# Qwen3.6-27B MXFP4 → GGUF & Metal Performance Analysis

**Date**: 2026-04-30

**Bottom line**: The ~30% tg TPS gap (llama.cpp ~18 vs MLX ~24) is NOT from quant format or F16 accumulation. Root causes identified:

1. 1151 GPU dispatches per tick with high per-dispatch overhead
2. 682 zero-ops (VIEW/RESHAPE/TRANSPOSE/PERMUTE) that still require encoding
3. MUL_MAT kernel memory access patterns
4. Non-MUL_MAT ops (GET_ROWS, CPY, SET_ROWS) that read/write ~400 MB/tick on top of the 4.8 GB weight reads

---

## 1. Model Architecture

- **HF class**: `Qwen3_5ForConditionalGeneration`, **GGUF arch**: `qwen35`
- 64 layers: 3 linear_attention + 1 full_attention per 4 (GatedDeltaNet + GQA)
- Linear attn: k_heads=16, v_heads=48, k_dim=128, v_dim=128
- Full attn: heads=24, kv_heads=4, head_dim=256, partial RoPE factor=0.25
- V-head reordering required (grouped→tiled)
- Config: `/Volumes/FastStore/hugging/Qwen3.6-27B/config.json`

## 2. MXFP4 Tensor Format

**Sources**: `/Users/sleepy/.omlx/models/Qwen3.6-27B-mxfp4/` (3 shards, ~14.9 GB)

- Weights: `*.weight` dtype=U32, shape `[out, in/8]` (8 nibbles per uint32)
- Scales: `*.scales` dtype=U8, shape `[out, in/32]`, E4M3 unsigned (bias=7)
- Non-quantized (BF16): layernorm, conv1d, A_log, dt_bias, norm, vision tower
- Prefix: `language_model.model.layers.N` (differs from BF16 `model.language_model.layers.N`)
- Full attention has separate Q/K/V projections (not fused `in_proj_qkv`)
- `conv1d.weight` shape differs: BF16 `[10240, 1, 4]` vs MXFP4 `[10240, 4, 1]`
- **GGML expects E8M0 scales**; MLX stores E4M3. Convert via `ue4m3_to_fp32()` + `fp32_to_ue4m3()`
- Nibble packing differs from the GptOss MoE format; must verify against BF16 ground truth

## 3. Converter Status

- `GptOssModel` (line 12143): Only existing MXFP4 handler; MoE-specific, not reusable
- `Qwen3_5TextModel` (line 5435): No MXFP4 support; will crash on MXFP4 weights
- `_LinearAttentionVReorderBase` (line 5267): Has the NVFP4 V-head reordering template
- Detection works: `quantization.mode == "mxfp4"` from config.json

## 4. Benchmarks

### MLX (M4 Max)

| Model | Quant | Size | tg 1K (tok/s) | tg 64K (tok/s) |
|-------|-------|------|---------------|----------------|
| Text-mxfp4-mlx | MXFP4 | 14.0 GB | 23.8 | 12.0 |
| mxfp4 | MXFP4 | 14.9 GB | 23.7 | 12.4 |
| 4bit | affine-4b | 15.7 GB | 22.6 | 12.0 |
| oQ4 | oQ4 | 16.3 GB | 21.7 | 11.7 |

### llama.cpp (M4 Max)

| Format | Size | tg 1K, q8 KV (tok/s) | tg 4K, q8 KV (tok/s) |
|--------|------|----------------------|----------------------|
| IQ4_XS | 15.4 GB | 18.3 | 18.4 |
| Q4_0 | 15.8 GB | 18.0 | 18.0 |
| IQ4_NL | 16.1 GB | 17.8 | 17.9 |

**All three GGUF formats are within 3% of each other. Bandwidth dominates. KV cache type (q4 vs q8) has no measurable effect.**

Effective bandwidth, assuming each generated token streams the full weight file once: MLX 4bit (15.7 GB) gives 22.6 × 15.7 ≈ 355 GB/s; llama.cpp IQ4_NL (16.1 GB) gives 17.8 × 16.1 ≈ 287 GB/s. **That's 19% less bandwidth utilization**: the kernel gap.

## 5. Root Cause: F32 Accumulation

Every llama.cpp Metal kernel accumulates in F32. MLX accumulates in F16. Apple GPU does F16 FMA at 2× F32 rate.

**Q4_0 decode kernel** (`block_q_n_dot_y`, line 3228), excerpt:

```metal
float d = qb_curr->d;
float4 acc = {0.f};                                             // F32 accumulators
acc[0] += yl[i] * (qs[i/2] & mask);                             // F32×F32
return d * (sumy * -8.f + acc[0] + acc[1] + acc[2] + acc[3]);   // F32 final
```

**IQ4_NL decode kernel** (line 8884), excerpt:

```metal
shmem_f32[tiisg] = kvalues_iq4nl_f[tiisg%16];       // F32 lookup
sumf[row] += (float)xb.d * (acc1[0] + ...);         // F32 accumulate
```

**MLX decode kernel** (`fp_qmv_fast`), excerpt:

```metal
half converted = as_type<half>(ushort((bits & 7) << 9));  // 2 ops to half
converted *= 16384.0;                                      // half multiply (rebias exponent by 2^14)
return bits & 8 ? -converted : converted;                  // apply sign
// Then: simdgroup_matrix accumulation = F16×F16 throughout
```

**Impact**: ~15-20% of the gap.
The rest is threadgroup occupancy (N_R0=2 vs 4) and kernel dispatch overhead (~3-5%).

## 6. Format Alignment Analysis

| Format | bpw | Block bytes (weights/block) | 4-byte aligned | N_R0 | Lookup | Dequant |
|--------|-----|-----------------------------|----------------|------|--------|---------|
| Q4_0 | 4.50 | 18 (32) | No | 4 | No | Fused dot |
| IQ4_NL | 4.50 | 18 (32) | No | 2 | Yes | Table+F32 |
| IQ4_XS | 4.25 | 136 (256) | Yes | 2 | Yes | Sub-block scale |
| MXFP4 | 4.25 | 17 (32) | No | 2 | Yes | Shift+table+F32 |

No format is "aligned" to Apple SIMD. Q4_0 is closest (simplest kernel, highest occupancy) but still accumulates in F32. MXFP4 ties IQ4_XS for the best compression (4.25 bpw) but needs E8M0 scale conversion. IQ4_XS is the smallest GGUF but has the most complex kernel.

**No new format needed.** F16 accumulation on existing formats is the path.

## 7. Key File Paths

- **Converter**: `convert_hf_to_gguf.py` lines 5435 (Qwen3_5TextModel), 5267 (_LinearAttentionVReorderBase), 734 (quant detection), 12143 (GptOss MXFP4)
- **Metal kernels**: `ggml/src/ggml-metal/ggml-metal.metal` lines 3228 (Q4_0 dot), 8850 (IQ4_NL mul_mv), 8960 (IQ4_XS mul_mv), 9069 (MXFP4 mul_mv), 597-625 (MXFP4 dequant)
- **Tuning**: `ggml/src/ggml-metal/ggml-metal-impl.h`: N_R0/N_SG constants
- **GGML format**: `ggml/src/ggml-common.h` line 204 (block_mxfp4), `gguf-py/gguf/quants.py` line 656 (MXFP4 quant)
- **Model architecture**: `src/models/qwen35.cpp`, `src/models/delta-net-base.cpp`
- **oMLX**: `/Applications/oMLX.app/Contents/Resources/omlx/patches/qwen3_5_attention.py` (RoPE fix), `gated_delta_advance.py` (cache fix), `turboquant_kv.py` (codebook KV), `specprefill.py` (sparse prefill)
- **MLX**: `.../mlx/include/mlx/backend/metal/kernels/fp4.h` (F16 E2M1), `fp_quantized.h` (MXFP4 GEMM)

---

## 8. IMPLEMENTATION PLAN: F16 ACCUMULATION KERNELS

**⚠️ This is an untested proposal. Treat as pseudocode. Actual implementation may need adjustments for register pressure, threadgroup memory limits, and hardware-specific tuning.
Always benchmark before/after on target hardware.**

### 8.1 Create `dequantize_*_half` variants

**File**: `ggml/src/ggml-metal/ggml-metal.metal`

For each quant format that will get F16 decode kernels, add a half-precision dequant function. These output `half4x4` instead of `float4x4`.

**Q4_0** (currently at line 172):

```metal
// EXISTING: templated on the output type; the math is done in float
template <typename type4x4>
void dequantize_q4_0(device const block_q4_0 * xb, short il, thread type4x4 & reg) {
    // ...
}

// PROPOSED: add dequantize_q4_0_half, writing half4x4 directly.
// Nibbles are extracted with shifts rather than masked in place with the bit
// position folded into the scale: for the 0xF000 mask that scale would be
// d/4096, which is subnormal (or zero) in half for typical d.
void dequantize_q4_0_half(device const block_q4_0 * xb, short il, thread half4x4 & reg) {
    device const uint16_t * qs = ((device const uint16_t *)xb + 1);  // skip the half scale
    const half  d  = xb->d;
    const half  md = -8.0h * d;     // value = d * (q - 8) = d*q + md
    const short s0 = il ? 4 : 0;    // il = 0: low nibbles, il = 1: high nibbles
    const short s1 = s0 + 8;        // same nibble in the upper byte of each uint16

    for (int i = 0; i < 8; ++i) {
        reg[i/2][2*(i%2)+0] = d * (half)((qs[i] >> s0) & 0xF) + md;
        reg[i/2][2*(i%2)+1] = d * (half)((qs[i] >> s1) & 0xF) + md;
    }
}
```

**IQ4_NL** (line 921):

```metal
// PROPOSED: add dequantize_iq4_nl_half.
// A dequant helper has no access to the kernel's `shmem`, so the caller passes
// the half lookup table it staged in threadgroup memory (see §8.3). Note this
// extra parameter changes the usual (xb, il, reg) dequant signature.
void dequantize_iq4_nl_half(device const block_iq4_nl * xb, short il,
                            threadgroup const half * kv, thread half4x4 & reg) {
    device const uint16_t * q4 = (device const uint16_t *)xb->qs;
    const half d = xb->d;
    uint32_t aux32;
    thread const uint8_t * q8 = (thread const uint8_t *)&aux32;

    for (int i = 0; i < 4; ++i) {
        // il selects the low (0) or high (1) nibble of each byte
        aux32 = (((uint32_t)q4[2*i] | ((uint32_t)q4[2*i+1] << 16)) >> 4*il) & 0x0f0f0f0f;
        reg[i][0] = d * kv[q8[0]];
        reg[i][1] = d * kv[q8[1]];
        reg[i][2] = d * kv[q8[2]];
        reg[i][3] = d * kv[q8[3]];
    }
}
```

**MXFP4** (line 597): Similar pattern, but the E8M0 scale must be converted first: `ushort(bits) << 7` reinterpreted as bfloat16, or `(uint32_t)bits << 23` reinterpreted as float32 and then cast to half. Half only covers 2^-24 to 2^15 (2^-14 and up for normals), so scale exponents outside that range underflow or overflow in the cast.
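Before porting the MXFP4 path to half, it can help to have a scalar CPU reference for one block to diff GPU output against. The sketch below rests on assumptions: `BlockMxfp4Ref`, `e8m0_to_float`, `e8m0_fits_half_normal`, `e2m1_to_float`, and `dequantize_mxfp4_ref` are hypothetical names; the scale is decoded with the float32 `<< 23` route mentioned above; and the nibble order (low nibbles for elements 0..15, high nibbles for 16..31) mirrors Q4_0 and still has to be checked against `block_mxfp4` in `ggml-common.h` and BF16 ground truth.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstring>

// Hypothetical reference layout for one 32-element MXFP4 block (17 bytes).
// ASSUMPTION: low nibbles -> elements 0..15, high nibbles -> 16..31.
struct BlockMxfp4Ref {
    uint8_t e;       // E8M0 shared scale: 2^(e - 127)
    uint8_t qs[16];  // 32 E2M1 codes, two per byte
};

// E8M0 -> float by placing the byte in the float32 exponent field.
// Edge cases differ from the spec: e == 0 gives 0.0f, e == 255 gives +inf.
float e8m0_to_float(uint8_t e) {
    uint32_t bits = uint32_t(e) << 23;
    float f;
    std::memcpy(&f, &bits, sizeof f);
    return f;
}

// True when 2^(e - 127) is a normal half (2^-14 .. 2^15), i.e. when
// casting the scale to half in the proposed kernel keeps full precision.
bool e8m0_fits_half_normal(uint8_t e) {
    const int expo = int(e) - 127;
    return expo >= -14 && expo <= 15;
}

// E2M1: 1 sign bit, 2 exponent bits (bias 1), 1 mantissa bit.
float e2m1_to_float(uint8_t code) {
    const int sign = (code >> 3) & 1;
    const int expo = (code >> 1) & 3;
    const int mant = code & 1;
    const float mag = expo == 0 ? 0.5f * mant                               // 0 or 0.5
                                : std::ldexp(1.0f + 0.5f * mant, expo - 1); // 1 .. 6
    return sign ? -mag : mag;
}

// Dequantize one block into 32 floats.
void dequantize_mxfp4_ref(const BlockMxfp4Ref & b, float out[32]) {
    const float d = e8m0_to_float(b.e);
    for (int j = 0; j < 16; ++j) {
        out[j]      = d * e2m1_to_float(b.qs[j] & 0x0F);
        out[j + 16] = d * e2m1_to_float(b.qs[j] >> 4);
    }
}
```

The `e8m0_fits_half_normal` check encodes the range caveat for the half path: only E8M0 bytes 113 through 142 map to a normal half scale.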
### 8.2 Create F16 mul_mv kernel variants

**File**: `ggml/src/ggml-metal/ggml-metal.metal`

For each format, create a dedicated F16 decode kernel that accumulates in `half` and reduces via `simd_sum` at the end.

**Q4_0**, the most impactful starting point:

```metal
template<int NR0, typename args_t>
void kernel_mul_mv_q4_0_f16_impl(
        args_t args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        threadgroup  char * shmem,
        uint3  tgpig,
        ushort tiisg,
        ushort sgitg) {
    // Same structure as kernel_mul_mv_q4_0_f32_impl,
    // BUT: accumulate in half, reduce at the end.
    half sumf[NR0] = {0.0h};  // HALF accumulators

    // ... same block loading and nibble extraction ...

    // In the inner loop, use half multiplies. Q4_0 has no min term `m`;
    // the -8 offset is applied through sumy, as in block_q_n_dot_y:
    //   half d    = (half)xb.d;
    //   half yl_h = (half)yl[i];
    //   sumf[row] += d * (acc + sumy * -8.0h);   // HALF FMA

    // Final reduction: cast to float for simd_sum.
    // first_row and dst_f32 are set up exactly as in the F32 kernel.
    for (int row = 0; row < NR0 && first_row + row < args.ne0; ++row) {
        float sum_all = simd_sum((float)sumf[row]);
        if (tiisg == 0) {
            dst_f32[first_row + row] = sum_all;
        }
    }
}
```

Key insight: the output is still F32 (single-token decode produces one float per row). Only the intermediate accumulation is F16, so `dst` stays `float *` and no output format change is needed.
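The `d * (acc + sumy * -8)` form used by the Q4_0 decode kernels is the dequantize-then-dot product with the offset factored out. A scalar sketch of that identity, usable as a CPU oracle when validating an F16 kernel, follows. `BlockQ4_0Ref` and both functions are hypothetical; the scale is widened to float (the real `block_q4_0` stores it as fp16), and the kernel's masked-nibble and pre-scaled `yl` layout is deliberately not reproduced.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>

// Hypothetical CPU-side Q4_0 block: 32 weights, scale widened to float.
// Element j comes from the low nibble of qs[j], element j + 16 from the high nibble.
struct BlockQ4_0Ref {
    float   d;
    uint8_t qs[16];
};

// Dequantize-then-dot: w_j = d * (q_j - 8).
float q4_0_dot_naive(const BlockQ4_0Ref & b, const float y[32]) {
    float acc = 0.0f;
    for (int j = 0; j < 16; ++j) {
        acc += b.d * float((b.qs[j] & 0x0F) - 8) * y[j];
        acc += b.d * float((b.qs[j] >> 4) - 8) * y[j + 16];
    }
    return acc;
}

// Factored form: d * (sum q_j*y_j - 8 * sum y_j). Raw nibble products and
// sum(y) are accumulated separately; scale and offset are applied once per block.
float q4_0_dot_factored(const BlockQ4_0Ref & b, const float y[32]) {
    float acc = 0.0f, sumy = 0.0f;
    for (int j = 0; j < 16; ++j) {
        acc  += float(b.qs[j] & 0x0F) * y[j] + float(b.qs[j] >> 4) * y[j + 16];
        sumy += y[j] + y[j + 16];
    }
    return b.d * (acc - 8.0f * sumy);
}
```

Because the factored form applies `d` and the `-8` offset only once per block, each inner-loop term is just a nibble (0..15) times an activation, which is what makes running that loop in `half` plausible.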
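**IQ4_NL** (similar structure to the existing `kernel_mul_mv_iq4_nl_f32_impl` at line 8850):

- Same pattern: `half sumf[NR0]`, `half4` accumulators, lookup table loaded into threadgroup memory as `half` values, final `simd_sum((float)sumf[row])`
- Must also create `dequantize_iq4_nl_t4_half` for the ext path (batch sizes 2-8)

**IQ4_XS** (similar to the existing kernel at line 8960):

- Same pattern with `half` accumulators
- Must also create `dequantize_iq4_xs_half` variants

**MXFP4** (existing kernel at line 9069):

- `dequantize_mxfp4_half`: E8M0→half via `ushort(bits << 7)` reinterpreted as bfloat16, or by rebiasing an E8M0 byte `e` straight into the half exponent field (`(e - 112) << 10`, valid only for `e` in 113..142, i.e. 2^-14 to 2^15)
- `kvalues_mxfp4_h`: half-precision lookup table in threadgroup shared memory

### 8.3 Threadgroup shared memory for lookup tables

**Current state**: `kernel_mul_mv_mxfp4_f32_impl` loads `kvalues_mxfp4_f` into threadgroup shared memory, and the IQ4_NL `mul_mv` kernel does the same with `kvalues_iq4nl_f` (see the snippet below and §5). The source tables themselves are `constexpr constant` arrays.

**For F16 variants**, the lookup values should be loaded into threadgroup memory as `half`:

```metal
// EXISTING (IQ4_NL, F32):
threadgroup float * shmem_f32 = (threadgroup float *) shmem;
shmem_f32[tiisg] = kvalues_iq4nl_f[tiisg%16];
threadgroup_barrier(mem_flags::mem_threadgroup);

// PROPOSED (IQ4_NL, F16):
threadgroup half * shmem_h = (threadgroup half *) shmem;
shmem_h[tiisg] = (half)kvalues_iq4nl_f[tiisg%16];  // cast constant float to half
threadgroup_barrier(mem_flags::mem_threadgroup);
```

Memory cost: each of the 32 SIMD lanes writes one entry, so the 16-entry table is stored twice: 32 halves = 64 bytes (vs 128 bytes for the F32 version), which is trivial. All 16 IQ4_NL values are small integers (-127 to 113), so the cast to half is exact. The existing IQ4_NL kernel already allocates `shmem` with `threadgroup(0)`.

### 8.4 Dispatch registration

**File**: `ggml/src/ggml-metal/ggml-metal-ops.cpp`

Add F16 variants alongside the existing F32 variants. Each format needs:

1. A new `mul_mv` kernel name registered in `ggml_metal_op_mul_mat`
2. A new `mul_mv_ext` template instantiation for batch sizes 2-8
3. A new `mul_mm` template instantiation for batched GEMM

Example for Q4_0:

```cpp
// EXISTING:
{"kernel_mul_mv_q4_0_f32", ...}
{"kernel_mul_mv_ext_q4_0_f32_r1_2", ...}
{"kernel_mul_mm_q4_0_f32", ...}
{"kernel_mul_mm_q4_0_f16", ...}          // f16 src1 (activations); output and dequant still f32

// PROPOSED:
{"kernel_mul_mv_q4_0_f16", ...}          // f16 DEQUANT + accumulation
// ext variants r1_2 through r1_5
{"kernel_mul_mm_q4_0_f16_dequant", ...}  // f16 dequant, f16 accumulation in threadgroup
```

In the existing names the trailing type appears to denote the `src1` (activation) type (cf. `kernel_mul_mv_f16_f32`), so `kernel_mul_mv_q4_0_f16` could be read as "Q4_0 weights × F16 activations"; a distinct suffix would avoid that ambiguity.

### 8.5 Kernel selection logic

**File**: `ggml/src/ggml-metal/ggml-metal-ops.cpp`

The dispatch logic (around line 2025) currently selects kernels based on `op->src[0]->type`. For F16 variants, add a runtime or compile-time flag to choose F16 accumulation kernels when available. Two approaches:

1. **Always prefer F16**: Replace existing kernels with F16 variants. The output is still F32, so no downstream changes are needed. Risk: a `half` running sum can lose precision (or overflow past 65504) when it spans many terms, i.e. long rows in large models.
2. **Conditional selection**: Add a `GGML_METAL_F16_DEQUANT` flag (env var or build option) that selects F16 kernels when set.

**Recommended**: Start with approach 1 for Q4_0 only (simplest kernel, well-tested format). If precision is fine, extend to other formats. Q4_0's per-element arithmetic is benign in F16 (d × (nibble - 8), where the integers -8..7 are exact in half, so only the scale multiply rounds); the open question is how many terms each `half` accumulator sums before the F32 `simd_sum`.

### 8.6 Tuning: N_R0/N_SG parameters

**File**: `ggml/src/ggml-metal/ggml-metal-impl.h`

Current values and proposed changes to benchmark (rows per threadgroup = N_R0 × N_SG):

```c
// CURRENT:
#define N_R0_Q4_0   4
#define N_SG_Q4_0   2   // 8 rows/tg, already good
#define N_R0_IQ4_NL 2
#define N_SG_IQ4_NL 2   // 4 rows/tg, try increasing
#define N_R0_MXFP4  2
#define N_SG_MXFP4  2   // 4 rows/tg, try increasing

// PROPOSED BENCHMARK VARIANTS (replace the values above in separate builds):
#define N_R0_IQ4_NL 4   // try 8 rows/tg like Q4_0
#define N_R0_MXFP4  4   // try 8 rows/tg
```

These are compile-time constants. Create benchmark builds with each variant and measure tg TPS on M4 Max at 1K, 4K, and 16K contexts.

### 8.7 mul_mm (prefill) F16 path

The mat-mat path already has `_f16` variants (e.g. `kernel_mul_mm_q4_0_f16`). These dequantize to F32 in `float4x4`, then store as `half2x4` in threadgroup memory for the simdgroup multiply. The F16 optimization here is to change the dequant functions to output `half4x4` directly, skipping the F32 intermediate.

This is lower priority than the decode (mul_mv) path because prefill is compute-bound and the mat-mat kernels already feed `simdgroup_half8x8` operands to the multiply stage. Since the tiles are already stored as half, the gain would come from fewer float→half conversions and lower register pressure during dequant rather than from reduced threadgroup memory.

### 8.8 Priority order

1. **Q4_0 F16 mul_mv kernel**: highest impact, simplest kernel (no lookup table), highest N_R0. File: `ggml-metal.metal`, new `kernel_mul_mv_q4_0_f16_impl`
2. **IQ4_NL F16 mul_mv kernel**: second format; the lookup table needs half shmem. File: `ggml-metal.metal`, new `kernel_mul_mv_iq4_nl_f16_impl`
3. **IQ4_XS F16 mul_mv kernel**: third format. File: `ggml-metal.metal`, new `kernel_mul_mv_iq4_xs_f16_impl`
4. **MXFP4 F16 mul_mv kernel**: fourth format, after the converter works. File: `ggml-metal.metal`, modify `kernel_mul_mv_mxfp4_f32_impl`
5. **N_R0 benchmarking**: `ggml-metal-impl.h`, try N_R0=4 for IQ4_NL and MXFP4
6. **mul_mm F16 dequant**: lower priority; the mat-mat path already uses half simdgroup operands
7. **MXFP4 converter**: extend `Qwen3_5TextModel` per §3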
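The precision risk of approach 1 in §8.5 can be estimated on the CPU before writing any Metal. The sketch below emulates IEEE half rounding in software (round-to-nearest-even, normals and subnormals, overflow to infinity) and accumulates the same dot product in float and in emulated half. `round_to_half`, `DotResult`, and `dot_compare` are hypothetical helpers, and the half path rounds once per multiply-add through a double intermediate, which approximates, but is not guaranteed to match bit-for-bit, a hardware half FMA.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Round a double to the nearest IEEE binary16 value (returned as double),
// ties to even. Covers normals, subnormals (spacing 2^-24) and overflow.
double round_to_half(double x) {
    if (!std::isfinite(x) || x == 0.0) return x;
    // 65504 is the largest half; anything >= 65520 rounds up to infinity.
    if (std::fabs(x) >= 65520.0)
        return std::copysign(std::numeric_limits<double>::infinity(), x);
    int e;
    std::frexp(x, &e);                          // |x| = m * 2^e, m in [0.5, 1)
    const int ulp_exp = std::max(e - 11, -24);  // 11 significant bits, subnormal floor
    // nearbyint uses the default rounding mode (round-to-nearest-even).
    return std::ldexp(std::nearbyint(std::ldexp(x, -ulp_exp)), ulp_exp);
}

struct DotResult {
    double ref;      // double-precision reference
    double f32_acc;  // float accumulator (current kernels)
    double f16_acc;  // emulated half accumulator (proposed kernels)
};

// Inputs are assumed to be half-representable already (dequantized weights
// and activations), so only the accumulation differs between the paths.
DotResult dot_compare(const std::vector<double> & a, const std::vector<double> & b) {
    DotResult r{0.0, 0.0, 0.0};
    float acc32 = 0.0f;
    for (std::size_t i = 0; i < a.size() && i < b.size(); ++i) {
        r.ref    += a[i] * b[i];
        acc32     = std::fma(float(a[i]), float(b[i]), acc32);
        r.f16_acc = round_to_half(r.f16_acc + a[i] * b[i]);  // one rounding per multiply-add
    }
    r.f32_acc = acc32;
    return r;
}
```

Summing 4096 products of 1.0 is a deliberately bad case: the float accumulator reaches 4096, while the half accumulator stalls at 2048, where the spacing between halves becomes 2 and each +1 rounds back to the even neighbour. Real activations have mixed signs, and in the proposed kernel each lane only accumulates its slice of the row before the float `simd_sum`, so errors should be smaller in practice; the effect still grows with the number of terms that land in one `half` accumulator.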
### 8.7 mul_mm (prefill) F16 path The mat-mat path already has `_f16` output variants (e.g., `kernel_mul_mm_q4_0_f16`). These dequantize to F32 in `float4x4` then store as `half2x4` for the simdgroup multiply. The F16 optimization here is to change the dequant functions to output `half4x4` directly, so the threadgroup memory stores half the data. This is lower priority than the decode (mul_mv) path because prefill is compute-bound and the mat-mat kernels already use `simdgroup_half8x8` for the accumulation stage. The main gain would be reduced threadgroup memory pressure. ### 8.8 Priority order 1. **Q4_0 F16 mul_mv kernel** — highest impact, simplest kernel (no lookup table), highest N_R0. File: `ggml-metal.metal` new `kernel_mul_mv_q4_0_f16_impl` 2. **IQ4_NL F16 mul_mv kernel** — second format, lookup table needs half shmem. File: `ggml-metal.metal` new `kernel_mul_mv_iq4_nl_f16_impl` 3. **IQ4_XS F16 mul_mv kernel** — third format. File: `ggml-metal.metal` new `kernel_mul_mv_iq4_xs_f16_impl` 4. **MXFP4 F16 mul_mv kernel** — fourth format, after converter works. File: `ggml-metal.metal` modify `kernel_mul_mv_mxfp4_f32_impl` 5. **N_R0 benchmarking** — `ggml-metal-impl.h`, try N_R0=4 for IQ4_NL and MXFP4 6. **mul_mm F16 dequant** — lower priority, mat-mat path already uses half simdgroup 7. **MXFP4 converter** — extend `Qwen3_5TextModel` per §3