mtmd: add mtmd_image_tokens_get_decoder_pos() API (#21851)

* mtmd: add mtmd_image_tokens_get_decoder_pos() API

* consistent naming

* fix build
This commit is contained in:
Xuan-Son Nguyen
2026-04-14 16:07:41 +02:00
committed by GitHub
parent 1f30ac0cea
commit 707c0b7a6e
5 changed files with 49 additions and 17 deletions
+18 -13
View File
@@ -114,6 +114,13 @@ llama_pos mtmd_helper_get_n_pos(const mtmd_input_chunks * chunks) {
return n_pos;
}
void mtmd_helper_image_get_decoder_pos(const mtmd_image_tokens * chunks, mtmd_decoder_pos * out_pos) {
size_t n_tokens = mtmd_image_tokens_get_n_tokens(chunks);
for (size_t i = 0; i < n_tokens; i++) {
out_pos[i] = mtmd_image_tokens_get_decoder_pos(chunks, i);
}
}
// helper struct to make working with embd batch easier
// note: this will be removed after llama_batch_ext refactoring
struct decode_embd_batch {
@@ -156,18 +163,15 @@ struct decode_embd_batch {
}
// M-RoPE for image
void set_position_mrope_2d(llama_pos pos_0, int nx, int ny, llama_seq_id seq_id) {
void set_position_mrope_2d(llama_pos pos_0, const std::vector<mtmd_decoder_pos> & rel_pos, llama_seq_id seq_id) {
GGML_ASSERT(n_pos_per_embd == 4);
GGML_ASSERT(nx > 0 && ny > 0 && nx * ny == batch.n_tokens);
GGML_ASSERT(!rel_pos.empty() && (int32_t)rel_pos.size() == batch.n_tokens);
seq_id_0[0] = seq_id;
for (int y = 0; y < ny; y++) {
for (int x = 0; x < nx; x++) {
int i = y * nx + x;
pos[i ] = pos_0;
pos[i + batch.n_tokens ] = pos_0 + y;
pos[i + batch.n_tokens * 2] = pos_0 + x;
pos[i + batch.n_tokens * 3] = 0; // last pos dim is unused
}
for (int32_t i = 0; i < batch.n_tokens; i++) {
pos[i ] = pos_0 + rel_pos[i].t;
pos[i + batch.n_tokens ] = pos_0 + rel_pos[i].y;
pos[i + batch.n_tokens * 2] = pos_0 + rel_pos[i].x;
pos[i + batch.n_tokens * 3] = 0; // last pos dim is unused
}
for (int i = 0; i < batch.n_tokens; i++) {
batch.n_seq_id[i] = 1;
@@ -262,9 +266,10 @@ int32_t mtmd_helper_decode_image_chunk(
LOG_ERR("failed to decode chunk: image tokens are null\n");
return -1;
}
const int nx = mtmd_image_tokens_get_nx(image_tokens);
const int ny = mtmd_image_tokens_get_ny(image_tokens);
batch_embd.set_position_mrope_2d(n_past, nx, ny, seq_id);
const auto n_tokens = mtmd_image_tokens_get_n_tokens(image_tokens);
std::vector<mtmd_decoder_pos> rel_pos(n_tokens);
mtmd_helper_image_get_decoder_pos(image_tokens, rel_pos.data());
batch_embd.set_position_mrope_2d(n_past, rel_pos, seq_id);
} else if (chunk_type == MTMD_INPUT_CHUNK_TYPE_AUDIO) {
batch_embd.set_position_mrope_1d(n_past, seq_id);
} else {