diff --git a/ANALYSIS_QWEN3_5_MXFP4.md b/ANALYSIS_QWEN3_5_MXFP4.md new file mode 100644 index 000000000..ae96acf81 --- /dev/null +++ b/ANALYSIS_QWEN3_5_MXFP4.md @@ -0,0 +1,319 @@ +# Qwen3.6-27B MXFP4 → GGUF & Metal Performance Analysis + +**Date**: 2026-04-30 + +**Bottom line**: The ~30% tg TPS gap (llama.cpp ~18 vs MLX ~24) is NOT from quant format or F16 accumulation. Root causes identified: (1) 1151 GPU dispatches per tick with high per-dispatch overhead, (2) 682 zero-ops (VIEW/RESHAPE/TRANSPOSE/PERMUTE) that still require encoding, (3) MUL_MAT kernel memory access patterns, (4) non-MUL_MAT ops (GET_ROWS, CPY, SET_ROWS) that read/write ~400 MB/tick on top of the 4.8 GB weight reads. + +--- + +## 1. Model Architecture + +- **HF class**: `Qwen3_5ForConditionalGeneration`, **GGUF arch**: `qwen35` +- 64 layers: 3 linear_attention + 1 full_attention per 4 (GatedDeltaNet + GQA) +- Linear attn: k_heads=16, v_heads=48, k_dim=128, v_dim=128 +- Full attn: heads=24, kv_heads=4, head_dim=256, partial RoPE factor=0.25 +- V-head reordering required (grouped→tiled) +- Config: `/Volumes/FastStore/hugging/Qwen3.6-27B/config.json` + +## 2. MXFP4 Tensor Format + +**Sources**: `/Users/sleepy/.omlx/models/Qwen3.6-27B-mxfp4/` (3 shards, ~14.9GB) + +- Weights: `*.weight` dtype=U32, shape `[out, in/8]` — 8 nibbles per uint32 +- Scales: `*.scales` dtype=U8, shape `[out, in/32]` — E4M3 unsigned (bias=7) +- Non-quantized (BF16): layernorm, conv1d, A_log, dt_bias, norm, vision tower +- Prefix: `language_model.model.layers.N` (differs from BF16 `model.language_model.layers.N`) +- Full attention has separate Q/K/V projections (not fused `in_proj_qkv`) +- `conv1d.weight` shape differs: BF16 `[10240, 1, 4]` vs MXFP4 `[10240, 4, 1]` +- **GGML expects E8M0 scales**; MLX stores E4M3. Convert via `ue4m3_to_fp32()` + `fp32_to_ue4m3()` +- Nibble packing differs from GptOss MoE format; must verify against BF16 ground truth + +## 3. Converter Status + +- `GptOssModel` (line 12143): Only existing MXFP4 handler, MoE-specific, not reusable +- `Qwen3_5TextModel` (line 5435): No MXFP4 support, will crash on MXFP4 weights +- `_LinearAttentionVReorderBase` (line 5267): Has NVFP4 V-head reordering template +- Detection works: `quantization.mode == "mxfp4"` from config.json + +## 4. Benchmarks + +### MLX (M4 Max) + +| Model | Quant | Size | tg 1K | tg 64K | +|-------|-------|------|-------|--------| +| Text-mxfp4-mlx | MXFP4 | 14.0 GB | 23.8 | 12.0 | +| mxfp4 | MXFP4 | 14.9 GB | 23.7 | 12.4 | +| 4bit | affine-4b | 15.7 GB | 22.6 | 12.0 | +| oQ4 | oQ4 | 16.3 GB | 21.7 | 11.7 | + +### llama.cpp (M4 Max) + +| Format | Size | tg 1K (q8) | tg 4K (q8) | +|--------|------|-----------|-----------| +| IQ4_XS | 15.4 GB | 18.3 | 18.4 | +| Q4_0 | 15.8 GB | 18.0 | 18.0 | +| IQ4_NL | 16.1 GB | 17.8 | 17.9 | + +**All three GGUF formats within 3% of each other. Bandwidth dominates. KV cache type (q4 vs q8) has zero effect.** + +MLX effective bandwidth at 15.7 GB: 22.6 × 15.7 = 355 GB/s. llama.cpp at 16.1 GB: 17.8 × 16.1 = 287 GB/s. **That's 19% less bandwidth utilization** — the kernel gap. + +## 5. Root Cause: F32 Accumulation + +Every llama.cpp Metal kernel accumulates in F32. MLX accumulates in F16. Apple GPU does F16 FMA at 2× F32 rate. + +**Q4_0 decode kernel** (`block_q_n_dot_y`, line 3228): +```metal +float d = qb_curr->d; +float4 acc1 = {0.f}; +acc1 += yl[i] * (qs[i/2] & mask); // F32×F32 +return d * (sumy * -8.f + acc); // F32 final +``` + +**IQ4_NL decode kernel** (line 8884): +```metal +shmem_f32[tiisg] = kvalues_iq4nl_f[tiisg%16]; // F32 lookup +sumf[row] += (float)xb.d * (acc1[0] + ...); // F32 accumulate +``` + +**MLX decode kernel** (`fp_qmv_fast`): +```metal +half converted = as_type(ushort((bits & 7) << 9)); // 2 ops to half +converted *= 16384.0; // half multiply +return bits & 8 ? -converted : converted; // apply sign +// Then: simdgroup_matrix accumulation = F16×F16 throughout +``` + +**Impact**: ~15-20% of the gap. The rest is threadgroup occupancy (N_R0=2 vs 4) and kernel dispatch overhead (~3-5%). + +## 6. Format Alignment Analysis + +| Format | bpw | Block bytes | 4B-align | N_R0 | Lookup | Dequant | +|--------|-----|-------------|----------|------|--------|---------| +| Q4_0 | 4.50 | 18 | No | 4 | No | Fused dot | +| IQ4_NL | 4.50 | 18 | No | 2 | Yes | Table+F32 | +| IQ4_XS | ~4.25 | 136 | Yes | 2 | Yes | Sub-block scale | +| MXFP4 | 4.25 | 17 | No | 2 | Yes | Shift+table+F32 | + +No format is "aligned" to Apple SIMD. Q4_0 is closest (simplest kernel, highest occupancy) but still accumulates F32. MXFP4 has the best compression (4.25 bpw) but needs E8M0 conversion. IQ4_XS is smallest GGUF but has the most complex kernel. + +**No new format needed.** F16 accumulation on existing formats is the path. + +## 7. Key File Paths + +**Converter**: `convert_hf_to_gguf.py` lines 5435 (Qwen3_5TextModel), 5267 (_LinearAttentionVReorderBase), 734 (quant detection), 12143 (GptOss MXFP4) + +**Metal kernels**: `ggml/src/ggml-metal/ggml-metal.metal` lines 3228 (Q4_0 dot), 8850 (IQ4_NL mul_mv), 8960 (IQ4_XS mul_mv), 9069 (MXFP4 mul_mv), 597-625 (MXFP4 dequant) + +**Tuning**: `ggml/src/ggml-metal/ggml-metal-impl.h` — N_R0/N_SG constants + +**GGML format**: `ggml/src/ggml-common.h` line 204 (block_mxfp4), `gguf-py/gguf/quants.py` line 656 (MXFP4 quant) + +**Model architecture**: `src/models/qwen35.cpp`, `src/models/delta-net-base.cpp` + +**oMLX**: `/Applications/oMLX.app/Contents/Resources/omlx/patches/qwen3_5_attention.py` (RoPE fix), `gated_delta_advance.py` (cache fix), `turboquant_kv.py` (codebook KV), `specprefill.py` (sparse prefill) + +**MLX**: `.../mlx/include/mlx/backend/metal/kernels/fp4.h` (F16 E2M1), `fp_quantized.h` (MXFP4 GEMM) + +--- + +## 8. IMPLEMENTATION PLAN: F16 ACCUMULATION KERNELS + +**⚠️ This is a non-tested proposal. Treat as pseudocode. Actual implementation may need adjustments for register pressure, threadgroup memory limits, and hardware-specific tuning. Always benchmark before/after on target hardware.** + +### 8.1 Create `dequantize_*_half` variants + +**File**: `ggml/src/ggml-metal/ggml-metal.metal` + +For each quant format that will get F16 decode kernels, add a half-precision dequant function. These output `half4x4` instead of `float4x4`. + +**Q4_0** (currently at line 172): +```metal +// EXISTING: +void dequantize_q4_0(device const block_q4_0 * xb, short il, thread type4x4 & reg) { + // outputs float4x4 +} + +// PROPOSED: add dequantize_q4_0_half outputting half4x4 +void dequantize_q4_0_half(device const block_q4_0 * xb, short il, thread half4x4 & reg) { + device const uint16_t * qs = ((device const uint16_t *)xb + 1); + const half d1 = (half)xb->d; + const half d2 = d1 / (half)16.0h; + const half md = (half)(-8.0h) * (half)xb->d; + // Same nibble extraction, but multiply in half: + for (int i = 0; i < 8; ++i) { + reg[i/2][2*(i%2)+0] = d1 * (half)(qs[i] & mask0) + md; + reg[i/2][2*(i%2)+1] = d2 * (half)(qs[i] & mask1) + md; + } +} +``` + +**IQ4_NL** (line 921): +```metal +// PROPOSED: add dequantize_iq4_nl_half +void dequantize_iq4_nl_half(device const block_iq4_nl * xb, short il, thread half4x4 & reg) { + device const uint16_t * q4 = (device const uint16_t *)xb->qs; + const half d = (half)xb->d; + uint32_t aux32; + thread const uint8_t * q8 = (thread const uint8_t *)&aux32; + // Use half-precision lookup table: + threadgroup half * shmem_h = (threadgroup half *)shmem; + // Need to load kvalues as half in shmem first (see §8.3) + for (int i = 0; i < 4; ++i) { + aux32 = ((q4[2*i] | (q4[2*i+1] << 16)) >> 4*il) & 0x0f0f0f0f; + reg[i][0] = d * shmem_h[q8[0]]; + reg[i][1] = d * shmem_h[q8[1]]; + reg[i][2] = d * shmem_h[q8[2]]; + reg[i][3] = d * shmem_h[q8[3]]; + } +} +``` + +**MXFP4** (line 597): Similar pattern, but E8M0→half conversion is `ushort(bits << 7)` → bfloat16, or `(uint32_t)bits << 23` → float32 then cast to half. + +### 8.2 Create F16 mul_mv kernel variants + +**File**: `ggml/src/ggml-metal/ggml-metal.metal` + +For each format, create a dedicated F16 decode kernel that accumulates in `half` and reduces via `simd_sum` at the end. + +**Q4_0** — the most impactful starting point: +```metal +template +void kernel_mul_mv_q4_0_f16_impl( + args_t args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem, + uint3 tgpig, + ushort tiisg, + ushort sgitg) { + // Same structure as kernel_mul_mv_q4_0_f32_impl + // BUT: accumulate in half, reduce at end: + half sumf[NR0] = {0.0h}; // HALF accumulators + + // ... same block loading and nibble extraction ... + + // In inner loop, use half multiply: + // half d = (half)xb.d; + // half m = (half)xb.m; + // half yl_h = (half)yl[i]; + // sumf[row] += d * (nibble_val + m * sumy); // HALF FMA + + // Final reduction: cast to float for simd_sum + for (int row = 0; row < NR0 && first_row + row < args.ne0; ++row) { + float sum_all = simd_sum((float)sumf[row]); + if (tiisg == 0) { + dst_f32[first_row + row] = sum_all; + } + } +} +``` + +Key insight: The output is still F32 (single token decode produces one float per row). Only the intermediate accumulation is F16. This means `dst` type stays `float*` — no output format change needed. + +**IQ4_NL** (similar structure to existing `kernel_mul_mv_iq4_nl_f32_impl` at line 8850): +- Same pattern: `half sumf[NR0]`, `half4` accumulators, lookup table loaded into threadgroup as `half` values, final `simd_sum((float)sumf[row])` +- Must also create `dequantize_iq4_nl_t4_half` for the ext path (batch sizes 2-8) + +**IQ4_XS** (similar to existing line 8960): +- Same pattern with `half` accumulators +- Must also create `dequantize_iq4_xs_half` variants + +**MXFP4** (existing line 9069): +- `dequantize_mxfp4_half`: E8M0→half via `ushort(bits << 7) → bfloat16` or bit-shift to float16 +- `kvalues_mxfp4_h`: half-precision lookup table in threadgroup shared memory + +### 8.3 Threadgroup shared memory for lookup tables + +**Current state**: Only `kernel_mul_mv_mxfp4_f32_impl` loads `kvalues_mxfp4_f` into threadgroup shared memory. IQ4_NL and IQ4_XS use `constexpr constant` arrays. + +**For F16 variants**, the lookup values should be loaded as `half` into threadgroup: + +```metal +// EXISTING (IQ4_NL, F32): +threadgroup float * shmem_f32 = (threadgroup float *) shmem; +shmem_f32[tiisg] = kvalues_iq4nl_f[tiisg%16]; +threadgroup_barrier(mem_flags::mem_threadgroup); + +// PROPOSED (IQ4_NL, F16): +threadgroup half * shmem_h = (threadgroup half *) shmem; +shmem_h[tiisg] = (half)kvalues_iq4nl_f[tiisg%16]; // cast constant float to half +threadgroup_barrier(mem_flags::mem_threadgroup); +``` + +Memory cost: 16 halves = 32 bytes (trivial). The existing IQ4_NL kernel already allocates `shmem` with `threadgroup(0)`. + +### 8.4 Dispatch registration + +**File**: `ggml/src/ggml-metal/ggml-metal-ops.cpp` + +Add F16 variants alongside existing F32 variants. Each format needs: +1. A new `mul_mv` kernel name registered in `ggml_metal_op_mul_mat` +2. A new `mul_mv_ext` template instantiation for batch sizes 2-8 +3. A new `mul_mm` template instantiation for batched GEMM + +Example for Q4_0: +```cpp +// EXISTING: +{"kernel_mul_mv_q4_0_f32", ...} +{"kernel_mul_mv_ext_q4_0_f32_r1_2", ...} +{"kernel_mul_mm_q4_0_f32", ...} +{"kernel_mul_mm_q4_0_f16", ...} // f16 OUTPUT, still f32 dequant + +// PROPOSED: +{"kernel_mul_mv_q4_0_f16", ...} // f16 DEQUANT + accumulation +// ext variants r1_2 through r1_5 +{"kernel_mul_mm_q4_0_f16_dequant", ...} // f16 dequant, f16 accumulation in threadgroup +``` + +### 8.5 Kernel selection logic + +**File**: `ggml/src/ggml-metal/ggml-metal-ops.cpp` + +The dispatch logic (around line 2025) currently selects kernels based on `op->src[0]->type`. For F16 variants, add a runtime or compile-time flag to choose F16 accumulation kernels when available. + +Two approaches: +1. **Always prefer F16**: Replace existing kernels with F16 variants. The output is still F32 — no downstream changes needed. Risk: F16 may lose precision for very large models. +2. **Conditional selection**: Add a `GGML_METAL_F16_DEQUANT` flag (env var or build option) that selects F16 kernels when set. + +**Recommended**: Start with approach 1 for Q4_0 only (simplest kernel, well-tested format). If precision is fine, extend to other formats. Q4_0's arithmetic is trivially stable in F16 (scale + offset × nibble, range is well within F16 precision). + +### 8.6 Tuning: N_R0/N_SG parameters + +**File**: `ggml/src/ggml-metal/ggml-metal-impl.h` + +Current values and proposed changes to benchmark: + +```c +// CURRENT: +#define N_R0_Q4_0 4 +#define N_SG_Q4_0 2 // 8 rows/tg — already good +#define N_R0_IQ4_NL 2 +#define N_SG_IQ4_NL 2 // 4 rows/tg — try increasing +#define N_R0_MXFP4 2 +#define N_SG_MXFP4 2 // 4 rows/tg — try increasing + +// PROPOSED BENCHMARK VARIANTS: +#define N_R0_IQ4_NL 4 // try 8 rows/tg like Q4_0 +#define N_R0_MXFP4 4 // try 8 rows/tg +``` + +These are compile-time constants. Create benchmark builds with each variant and measure tg TPS on M4 Max at 1K, 4K, and 16K contexts. + +### 8.7 mul_mm (prefill) F16 path + +The mat-mat path already has `_f16` output variants (e.g., `kernel_mul_mm_q4_0_f16`). These dequantize to F32 in `float4x4` then store as `half2x4` for the simdgroup multiply. The F16 optimization here is to change the dequant functions to output `half4x4` directly, so the threadgroup memory stores half the data. + +This is lower priority than the decode (mul_mv) path because prefill is compute-bound and the mat-mat kernels already use `simdgroup_half8x8` for the accumulation stage. The main gain would be reduced threadgroup memory pressure. + +### 8.8 Priority order + +1. **Q4_0 F16 mul_mv kernel** — highest impact, simplest kernel (no lookup table), highest N_R0. File: `ggml-metal.metal` new `kernel_mul_mv_q4_0_f16_impl` +2. **IQ4_NL F16 mul_mv kernel** — second format, lookup table needs half shmem. File: `ggml-metal.metal` new `kernel_mul_mv_iq4_nl_f16_impl` +3. **IQ4_XS F16 mul_mv kernel** — third format. File: `ggml-metal.metal` new `kernel_mul_mv_iq4_xs_f16_impl` +4. **MXFP4 F16 mul_mv kernel** — fourth format, after converter works. File: `ggml-metal.metal` modify `kernel_mul_mv_mxfp4_f32_impl` +5. **N_R0 benchmarking** — `ggml-metal-impl.h`, try N_R0=4 for IQ4_NL and MXFP4 +6. **mul_mm F16 dequant** — lower priority, mat-mat path already uses half simdgroup +7. **MXFP4 converter** — extend `Qwen3_5TextModel` per §3 \ No newline at end of file diff --git a/BENCHMARKS.md b/BENCHMARKS.md new file mode 100644 index 000000000..8d073bb68 --- /dev/null +++ b/BENCHMARKS.md @@ -0,0 +1,156 @@ +# Baseline Benchmarks + +**Date**: 2026-04-30 +**Hardware**: Apple M4 Max +**Build**: 683c5acb9 (upstream main) +**Command**: `llama-bench -m MODEL -p 512 -t 1 -n 128 -o md -r 3` (pp512/tg128) + `llama-bench -m MODEL -p 1 -t 1 -n 4096 -o md -r 2` (tg4096) + +## pp512 (tokens/s) + +| Model | Q4_0 | IQ4_NL | IQ4_XS | +|-------|------|--------|--------| +| 4B | 1262.78 | 1252.70 | 1238.49 | +| 9B | 712.91 | 707.50 | 697.51 | + +## tg128 (tokens/s) + +| Model | Q4_0 | IQ4_NL | IQ4_XS | +|-------|------|--------|--------| +| 4B | 80.00 | 79.24 | 80.04 | +| 9B | 53.83 | 53.93 | 54.95 | + +## tg4096 (tokens/s) + +| Model | Q4_0 | IQ4_NL | IQ4_XS | +|-------|------|--------|--------| +| 4B | 76.09 | 75.24 | 45.23 | +| 9B | 52.06 | 51.95 | 38.51 | + +## Perplexity (Q4_0 4B, ctx=128) + +PPL = 2.2641 +/- 0.47327 + +## Effective bandwidth (9B models, tg128) + +| Format | Size (GiB) | tg TPS | Eq BW (GB/s) | +|--------|-----------|--------|-----------| +| Q4_0 | 5.00 | 53.83 | 289 | +| IQ4_NL | 4.99 | 53.93 | 289 | +| IQ4_XS | 4.80 | 54.95 | 283 | + +--- + +# F16 Accumulation Results + +**Date**: 2026-04-30 +**Build**: 683c5acb9 + F16 Q4_0 kernel (GGML_METAL_F16_ACCUM=1) + +## Q4_0 with F16 accumulation (tg4096) + +| Model | tg4096 F32 | tg4096 F16 | Delta | +|-------|-----------|-----------|-------| +| 4B | 76.09 | 76.15 | +0.08% | +| 9B | 52.06 | 51.94 | -0.23% | + +## Perplexity with F16 accumulation (Q4_0 4B, ctx=128) + +PPL = 2.2641 +/- 0.47327 (identical to baseline) + +**Conclusion**: F16 accumulation = zero perf improvement, zero quality impact. Reverted. + +--- + +# Graph Profile (tokgen decode) + +**Date**: 2026-04-30 +**Build**: 683c5acb9 (upstream main, clean) +**Tool**: `llama-eval-callback-profile` (custom, non-syncing cb_eval) +**Test**: p="The", n=32, ctx=256, t=1 + +**Key finding**: llama.cpp dispatches 1833 ops per decode tick (9B model). 682 are zero-ops (VIEW/RESHAPE/TRANSPOSE/PERMUTE — no GPU kernel). 1151 are actual GPU dispatches. This is a significant structural source of overhead. + +## 9B Q4_0 (52.9 tok/s, 1833 ops/tick, 1151 GPU dispatches/tick) + +| Op | PerTick | BytesIn/tk | BytesOut/tk | GPU? | Notes | +|----|--------|------------|-------------|------|-------| +| VIEW | 346 | 274 MB | 116 MB | NO | metadata only | +| RESHAPE | 288 | 108 MB | 108 MB | NO | metadata only | +| GET_ROWS | 99 | 678 MB | 53 MB | YES | token embed + DeltaNet state | +| CPY | 97 | 106 MB | 53 MB | YES | type conversion/layout | +| MUL_MAT | 249 | **4797 MB** | 7 MB | YES | weight matmuls (dominant) | +| GATED_DELTA_NET | 24 | 51 MB | 51 MB | YES | linear attention update | +| PERMUTE | 24 | 9 MB | 9 MB | NO | metadata only | +| SET_ROWS | 16 | 8 MB | 8 MB | YES | KV cache write | +| GLU | 32 | 3 MB | 2 MB | YES | FFN activation | +| MUL | 161 | 4 MB | 2 MB | YES | element-wise multiply | +| UNARY/SILU | 104 | 1 MB | 1 MB | YES | activation functions | +| RMS_NORM | 105 | 2 MB | 2 MB | YES | layer norms | +| ADD | 88 | 2 MB | 1 MB | YES | residual connections | +| SSM_CONV | 24 | 6 MB | 1 MB | YES | DeltaNet conv1d | +| L2_NORM | 48 | 0.4 MB | 0.4 MB | YES | q/k norm | +| ROPE | 16 | 0.2 MB | 0.2 MB | YES | rotary embeddings | +| FLASH_ATTN_EXT | 8 | 9 MB | 0.1 MB | YES | full attention (8 layers) | +| CONCAT | 24 | 3 MB | 3 MB | YES | tensor concatenation | +| SCALE | 48 | 0 | 0 | YES | scaling | +| CONT | 8 | 0.3 MB | 0.1 MB | YES | contiguous copy | +| TRANSPOSE | 24 | 1 MB | 1 MB | NO | metadata only | + +**Total data read per tick**: ~6.1 GB (MUL_MAT = 4.8 GB, GET_ROWS = 0.7 GB, CPY = 0.1 GB, rest ≈ 0.5 GB) + +## Context length impact (9B Q4_0) + +| Context | SET_ROWS | TPS | Notes | +|---------|----------|-----|-------| +| 256 | 8 MB | 52.9 | KV cache negligible | +| 2048 | 67 MB | 52.8 | Still negligible | +| 8192 | 268 MB | 52.5 | Still negligible | + +KV cache for 8 full-attention layers is tiny compared to MUL_MAT weight reads. The GatedDeltaNet state (51 MB) is larger but constant with context. + +## Architecture-specific notes + +Qwen3.5 has a hybrid architecture: 3 GatedDeltaNet + 1 full-attention per group of 4 layers. + +Per GatedDeltaNet layer: +- 3 input matmuls (qkv_a, alpha, beta) — Q8_0 ranked +- 1 z-gate matmul — Q4_0 +- 1 output projection matmul — Q4_0 +- 3 FFN matmuls (gate, up, out) — Q4_0 +- SSM_CONV, L2_NORM, SCALE, MUL for state update +- Total: ~7-8 MUL_MAT + SSM_CONV + misc + +Per full-attention layer: +- 3 input projections (Q, K, V) — Q4_0 +- 1 output projection — Q4_0 +- 3 FFN matmuls (gate, up, out) — Q4_0 +- ROPE, FLASH_ATTN_EXT +- Total: 7-8 MUL_MAT + +## Dispatch overhead analysis + +- 1833 ops/tick, 682 zero-ops (metadata), 1151 GPU dispatches +- At 52.9 tok/s → 18.9 ms/tick → 16.4 us per GPU dispatch average +- M4 Max Metal dispatch floor: ~3-5 us (from profiling) +- Dispatch overhead: 3.5-5.8 ms/tick (18-30% of total) +- MUL_MAT weight reads: 4.8 GB at observed 289 GB/s ≈ 16.6 ms (but pipelined with other ops) +- Other data: ~1.3 GB reads + ~0.4 GB writes ≈ 5-6 ms at 289 GB/s +- **Neither compute, bandwidth, nor dispatch is fully utilized** + +## Comparison with MLX + +MLX achieves ~355 GB/s effective bandwidth vs llama.cpp's ~289 GB/s on similar models (24% gap). + +Potential sources of gap: +1. **Kernel memory access patterns**: MLX uses contiguous weight reads, llama.cpp uses interleaved +2. **Dispatch efficiency**: 1151 GPU dispatches vs likely fewer in MLX (fewer view/reshape ops?) +3. **Non-MUL_MAT ops**: Nearly 600 MB/tick of reads for GET_ROWS/CPY/SET_ROWS — are these as efficient in llama.cpp? +4. **Graph optimization**: llama.cpp has many zero-ops (682 VIEW/RESHAPE/TRANSPOSE/PERMUTE) that still need encoding — can these be eliminated? + +## Profiling methodology + +- `llama-eval-callback-profile`: custom tool using `cb_eval` to observe ops without forcing sync +- `GGML_METAL_GRAPH_DEBUG=1` with `-v` flag: shows per-op graph structure (requires DEBUG log level) +- `GGML_METAL_CAPTURE_COMPUTE=2`: captures Xcode Instruments GPUtrace of 2nd compute call (first tokgen) +- Concurrency disabled: `GGML_METAL_CONCURRENCY_DISABLE=1` → ~53 → 52 tok/s (slightly worse) +- Fusion disabled: `GGML_METAL_FUSION_DISABLE=1` → negligible impact \ No newline at end of file diff --git a/GIT.md b/GIT.md new file mode 100644 index 000000000..b626dc6a5 --- /dev/null +++ b/GIT.md @@ -0,0 +1,170 @@ +# Git Workflow — llama.cpp M4 Max Performance Fork + +This is a private fork of [ggerganov/llama.cpp](https://github.com/ggerganov/llama.cpp) focused on Apple M4 Max Metal performance. All development happens on our Gitea instance. No changes ever touch upstream GitHub. + +## Remotes + +``` +origin → https://github.com/ggerganov/llama.cpp.git (read-only: git pull only) +gitea → ssh://sleepy@git.kokoham.com:2222/sleepy/llama.cpp.git (read/write) +``` + +- `origin` has no credentials — can pull but cannot push. Safe for agents. +- `gitea` is the working fork on our Gitea instance (SSH port 2222, user `sleepy`). + +## Syncing Upstream + +```bash +git fetch origin +git merge origin/master # fast-forward if clean +git push gitea master +``` + +Do this periodically. Conflicts should be rare since we only add tools/docs, not modify core code. + +## Branch Structure + +``` +master — always tracks upstream master (clean merge) +feature/ — active development branches (e.g., feature/mul-mat-contig-reads) +profile/ — profiling/measurement branches +fix/ — bug fixes found during profiling +exp/ — experimental, may be discarded +``` + +Branches are short-lived. Merge to master via PR, then delete. + +## Issue Tracking + +All work items are tracked as issues on https://git.kokoham.com/sleepy/llama.cpp/issues. + +Issue labels: +- `perf` — performance investigation +- `kernel` — Metal kernel changes +- `profiling` — measurement/tooling +- `doc` — documentation only +- `bug` — correctness issues +- `infra` — CI, build, repo setup + +## Pull Request Workflow + +1. Create branch from master: `git checkout -b feature/` +2. Make changes, commit with `[area] description` conventions (see below) +3. Push branch: `git push gitea feature/` +4. Create PR on Gitea targeting `master` +5. Before merge: build, benchmark (record in BENCHMARKS.md), perplexity check if kernel changed +6. Squash-merge to master + +## Commit Messages + +Format: `[area] short description (max 72 chars)` + +Areas: `metal`, `profile`, `docs`, `build`, `tool` + +Examples: +``` +[metal] add contiguous weight read path to Q4_0 mul_mat kernel +[profile] add per-op timing to metal encode loop +[docs] graph profile results for 9B Q4_0 at ctx=256 +[tool] llama-eval-callback-profile: non-syncing cb_eval profiler +``` + +## Agent Instructions + +When working autonomously, agents MUST: + +1. **Never push to `origin`** — `origin` has no credentials, this is a safety measure +2. **Create a branch** for any code change: `feature/-` +3. **Reference the issue** in commits: `[area] description (#123)` +4. **Run benchmarks** before/after kernel changes and record in BENCHMARKS.md +5. **Run perplexity** to verify correctness after any kernel change: + ```bash + ./build-build/bin/llama-perplexity -m MODEL.gguf -f /tmp/coherence_test.txt -t 1 --chunks 1 -c 128 + ``` +6. **Build succeeds** before pushing: + ```bash + cmake --build build-build -j$(sysctl -n hw.ncpu) + ``` +7. **Push branch** to gitea, then **create PR via Gitea API** (not via git push) + +## Build + +```bash +# Initial cmake (one time) +cmake -B build-build -DGGML_METAL=ON -DGGML_BLAS=ON -DGGML_ACCELERATE=ON + +# Incremental build +cmake --build build-build -j$(sysctl -n hw.ncpu) + +# Build specific target +cmake --build build-build --target llama-eval-callback-profile -j$(sysctl -n hw.ncpu) +``` + +## Benchmark Commands + +```bash +# Quick bench (pp + tg) +./build-build/bin/llama-bench -m MODEL.gguf -p 512 -t 1 -n 128 -o md -r 3 + +# Long tg bench (bandwidth-sensitive) +./build-build/bin/llama-bench -m MODEL.gguf -p 1 -t 1 -n 4096 -o md -r 2 + +# Perplexity +./build-build/bin/llama-perplexity -m MODEL.gguf -f /tmp/coherence_test.txt -t 1 --chunks 1 -c 128 +``` + +## Profiling Tools + +| Tool | What it does | +|------|-------------| +| `llama-eval-callback-profile` | Counts ops + bytes per decode tick (non-syncing cb_eval) | +| `GGML_METAL_GRAPH_DEBUG=1` | Prints per-op graph during compute (needs `-v` flag) | +| `GGML_METAL_GRAPH_DEBUG=2` | Also prints tensor shapes | +| `GGML_METAL_CAPTURE_COMPUTE=N` | Captures Nth compute call to Xcode Instruments GPUtrace | +| `GGML_METAL_CONCURRENCY_DISABLE=1` | Disable concurrent encoding (benchmark impact) | +| `GGML_METAL_FUSION_DISABLE=1` | Disable op fusion (benchmark impact) | + +## Model Files + +Located at `/Users/sleepy/.llama/models/`: + +``` +Qwen3.5-4B-Q4_0.gguf (2.40 GiB) +Qwen3.5-9B-Q4_0.gguf (5.00 GiB) +Qwen3.5-9B-IQ4_NL.gguf (4.99 GiB) +Qwen3.5-9B-IQ4_XS.gguf (4.80 GiB) +Qwen3.6-27B-Q4_0.gguf (14.70 GiB) +``` + +## Key Source Files + +``` +ggml/src/ggml-metal/ggml-metal.metal — Metal shader kernels (Q4_0 dot: line 3228) +ggml/src/ggml-metal/ggml-metal-device.cpp — Pipeline dispatch (get_pipeline_mul_mv: line 741) +ggml/src/ggml-metal/ggml-metal-ops.cpp — Op encoding (MUL_MAT: line 2257) +ggml/src/ggml-metal/ggml-metal-context.m — Graph compute (line 438) +ggml/src/ggml-metal/ggml-metal-impl.h — Tuning params (N_R0, N_SG) +examples/eval-callback/eval-callback-profile.cpp — Custom profiler tool +BENCHMARKS.md — All benchmark results +ANALYSIS_QWEN3_5_MXFP4.md — MXFP4 format analysis +``` + +## Gitea API + +Base: `https://git.kokoham.com/api/v1` +Token in `~/.gitea_token` (not committed). +Local API from server: `http://127.0.0.1:18431/api/v1` + +```bash +# Create issue +curl -X POST "http://127.0.0.1:18431/api/v1/repos/sleepy/llama.cpp/issues" \ + -H "Authorization: token $(cat ~/.gitea_token)" \ + -H "Content-Type: application/json" \ + -d '{"title":"...","body":"...","labels":["perf"]}' + +# Create PR +curl -X POST "http://127.0.0.1:18431/api/v1/repos/sleepy/llama.cpp/pulls" \ + -H "Authorization: token $(cat ~/.gitea_token)" \ + -H "Content-Type: application/json" \ + -d '{"title":"...","body":"...","head":"feature/xyz","base":"master"}' +``` diff --git a/examples/eval-callback/CMakeLists.txt b/examples/eval-callback/CMakeLists.txt index 63fbe59dc..023fc5b04 100644 --- a/examples/eval-callback/CMakeLists.txt +++ b/examples/eval-callback/CMakeLists.txt @@ -1,7 +1,12 @@ set(TARGET llama-eval-callback) add_executable(${TARGET} eval-callback.cpp) + +set(TARGET_PROFILE llama-eval-callback-profile) +add_executable(${TARGET_PROFILE} eval-callback-profile.cpp) install(TARGETS ${TARGET} RUNTIME) target_link_libraries(${TARGET} PRIVATE llama-common llama ${CMAKE_THREAD_LIBS_INIT}) +target_link_libraries(${TARGET_PROFILE} PRIVATE llama-common llama ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${TARGET_PROFILE} PRIVATE cxx_std_17) target_compile_features(${TARGET} PRIVATE cxx_std_17) if(LLAMA_BUILD_TESTS) diff --git a/examples/eval-callback/eval-callback-profile.cpp b/examples/eval-callback/eval-callback-profile.cpp new file mode 100644 index 000000000..c53938dbf --- /dev/null +++ b/examples/eval-callback/eval-callback-profile.cpp @@ -0,0 +1,136 @@ +#include "arg.h" +#include "common.h" +#include "log.h" +#include "llama.h" +#include "sampling.h" +#include +#include +#include +#include +#include +#include + +struct op_stat { + int count = 0; + double total_bytes_in = 0; + double total_bytes_out = 0; +}; + +static std::map g_op_stats; +static bool g_collect = false; + +static bool eval_callback(struct ggml_tensor * t, bool ask, void * user_data) { + (void)user_data; + if (!g_collect) return false; + if (t->op == GGML_OP_NONE) return false; + if (!ask) return false; // after compute, just observe + + auto & s = g_op_stats[t->op]; + s.count++; + s.total_bytes_out += (double)ggml_nbytes(t); + for (int i = 0; i < GGML_MAX_SRC; i++) { + if (t->src[i]) s.total_bytes_in += (double)ggml_nbytes(t->src[i]); + } + + return false; // no sync needed +} + +int main(int argc, char ** argv) { + std::setlocale(LC_NUMERIC, "C"); + + common_params params; + params.n_predict = 32; + params.cb_eval = eval_callback; + params.cb_eval_user_data = nullptr; + + if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON)) { + return 1; + } + + llama_backend_init(); + llama_numa_init(params.numa); + + common_init(); + + auto llama_init = common_init_from_params(params); + auto * model = llama_init->model(); + auto * ctx = llama_init->context(); + + if (!model || !ctx) { + LOG_ERR("failed to init\n"); + return 1; + } + + const llama_vocab * vocab = llama_model_get_vocab(model); + const bool add_bos = llama_vocab_get_add_bos(vocab); + std::string prompt = params.prompt.empty() ? "The" : params.prompt; + std::vector tokens = common_tokenize(ctx, prompt, add_bos, true); + + // Prefill + LOG_INF("Prefilling %zu tokens...\n", tokens.size()); + if (llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size())) != 0) { + LOG_ERR("prefill failed\n"); + return 1; + } + + // Start collecting op stats + g_collect = true; + + int n_gen = params.n_predict; + LOG_INF("Generating %d tokens...\n", n_gen); + + llama_token new_token = common_sampler_sample(llama_init->sampler(0), ctx, -1); + + auto t_start = ggml_time_us(); + + for (int i = 0; i < n_gen; i++) { + if (llama_decode(ctx, llama_batch_get_one(&new_token, 1)) != 0) { + LOG_ERR("decode failed at step %d\n", i); + break; + } + new_token = common_sampler_sample(llama_init->sampler(0), ctx, -1); + } + + auto t_end = ggml_time_us(); + + double total_ms = (t_end - t_start) / 1000.0; + double tps = n_gen / (total_ms / 1000.0); + + LOG("\n=== Tokgen Graph Profile ===\n"); + LOG("Total: %d tokens in %.1f ms (%.1f tok/s)\n", n_gen, total_ms, tps); + LOG("Total ops per decode step: %d\n\n", g_op_stats.empty() ? 0 : + std::accumulate(g_op_stats.begin(), g_op_stats.end(), 0, [](int s, const auto & p) { return s + p.second.count; }) / n_gen); + + double total_bytes = 0; + for (auto & [op, s] : g_op_stats) total_bytes += s.total_bytes_out; + + // sort by bytes_out descending + std::vector> sorted(g_op_stats.begin(), g_op_stats.end()); + std::sort(sorted.begin(), sorted.end(), [](auto & a, auto & b) { + return a.second.total_bytes_out > b.second.total_bytes_out; + }); + + LOG("%-20s %8s %14s %14s %10s\n", "Op", "PerTick", "BytesIn/tk", "BytesOut/tk", "Frac%"); + LOG("%-20s %8s %14s %14s %10s\n", "---", "---", "---", "---", "---"); + + for (auto & [op, s] : sorted) { + double per_tick = (double)s.count / n_gen; + double bytes_in_per = s.total_bytes_in / n_gen; + double bytes_out_per = s.total_bytes_out / n_gen; + double frac = s.total_bytes_out / total_bytes * 100; + LOG("%-20s %8.1f %14.0f %14.0f %9.1f%%\n", + ggml_op_name(op), per_tick, bytes_in_per, bytes_out_per, frac); + } + + // Estimate effective bandwidth if dominated by memory ops + double model_gib = (double)llama_model_size(model) / (1ull << 30); + double eff_bw = tps * model_gib * 1.074; + LOG("\nModel size: %.2f GiB\n", model_gib); + LOG("Effective BW (model * tps): %.0f GB/s\n", eff_bw); + LOG("M4 Max theoretical BW: ~400 GB/s\n"); + LOG("BW utilization: %.1f%%\n", eff_bw / 400 * 100); + + llama_perf_context_print(ctx); + llama_backend_free(); + return 0; +}