mtmd: add pos_0 to mtmd_image_tokens_get_decoder_pos (breaking change) (#22082)

* mtmd: add pos_0 to mtmd_image_tokens_get_decoder_pos

* fix build
This commit is contained in:
Xuan-Son Nguyen
2026-04-19 11:57:21 +02:00
committed by GitHub
parent bcdcc1044f
commit 19124078be
5 changed files with 22 additions and 17 deletions
+1 -1
View File
@@ -42,7 +42,7 @@ int main(void) {
const mtmd_image_tokens * image_tokens = mtmd_input_chunk_get_tokens_image(chunk); const mtmd_image_tokens * image_tokens = mtmd_input_chunk_get_tokens_image(chunk);
size_t n_tokens = mtmd_image_tokens_get_n_tokens(image_tokens); size_t n_tokens = mtmd_image_tokens_get_n_tokens(image_tokens);
// get position of the last token, which should be (nx - 1, ny - 1) // get position of the last token, which should be (nx - 1, ny - 1)
struct mtmd_decoder_pos pos = mtmd_image_tokens_get_decoder_pos(image_tokens, n_tokens - 1); struct mtmd_decoder_pos pos = mtmd_image_tokens_get_decoder_pos(image_tokens, 0, n_tokens - 1);
size_t nx = pos.x + 1; size_t nx = pos.x + 1;
size_t ny = pos.y + 1; size_t ny = pos.y + 1;
const char * id = mtmd_image_tokens_get_id(image_tokens); const char * id = mtmd_image_tokens_get_id(image_tokens);
+10 -10
View File
@@ -114,10 +114,10 @@ llama_pos mtmd_helper_get_n_pos(const mtmd_input_chunks * chunks) {
return n_pos; return n_pos;
} }
void mtmd_helper_image_get_decoder_pos(const mtmd_image_tokens * chunks, mtmd_decoder_pos * out_pos) { void mtmd_helper_image_get_decoder_pos(const mtmd_image_tokens * chunks, llama_pos pos_0, mtmd_decoder_pos * out_pos) {
size_t n_tokens = mtmd_image_tokens_get_n_tokens(chunks); size_t n_tokens = mtmd_image_tokens_get_n_tokens(chunks);
for (size_t i = 0; i < n_tokens; i++) { for (size_t i = 0; i < n_tokens; i++) {
out_pos[i] = mtmd_image_tokens_get_decoder_pos(chunks, i); out_pos[i] = mtmd_image_tokens_get_decoder_pos(chunks, pos_0, i);
} }
} }
@@ -163,15 +163,15 @@ struct decode_embd_batch {
} }
// M-RoPE for image // M-RoPE for image
void set_position_mrope_2d(llama_pos pos_0, const std::vector<mtmd_decoder_pos> & rel_pos, llama_seq_id seq_id) { void set_position_mrope_2d(const std::vector<mtmd_decoder_pos> & rel_pos, llama_seq_id seq_id) {
GGML_ASSERT(n_pos_per_embd == 4); GGML_ASSERT(n_pos_per_embd == 4);
GGML_ASSERT(!rel_pos.empty() && (int32_t)rel_pos.size() == batch.n_tokens); GGML_ASSERT(!rel_pos.empty() && (int32_t)rel_pos.size() == batch.n_tokens);
seq_id_0[0] = seq_id; seq_id_0[0] = seq_id;
for (int32_t i = 0; i < batch.n_tokens; i++) { for (int32_t i = 0; i < batch.n_tokens; i++) {
pos[i ] = pos_0 + rel_pos[i].t; pos[i ] = rel_pos[i].t;
pos[i + batch.n_tokens ] = pos_0 + rel_pos[i].y; pos[i + batch.n_tokens ] = rel_pos[i].y;
pos[i + batch.n_tokens * 2] = pos_0 + rel_pos[i].x; pos[i + batch.n_tokens * 2] = rel_pos[i].x;
pos[i + batch.n_tokens * 3] = 0; // last pos dim is unused pos[i + batch.n_tokens * 3] = rel_pos[i].z;
} }
for (int i = 0; i < batch.n_tokens; i++) { for (int i = 0; i < batch.n_tokens; i++) {
batch.n_seq_id[i] = 1; batch.n_seq_id[i] = 1;
@@ -188,7 +188,7 @@ struct decode_embd_batch {
pos[i ] = pos_0 + i; pos[i ] = pos_0 + i;
pos[i + batch.n_tokens ] = pos_0 + i; pos[i + batch.n_tokens ] = pos_0 + i;
pos[i + batch.n_tokens * 2] = pos_0 + i; pos[i + batch.n_tokens * 2] = pos_0 + i;
pos[i + batch.n_tokens * 3] = 0; // last pos dim is unused pos[i + batch.n_tokens * 3] = pos_0 + i;
} }
for (int i = 0; i < batch.n_tokens; i++) { for (int i = 0; i < batch.n_tokens; i++) {
batch.n_seq_id[i] = 1; batch.n_seq_id[i] = 1;
@@ -268,8 +268,8 @@ int32_t mtmd_helper_decode_image_chunk(
} }
const auto n_tokens = mtmd_image_tokens_get_n_tokens(image_tokens); const auto n_tokens = mtmd_image_tokens_get_n_tokens(image_tokens);
std::vector<mtmd_decoder_pos> rel_pos(n_tokens); std::vector<mtmd_decoder_pos> rel_pos(n_tokens);
mtmd_helper_image_get_decoder_pos(image_tokens, rel_pos.data()); mtmd_helper_image_get_decoder_pos(image_tokens, n_past, rel_pos.data());
batch_embd.set_position_mrope_2d(n_past, rel_pos, seq_id); batch_embd.set_position_mrope_2d(rel_pos, seq_id);
} else if (chunk_type == MTMD_INPUT_CHUNK_TYPE_AUDIO) { } else if (chunk_type == MTMD_INPUT_CHUNK_TYPE_AUDIO) {
batch_embd.set_position_mrope_1d(n_past, seq_id); batch_embd.set_position_mrope_1d(n_past, seq_id);
} else { } else {
+1 -1
View File
@@ -49,7 +49,7 @@ MTMD_API llama_pos mtmd_helper_get_n_pos(const mtmd_input_chunks * chunks);
// helper to get the list of relative positions corresponding to the embedding tokens, to be used by M-RoPE // helper to get the list of relative positions corresponding to the embedding tokens, to be used by M-RoPE
// out_pos must have length == mtmd_helper_get_n_tokens(image) // out_pos must have length == mtmd_helper_get_n_tokens(image)
MTMD_API void mtmd_helper_image_get_decoder_pos(const mtmd_image_tokens * image, struct mtmd_decoder_pos * out_pos); MTMD_API void mtmd_helper_image_get_decoder_pos(const mtmd_image_tokens * image, llama_pos pos_0, struct mtmd_decoder_pos * out_pos);
// helper function that automatically: // helper function that automatically:
// 1. run llama_decode() on text chunks // 1. run llama_decode() on text chunks
+7 -4
View File
@@ -1246,11 +1246,14 @@ size_t mtmd_image_tokens_get_ny(const mtmd_image_tokens * image_tokens) {
return image_tokens->ny; return image_tokens->ny;
} }
mtmd_decoder_pos mtmd_image_tokens_get_decoder_pos(const mtmd_image_tokens * image_tokens, size_t i) { mtmd_decoder_pos mtmd_image_tokens_get_decoder_pos(const mtmd_image_tokens * image_tokens, llama_pos pos_0, size_t i) {
mtmd_decoder_pos pos; mtmd_decoder_pos pos;
pos.t = 0; // M-RoPE logic
pos.x = i % image_tokens->nx; // TODO: support other types of position encoding if needed
pos.y = i / image_tokens->nx; pos.t = pos_0;
pos.x = pos_0 + (i % image_tokens->nx);
pos.y = pos_0 + (i / image_tokens->nx);
pos.z = 0; // unused for now
return pos; return pos;
} }
+3 -1
View File
@@ -196,11 +196,13 @@ struct mtmd_decoder_pos {
uint32_t t; uint32_t t;
uint32_t x; uint32_t x;
uint32_t y; uint32_t y;
uint32_t z; // unused for now, reserved for future use
}; };
// get position for decoder attention, to be used by M-RoPE models // get position for decoder attention, to be used by M-RoPE models
// i is the index of the embedding token, ranging from 0 to mtmd_image_tokens_get_n_tokens() - 1 // i is the index of the embedding token, ranging from 0 to mtmd_image_tokens_get_n_tokens() - 1
// pos_0 is the absolute position of the first token
// return relative position (for example, embedding 0 will have position (0, 0, 0); remember to adjust it to the current absolute position) // return relative position (for example, embedding 0 will have position (0, 0, 0); remember to adjust it to the current absolute position)
MTMD_API struct mtmd_decoder_pos mtmd_image_tokens_get_decoder_pos(const mtmd_image_tokens * image_tokens, size_t i); MTMD_API struct mtmd_decoder_pos mtmd_image_tokens_get_decoder_pos(const mtmd_image_tokens * image_tokens, llama_pos pos_0, size_t i);
// tokenize an input text prompt and a list of bitmaps (images/audio) // tokenize an input text prompt and a list of bitmaps (images/audio)
// the prompt must have the input image marker (default: "<__media__>") in it // the prompt must have the input image marker (default: "<__media__>") in it