"""
|
|
Ternary training for Qwen3-0.6B on train_data.txt.
|
|
|
|
Approach:
|
|
- Load Qwen3-0.6B via mlx_lm (latent weights start as pretrained FP weights).
|
|
- Replace every nn.Linear under the transformer blocks (Q/K/V/O, gate/up/down)
|
|
with a TernaryLinear module.
|
|
- The lm_head is tied to the input embedding for Qwen3-0.6B; we leave the
|
|
embedding/lm_head non-ternary (BitNet b1.58 lineage keeps embeddings in
|
|
higher precision; ternarizing a 151936x1024 lookup table also doesn't
|
|
save FLOPs since it's a gather, not a matmul).
|
|
- Group-wise scale s = mean(|W|) over groups of 128 along the input dim.
|
|
- Forward uses W_t = clip(round(W / s), -1, 1) * s. STE makes the gradient
|
|
of the projection an identity so dL/dW_latent = dL/dW_ternary.
|
|
- Train with AdamW for >=200 steps on train_data.txt, then verify, generate,
|
|
and compute held-out perplexity.
|
|
"""

import json
import math
import os
import time
from pathlib import Path

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import numpy as np
from mlx_lm import load
from mlx_lm.generate import generate

GROUP_SIZE = 128
SEQ_LEN = int(os.environ.get("SEQ_LEN", 256))
BATCH_SIZE = int(os.environ.get("BATCH_SIZE", 4))
NUM_STEPS = int(os.environ.get("NUM_STEPS", 250))
LR = float(os.environ.get("LR", 5e-4))
WARMUP_STEPS = int(os.environ.get("WARMUP_STEPS", 30))
SEED = int(os.environ.get("SEED", 0))
DATA_PATH = Path(__file__).parent / "train_data.txt"


# -------------------------- TernaryLinear --------------------------


class TernaryLinear(nn.Module):
    """Linear layer with weight stored in float (latent) but projected to
    a group-wise ternary {-s, 0, +s} representation in the forward pass.

    Gradient flows through the projection via a straight-through estimator:
        W_eff = W + stop_gradient(W_t - W)
    so dL/dW = dL/dW_eff at the point W_eff = W_t.
    """

    def __init__(self, in_features: int, out_features: int, group_size: int = GROUP_SIZE):
        super().__init__()
        if in_features % group_size != 0:
            raise ValueError(f"in_features={in_features} not divisible by group_size={group_size}")
        self.in_features = in_features
        self.out_features = out_features
        self.group_size = group_size
        scale = in_features ** -0.5
        self.weight = mx.random.normal((out_features, in_features)) * scale

    @staticmethod
    def ternarize(W: mx.array, group_size: int) -> mx.array:
        out, in_ = W.shape
        Wg = W.reshape(out, in_ // group_size, group_size)
        # Per-group scale: mean(|W|), floored to avoid division by zero.
        s = mx.mean(mx.abs(Wg), axis=-1, keepdims=True)
        s = mx.maximum(s, 1e-8)
        q = mx.clip(mx.round(Wg / s), -1, 1)
        Wt = q * s
        return Wt.reshape(out, in_)

    def __call__(self, x: mx.array) -> mx.array:
        Wt = self.ternarize(self.weight, self.group_size)
        # STE: identity gradient through the projection.
        W_eff = self.weight + mx.stop_gradient(Wt - self.weight)
        return x @ W_eff.T

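# Worked example (a hedged sketch, not executed by this script): with
# group_size=2, ternarizing the single row [0.40, -0.05, 0.02, 0.30] splits it
# into groups [0.40, -0.05] (s = 0.225) and [0.02, 0.30] (s = 0.16):
#   round( 0.40 / 0.225) =  2 -> clip -> +1    round(-0.05 / 0.225) = 0
#   round( 0.02 / 0.16 ) =  0                  round( 0.30 / 0.16 ) = 2 -> clip -> +1
# so ternarize(W, 2) == [[0.225, 0.0, 0.0, 0.16]]; every entry lies in
# {-s, 0, +s} for its group's scale s.

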
def replace_linear_with_ternary(parent: nn.Module):
    """Recursively walk `parent` and swap nn.Linear children for TernaryLinear,
    transferring the pretrained weight into the latent slot.

    NOTE: Qwen3-0.6B's block linears are bias-free, so no bias is transferred.
    """
    children = parent.children()
    for name, child in children.items():
        if isinstance(child, nn.Linear) and not isinstance(child, TernaryLinear):
            out_f, in_f = child.weight.shape
            tl = TernaryLinear(in_f, out_f, group_size=GROUP_SIZE)
            tl.weight = child.weight
            setattr(parent, name, tl)
        elif isinstance(child, nn.Module):
            replace_linear_with_ternary(child)
        elif isinstance(child, list):
            for item in child:
                if isinstance(item, nn.Module):
                    replace_linear_with_ternary(item)

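# Hedged usage sketch (toy module for illustration only, not from the run):
#   class Block(nn.Module):
#       def __init__(self):
#           super().__init__()
#           self.proj = nn.Linear(256, 256)
#   blk = Block()
#   replace_linear_with_ternary(blk)
#   assert isinstance(blk.proj, TernaryLinear)  # pretrained weight carried over

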
def count_ternary_layers(model):
    out = collect_ternary_modules(model)
    return len(out), sum(tl.weight.size for tl in out)


def collect_ternary_modules(model):
    """Walk the tree and return all TernaryLinear modules."""
    found = []

    def walk(mod):
        for _, child in mod.children().items():
            if isinstance(child, TernaryLinear):
                found.append(child)
            elif isinstance(child, nn.Module):
                walk(child)
            elif isinstance(child, list):
                for item in child:
                    if isinstance(item, nn.Module):
                        walk(item)

    walk(model)
    return found


# -------------------------- Data --------------------------


def load_tokens(tokenizer, path: Path):
    text = path.read_text()
    # Split on blank lines into "paragraphs" so val/train are coherent units.
    paragraphs = [p.strip() for p in text.split("\n\n") if p.strip()]
    # Hold out the last 10% as validation.
    n_val = max(1, len(paragraphs) // 10)
    train_paras = paragraphs[:-n_val]
    val_paras = paragraphs[-n_val:]

    train_ids = tokenizer.encode("\n\n".join(train_paras))
    val_ids = tokenizer.encode("\n\n".join(val_paras))
    return mx.array(train_ids, dtype=mx.int32), mx.array(val_ids, dtype=mx.int32)


def sample_batch(tokens: mx.array, seq_len: int, batch_size: int, rng):
    n = tokens.shape[0]
    max_start = n - seq_len - 1
    starts = rng.integers(0, max_start, size=(batch_size,))
    x = mx.stack([tokens[s : s + seq_len] for s in starts.tolist()])
    y = mx.stack([tokens[s + 1 : s + 1 + seq_len] for s in starts.tolist()])
    return x, y


# -------------------------- Loss --------------------------


def loss_fn(model, x, y):
    logits = model(x)
    logits = logits.astype(mx.float32)
    # Cross-entropy over the vocab.
    log_probs = nn.log_softmax(logits, axis=-1)
    # Gather the target log-probs.
    B, L, V = log_probs.shape
    flat = log_probs.reshape(B * L, V)
    flat_y = y.reshape(B * L)
    nll = -flat[mx.arange(B * L), flat_y]
    return nll.mean()

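# The manual gather above should be equivalent to averaging MLX's built-in
# nn.losses.cross_entropy(logits, y) over all positions (hedged: exact
# signatures vary across mlx versions); the explicit form keeps the float32
# upcast and the gather visible.

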
# -------------------------- LR schedule --------------------------


def lr_at(step):
    if step < WARMUP_STEPS:
        # Linear warmup from LR/WARMUP_STEPS up to the peak LR.
        return LR * (step + 1) / WARMUP_STEPS
    # Cosine decay to 10% of peak.
    progress = (step - WARMUP_STEPS) / max(1, NUM_STEPS - WARMUP_STEPS)
    progress = min(1.0, progress)
    return LR * (0.1 + 0.9 * 0.5 * (1 + math.cos(math.pi * progress)))

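# Spot values with the defaults (LR=5e-4, WARMUP_STEPS=30, NUM_STEPS=250):
#   step 0   -> 5e-4 * 1/30             ~= 1.7e-5  (warmup start)
#   step 29  -> 5e-4 * 30/30             = 5.0e-4  (peak)
#   step 140 -> progress=0.5, cos=0     -> 0.55*LR  = 2.75e-4
#   step 249 -> progress~=1,  cos~=-1   -> ~0.10*LR ~= 5.0e-5 (floor)

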
# -------------------------- Verification --------------------------


def verify_ternary(model):
    """For every TernaryLinear, recompute the projection and check it is
    exactly {-s, 0, +s} per group.
    """
    bad = 0
    total = 0
    max_err = 0.0
    for tl in collect_ternary_modules(model):
        W = tl.weight
        Wt = TernaryLinear.ternarize(W, tl.group_size)
        out, in_ = Wt.shape
        Wg = Wt.reshape(out, in_ // tl.group_size, tl.group_size)
        # Recompute the scale used in the projection from the latent weight,
        # then check that Wt / s_orig is an integer in {-1, 0, +1}.
        s_orig = mx.maximum(mx.mean(mx.abs(W.reshape(out, in_ // tl.group_size, tl.group_size)), axis=-1, keepdims=True), 1e-8)
        q = Wg / s_orig
        # Distance to the nearest integer in {-1, 0, +1}.
        nearest = mx.clip(mx.round(q), -1, 1)
        err = mx.max(mx.abs(q - nearest)).item()
        max_err = max(max_err, err)
        bad += (mx.abs(q - nearest) > 1e-5).sum().item()
        total += Wt.size
    return bad, total, max_err


# -------------------------- Main --------------------------


def main():
    mx.random.seed(SEED)
    rng = np.random.default_rng(SEED)

    print("[1/6] Loading Qwen3-0.6B via mlx_lm…", flush=True)
    t0 = time.time()
    model, tokenizer = load("Qwen/Qwen3-0.6B")
    print(f" loaded in {time.time()-t0:.1f}s", flush=True)

print("[2/6] Replacing nn.Linear with TernaryLinear (Q/K/V/O + gate/up/down)…", flush=True)
|
|
# Walk into model.model.layers (transformer blocks) only.
|
|
for layer in model.model.layers:
|
|
replace_linear_with_ternary(layer)
|
|
n_ternary, n_ternary_params = count_ternary_layers(model)
|
|
print(f" replaced {n_ternary} linear layers ({n_ternary_params/1e6:.1f}M ternary params)", flush=True)
|
|
|
|
# Force the parameter tree to materialize on device with proper dtypes.
|
|
# Latent weights are kept in float32 for stable optimizer math.
|
|
def to_f32(p):
|
|
if isinstance(p, mx.array) and p.dtype != mx.float32 and p.ndim >= 1:
|
|
return p.astype(mx.float32)
|
|
return p
|
|
# Only cast the ternary latent weights to fp32; leave norms/embeddings alone.
|
|
for tl in collect_ternary_modules(model):
|
|
tl.weight = tl.weight.astype(mx.float32)
|
|
|
|
# Quick smoke test: forward pass on 1 token.
|
|
test_in = mx.array([[tokenizer.eos_token_id]])
|
|
_ = model(test_in)
|
|
mx.eval(_)
|
|
print(" smoke forward pass ok", flush=True)
|
|
|
|
print("[3/6] Tokenizing train_data.txt…", flush=True)
|
|
train_tokens, val_tokens = load_tokens(tokenizer, DATA_PATH)
|
|
print(f" train tokens: {train_tokens.size}, val tokens: {val_tokens.size}", flush=True)
|
|
|
|
print(f"[4/6] Training for {NUM_STEPS} steps (batch={BATCH_SIZE}, seq_len={SEQ_LEN}, lr_peak={LR})…", flush=True)
|
|
|
|
    optimizer = optim.AdamW(learning_rate=LR, weight_decay=0.0, betas=(0.9, 0.95))

    loss_and_grad = nn.value_and_grad(model, loss_fn)

    losses = []
    t0 = time.time()
    for step in range(NUM_STEPS):
        x, y = sample_batch(train_tokens, SEQ_LEN, BATCH_SIZE, rng)
        optimizer.learning_rate = lr_at(step)
        loss, grads = loss_and_grad(model, x, y)
        optimizer.update(model, grads)
        mx.eval(model.parameters(), optimizer.state, loss)
        l = loss.item()
        losses.append(l)
        if step % 10 == 0 or step == NUM_STEPS - 1:
            elapsed = time.time() - t0
            print(f" step {step:4d}/{NUM_STEPS} loss={l:.4f} lr={optimizer.learning_rate.item():.2e} ({elapsed:.0f}s)", flush=True)

    print(f" first 5 loss avg: {sum(losses[:5])/5:.4f}", flush=True)
    print(f" last 20 loss avg: {sum(losses[-20:])/20:.4f}", flush=True)

print("[5/6] Verifying ternary projection…", flush=True)
|
|
bad, total, max_err = verify_ternary(model)
|
|
print(f" bad weights: {bad}/{total} max projection error: {max_err:.2e}", flush=True)
|
|
ternary_ok = bad == 0 and max_err < 1e-4
|
|
print(f" TERNARY OK: {ternary_ok}", flush=True)
|
|
|
|
print("[6/6] Validation perplexity + samples…", flush=True)
|
|
# Compute val perplexity over non-overlapping windows.
|
|
n_val = val_tokens.size
|
|
win = SEQ_LEN
|
|
n_windows = max(1, (n_val - 1) // win)
|
|
nll_sum = 0.0
|
|
tok_count = 0
|
|
for i in range(n_windows):
|
|
s = i * win
|
|
x = val_tokens[s : s + win][None, :]
|
|
y = val_tokens[s + 1 : s + 1 + win][None, :]
|
|
if y.shape[1] < win:
|
|
break
|
|
l = loss_fn(model, x, y)
|
|
mx.eval(l)
|
|
nll_sum += l.item() * y.size
|
|
tok_count += y.size
|
|
val_nll = nll_sum / max(1, tok_count)
|
|
val_ppl = math.exp(val_nll)
|
|
print(f" val nll: {val_nll:.4f} val ppl: {val_ppl:.2f}", flush=True)
|
|
|
|
    # Generate samples. Generation automatically uses the ternary forward
    # path, since model.__call__ goes through TernaryLinear.
    samples = []
    prompts = [
        "Open source software has",
        "World War II was",
        "The development of antibiotics",
        "Sleep is essential for",
        "The scientific method is",
    ]
    for p in prompts:
        try:
            txt = generate(model, tokenizer, prompt=p, max_tokens=60, verbose=False)
        except TypeError:
            # Fall back for older/newer mlx_lm signatures.
            txt = generate(model, tokenizer, p, max_tokens=60)
        samples.append((p, txt))
        print(f" [{p!r}] -> {txt!r}", flush=True)

    report = {
        "first_5_loss_avg": sum(losses[:5]) / 5,
        "last_20_loss_avg": sum(losses[-20:]) / 20,
        "final_loss": losses[-1],
        "ternary_ok": ternary_ok,
        "ternary_max_err": max_err,
        "val_nll": val_nll,
        "val_ppl": val_ppl,
        "num_steps": NUM_STEPS,
        "batch_size": BATCH_SIZE,
        "seq_len": SEQ_LEN,
        "lr_peak": LR,
        "group_size": GROUP_SIZE,
        "n_ternary_layers": n_ternary,
        "n_ternary_params": int(n_ternary_params),
        "samples": [{"prompt": p, "generation": t} for p, t in samples],
        "loss_curve": losses,
    }
    out_path = Path(__file__).parent / "report.json"
    out_path.write_text(json.dumps(report, indent=2))
    print(f" wrote report to {out_path}", flush=True)


if __name__ == "__main__":
    main()