Files
deep_pro_judge/glm5/ternary_training/run_rerun.py
T
sleepy 45c3aad453 feat: expand to 6 models, 8 challenges; rewrite README with DeepSeek V4 Pro analysis
- Add Claude Opus 4.7, Kimi K2.6, GLM-5.1 to existing GLM-5, Qwen3-6, MiniMax-M2.7
- Add 5 new challenges: flash attention fwd/bwd, beam search, DFlash, ternary training
- Rewrite README with TL;DR rankings, grade matrix, and DeepSeek V4 Pro attribution
- Add analysis/ folder with cross-model comparisons and per-challenge deep dives
- Add deploy_challenges.sh script
- Expand .gitignore to exclude Python envs, ML weights, and build artifacts
2026-04-27 18:49:22 +02:00

659 lines
25 KiB
Python

#!/usr/bin/env python3
"""
Ternary Bonsai Re-Run: Train on train_data.txt
================================================
Same architecture/hyperparameters as the original run, but using the
provided train_data.txt file instead of WikiText-2.
"""
import argparse
import math
import os
import sys
import time
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from mlx.optimizers import AdamW
# =============================================================================
# TERNARY LINEAR LAYER (identical to original)
# =============================================================================
# Number of consecutive last-axis weights that share one scale in the ternary projection.
GROUP_SIZE = 128
@mx.custom_function
def ternary_projection(w):
    """Project ``w`` onto group-wise scaled ternary values.

    The last axis of ``w`` is split into groups of ``GROUP_SIZE`` entries
    (zero-padded on the right when it does not divide evenly).  Every group
    is divided by the mean absolute value of its entries, rounded, clipped
    to ``{-1, 0, +1}`` and multiplied back by that scale.

    Args:
        w: Weight array; all leading axes are flattened into rows.

    Returns:
        An array with the same shape as ``w`` holding the projected values.
    """
    shape_in = w.shape
    rows = w.reshape(-1, shape_in[-1])
    n_cols = rows.shape[-1]

    # Columns of zeros needed to reach the next multiple of GROUP_SIZE.
    n_pad = -n_cols % GROUP_SIZE
    if n_pad > 0:
        rows = mx.pad(rows, [(0, 0), (0, n_pad)], constant_values=0.0)
    n_rows = rows.shape[0]
    width = rows.shape[-1]

    groups = rows.reshape(n_rows, width // GROUP_SIZE, GROUP_SIZE)

    # Per-group scale; padded zeros (if any) take part in this mean.
    group_scale = mx.mean(mx.abs(groups), axis=-1, keepdims=True)
    # Avoid dividing by (near-)zero for groups that are all ~0.
    group_scale = mx.where(
        group_scale < 1e-8, mx.ones_like(group_scale), group_scale
    )

    codes = mx.clip(mx.round(groups / group_scale), -1.0, 1.0)
    projected = (codes * group_scale).reshape(n_rows, width)
    if n_pad > 0:
        # Drop the padding columns again.
        projected = projected[:, :n_cols]
    return projected.reshape(shape_in)
@ternary_projection.vjp
def ternary_projection_vjp(primals, cotangent, output):
    """Backward rule for ``ternary_projection``.

    The incoming cotangent is handed back unchanged as the gradient with
    respect to the single input, i.e. the rounding/clipping in the forward
    pass is treated as the identity for differentiation.
    """
    grad_wrt_w = cotangent
    return (grad_wrt_w,)
class TernaryLinear(nn.Module):
    """Linear layer whose weight passes through ``ternary_projection`` on every call.

    The stored ``weight`` has shape ``(out_features, in_features)`` and is
    initialised from a normal distribution scaled by ``in_features ** -0.5``.
    """

    def __init__(self, in_features, out_features, bias=False):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        init_scale = in_features ** (-0.5)
        self.weight = mx.random.normal(shape=(out_features, in_features)) * init_scale
        if bias:
            self.bias = mx.zeros((out_features,))
        else:
            self.bias = None

    def __call__(self, x):
        """Return ``x @ projected_weight.T`` plus the bias when one exists."""
        projected = ternary_projection(self.weight)
        y = x @ projected.T
        if self.bias is None:
            return y
        return y + self.bias
class TernaryEmbedding(nn.Module):
    """Embedding table whose weight passes through ``ternary_projection`` on use.

    The stored ``weight`` has shape ``(num_embeddings, embedding_dim)`` and is
    initialised from a normal distribution scaled by ``embedding_dim ** -0.5``.
    """

    def __init__(self, num_embeddings, embedding_dim):
        super().__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        init_scale = embedding_dim ** (-0.5)
        self.weight = mx.random.normal(shape=(num_embeddings, embedding_dim)) * init_scale

    def __call__(self, x):
        """Look up rows of the projected table using the indices in ``x``."""
        table = ternary_projection(self.weight)
        return table[x]

    def as_linear(self, x):
        """Use the projected table as an output projection: ``x @ table.T``."""
        table = ternary_projection(self.weight)
        return x @ table.T
# =============================================================================
# TERNARY QWEN3 MODEL (identical to original)
# =============================================================================
from dataclasses import dataclass
from typing import Any, Dict, Optional, Union
from mlx_lm.models.base import create_attention_mask, scaled_dot_product_attention
from mlx_lm.models.activations import swiglu
from mlx_lm.models.rope_utils import initialize_rope
@dataclass
class ModelArgs:
    """Configuration values consumed by the ternary model classes in this file.

    In ``main`` every field is overwritten with the value read from the
    loaded source model's ``args``; the defaults below apply only when a
    field is not passed explicitly.
    """

    model_type: str = "qwen3"
    # Width of the residual stream.
    hidden_size: int = 1024
    # Number of transformer blocks.
    num_hidden_layers: int = 28
    # Hidden width of the MLP.
    intermediate_size: int = 3072
    # Query heads.
    num_attention_heads: int = 16
    rms_norm_eps: float = 1e-6
    vocab_size: int = 151936
    # Key/value heads.
    num_key_value_heads: int = 8
    max_position_embeddings: int = 40960
    rope_theta: float = 1000000.0
    # Per-head width used by the attention projections.
    head_dim: int = 128
    # When True the embedding table doubles as the output projection.
    tie_word_embeddings: bool = True
    # Passed through to ``initialize_rope`` as ``scaling_config``.
    rope_scaling: Optional[Dict[str, Union[float, str]]] = None
class TernaryAttention(nn.Module):
    """Self-attention block built from ``TernaryLinear`` projections.

    Queries use ``num_attention_heads`` heads and keys/values use
    ``num_key_value_heads`` heads, each of width ``head_dim``.  Queries and
    keys are RMS-normalised per head and rotated with RoPE before attention.
    """

    def __init__(self, args):
        super().__init__()
        hidden = args.hidden_size
        per_head = args.head_dim
        self.n_heads = args.num_attention_heads
        self.n_kv_heads = args.num_key_value_heads
        self.scale = per_head ** -0.5

        q_width = self.n_heads * per_head
        kv_width = self.n_kv_heads * per_head
        self.q_proj = TernaryLinear(hidden, q_width)
        self.k_proj = TernaryLinear(hidden, kv_width)
        self.v_proj = TernaryLinear(hidden, kv_width)
        self.o_proj = TernaryLinear(q_width, hidden)

        self.q_norm = nn.RMSNorm(per_head, eps=args.rms_norm_eps)
        self.k_norm = nn.RMSNorm(per_head, eps=args.rms_norm_eps)
        self.rope = initialize_rope(
            per_head,
            base=args.rope_theta,
            traditional=False,
            scaling_config=args.rope_scaling,
            max_position_embeddings=args.max_position_embeddings,
        )

    def __call__(self, x, mask=None, cache=None):
        """Attend over ``x`` (unpacked as batch, length, features).

        When ``cache`` is given, RoPE is applied at ``cache.offset`` and the
        new keys/values are merged into the cache before attention.
        """
        batch, length, _ = x.shape

        def to_heads(t, n_heads):
            # (batch, length, n_heads * d) -> (batch, n_heads, length, d)
            return t.reshape(batch, length, n_heads, -1)

        queries = self.q_norm(to_heads(self.q_proj(x), self.n_heads)).transpose(0, 2, 1, 3)
        keys = self.k_norm(to_heads(self.k_proj(x), self.n_kv_heads)).transpose(0, 2, 1, 3)
        values = to_heads(self.v_proj(x), self.n_kv_heads).transpose(0, 2, 1, 3)

        if cache is None:
            queries = self.rope(queries)
            keys = self.rope(keys)
        else:
            # Both rotations must use the offset from before the cache update.
            queries = self.rope(queries, offset=cache.offset)
            keys = self.rope(keys, offset=cache.offset)
            keys, values = cache.update_and_fetch(keys, values)

        attended = scaled_dot_product_attention(
            queries, keys, values, cache=cache, scale=self.scale, mask=mask
        )
        merged = attended.transpose(0, 2, 1, 3).reshape(batch, length, -1)
        return self.o_proj(merged)
class TernaryMLP(nn.Module):
    """Gated feed-forward block: ``down_proj(swiglu(gate_proj(x), up_proj(x)))``."""

    def __init__(self, dim, hidden_dim):
        super().__init__()
        # Creation order is kept (gate, down, up) so random initialisation
        # draws happen in the same sequence.
        self.gate_proj = TernaryLinear(dim, hidden_dim)
        self.down_proj = TernaryLinear(hidden_dim, dim)
        self.up_proj = TernaryLinear(dim, hidden_dim)

    def __call__(self, x):
        gate = self.gate_proj(x)
        up = self.up_proj(x)
        return self.down_proj(swiglu(gate, up))
class TernaryTransformerBlock(nn.Module):
    """One transformer layer: normalised attention and MLP, each with a residual add."""

    def __init__(self, args):
        super().__init__()
        self.num_attention_heads = args.num_attention_heads
        self.hidden_size = args.hidden_size
        self.self_attn = TernaryAttention(args)
        self.mlp = TernaryMLP(args.hidden_size, args.intermediate_size)
        self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
        self.post_attention_layernorm = nn.RMSNorm(
            args.hidden_size, eps=args.rms_norm_eps
        )

    def __call__(self, x, mask=None, cache=None):
        """Apply attention then the MLP, normalising the input of each sub-layer."""
        attn_out = self.self_attn(self.input_layernorm(x), mask, cache)
        after_attn = x + attn_out
        mlp_out = self.mlp(self.post_attention_layernorm(after_attn))
        return after_attn + mlp_out
class TernaryQwen3Model(nn.Module):
    """Embedding, stack of ``TernaryTransformerBlock`` layers, and a final RMSNorm."""

    def __init__(self, args):
        super().__init__()
        self.args = args
        self.vocab_size = args.vocab_size
        self.num_hidden_layers = args.num_hidden_layers
        self.embed_tokens = TernaryEmbedding(args.vocab_size, args.hidden_size)
        self.layers = [
            TernaryTransformerBlock(args) for _ in range(args.num_hidden_layers)
        ]
        self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)

    def __call__(self, inputs, cache=None, input_embeddings=None):
        """Return the normalised hidden states for ``inputs``.

        ``input_embeddings``, when provided, is used instead of embedding
        ``inputs``.  ``cache`` is a per-layer sequence; without it every
        layer receives ``None``.
        """
        if input_embeddings is None:
            hidden = self.embed_tokens(inputs)
        else:
            hidden = input_embeddings

        layer_caches = cache if cache is not None else [None] * len(self.layers)
        mask = create_attention_mask(hidden, layer_caches[0])

        for block, block_cache in zip(self.layers, layer_caches):
            hidden = block(hidden, mask, block_cache)
        return self.norm(hidden)
