mtmd: add Gemma 4 audio conformer encoder support (#21421)
* mtmd: add Gemma 4 audio conformer encoder support

Add audio processing for Gemma 4 E2B/E4B via a USM-style Conformer.

Architecture:
- 12-layer Conformer: FFN → Self-Attention → Causal Conv1D → FFN → Norm
- Subsampling Conv Projection: 2x Conv2D(stride=2) with LayerNorm
- Full self-attention with sinusoidal RPE and sliding window mask (24)
- Logit softcapping at 50.0, ClippableLinear clamping
- Output: 1024 → 1536 → RMSNorm → multimodal embedder

Mel preprocessing (dedicated mtmd_audio_preprocessor_gemma4a):
- HTK mel scale, 128 bins, magnitude STFT, mel_floor=1e-3
- Standard periodic Hann window (320 samples), zero-padded to FFT size
- Semicausal left-padding (frame_length/2 samples)
- Frame count matched to PyTorch (unfold formula)
- No pre-emphasis, no Whisper-style normalization
- Mel cosine similarity vs PyTorch: 0.9998

Key fixes:
- Tensor loading dedup: prevent get_tensor() from creating duplicate
  entries in ctx_data. Fixed with an std::unordered_set guard.
- ClippableLinear clamp_info loading moved after per-layer tensors.
- Sliding window mask (24 positions) matching PyTorch context_size.
- Skip Whisper normalization for Gemma4 mel output.

Tested on E2B and E4B with CPU and Vulkan backends. Transcribes:
"Glad to see things are going well and business is starting to pick up"
(matching ground truth).

Ref: #21325
+158
-4
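For reference, a minimal sketch of the logit softcapping mentioned in the commit message, assuming the usual cap * tanh(x / cap) form used elsewhere in the Gemma family; the helper name is illustrative, not taken from this patch:

    #include <cmath>

    // Softly bound an attention logit to (-cap, cap); cap = 50.0 per the commit message.
    static float softcap_logit(float x, float cap = 50.0f) {
        return cap * std::tanh(x / cap);
    }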
@@ -931,6 +931,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
             {
                 builder = std::make_unique<clip_graph_conformer>(ctx, img);
             } break;
+        case PROJECTOR_TYPE_GEMMA4A:
+            {
+                builder = std::make_unique<clip_graph_gemma4a>(ctx, img);
+            } break;
         case PROJECTOR_TYPE_GLM4V:
             {
                 builder = std::make_unique<clip_graph_glm4v>(ctx, img);
@@ -1459,6 +1463,16 @@ struct clip_model_loader {
                     hparams.audio_window_len = 400;
                     hparams.audio_hop_len = 160;
                 } break;
+            case PROJECTOR_TYPE_GEMMA4A:
+                {
+                    // Gemma4 feature_extraction_gemma4.py:
+                    // frame_length_ms=20 -> 320 samples, n_fft=512, hop=10ms -> 160
+                    hparams.audio_chunk_len = 0; // no fixed-length padding
+                    hparams.audio_sample_rate = 16000;
+                    hparams.audio_n_fft = 512;
+                    hparams.audio_window_len = 320; // 20ms frame (NOT 25ms/400)
+                    hparams.audio_hop_len = 160;
+                } break;
             case PROJECTOR_TYPE_JANUS_PRO:
                 {
                     hparams.image_pad_color = {127, 127, 127};
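For orientation, the frame count these parameters produce can be sketched from the commit message's description (semicausal left-padding of frame_length/2 samples, then PyTorch's unfold formula); the helper below is illustrative, not code from this patch:

    // Hypothetical helper: frames = (padded_len - frame_len) / hop + 1 (floor division),
    // with padded_len = n_samples + frame_len / 2 (semicausal left-padding).
    static int gemma4a_n_frames(int n_samples, int frame_len = 320, int hop = 160) {
        const int padded = n_samples + frame_len / 2;
        return padded < frame_len ? 0 : (padded - frame_len) / hop + 1;
    }
    // e.g. 1 s at 16 kHz: (16000 + 160 - 320) / 160 + 1 = 100 frames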
@@ -1561,16 +1575,21 @@ struct clip_model_loader {
         }

         // helper function
+        std::unordered_set<std::string> loaded_tensor_names;
         auto get_tensor = [&](const std::string & name, bool required = true) {
+            // Each tensor should only be loaded once; duplicates indicate a bug
+            if (loaded_tensor_names.count(name)) {
+                throw std::runtime_error(string_format("%s: tensor already loaded: %s\n", __func__, name.c_str()));
+            }
             ggml_tensor * cur = ggml_get_tensor(ctx_meta.get(), name.c_str());
             if (!cur && required) {
                 throw std::runtime_error(string_format("%s: unable to find tensor %s\n", __func__, name.c_str()));
             }
             if (cur) {
                 tensors_to_load.push_back(cur);
                 // add tensors to context
                 ggml_tensor * data_tensor = ggml_dup_tensor(ctx_clip.ctx_data.get(), cur);
                 ggml_set_name(data_tensor, cur->name);
+                loaded_tensor_names.insert(name);
                 cur = data_tensor;
             }
             return cur;
@@ -2186,6 +2205,76 @@ struct clip_model_loader {
                     model.mm_fc_w = get_tensor(string_format(TN_MM_PROJECTOR, "weight"));
                     model.mm_fc_b = get_tensor(string_format(TN_MM_PROJECTOR, "bias"));
                 } break;
+            case PROJECTOR_TYPE_GEMMA4A:
+                {
+                    for (int i = 0; i < 2; i++) {
+                        model.sscp_conv_w[i] = get_tensor(string_format(TN_A_CONV1D, i, "weight"));
+                        model.sscp_conv_b[i] = get_tensor(string_format(TN_A_CONV1D, i, "bias"), false);
+                        model.sscp_norm_w[i] = get_tensor(string_format(TN_A_CONV1D_NORM, i, "weight"), false);
+                    }
+                    model.sscp_inp_proj_w = get_tensor(string_format(TN_A_INP_PROJ, "weight"));
+                    model.sscp_inp_proj_b = get_tensor(string_format(TN_A_INP_PROJ, "bias"), false);
+                    model.audio_out_proj_w = get_tensor(string_format(TN_A_OUT_PROJ, "weight"), false);
+                    model.audio_out_proj_b = get_tensor(string_format(TN_A_OUT_PROJ, "bias"), false);
+                    // audio multimodal embedder (mm.a.* namespace, not mm.*)
+                    model.mm_soft_emb_norm_w = get_tensor(string_format(TN_A_MM_SOFT_EMB_N, "weight"), false);
+                    model.mm_input_proj_w = get_tensor(string_format(TN_A_MM_INP_PROJ, "weight"), false);
+
+                    // Per-layer tensors NOT loaded by the generic loop above
+                    for (int il = 0; il < hparams.n_layer; ++il) {
+                        auto & layer = model.layers[il];
+
+                        // Gemma4 audio conformer-specific tensors
+                        layer.ff_norm_w = get_tensor(string_format(TN_FFN_NORM, prefix, il, "weight"));
+                        layer.attn_pre_norm_w = get_tensor(string_format(TN_A_ATTN_PRE_NORM, prefix, il, "weight"), false);
+                        layer.per_dim_scale_w = get_tensor(string_format(TN_A_PER_DIM_SCALE, prefix, il, "weight"), false);
+                        layer.per_dim_k_scale_w = get_tensor(string_format(TN_A_PER_DIM_K_SCALE, prefix, il, "weight"), false);
+                        layer.attn_k_rel_w = get_tensor(string_format(TN_A_ATTN_K_REL, prefix, il, "weight"), false);
+
+                        // Convolution module
+                        // Note: conv_norm / norm_conv are swapped in GGUF due to
+                        // upstream tensor_mapping.py, so we load them in reverse order
+                        layer.norm_conv_w = get_tensor(string_format(TN_CONV_NORM, prefix, il, "weight"), false);
+                        layer.norm_conv_b = get_tensor(string_format(TN_CONV_NORM, prefix, il, "bias"), false);
+                        layer.conv_pw1_w = get_tensor(string_format(TN_CONV_PW1, prefix, il, "weight"));
+                        layer.conv_pw1_b = get_tensor(string_format(TN_CONV_PW1, prefix, il, "bias"), false);
+                        layer.conv_dw_w = get_tensor(string_format(TN_CONV_DW, prefix, il, "weight"));
+                        layer.conv_dw_b = get_tensor(string_format(TN_CONV_DW, prefix, il, "bias"), false);
+                        layer.conv_norm_w = get_tensor(string_format(TN_NORM_CONV, prefix, il, "weight"), false);
+                        layer.conv_norm_b = get_tensor(string_format(TN_NORM_CONV, prefix, il, "bias"), false);
+                        layer.conv_pw2_w = get_tensor(string_format(TN_CONV_PW2, prefix, il, "weight"));
+                        layer.conv_pw2_b = get_tensor(string_format(TN_CONV_PW2, prefix, il, "bias"), false);
+
+                        // FFN2 (second half-step)
+                        layer.ff_norm_1_w = get_tensor(string_format(TN_FFN_NORM_1, prefix, il, "weight"));
+                        layer.ff_up_1_w = get_tensor(string_format(TN_FFN_UP_1, prefix, il, "weight"));
+                        layer.ff_up_1_b = get_tensor(string_format(TN_FFN_UP_1, prefix, il, "bias"), false);
+                        layer.ff_down_1_w = get_tensor(string_format(TN_FFN_DOWN_1, prefix, il, "weight"));
+                        layer.ff_down_1_b = get_tensor(string_format(TN_FFN_DOWN_1, prefix, il, "bias"), false);
+                        layer.ff_post_norm_1_w = get_tensor(string_format(TN_A_FFN_POST_NORM_1, prefix, il, "weight"), false);
+                    }
+
+                    // Load clamp info for ClippableLinear AFTER all tensors are loaded
+                    for (auto * tensor : tensors_to_load) {
+                        std::string name = tensor->name;
+                        if (string_ends_with(name, ".weight")) {
+                            std::string name_inp_max = name;
+                            std::string name_inp_min = name;
+                            std::string name_out_max = name;
+                            std::string name_out_min = name;
+                            string_replace_all(name_inp_max, ".weight", ".input_max");
+                            string_replace_all(name_inp_min, ".weight", ".input_min");
+                            string_replace_all(name_out_max, ".weight", ".output_max");
+                            string_replace_all(name_out_min, ".weight", ".output_min");
+                            model.clamp_info_map[name] = {
+                                get_scalar(name_inp_max, FLT_MAX),
+                                get_scalar(name_inp_min, -FLT_MAX),
+                                get_scalar(name_out_max, FLT_MAX),
+                                get_scalar(name_out_min, -FLT_MAX)
+                            };
+                        }
+                    }
+                } break;
             case PROJECTOR_TYPE_LFM2A:
                 {
                     for (int i : {0, 2, 3, 5, 6}) {
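The hunk above only loads the clamp bounds; how a ClippableLinear applies them is not shown in this diff. A sketch under the obvious assumption of clamp → matmul → clamp, where the function and the clamp_info field names are illustrative:

    // Illustrative only: wrap a matmul with the loaded input/output bounds.
    static ggml_tensor * clippable_linear(ggml_context * ctx0, ggml_tensor * w, ggml_tensor * b,
                                          ggml_tensor * x, const clamp_info & c) {
        x = ggml_clamp(ctx0, x, c.input_min, c.input_max);       // clamp activations going in
        ggml_tensor * y = ggml_mul_mat(ctx0, w, x);
        if (b) {
            y = ggml_add(ctx0, y, b);
        }
        return ggml_clamp(ctx0, y, c.output_min, c.output_max);  // clamp activations coming out
    }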
@@ -2246,7 +2335,10 @@ struct clip_model_loader {
             ggml_backend_buffer_set_usage(ctx_clip.buf.get(), GGML_BACKEND_BUFFER_USAGE_WEIGHTS);
             for (auto & t : tensors_to_load) {
                 ggml_tensor * cur = ggml_get_tensor(ctx_clip.ctx_data.get(), t->name);
-                const size_t offset = tensor_offset[t->name];
+                GGML_ASSERT(cur && "tensor not found in ctx_data");
+                auto it_off = tensor_offset.find(t->name);
+                GGML_ASSERT(it_off != tensor_offset.end() && "no offset for tensor");
+                const size_t offset = it_off->second;
                 fin.seekg(offset, std::ios::beg);
                 if (!fin) {
                     throw std::runtime_error(string_format("%s: failed to seek for tensor %s\n", __func__, t->name));
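The likely motivation for the find()-based lookup above: the map's operator[] default-inserts a zero value for a missing key, so an unknown tensor name would previously seek to offset 0 and load garbage silently, while find() plus GGML_ASSERT fails loudly. A self-contained illustration:

    #include <cassert>
    #include <map>
    #include <string>

    int main() {
        std::map<std::string, size_t> tensor_offset;
        size_t off = tensor_offset["missing"];     // operator[] default-inserts 0: silent wrong seek
        auto it = tensor_offset.find("also_missing");
        assert(it == tensor_offset.end());         // find() lets the caller fail loudly instead
        return (int) off;
    }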
@@ -2266,6 +2358,7 @@ struct clip_model_loader {

         LOG_DBG("%s: loaded %zu tensors from %s\n", __func__, tensors_to_load.size(), fname.c_str());
     }

 }

 struct support_info_op {
@@ -2538,8 +2631,7 @@ struct clip_init_result clip_init(const char * fname, struct clip_context_params

         // TODO: we don't support audio for Gemma 3N, but GGUF contains audio tensors
         // we can remove this check when we implement audio support for Gemma 3N
-        skip_audio = ctx_vision->model.proj_type == PROJECTOR_TYPE_GEMMA3NV
-                  || ctx_vision->model.proj_type == PROJECTOR_TYPE_GEMMA4V;
+        skip_audio = ctx_vision->model.proj_type == PROJECTOR_TYPE_GEMMA3NV;
     }

     if (loader.has_audio && !skip_audio) {
@@ -2893,6 +2985,16 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
             {
                 n_patches = ((((img->nx + 1) / 2) + 1) / 2 + 1) / 2;
             } break;
+        case PROJECTOR_TYPE_GEMMA4A:
+            {
+                // Two Conv2D stride-2: O = floor((I + 2p - k) / s) + 1, p=1, k=3, s=2
+                // O = floor((I - 1) / 2) + 1
+                int n = img->nx;
+                for (int i = 0; i < 2; i++) {
+                    n = (n - 1) / 2 + 1;
+                }
+                n_patches = n;
+            } break;
         default:
             GGML_ABORT("unsupported projector type");
     }
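A quick worked example of the recurrence above, tying it back to the mel front-end: 1 s of 16 kHz audio yields 100 mel frames (see the frame-count sketch earlier), and each stride-2 conv roughly halves that:

    // n = 100 mel frames -> (100-1)/2+1 = 50 -> (50-1)/2+1 = 25 output tokens
    int n = 100;
    for (int i = 0; i < 2; i++) {
        n = (n - 1) / 2 + 1;
    }
    // n == 25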
@@ -3352,6 +3454,56 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
                 }
                 set_input_i32("pos_w", pos_data);
             } break;
+        case PROJECTOR_TYPE_GEMMA4A:
+            {
+                GGML_ASSERT(imgs.entries.size() == 1);
+                const auto & img0 = imgs.entries.front();
+                // Compute n_pos matching SSCP output: two stride-2 convs
+                int n_pos = img0->nx;
+                for (int i = 0; i < 2; i++) { n_pos = (n_pos - 1) / 2 + 1; }
+
+                // Chunked local attention: blocked causal mask and RPE
+                const int chunk_size = 12;
+                const int max_past = 12;
+                const int context_size = chunk_size + max_past;
+                const int num_blocks = (n_pos + chunk_size - 1) / chunk_size;
+
+                // Blocked causal attention mask: [context_size, chunk_size, num_blocks]
+                {
+                    std::vector<float> mask(context_size * chunk_size * num_blocks, -1e9f);
+                    for (int b = 0; b < num_blocks; b++) {
+                        for (int q = 0; q < chunk_size; q++) {
+                            int gq = b * chunk_size + q;
+                            for (int k = 0; k < context_size; k++) {
+                                int gk = b * chunk_size - max_past + k;
+                                if (gq < n_pos && gk >= 0 && gk < n_pos && gk <= gq && (gq - gk) < max_past) {
+                                    mask[k + q * context_size + b * context_size * chunk_size] = 0.0f;
+                                }
+                            }
+                        }
+                    }
+                    set_input_f32("kq_mask", mask);
+                }
+
+                // Sinusoidal RPE: 13 positions [12, 11, ..., 0]
+                {
+                    const int n_embd = ctx->model.hparams.n_embd;
+                    const int num_timescales = n_embd / 2;
+                    const float log_timescale_increment = logf(10000.0f) / std::max(num_timescales - 1, 1);
+                    const int rpe_len = max_past + 1;
+                    std::vector<float> pos_emb(n_embd * rpe_len, 0.0f);
+                    for (int p = 0; p < rpe_len; p++) {
+                        float position = (float)(max_past - p);
+                        for (int i = 0; i < num_timescales; i++) {
+                            float inv_ts = expf(-(float)i * log_timescale_increment);
+                            float scaled = position * inv_ts;
+                            pos_emb[p * n_embd + i] = sinf(scaled);
+                            pos_emb[p * n_embd + i + num_timescales] = cosf(scaled);
+                        }
+                    }
+                    set_input_f32("pos_emb", pos_emb);
+                }
+            } break;
         case PROJECTOR_TYPE_LFM2A:
             {
                 GGML_ASSERT(imgs.entries.size() == 1);
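Read together with the commit message's "sliding window mask (24 positions)": each block buffers context_size = chunk_size + max_past = 24 key positions, and within that buffer the mask admits key gk for query gq exactly when the predicate below holds (an illustrative restatement of the loop condition above, not code from this patch):

    // Equivalent per-pair form of the mask condition above.
    static bool gemma4a_attends(int gq, int gk, int n_pos, int max_past = 12) {
        return gq < n_pos && gk >= 0 && gk < n_pos   // both positions in range
            && gk <= gq                              // causal
            && (gq - gk) < max_past;                 // within the local window
    }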
@@ -3516,6 +3668,8 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
             return ctx->model.mm_fc_w->ne[1];
         case PROJECTOR_TYPE_LFM2A:
             return ctx->model.position_embeddings->ne[0];
+        case PROJECTOR_TYPE_GEMMA4A:
+            return ctx->model.hparams.projection_dim;
         case PROJECTOR_TYPE_GLM4V:
             return ctx->model.mm_ffn_down_w->ne[1];
         default: