mtmd: add Gemma 4 audio conformer encoder support (#21421)

* mtmd: add Gemma 4 audio conformer encoder support

Add audio processing for Gemma 4 E2B/E4B via a USM-style Conformer.

Architecture:
- 12-layer Conformer: FFN → Self-Attention → Causal Conv1D → FFN → Norm
- Subsampling Conv Projection: 2x Conv2D(stride=2) with LayerNorm
- Full self-attention with sinusoidal RPE and sliding window mask (24)
- Logit softcapping at 50.0, ClippableLinear clamping
- Output: 1024 → 1536 → RMSNorm → multimodal embedder

Mel preprocessing (dedicated mtmd_audio_preprocessor_gemma4a):
- HTK mel scale, 128 bins, magnitude STFT, mel_floor=1e-3
- Standard periodic Hann window (320 samples), zero-padded to FFT size
- Semicausal left-padding (frame_length/2 samples)
- Frame count matched to PyTorch (unfold formula)
- No pre-emphasis, no Whisper-style normalization
- Mel cosine similarity vs PyTorch: 0.9998

Key fixes:
- Tensor loading dedup: prevent get_tensor() from creating duplicate
  entries in ctx_data. Fixed with std::set guard.
- ClippableLinear clamp_info loading moved after per-layer tensors.
- Sliding window mask (24 positions) matching PyTorch context_size.
- Skip Whisper normalization for Gemma 4 mel output.

Tested on E2B and E4B with CPU and Vulkan backends.
Transcribes: "Glad to see things are going well and business is starting
to pick up" (matching ground truth).

Ref: #21325
This commit is contained in:
Stephen Cox
2026-04-13 00:15:26 +12:00
committed by GitHub
parent 9e209c5aee
commit 547765a93e
11 changed files with 649 additions and 29 deletions
+126 -22
View File
@@ -8,6 +8,7 @@
#include <vector>
#include <fstream>
#include <algorithm>
#include <functional>
// some of the code here is copied from whisper.cpp
@@ -37,23 +38,36 @@ void mtmd_audio_cache::fill_mel_filterbank_matrix(int n_mel,
float fmin,
float fmax,
bool slaney_area_norm,
float scale) {
float scale,
bool use_htk) {
GGML_ASSERT(n_mel > 0 && n_fft > 1);
if (fmax <= 0.0f) {
fmax = 0.5f * sample_rate;
}
// Slaney scale (matches librosa default)
const double min_log_hz = 1000.0;
const double lin_slope = 3 / 200.;
const double min_log_mel = min_log_hz * lin_slope;
const double log_step = log(6.4) / 27.0;
auto hz_to_mel = [min_log_hz, lin_slope, log_step, min_log_mel](const double f_hz) -> double {
return (f_hz < min_log_hz) ? f_hz * lin_slope : min_log_mel + log(f_hz / min_log_hz) / log_step;
};
auto mel_to_hz = [min_log_hz, lin_slope, log_step, min_log_mel](const double m) -> double {
return (m < min_log_mel) ? m / lin_slope : min_log_hz * exp((m - min_log_mel) * log_step);
};
std::function<double(double)> hz_to_mel;
std::function<double(double)> mel_to_hz;
if (use_htk) {
hz_to_mel = [](const double f_hz) -> double {
return 2595.0 * log10(1.0 + f_hz / 700.0);
};
mel_to_hz = [](const double m) -> double {
return 700.0 * (pow(10.0, m / 2595.0) - 1.0);
};
} else {
// Slaney scale (matches librosa default)
const double min_log_hz = 1000.0;
const double lin_slope = 3 / 200.;
const double min_log_mel = min_log_hz * lin_slope;
const double log_step = log(6.4) / 27.0;
hz_to_mel = [min_log_hz, lin_slope, log_step, min_log_mel](const double f_hz) -> double {
return (f_hz < min_log_hz) ? f_hz * lin_slope : min_log_mel + log(f_hz / min_log_hz) / log_step;
};
mel_to_hz = [min_log_hz, lin_slope, log_step, min_log_mel](const double m) -> double {
return (m < min_log_mel) ? m / lin_slope : min_log_hz * exp((m - min_log_mel) * log_step);
};
}
// infer N_fft from n_fft_bins
const double bin_hz_step = double(sample_rate) / double(n_fft);
@@ -257,10 +271,13 @@ struct filter_params {
int32_t hann_window_size;
int32_t hop_length;
int32_t sample_rate;
bool center_padding = false;
float preemph = 0.f;
bool no_padding = false;
bool center_padding = false;
float preemph = 0.f;
bool use_natural_log = false;
bool norm_per_feature = false;
bool use_magnitude = false; // |X| instead of |X|^2
float mel_floor = 5.960464477539063e-08f;
};
static void log_mel_spectrogram_worker_thread(int ith,
@@ -301,10 +318,10 @@ static void log_mel_spectrogram_worker_thread(int ith,
// FFT
fft(cache, fft_in.data(), frame_size, fft_out.data());
// Calculate modulus^2 of complex numbers
// Use pow(fft_out[2 * j + 0], 2) + pow(fft_out[2 * j + 1], 2) causes inference quality problem? Interesting.
// Calculate modulus^2 (power) or modulus (magnitude)
for (int j = 0; j < n_fft_bins; j++) {
fft_out[j] = (fft_out[2 * j + 0] * fft_out[2 * j + 0] + fft_out[2 * j + 1] * fft_out[2 * j + 1]);
float power = (fft_out[2 * j + 0] * fft_out[2 * j + 0] + fft_out[2 * j + 1] * fft_out[2 * j + 1]);
fft_out[j] = params.use_magnitude ? sqrtf(power) : power;
}
// mel spectrogram
@@ -324,9 +341,10 @@ static void log_mel_spectrogram_worker_thread(int ith,
for (; k < n_fft_bins; k++) {
sum += fft_out[k] * filters.data[j * n_fft_bins + k];
}
sum = std::max(sum, (double)params.mel_floor);
sum = params.use_natural_log
? log(sum + 5.960464477539063e-08)
: log10(std::max(sum, 1e-10));
? log(sum)
: log10(sum);
out.data[j * out.n_len + i] = sum;
}
}
@@ -360,7 +378,12 @@ static bool log_mel_spectrogram(
// Padding
std::vector<float> samples_padded;
if (params.center_padding) {
if (params.no_padding) {
// no padding, use samples as-is
samples_padded = std::vector<float>(samples, samples + n_samples);
samples = samples_padded.data();
n_samples = samples_padded.size();
} else if (params.center_padding) {
const auto pad_amount = frame_size / 2;
samples_padded = std::vector<float>(n_samples + 2 * pad_amount, 0);
std::copy(samples, samples + n_samples, samples_padded.data() + pad_amount);
@@ -464,8 +487,8 @@ static bool log_mel_spectrogram(
out.data[i * out.n_len + j] = 0.0;
}
}
} else {
// clamping and normalization
} else if (!params.no_padding) {
// Whisper-style clamping and normalization (NOT used by Gemma4)
double mmax = -1e20;
for (int i = 0; i < out.n_mel*out.n_len; i++) {
if (out.data[i] > mmax) {
@@ -627,6 +650,87 @@ bool mtmd_audio_preprocessor_conformer::preprocess(const float *
return true;
}
//
// mtmd_audio_preprocessor_gemma4a
//
// Precompute the DSP tables consumed by preprocess():
//  - sin/cos twiddle factors for the FFT
//  - a periodic Hann window of audio_window_len samples, zero-padded to the FFT size
//  - an HTK-scale mel filterbank (no Slaney area normalization, unit scale)
void mtmd_audio_preprocessor_gemma4a::initialize() {
    cache.fill_sin_cos_table(hparams.audio_n_fft);

    // Periodic Hann: w[i] = 0.5 * (1 - cos(2*pi*i / N)); entries past the
    // window length remain zero so the window is padded out to n_fft.
    const uint32_t win_len = (uint32_t)hparams.audio_window_len;
    cache.hann_window.assign(hparams.audio_n_fft, 0.0f);
    for (uint32_t i = 0; i < win_len; i++) {
        const float phase = (2.0f * (float)M_PI * i) / hparams.audio_window_len;
        cache.hann_window[i] = 0.5f - 0.5f * cosf(phase);
    }

    // Mel filterbank over the full band [0, sr/2] on the HTK mel scale.
    cache.fill_mel_filterbank_matrix(
        hparams.n_mel_bins, hparams.audio_n_fft, hparams.audio_sample_rate,
        0.0f, hparams.audio_sample_rate / 2.0f,
        /*slaney_area_norm=*/ false,
        /*scale=*/ 1.0f,
        /*use_htk=*/ true
    );
}
// Convert a mono waveform into Gemma 4 log-mel spectrogram chunks.
//
// samples   - pointer to n_samples PCM floats (mono)
// n_samples - number of input samples; returns false when zero
// output    - one mtmd_audio_mel appended per 30-second chunk
//
// Returns false on empty input or if log_mel_spectrogram fails; true otherwise.
// Requires initialize() to have been called (asserted via the cache tables).
bool mtmd_audio_preprocessor_gemma4a::preprocess(const float * samples,
size_t n_samples,
std::vector<mtmd_audio_mel> & output) {
if (n_samples == 0) {
return false;
}
// initialize() must have populated the FFT tables and mel filterbank
GGML_ASSERT(!cache.sin_vals.empty());
GGML_ASSERT(!cache.cos_vals.empty());
GGML_ASSERT(!cache.filters.data.empty());
// Gemma 4 mel configuration: magnitude STFT, natural log, mel floor 1e-3,
// no pre-emphasis and no Whisper-style normalization (no_padding=true skips it).
filter_params params;
params.n_mel = hparams.n_mel_bins;
params.n_fft_bins = 1 + (hparams.audio_n_fft / 2);
params.hann_window_size = hparams.audio_n_fft; // window is zero-padded to FFT size
params.hop_length = hparams.audio_hop_len;
params.sample_rate = hparams.audio_sample_rate;
params.no_padding = true;
params.center_padding = false;
params.preemph = 0.0f;
params.use_natural_log = true;
params.use_magnitude = true;
params.mel_floor = 0.001f;
params.norm_per_feature = false;
// Split into 30-second chunks (model context limit, ~750 tokens each)
const size_t chunk_samples = 30 * hparams.audio_sample_rate;
for (size_t off = 0; off < n_samples; off += chunk_samples) {
const float * chunk_ptr = samples + off;
// last chunk may be shorter than 30 s
size_t chunk_len = std::min(chunk_samples, n_samples - off);
// Semicausal left-padding + right-padding to match PyTorch frame count
const int pad_left = hparams.audio_window_len / 2;
const int fft_size = hparams.audio_n_fft;
const int hop = hparams.audio_hop_len;
// sample count after prepending the semicausal left pad
const int n_with_left = (int)chunk_len + pad_left;
// PyTorch: unfold(size=frame_length+1, step=hop) on semicausal-padded waveform
// NOTE(review): for a chunk shorter than (audio_window_len + 1 - pad_left)
// samples, pt_frames can be <= 0, making n_padded_needed and the trimmed
// n_len below non-positive — confirm short-audio inputs are handled upstream.
const int pt_frames = (n_with_left - (hparams.audio_window_len + 1)) / hop + 1;
// total samples the STFT needs so the last frame's FFT window fits entirely
const int n_padded_needed = (pt_frames - 1) * hop + fft_size;
// right-pad with zeros up to n_padded_needed; never pad less than pad_left
const int total_pad = std::max((int)(n_padded_needed - (int)chunk_len), pad_left);
std::vector<float> padded_samples(total_pad + chunk_len, 0.0f);
// place the chunk after the (zeroed) semicausal left pad
std::copy(chunk_ptr, chunk_ptr + chunk_len, padded_samples.data() + pad_left);
mtmd_audio_mel out_chunk;
bool ok = log_mel_spectrogram(padded_samples.data(), padded_samples.size(), 4, params, cache, out_chunk);
if (!ok) {
return false;
}
// Trim to PyTorch frame count
out_chunk.n_len = std::min(out_chunk.n_len, pt_frames);
output.push_back(std::move(out_chunk));
}
return true;
}
//
// mtmd_audio_streaming_istft implementation
//