# deep_pro_judge/opus47_1m/ternary_training/ternary_train.py
"""
Ternary training for Qwen3-0.6B on train_data.txt.
Approach:
- Load Qwen3-0.6B via mlx_lm (latent weights start as pretrained FP weights).
- Replace every nn.Linear under the transformer blocks (Q/K/V/O, gate/up/down)
with a TernaryLinear module.
- The lm_head is tied to the input embedding for Qwen3-0.6B; we leave the
  embedding/lm_head non-ternary (the BitNet b1.58 lineage keeps embeddings in
  higher precision, and ternarizing the 151936x1024 input embedding saves no
  matmul FLOPs anyway: the lookup is a gather, not a matmul).
- Group-wise scale s = mean(|W|) over groups of 128 along the input dim.
- Forward uses W_t = clip(round(W / s), -1, 1) * s. STE makes the gradient
of the projection an identity so dL/dW_latent = dL/dW_ternary.
- Train with AdamW for >=200 steps on train_data.txt, then verify, generate,
and compute held-out perplexity.
"""
import json
import math
import os
import time
from pathlib import Path
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import numpy as np
from mlx_lm import load
from mlx_lm.generate import generate
GROUP_SIZE = 128
SEQ_LEN = int(os.environ.get("SEQ_LEN", 256))
BATCH_SIZE = int(os.environ.get("BATCH_SIZE", 4))
NUM_STEPS = int(os.environ.get("NUM_STEPS", 250))
LR = float(os.environ.get("LR", 5e-4))
WARMUP_STEPS = int(os.environ.get("WARMUP_STEPS", 30))
SEED = int(os.environ.get("SEED", 0))
DATA_PATH = Path(__file__).parent / "train_data.txt"
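# Example override from the shell (hypothetical invocation):
#   NUM_STEPS=500 LR=3e-4 python ternary_train.py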
# -------------------------- TernaryLinear --------------------------
class TernaryLinear(nn.Module):
"""Linear layer with weight stored in float (latent) but projected to
a group-wise ternary {-s, 0, +s} representation in the forward pass.
    Gradient flows through the projection via a Straight-Through Estimator:
        W_eff = W + stop_gradient(W_t - W)
    so the forward pass sees W_t while the backward pass treats the
    projection as identity: dL/dW = dL/dW_eff.
"""
def __init__(self, in_features: int, out_features: int, group_size: int = GROUP_SIZE):
super().__init__()
if in_features % group_size != 0:
raise ValueError(f"in_features={in_features} not divisible by group_size={group_size}")
self.in_features = in_features
self.out_features = out_features
self.group_size = group_size
scale = in_features ** -0.5
self.weight = mx.random.normal((out_features, in_features)) * scale
@staticmethod
def ternarize(W: mx.array, group_size: int) -> mx.array:
out, in_ = W.shape
Wg = W.reshape(out, in_ // group_size, group_size)
# Per-group scale: mean(|W|).
s = mx.mean(mx.abs(Wg), axis=-1, keepdims=True)
s = mx.maximum(s, 1e-8)
q = mx.clip(mx.round(Wg / s), -1, 1)
Wt = q * s
return Wt.reshape(out, in_)
def __call__(self, x: mx.array) -> mx.array:
Wt = self.ternarize(self.weight, self.group_size)
# STE: identity gradient through the projection.
W_eff = self.weight + mx.stop_gradient(Wt - self.weight)
return x @ W_eff.T
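# STE sanity sketch (hypothetical snippet, not executed by this script): the
# gradient of a reduction over W_eff w.r.t. the latent weight is exactly the
# identity,
#   f = lambda w: (w + mx.stop_gradient(TernaryLinear.ternarize(w, 2) - w)).sum()
#   mx.grad(f)(mx.ones((2, 4)))   # -> all-ones (2, 4) array
# while the value f(w) itself is computed from the ternarized weights.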
def replace_linear_with_ternary(parent: nn.Module):
"""Recursively walk `parent` and swap nn.Linear children for TernaryLinear,
transferring the pretrained weight into the latent slot.
"""
children = parent.children()
for name, child in children.items():
if isinstance(child, nn.Linear) and not isinstance(child, TernaryLinear):
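            # Qwen3's attention/MLP projections are bias-free, so transferring
            # the weight alone is sufficient here.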
in_f = child.weight.shape[1]
out_f = child.weight.shape[0]
tl = TernaryLinear(in_f, out_f, group_size=GROUP_SIZE)
tl.weight = child.weight
setattr(parent, name, tl)
elif isinstance(child, nn.Module):
replace_linear_with_ternary(child)
elif isinstance(child, list):
for item in child:
if isinstance(item, nn.Module):
replace_linear_with_ternary(item)
def collect_ternary_modules(model):
    """Walk the module tree and return every TernaryLinear."""
    found = []
    def walk(mod):
        for _, child in mod.children().items():
            if isinstance(child, TernaryLinear):
                found.append(child)
            elif isinstance(child, nn.Module):
                walk(child)
            elif isinstance(child, list):
                for item in child:
                    if isinstance(item, nn.Module):
                        walk(item)
    walk(model)
    return found
def count_ternary_layers(model):
    """Return (number of TernaryLinear layers, total latent parameter count)."""
    out = collect_ternary_modules(model)
    return len(out), sum(tl.weight.size for tl in out)
# -------------------------- Data --------------------------
def load_tokens(tokenizer, path: Path):
text = path.read_text()
# Split by blank lines into "paragraphs" so val/train are coherent units.
paragraphs = [p.strip() for p in text.split("\n\n") if p.strip()]
# Hold out the last 10% as validation.
n_val = max(1, len(paragraphs) // 10)
train_paras = paragraphs[:-n_val]
val_paras = paragraphs[-n_val:]
train_ids = tokenizer.encode("\n\n".join(train_paras))
val_ids = tokenizer.encode("\n\n".join(val_paras))
return mx.array(train_ids, dtype=mx.int32), mx.array(val_ids, dtype=mx.int32)
def sample_batch(tokens: mx.array, seq_len: int, batch_size: int, rng):
n = tokens.shape[0]
max_start = n - seq_len - 1
starts = rng.integers(0, max_start, size=(batch_size,))
x = mx.stack([tokens[s : s + seq_len] for s in starts.tolist()])
y = mx.stack([tokens[s + 1 : s + 1 + seq_len] for s in starts.tolist()])
return x, y
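# Targets are the inputs shifted right by one token, e.g. with seq_len=3 and
# start s=0: x = [t0, t1, t2], y = [t1, t2, t3].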
# -------------------------- Loss --------------------------
def loss_fn(model, x, y):
logits = model(x)
logits = logits.astype(mx.float32)
# Cross-entropy over vocab.
log_probs = nn.log_softmax(logits, axis=-1)
# Gather target log-probs.
B, L, V = log_probs.shape
flat = log_probs.reshape(B * L, V)
flat_y = y.reshape(B * L)
nll = -flat[mx.arange(B * L), flat_y]
return nll.mean()
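# Equivalent formulation via the built-in loss (a sketch; mlx's
# nn.losses.cross_entropy takes raw logits, not log-probs, plus integer
# targets):
#   nn.losses.cross_entropy(logits.reshape(B * L, V), y.reshape(B * L)).mean()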
# -------------------------- LR schedule --------------------------
def lr_at(step):
if step < WARMUP_STEPS:
return LR * (step + 1) / WARMUP_STEPS
# Cosine decay to 10% of peak.
progress = (step - WARMUP_STEPS) / max(1, NUM_STEPS - WARMUP_STEPS)
progress = min(1.0, progress)
return LR * (0.1 + 0.9 * 0.5 * (1 + math.cos(math.pi * progress)))
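# With the defaults (LR=5e-4, WARMUP_STEPS=30, NUM_STEPS=250) this yields
# roughly:
#   step 0   -> 5e-4 / 30 ≈ 1.7e-5        (warmup start)
#   step 29  -> 5e-4                      (peak)
#   step 140 -> 0.55 * 5e-4 = 2.75e-4     (cosine midpoint, progress = 0.5)
#   step 249 -> ≈ 0.1 * 5e-4 = 5e-5       (decayed to the 10% floor)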
# -------------------------- Verification --------------------------
def verify_ternary(model):
"""For every TernaryLinear, recompute the projection and check it is
exactly {-s, 0, +s} per-group.
"""
bad = 0
total = 0
max_err = 0.0
for tl in collect_ternary_modules(model):
W = tl.weight
Wt = TernaryLinear.ternarize(W, tl.group_size)
        out, in_ = Wt.shape
        Wg = Wt.reshape(out, in_ // tl.group_size, tl.group_size)
        # Divide by the scale the projection actually used (recomputed from the
        # latent W; mean(|Wt|) would shrink wherever a group contains zeros)
        # and check that every entry lands on an integer in {-1, 0, +1}.
        s_orig = mx.maximum(mx.mean(mx.abs(W.reshape(out, in_ // tl.group_size, tl.group_size)), axis=-1, keepdims=True), 1e-8)
        q = Wg / s_orig
        # Distance to the nearest integer in {-1, 0, +1}.
        nearest = mx.clip(mx.round(q), -1, 1)
        err = mx.max(mx.abs(q - nearest)).item()
        max_err = max(max_err, err)
        bad += (mx.abs(q - nearest) > 1e-5).sum().item()
        total += Wt.size
return bad, total, max_err
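# Because this re-applies exactly the round/clip that the forward pass uses, a
# correct run should report bad == 0, with max_err at float32 rounding noise
# (q = (k * s) / s is not always bit-exact, but stays far below the 1e-5 gate).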
# -------------------------- Main --------------------------
def main():
mx.random.seed(SEED)
    rng = np.random.default_rng(SEED)
print("[1/6] Loading Qwen3-0.6B via mlx_lm…", flush=True)
t0 = time.time()
model, tokenizer = load("Qwen/Qwen3-0.6B")
print(f" loaded in {time.time()-t0:.1f}s", flush=True)
print("[2/6] Replacing nn.Linear with TernaryLinear (Q/K/V/O + gate/up/down)…", flush=True)
# Walk into model.model.layers (transformer blocks) only.
for layer in model.model.layers:
replace_linear_with_ternary(layer)
n_ternary, n_ternary_params = count_ternary_layers(model)
print(f" replaced {n_ternary} linear layers ({n_ternary_params/1e6:.1f}M ternary params)", flush=True)
    # Keep latent weights in float32 for stable optimizer math; cast only the
    # ternary latent weights and leave norms/embeddings in their loaded dtype.
    for tl in collect_ternary_modules(model):
        tl.weight = tl.weight.astype(mx.float32)
# Quick smoke test: forward pass on 1 token.
test_in = mx.array([[tokenizer.eos_token_id]])
_ = model(test_in)
mx.eval(_)
print(" smoke forward pass ok", flush=True)
print("[3/6] Tokenizing train_data.txt…", flush=True)
train_tokens, val_tokens = load_tokens(tokenizer, DATA_PATH)
print(f" train tokens: {train_tokens.size}, val tokens: {val_tokens.size}", flush=True)
print(f"[4/6] Training for {NUM_STEPS} steps (batch={BATCH_SIZE}, seq_len={SEQ_LEN}, lr_peak={LR})…", flush=True)
optimizer = optim.AdamW(learning_rate=LR, weight_decay=0.0, betas=(0.9, 0.95))
loss_and_grad = nn.value_and_grad(model, loss_fn)
losses = []
t0 = time.time()
for step in range(NUM_STEPS):
x, y = sample_batch(train_tokens, SEQ_LEN, BATCH_SIZE, rng)
optimizer.learning_rate = lr_at(step)
loss, grads = loss_and_grad(model, x, y)
optimizer.update(model, grads)
mx.eval(model.parameters(), optimizer.state, loss)
l = loss.item()
losses.append(l)
if step % 10 == 0 or step == NUM_STEPS - 1:
elapsed = time.time() - t0
print(f" step {step:4d}/{NUM_STEPS} loss={l:.4f} lr={optimizer.learning_rate.item():.2e} ({elapsed:.0f}s)", flush=True)
print(f" first 5 loss avg: {sum(losses[:5])/5:.4f}", flush=True)
print(f" last 20 loss avg: {sum(losses[-20:])/20:.4f}", flush=True)
print("[5/6] Verifying ternary projection…", flush=True)
bad, total, max_err = verify_ternary(model)
print(f" bad weights: {bad}/{total} max projection error: {max_err:.2e}", flush=True)
ternary_ok = bad == 0 and max_err < 1e-4
print(f" TERNARY OK: {ternary_ok}", flush=True)
print("[6/6] Validation perplexity + samples…", flush=True)
# Compute val perplexity over non-overlapping windows.
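    # Perplexity is exp of the mean per-token NLL, e.g. a mean nll of 3.0
    # corresponds to ppl = e^3 ≈ 20.1.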
n_val = val_tokens.size
win = SEQ_LEN
n_windows = max(1, (n_val - 1) // win)
nll_sum = 0.0
tok_count = 0
for i in range(n_windows):
s = i * win
x = val_tokens[s : s + win][None, :]
y = val_tokens[s + 1 : s + 1 + win][None, :]
if y.shape[1] < win:
break
l = loss_fn(model, x, y)
mx.eval(l)
nll_sum += l.item() * y.size
tok_count += y.size
val_nll = nll_sum / max(1, tok_count)
val_ppl = math.exp(val_nll)
print(f" val nll: {val_nll:.4f} val ppl: {val_ppl:.2f}", flush=True)
    # Generate samples. Generation automatically uses the ternary forward
    # path, since model.__call__ runs through the swapped-in TernaryLinear
    # modules.
samples = []
prompts = [
"Open source software has",
"World War II was",
"The development of antibiotics",
"Sleep is essential for",
"The scientific method is",
]
for p in prompts:
try:
txt = generate(model, tokenizer, prompt=p, max_tokens=60, verbose=False)
except TypeError:
# Older/newer mlx_lm signatures.
txt = generate(model, tokenizer, p, max_tokens=60)
samples.append((p, txt))
print(f" [{p!r}] -> {txt!r}", flush=True)
report = {
"first_5_loss_avg": sum(losses[:5])/5,
"last_20_loss_avg": sum(losses[-20:])/20,
"final_loss": losses[-1],
"ternary_ok": ternary_ok,
"ternary_max_err": max_err,
"val_nll": val_nll,
"val_ppl": val_ppl,
"num_steps": NUM_STEPS,
"batch_size": BATCH_SIZE,
"seq_len": SEQ_LEN,
"lr_peak": LR,
"group_size": GROUP_SIZE,
"n_ternary_layers": n_ternary,
"n_ternary_params": int(n_ternary_params),
"samples": [{"prompt": p, "generation": t} for p, t in samples],
"loss_curve": losses,
}
out_path = Path(__file__).parent / "report.json"
out_path.write_text(json.dumps(report, indent=2))
print(f" wrote report to {out_path}", flush=True)
if __name__ == "__main__":
main()