mtmd: add Gemma 4 audio conformer encoder support (#21421)
* mtmd: add Gemma 4 audio conformer encoder support Add audio processing for Gemma 4 E2B/E4B via a USM-style Conformer. Architecture: - 12-layer Conformer: FFN → Self-Attention → Causal Conv1D → FFN → Norm - Subsampling Conv Projection: 2x Conv2D(stride=2) with LayerNorm - Full self-attention with sinusoidal RPE and sliding window mask (24) - Logit softcapping at 50.0, ClippableLinear clamping - Output: 1024 → 1536 → RMSNorm → multimodal embedder Mel preprocessing (dedicated mtmd_audio_preprocessor_gemma4a): - HTK mel scale, 128 bins, magnitude STFT, mel_floor=1e-3 - Standard periodic Hann window (320 samples), zero-padded to FFT size - Semicausal left-padding (frame_length/2 samples) - Frame count matched to PyTorch (unfold formula) - No pre-emphasis, no Whisper-style normalization - Mel cosine similarity vs PyTorch: 0.9998 Key fixes: - Tensor loading dedup: prevent get_tensor() from creating duplicate entries in ctx_data. Fixed with std::set guard. - ClippableLinear clamp_info loading moved after per-layer tensors. - Sliding window mask (24 positions) matching PyTorch context_size. - Skip Whisper normalization for Gemma4 mel output. Tested on E2B and E4B with CPU and Vulkan backends. Transcribes: "Glad to see things are going well and business is starting to pick up" (matching ground truth). Ref: #21325
This commit is contained in:
+126
-22
@@ -8,6 +8,7 @@
|
||||
#include <vector>
|
||||
#include <fstream>
|
||||
#include <algorithm>
|
||||
#include <functional>
|
||||
|
||||
// some of the code here is copied from whisper.cpp
|
||||
|
||||
@@ -37,23 +38,36 @@ void mtmd_audio_cache::fill_mel_filterbank_matrix(int n_mel,
|
||||
float fmin,
|
||||
float fmax,
|
||||
bool slaney_area_norm,
|
||||
float scale) {
|
||||
float scale,
|
||||
bool use_htk) {
|
||||
GGML_ASSERT(n_mel > 0 && n_fft > 1);
|
||||
if (fmax <= 0.0f) {
|
||||
fmax = 0.5f * sample_rate;
|
||||
}
|
||||
|
||||
// Slaney scale (matches librosa default)
|
||||
const double min_log_hz = 1000.0;
|
||||
const double lin_slope = 3 / 200.;
|
||||
const double min_log_mel = min_log_hz * lin_slope;
|
||||
const double log_step = log(6.4) / 27.0;
|
||||
auto hz_to_mel = [min_log_hz, lin_slope, log_step, min_log_mel](const double f_hz) -> double {
|
||||
return (f_hz < min_log_hz) ? f_hz * lin_slope : min_log_mel + log(f_hz / min_log_hz) / log_step;
|
||||
};
|
||||
auto mel_to_hz = [min_log_hz, lin_slope, log_step, min_log_mel](const double m) -> double {
|
||||
return (m < min_log_mel) ? m / lin_slope : min_log_hz * exp((m - min_log_mel) * log_step);
|
||||
};
|
||||
std::function<double(double)> hz_to_mel;
|
||||
std::function<double(double)> mel_to_hz;
|
||||
|
||||
if (use_htk) {
|
||||
hz_to_mel = [](const double f_hz) -> double {
|
||||
return 2595.0 * log10(1.0 + f_hz / 700.0);
|
||||
};
|
||||
mel_to_hz = [](const double m) -> double {
|
||||
return 700.0 * (pow(10.0, m / 2595.0) - 1.0);
|
||||
};
|
||||
} else {
|
||||
// Slaney scale (matches librosa default)
|
||||
const double min_log_hz = 1000.0;
|
||||
const double lin_slope = 3 / 200.;
|
||||
const double min_log_mel = min_log_hz * lin_slope;
|
||||
const double log_step = log(6.4) / 27.0;
|
||||
hz_to_mel = [min_log_hz, lin_slope, log_step, min_log_mel](const double f_hz) -> double {
|
||||
return (f_hz < min_log_hz) ? f_hz * lin_slope : min_log_mel + log(f_hz / min_log_hz) / log_step;
|
||||
};
|
||||
mel_to_hz = [min_log_hz, lin_slope, log_step, min_log_mel](const double m) -> double {
|
||||
return (m < min_log_mel) ? m / lin_slope : min_log_hz * exp((m - min_log_mel) * log_step);
|
||||
};
|
||||
}
|
||||
|
||||
// infer N_fft from n_fft_bins
|
||||
const double bin_hz_step = double(sample_rate) / double(n_fft);
|
||||
@@ -257,10 +271,13 @@ struct filter_params {
|
||||
int32_t hann_window_size;
|
||||
int32_t hop_length;
|
||||
int32_t sample_rate;
|
||||
bool center_padding = false;
|
||||
float preemph = 0.f;
|
||||
bool no_padding = false;
|
||||
bool center_padding = false;
|
||||
float preemph = 0.f;
|
||||
bool use_natural_log = false;
|
||||
bool norm_per_feature = false;
|
||||
bool use_magnitude = false; // |X| instead of |X|^2
|
||||
float mel_floor = 5.960464477539063e-08f;
|
||||
};
|
||||
|
||||
static void log_mel_spectrogram_worker_thread(int ith,
|
||||
@@ -301,10 +318,10 @@ static void log_mel_spectrogram_worker_thread(int ith,
|
||||
// FFT
|
||||
fft(cache, fft_in.data(), frame_size, fft_out.data());
|
||||
|
||||
// Calculate modulus^2 of complex numbers
|
||||
// Use pow(fft_out[2 * j + 0], 2) + pow(fft_out[2 * j + 1], 2) causes inference quality problem? Interesting.
|
||||
// Calculate modulus^2 (power) or modulus (magnitude)
|
||||
for (int j = 0; j < n_fft_bins; j++) {
|
||||
fft_out[j] = (fft_out[2 * j + 0] * fft_out[2 * j + 0] + fft_out[2 * j + 1] * fft_out[2 * j + 1]);
|
||||
float power = (fft_out[2 * j + 0] * fft_out[2 * j + 0] + fft_out[2 * j + 1] * fft_out[2 * j + 1]);
|
||||
fft_out[j] = params.use_magnitude ? sqrtf(power) : power;
|
||||
}
|
||||
|
||||
// mel spectrogram
|
||||
@@ -324,9 +341,10 @@ static void log_mel_spectrogram_worker_thread(int ith,
|
||||
for (; k < n_fft_bins; k++) {
|
||||
sum += fft_out[k] * filters.data[j * n_fft_bins + k];
|
||||
}
|
||||
sum = std::max(sum, (double)params.mel_floor);
|
||||
sum = params.use_natural_log
|
||||
? log(sum + 5.960464477539063e-08)
|
||||
: log10(std::max(sum, 1e-10));
|
||||
? log(sum)
|
||||
: log10(sum);
|
||||
out.data[j * out.n_len + i] = sum;
|
||||
}
|
||||
}
|
||||
@@ -360,7 +378,12 @@ static bool log_mel_spectrogram(
|
||||
|
||||
// Padding
|
||||
std::vector<float> samples_padded;
|
||||
if (params.center_padding) {
|
||||
if (params.no_padding) {
|
||||
// no padding, use samples as-is
|
||||
samples_padded = std::vector<float>(samples, samples + n_samples);
|
||||
samples = samples_padded.data();
|
||||
n_samples = samples_padded.size();
|
||||
} else if (params.center_padding) {
|
||||
const auto pad_amount = frame_size / 2;
|
||||
samples_padded = std::vector<float>(n_samples + 2 * pad_amount, 0);
|
||||
std::copy(samples, samples + n_samples, samples_padded.data() + pad_amount);
|
||||
@@ -464,8 +487,8 @@ static bool log_mel_spectrogram(
|
||||
out.data[i * out.n_len + j] = 0.0;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// clamping and normalization
|
||||
} else if (!params.no_padding) {
|
||||
// Whisper-style clamping and normalization (NOT used by Gemma4)
|
||||
double mmax = -1e20;
|
||||
for (int i = 0; i < out.n_mel*out.n_len; i++) {
|
||||
if (out.data[i] > mmax) {
|
||||
@@ -627,6 +650,87 @@ bool mtmd_audio_preprocessor_conformer::preprocess(const float *
|
||||
return true;
|
||||
}
|
||||
|
||||
//
|
||||
// mtmd_audio_preprocessor_gemma4a
|
||||
//
|
||||
|
||||
void mtmd_audio_preprocessor_gemma4a::initialize() {
|
||||
cache.fill_sin_cos_table(hparams.audio_n_fft);
|
||||
|
||||
// Standard periodic Hann window, zero-padded to FFT size
|
||||
cache.hann_window.assign(hparams.audio_n_fft, 0.0f);
|
||||
for (uint32_t i = 0; i < (uint32_t)hparams.audio_window_len; i++) {
|
||||
cache.hann_window[i] = 0.5f - 0.5f * cosf((2.0f * (float)M_PI * i) / hparams.audio_window_len);
|
||||
}
|
||||
|
||||
// HTK mel scale, no Slaney area normalization
|
||||
cache.fill_mel_filterbank_matrix(
|
||||
hparams.n_mel_bins, hparams.audio_n_fft, hparams.audio_sample_rate,
|
||||
0.0f, hparams.audio_sample_rate / 2.0f,
|
||||
/*slaney_area_norm=*/ false,
|
||||
/*scale=*/ 1.0f,
|
||||
/*use_htk=*/ true
|
||||
);
|
||||
}
|
||||
|
||||
bool mtmd_audio_preprocessor_gemma4a::preprocess(const float * samples,
|
||||
size_t n_samples,
|
||||
std::vector<mtmd_audio_mel> & output) {
|
||||
if (n_samples == 0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
GGML_ASSERT(!cache.sin_vals.empty());
|
||||
GGML_ASSERT(!cache.cos_vals.empty());
|
||||
GGML_ASSERT(!cache.filters.data.empty());
|
||||
|
||||
filter_params params;
|
||||
params.n_mel = hparams.n_mel_bins;
|
||||
params.n_fft_bins = 1 + (hparams.audio_n_fft / 2);
|
||||
params.hann_window_size = hparams.audio_n_fft; // window is zero-padded to FFT size
|
||||
params.hop_length = hparams.audio_hop_len;
|
||||
params.sample_rate = hparams.audio_sample_rate;
|
||||
params.no_padding = true;
|
||||
params.center_padding = false;
|
||||
params.preemph = 0.0f;
|
||||
params.use_natural_log = true;
|
||||
params.use_magnitude = true;
|
||||
params.mel_floor = 0.001f;
|
||||
params.norm_per_feature = false;
|
||||
|
||||
// Split into 30-second chunks (model context limit, ~750 tokens each)
|
||||
const size_t chunk_samples = 30 * hparams.audio_sample_rate;
|
||||
for (size_t off = 0; off < n_samples; off += chunk_samples) {
|
||||
const float * chunk_ptr = samples + off;
|
||||
size_t chunk_len = std::min(chunk_samples, n_samples - off);
|
||||
|
||||
// Semicausal left-padding + right-padding to match PyTorch frame count
|
||||
const int pad_left = hparams.audio_window_len / 2;
|
||||
const int fft_size = hparams.audio_n_fft;
|
||||
const int hop = hparams.audio_hop_len;
|
||||
const int n_with_left = (int)chunk_len + pad_left;
|
||||
// PyTorch: unfold(size=frame_length+1, step=hop) on semicausal-padded waveform
|
||||
const int pt_frames = (n_with_left - (hparams.audio_window_len + 1)) / hop + 1;
|
||||
const int n_padded_needed = (pt_frames - 1) * hop + fft_size;
|
||||
const int total_pad = std::max((int)(n_padded_needed - (int)chunk_len), pad_left);
|
||||
std::vector<float> padded_samples(total_pad + chunk_len, 0.0f);
|
||||
std::copy(chunk_ptr, chunk_ptr + chunk_len, padded_samples.data() + pad_left);
|
||||
|
||||
mtmd_audio_mel out_chunk;
|
||||
bool ok = log_mel_spectrogram(padded_samples.data(), padded_samples.size(), 4, params, cache, out_chunk);
|
||||
if (!ok) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Trim to PyTorch frame count
|
||||
out_chunk.n_len = std::min(out_chunk.n_len, pt_frames);
|
||||
|
||||
output.push_back(std::move(out_chunk));
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
//
|
||||
// mtmd_audio_streaming_istft implementation
|
||||
//
|
||||
|
||||
Reference in New Issue
Block a user