mtmd: add Gemma 4 audio conformer encoder support (#21421)

* mtmd: add Gemma 4 audio conformer encoder support

Add audio processing for Gemma 4 E2B/E4B via a USM-style Conformer.

Architecture:
- 12-layer Conformer: FFN → Self-Attention → Causal Conv1D → FFN → Norm
- Subsampling Conv Projection: 2x Conv2D(stride=2) with LayerNorm
- Full self-attention with sinusoidal RPE and sliding window mask (24)
- Logit softcapping at 50.0, ClippableLinear clamping
- Output: 1024 → 1536 → RMSNorm → multimodal embedder

Mel preprocessing (dedicated mtmd_audio_preprocessor_gemma4a):
- HTK mel scale, 128 bins, magnitude STFT, mel_floor=1e-3
- Standard periodic Hann window (320 samples), zero-padded to FFT size
- Semicausal left-padding (frame_length/2 samples)
- Frame count matched to PyTorch (unfold formula)
- No pre-emphasis, no Whisper-style normalization
- Mel cosine similarity vs PyTorch: 0.9998

Key fixes:
- Tensor loading dedup: prevent get_tensor() from creating duplicate
  entries in ctx_data. Fixed with std::set guard.
- ClippableLinear clamp_info loading moved after per-layer tensors.
- Sliding window mask (24 positions) matching PyTorch context_size.
- Skip Whisper normalization for Gemma4 mel output.

Tested on E2B and E4B with CPU and Vulkan backends.
Transcribes: "Glad to see things are going well and business is starting
to pick up" (matching ground truth).

Ref: #21325
This commit is contained in:
Stephen Cox
2026-04-13 00:15:26 +12:00
committed by GitHub
parent 9e209c5aee
commit 547765a93e
11 changed files with 649 additions and 29 deletions
+15
View File
@@ -181,6 +181,21 @@
#define TN_CONV_PW1 "%s.blk.%d.conv_pw1.%s"
#define TN_CONV_PW2 "%s.blk.%d.conv_pw2.%s"
// gemma4 audio conformer
#define TN_A_MM_INP_PROJ "mm.a.input_projection.%s"
#define TN_A_MM_SOFT_EMB_N "mm.a.soft_emb_norm.%s"
#define TN_A_INP_PROJ "a.input_projection.%s"
#define TN_A_CONV1D "a.conv1d.%d.%s"
#define TN_A_CONV1D_NORM "a.conv1d.%d.norm.%s"
#define TN_A_OUT_PROJ "a.pre_encode.out.%s"
#define TN_A_ATTN_PRE_NORM "%s.blk.%d.attn_pre_norm.%s"
#define TN_A_ATTN_POST_NORM "%s.blk.%d.attn_post_norm.%s"
#define TN_A_ATTN_K_REL "%s.blk.%d.attn_k_rel.%s"
#define TN_A_PER_DIM_SCALE "%s.blk.%d.per_dim_scale.%s"
#define TN_A_PER_DIM_K_SCALE "%s.blk.%d.per_dim_k_scale.%s"
#define TN_A_FFN_POST_NORM "%s.blk.%d.ffn_post_norm.%s"
#define TN_A_FFN_POST_NORM_1 "%s.blk.%d.ffn_post_norm_1.%s"
// mobilenetv5 (gemma3n) definitions
#define TN_MNV5_STEM_CONV "v.conv_stem.conv.weight"
#define TN_MNV5_STEM_BIAS "v.conv_stem.conv.bias"