class TernaryModel(nn.Module):
    """Language-model wrapper: ``TernaryQwen3Model`` plus an output projection.

    With ``tie_word_embeddings`` the embedding table is reused to produce
    logits; otherwise a separate ``lm_head`` ``TernaryLinear`` is created.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        self.model_type = args.model_type
        self.model = TernaryQwen3Model(args)
        if not args.tie_word_embeddings:
            self.lm_head = TernaryLinear(args.hidden_size, args.vocab_size)

    def __call__(self, inputs, cache=None, input_embeddings=None):
        """Return vocabulary logits for ``inputs``."""
        hidden = self.model(inputs, cache, input_embeddings)
        if not self.args.tie_word_embeddings:
            return self.lm_head(hidden)
        return self.model.embed_tokens.as_linear(hidden)

    @property
    def layers(self):
        """The transformer blocks of the wrapped model."""
        return self.model.layers
# =============================================================================
# WEIGHT CONVERSION (identical to original)
# =============================================================================
def convert_weights(src_model, dst_model):
    """Copy weights from a loaded source model into the ternary model in place.

    Source ``nn.Linear``, ``nn.Embedding`` and ``nn.RMSNorm`` tensors are
    collected under dotted names rooted at ``model.``; a top-level
    ``lm_head`` linear layer on ``src_model`` is collected as ``lm_head``
    so that untied output projections are copied too.  Destination
    ``TernaryLinear`` / ``TernaryEmbedding`` weights are stored as float32
    and ``nn.RMSNorm`` weights as float16.

    Destination tensors with no matching source entry keep their current
    values; their names are printed as a warning instead of being skipped
    silently.

    Args:
        src_model: Model to read from.  If it has a ``model`` attribute,
            that inner module is walked; otherwise ``src_model`` itself is.
        dst_model: Module tree to update (e.g. ``TernaryModel``).
    """
    src_root = src_model.model if hasattr(src_model, 'model') else src_model
    src_weights = {}

    def collect_src(module, prefix=''):
        for name in module:
            obj = module[name]
            full = f'{prefix}{name}'
            if isinstance(obj, nn.Linear):
                src_weights[f'{full}.weight'] = obj.weight
                # Layers built without a bias do not expose the attribute.
                bias = getattr(obj, 'bias', None)
                if bias is not None:
                    src_weights[f'{full}.bias'] = bias
            elif isinstance(obj, (nn.Embedding, nn.RMSNorm)):
                src_weights[f'{full}.weight'] = obj.weight
            elif isinstance(obj, (list, tuple)):
                for i, item in enumerate(obj):
                    if isinstance(item, nn.Module):
                        collect_src(item, f'{full}.{i}.')
            elif isinstance(obj, nn.Module):
                collect_src(obj, f'{full}.')

    collect_src(src_root, 'model.')

    # An untied output projection lives next to ``model`` on the wrapper, so
    # the walk above never reaches it.
    src_lm_head = getattr(src_model, 'lm_head', None)
    if isinstance(src_lm_head, nn.Linear):
        src_weights.setdefault('lm_head.weight', src_lm_head.weight)

    missing = []

    def set_dst(module, prefix=''):
        for name in module:
            obj = module[name]
            full = f'{prefix}{name}'
            if isinstance(obj, (TernaryLinear, TernaryEmbedding)):
                target_dtype = mx.float32
            elif isinstance(obj, nn.RMSNorm):
                # NOTE(review): norm weights are float16 while projections
                # are float32 -- confirm this mix is intentional.
                target_dtype = mx.float16
            elif isinstance(obj, (list, tuple)):
                for i, item in enumerate(obj):
                    if isinstance(item, nn.Module):
                        set_dst(item, f'{full}.{i}.')
                continue
            elif isinstance(obj, nn.Module):
                set_dst(obj, f'{full}.')
                continue
            else:
                continue

            key = f'{full}.weight'
            if key in src_weights:
                obj.weight = src_weights[key].astype(target_dtype)
            else:
                missing.append(key)

    set_dst(dst_model, '')

    if missing:
        preview = ', '.join(missing[:5])
        more = f' (+{len(missing) - 5} more)' if len(missing) > 5 else ''
        print(f"WARNING: {len(missing)} destination weight(s) had no source "
              f"tensor and were left unchanged: {preview}{more}")
# =============================================================================
# VERIFICATION (identical to original)
# =============================================================================
def verify_ternary(model):
    """Re-derive the ternary codes of every ternary layer and summarise them.

    For each ``TernaryLinear`` / ``TernaryEmbedding`` found in ``model`` the
    stored weight is grouped, scaled, rounded and clipped with the same
    arithmetic as ``ternary_projection``; the resulting codes are then
    counted.

    Returns:
        ``(all_ok, results)`` where ``results`` maps the dotted layer name
        to ``{'is_ternary', 'shape', 'distribution'}`` and ``all_ok`` is
        True when every layer's codes are all in ``{-1, 0, 1}``.
    """
    report = {}

    def codes_of(w):
        # Same grouping arithmetic as the forward projection, without the
        # final multiplication by the scale.
        rows = w.reshape(-1, w.shape[-1])
        n_cols = rows.shape[-1]
        n_pad = (GROUP_SIZE - (n_cols % GROUP_SIZE)) % GROUP_SIZE
        padded = rows
        if n_pad > 0:
            padded = mx.pad(rows, [(0, 0), (0, n_pad)], constant_values=0.0)
        n_rows = padded.shape[0]
        grouped = padded.reshape(n_rows, padded.shape[-1] // GROUP_SIZE, GROUP_SIZE)
        scale = mx.mean(mx.abs(grouped), axis=-1, keepdims=True)
        scale = mx.where(scale < 1e-8, mx.ones_like(scale), scale)
        codes = mx.clip(mx.round(grouped / scale), -1.0, 1.0).reshape(n_rows, -1)
        if n_pad > 0:
            codes = codes[:, :n_cols]
        return codes.reshape(-1)

    def walk(module, prefix=''):
        for name in module:
            child = module[name]
            path = f'{prefix}{name}'
            if isinstance(child, (TernaryLinear, TernaryEmbedding)):
                weight = child.weight
                flat_codes = codes_of(weight)
                total = int(flat_codes.size)
                counts = {
                    value: int(mx.sum(flat_codes == value)) for value in (-1, 0, 1)
                }
                # Every element equals at most one of the three values, so
                # "all elements are ternary" is the same as the counts
                # adding up to the element count.
                report[path] = {
                    'is_ternary': sum(counts.values()) == total,
                    'shape': tuple(weight.shape),
                    'distribution': {
                        -1: counts[-1] / total,
                        0: counts[0] / total,
                        1: counts[1] / total,
                    },
                }
            elif isinstance(child, (list, tuple)):
                for idx, item in enumerate(child):
                    if isinstance(item, nn.Module):
                        walk(item, f'{path}.{idx}.')
            elif isinstance(child, nn.Module):
                walk(child, f'{path}.')

    walk(model, '')
    all_ok = all(entry['is_ternary'] for entry in report.values())
    return all_ok, report
def generate_text(model, tokenizer, prompt, max_tokens=80, temp=0.8, max_context=512):
    """Sample a continuation of ``prompt`` one token at a time.

    The whole (truncated) token history is fed through ``model`` at every
    step -- no KV cache is used -- and the next token is drawn from the
    temperature-scaled logits of the last position.

    Args:
        model: Callable mapping a ``(1, length)`` token array to logits.
        tokenizer: Object providing ``encode(str)`` and ``decode(list)``.
        prompt: Text to continue.
        max_tokens: Number of tokens to sample.
        temp: Sampling temperature; values below 0.01 are raised to 0.01.
        max_context: Only the most recent ``max_context`` tokens are fed to
            the model at each step (previously hard-coded to 512).

    Returns:
        The decoded prompt plus sampled tokens.

    Raises:
        ValueError: If ``prompt`` encodes to no tokens or ``max_context``
            is not positive.
    """
    if max_context < 1:
        raise ValueError(f"max_context must be >= 1, got {max_context}")

    tokens = list(tokenizer.encode(prompt))
    if not tokens:
        # An empty sequence cannot be fed through the model.
        raise ValueError("prompt must encode to at least one token")

    temperature = max(temp, 0.01)
    for _ in range(max_tokens):
        # Slicing already handles histories shorter than the window.
        window = tokens[-max_context:]
        logits = model(mx.array([window]))
        last_logits = logits[:, -1, :] / temperature
        next_token = mx.random.categorical(last_logits, axis=-1)
        tokens.append(int(next_token[0]))
    return tokenizer.decode(tokens)
def collect_all_params(module, prefix=''):
    """Gather the ``weight`` of every ternary and RMSNorm layer under ``module``.

    Args:
        module: Module tree to walk.
        prefix: Dotted name prefix prepended to every key.

    Returns:
        Dict mapping ``'<dotted.path>.weight'`` to the stored weight array.
    """
    weight_layers = (TernaryLinear, TernaryEmbedding, nn.RMSNorm)
    found = {}
    for name in module:
        child = module[name]
        path = f'{prefix}{name}'

        if isinstance(child, weight_layers):
            found[f'{path}.weight'] = child.weight
            continue

        if isinstance(child, (list, tuple)):
            # Containers such as the layer list: recurse with the index in the path.
            for idx, item in enumerate(child):
                if isinstance(item, nn.Module):
                    found.update(collect_all_params(item, f'{path}.{idx}.'))
            continue

        if isinstance(child, nn.Module):
            found.update(collect_all_params(child, f'{path}.'))
    return found
# =============================================================================
# TRAINING (modified to use train_data.txt)
# =============================================================================
class LRSchedule:
    """Linear warmup followed by cosine decay from ``base_lr`` to ``min_lr``.

    Args:
        base_lr: Peak learning rate, reached at the end of warmup.
        warmup_steps: Number of warmup steps (0 disables warmup).
        total_steps: Step at which the cosine decay reaches ``min_lr``.
        min_lr: Floor of the schedule.
    """

    def __init__(self, base_lr, warmup_steps, total_steps, min_lr=1e-5):
        self.base_lr = base_lr
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.min_lr = min_lr

    def __call__(self, step):
        """Return the learning rate for zero-based ``step``."""
        if step < self.warmup_steps:
            # Ramps from base_lr / warmup_steps up to base_lr.
            return self.base_lr * (step + 1) / self.warmup_steps

        decay_steps = max(1, self.total_steps - self.warmup_steps)
        # Clamp so steps past ``total_steps`` stay at ``min_lr``; without the
        # clamp the cosine turns around and the rate climbs back up.
        progress = min(1.0, (step - self.warmup_steps) / decay_steps)
        cosine_decay = 0.5 * (1 + math.cos(math.pi * progress))
        return self.min_lr + (self.base_lr - self.min_lr) * cosine_decay
def clip_grad_norm(grads, max_norm=1.0):
    """Scale gradients so their global L2 norm does not exceed ``max_norm``.

    Only array leaves with at least one dimension contribute to the norm
    and are scaled; any other leaf is passed through untouched.

    Args:
        grads: Gradient tree as returned by ``nn.value_and_grad``.
        max_norm: Maximum allowed global norm.

    Returns:
        ``(clipped, total_norm)`` where ``clipped`` has the same structure
        as ``grads`` and ``total_norm`` is the pre-clipping norm as a float.
    """
    # The tree helpers are documented in ``mlx.utils``; import them from
    # there rather than relying on them being reachable via ``nn.utils``.
    from mlx.utils import tree_flatten, tree_map

    def contributes(g):
        return isinstance(g, mx.array) and g.ndim >= 1

    total_norm_sq = mx.array(0.0)
    for _, g in tree_flatten(grads):
        if contributes(g):
            # Accumulate in float32 so squaring low-precision gradients
            # cannot overflow their own dtype.
            total_norm_sq = total_norm_sq + mx.sum(mx.square(g.astype(mx.float32)))
    total_norm = mx.sqrt(total_norm_sq)

    scale = mx.where(
        total_norm > max_norm, max_norm / (total_norm + 1e-6), mx.array(1.0)
    )
    clipped = tree_map(lambda g: g * scale if contributes(g) else g, grads)
    return clipped, float(total_norm)
def _chunk_tokens(tokens, seq_len):
    """Split ``tokens`` into non-overlapping windows of ``seq_len + 1`` ids."""
    window = seq_len + 1
    # ``+ 1`` on the stop keeps a final window that ends exactly at the end.
    return [tokens[i:i + window] for i in range(0, len(tokens) - window + 1, window)]


def main():
    """Command-line entry point: convert, train, evaluate and save the model.

    Steps:
        1. Load the source model and tokenizer with ``mlx_lm.load``.
        2. Build a ``TernaryModel`` with the same configuration and copy weights.
        3. Run ``verify_ternary`` before training.
        4. Read and tokenize the text file; split 90% / 10% into train / val.
        5. Train for ``--steps`` steps with warmup + cosine LR and gradient
           clipping, then report validation perplexity, sample text and save
           the weights as safetensors.
    """
    parser = argparse.ArgumentParser(description="Ternary Bonsai Rerun on train_data.txt")
    parser.add_argument("--model-name", default="Qwen/Qwen3-0.6B")
    parser.add_argument("--data-path", default=os.path.join(os.path.dirname(__file__), "train_data.txt"))
    parser.add_argument("--batch-size", type=int, default=2)
    parser.add_argument("--seq-len", type=int, default=256)
    parser.add_argument("--steps", type=int, default=250)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--min-lr", type=float, default=5e-6)
    parser.add_argument("--warmup", type=int, default=25)
    parser.add_argument("--weight-decay", type=float, default=0.01)
    parser.add_argument("--log-every", type=int, default=10)
    parser.add_argument("--eval-every", type=int, default=50)
    parser.add_argument("--save-path", default="./ternary_trained_rerun")
    args = parser.parse_args()

    print("=" * 70)
    print("TERNARY BONSAI RE-RUN: train_data.txt")
    print("=" * 70)
    print(f"Model: {args.model_name}")
    print(f"Data: {args.data_path}")
    print(f"Steps: {args.steps}, Batch size: {args.batch_size}, Seq len: {args.seq_len}")
    print(f"LR: {args.lr}, Warmup: {args.warmup}, Weight decay: {args.weight_decay}")
    print()

    # Step 1: Load and convert model
    print("[1/5] Loading Qwen3-0.6B...")
    from mlx_lm import load
    src_model, tokenizer = load(args.model_name)
    src_args = src_model.args
    config = ModelArgs(
        model_type=src_args.model_type,
        hidden_size=src_args.hidden_size,
        num_hidden_layers=src_args.num_hidden_layers,
        intermediate_size=src_args.intermediate_size,
        num_attention_heads=src_args.num_attention_heads,
        rms_norm_eps=src_args.rms_norm_eps,
        vocab_size=src_args.vocab_size,
        num_key_value_heads=src_args.num_key_value_heads,
        max_position_embeddings=src_args.max_position_embeddings,
        rope_theta=src_args.rope_theta,
        head_dim=src_args.head_dim,
        tie_word_embeddings=src_args.tie_word_embeddings,
        rope_scaling=src_args.rope_scaling,
    )

    print("\n[2/5] Creating ternary model and copying weights...")
    model = TernaryModel(config)
    convert_weights(src_model, model)
    # Release the source model before training.
    del src_model
    mx.clear_cache()

    # Verify ternary projection before training
    print("\n[3/5] Pre-training ternary check...")
    all_ok, results = verify_ternary(model)
    print(f" All weights ternary: {all_ok}")
    if all_ok:
        for name, r in list(results.items())[:3]:
            d = r['distribution']
            print(f" {name}: shape={r['shape']}, "
                  f"-1:{d[-1]:.3f}, 0:{d[0]:.3f}, +1:{d[1]:.3f}")

    # Load training data from train_data.txt
    print(f"\n[4/5] Loading train_data.txt from {args.data_path}...")
    # Explicit encoding so the result does not depend on the platform locale.
    with open(args.data_path, 'r', encoding='utf-8') as f:
        text = f.read()
    print(f" Text length: {len(text):,} characters")

    # Tokenize - use 90% for train, 10% for validation
    all_tokens = tokenizer.encode(text)
    print(f" Total tokens: {len(all_tokens):,}")
    split_point = int(0.9 * len(all_tokens))
    train_tokens = all_tokens[:split_point]
    val_tokens = all_tokens[split_point:]
    print(f" Train tokens: {len(train_tokens):,}")
    print(f" Val tokens: {len(val_tokens):,}")

    seq_len = args.seq_len
    train_sequences = _chunk_tokens(train_tokens, seq_len)
    val_sequences = _chunk_tokens(val_tokens, seq_len)
    n_train = len(train_sequences)
    n_val = len(val_sequences)
    print(f" Train sequences: {n_train:,} (seq_len={seq_len})")
    print(f" Val sequences: {n_val:,}")
    if n_train == 0:
        print(f"ERROR: No training sequences! Data is too short for seq_len={seq_len}")
        return

    # Training loop
    print(f"\n[5/5] Training for {args.steps} steps...\n")
    lr_schedule = LRSchedule(args.lr, args.warmup, args.steps, args.min_lr)
    optimizer = AdamW(learning_rate=args.lr, weight_decay=args.weight_decay, betas=(0.9, 0.95))

    def loss_fn(model, batch):
        # Next-token prediction: inputs are all but the last id, targets all
        # but the first.
        input_ids = mx.array(batch[:, :-1])
        targets = mx.array(batch[:, 1:])
        logits = model(input_ids)
        return nn.losses.cross_entropy(logits, targets, reduction="mean")

    # Build the differentiated function once instead of on every step.
    loss_and_grad = nn.value_and_grad(model, loss_fn)

    step = 0
    losses = []
    start_time = time.time()
    # Keep cycling over the data until the requested number of steps is done.
    while step < args.steps:
        indices = np.random.permutation(n_train)
        for i in range(0, n_train, args.batch_size):
            if step >= args.steps:
                break
            batch_indices = indices[i:i + args.batch_size]
            shortfall = args.batch_size - len(batch_indices)
            if shortfall > 0:
                # Top up the last partial batch with random sequences; sample
                # with replacement only when there are too few to avoid it.
                extra = np.random.choice(
                    n_train, size=shortfall, replace=shortfall > n_train
                )
                batch_indices = np.concatenate([batch_indices, extra])
            batch = np.array([train_sequences[j] for j in batch_indices])

            current_lr = lr_schedule(step)
            optimizer.learning_rate = current_lr
            loss, grads = loss_and_grad(model, batch)
            grads, grad_norm = clip_grad_norm(grads, max_norm=1.0)
            optimizer.update(model, grads)
            # Materialise the updated parameters and optimizer state together
            # with the loss so lazy computation does not carry over.
            mx.eval(loss, model.parameters(), optimizer.state)
            losses.append(float(loss))
            step += 1

            if step % args.log_every == 0:
                recent = losses[-args.log_every:]
                avg_loss = np.mean(recent)
                elapsed = time.time() - start_time
                toks_per_sec = args.log_every * args.batch_size * seq_len / max(elapsed, 0.001)
                print(f" Step {step:4d}/{args.steps} | Loss: {avg_loss:.4f} | "
                      f"GradNorm: {grad_norm:.1f} | LR: {current_lr:.2e} | Tok/s: {toks_per_sec:.0f}")
                start_time = time.time()

            # Periodic evaluation needs at least one validation sequence.
            if step % args.eval_every == 0 and step > 0 and n_val > 0:
                val_indices = np.random.choice(n_val, size=min(args.batch_size, n_val), replace=False)
                val_batch = np.array([val_sequences[j] for j in val_indices])
                val_loss = loss_fn(model, val_batch)
                mx.eval(val_loss)
                val_ppl = math.exp(min(float(val_loss), 20))
                print(f" >> Eval at step {step}: val_loss={float(val_loss):.4f}, val_ppl={val_ppl:.1f}")
                all_ok, _ = verify_ternary(model)
                print(f" Ternary check: {'PASS' if all_ok else 'FAIL'}")

    # Final evaluation
    print("\n" + "=" * 70)
    print("FINAL EVALUATION")
    print("=" * 70)
    all_ok, results = verify_ternary(model)
    print(f"\n1. TERNARY VERIFICATION: {'PASS' if all_ok else 'FAIL'}")
    for name, r in sorted(results.items())[:10]:
        d = r['distribution']
        status = "OK" if r['is_ternary'] else "FAIL"
        print(f" [{status}] {name}: shape={r['shape']}, "
              f"-1:{d[-1]:.3f}, 0:{d[0]:.3f}, +1:{d[1]:.3f}")
    if len(results) > 10:
        print(f" ... ({len(results) - 10} more layers)")

    # Validation perplexity
    print("\n2. PERPLEXITY EVALUATION:")
    eval_batch_size = min(4, n_val)
    val_losses_list = []
    # A zero batch size would be an invalid range step.
    if eval_batch_size > 0:
        for i in range(0, min(n_val, 20), eval_batch_size):
            batch = np.array(val_sequences[i:i + eval_batch_size])
            if len(batch) < eval_batch_size:
                continue
            vl = loss_fn(model, batch)
            mx.eval(vl)
            val_losses_list.append(float(vl))
    avg_val_loss = np.mean(val_losses_list) if val_losses_list else float('inf')
    # The exponent is capped at 20 to keep the printed perplexity finite.
    final_val_ppl = math.exp(min(avg_val_loss, 20))
    # No training steps (e.g. --steps 0) leaves nothing to average.
    final_train_loss = np.mean(losses[-50:]) if losses else float('nan')
    vocab_size = config.vocab_size
    random_loss = math.log(vocab_size)
    print(f" Train loss (last 50): {final_train_loss:.4f}")
    print(f" Val loss: {avg_val_loss:.4f}")
    print(f" Val perplexity: {final_val_ppl:.1f}")
    print(f" Random baseline: perplexity={vocab_size} (loss={random_loss:.2f})")

    # Text generation
    print("\n3. TEXT GENERATION:")
    prompts = [
        "The history of the United States",
        "Open source software",
        "The development of computers",
        "World War II",
        "The philosophy of mind",
    ]
    for prompt in prompts:
        # Best effort: a failed sample must not abort saving the model.
        try:
            generated = generate_text(model, tokenizer, prompt, max_tokens=80, temp=0.7)
            print(f" Prompt: {prompt}")
            print(f" Output: {generated[:250]}")
            print()
        except Exception as e:
            print(f" Generation failed for '{prompt}': {e}")

    # Save
    if args.save_path:
        os.makedirs(args.save_path, exist_ok=True)
        print(f"\nSaving model to {args.save_path}...")
        params = collect_all_params(model)
        if params:
            mx.save_safetensors(
                os.path.join(args.save_path, "weights.safetensors"),
                params
            )
            print(f"Saved {len(params)} weight tensors.")

    print("\n" + "=" * 70)
    print("SUMMARY")
    print("=" * 70)
    print(f"Ternary projection verified: {all_ok}")
    print(f"Final training loss: {final_train_loss:.4f}")
    print(f"Validation perplexity: {final_val_ppl:.1f}")
    print(f"(Random baseline: {vocab_size})")
    print()
    print("Comparison with previous WikiText-2 run:")
    print(f" Previous: val_ppl=340.9, train_loss=6.13 (WikiText-2, 250 steps)")
    print(f" This run: val_ppl={final_val_ppl:.1f}, train_loss={final_train_loss:.4f} (train_data.txt, {args.steps} steps)")
# Run the training pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()