mtmd: add more sanity checks (#21047)

This commit is contained in:
Xuan-Son Nguyen
2026-03-27 11:00:52 +01:00
committed by GitHub
parent 20197b6fe3
commit 871f1a2d2f
6 changed files with 48 additions and 13 deletions
+10
View File
@@ -127,6 +127,7 @@ struct decode_embd_batch {
std::vector<int8_t> logits;
llama_batch batch;
decode_embd_batch(float * embd, int32_t n_tokens, int n_pos_per_embd, int n_mmproj_embd) : n_pos_per_embd(n_pos_per_embd), n_mmproj_embd(n_mmproj_embd) {
GGML_ASSERT(n_tokens > 0 && n_pos_per_embd > 0 && n_mmproj_embd > 0);
pos .resize(n_tokens * n_pos_per_embd);
n_seq_id.resize(n_tokens);
seq_ids .resize(n_tokens + 1);
@@ -157,6 +158,7 @@ struct decode_embd_batch {
// M-RoPE for image
void set_position_mrope_2d(llama_pos pos_0, int nx, int ny, llama_seq_id seq_id) {
GGML_ASSERT(n_pos_per_embd == 4);
GGML_ASSERT(nx > 0 && ny > 0 && nx * ny == batch.n_tokens);
seq_id_0[0] = seq_id;
for (int y = 0; y < ny; y++) {
for (int x = 0; x < nx; x++) {
@@ -192,6 +194,7 @@ struct decode_embd_batch {
}
llama_batch get_view(int offset, int n_tokens) {
GGML_ASSERT(offset >= 0 && n_tokens > 0 && offset + n_tokens <= batch.n_tokens);
llama_pos * pos_ptr;
pos_view.clear();
pos_view.reserve(n_tokens * n_pos_per_embd);
@@ -235,6 +238,7 @@ int32_t mtmd_helper_decode_image_chunk(
llama_seq_id seq_id,
int32_t n_batch,
llama_pos * new_n_past) {
GGML_ASSERT(n_batch > 0);
auto chunk_type = mtmd_input_chunk_get_type(chunk);
const char * name = chunk_type == MTMD_INPUT_CHUNK_TYPE_IMAGE ? "image" : "audio";
if (chunk_type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
@@ -312,6 +316,7 @@ int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx,
int32_t n_batch,
bool logits_last,
llama_pos * new_n_past) {
GGML_ASSERT(n_batch > 0);
int32_t ret;
llama_batch text_batch = llama_batch_init(n_batch, 0, 1);
auto chunk_type = mtmd_input_chunk_get_type(chunk);
@@ -508,6 +513,11 @@ mtmd_bitmap * mtmd_helper_bitmap_init_from_file(mtmd_context * ctx, const char *
fseek(f, 0, SEEK_END);
long file_size = ftell(f);
fseek(f, 0, SEEK_SET);
if (file_size < 0) {
LOG_ERR("Failed to get file size of %s\n", fname);
fclose(f);
return nullptr;
}
buf.resize(file_size);
size_t n_read = fread(buf.data(), 1, file_size, f);