import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from mlx_lm import load
from mlx_lm.models.qwen3 import Model
import numpy as np
from typing import Optional, Tuple, List
import time
import json


# ==============================================================================
# Ternary Linear Layer with Straight-Through Estimator (STE)
# ==============================================================================

class TernaryLinear(nn.Module):
    """
    Ternary linear layer: weights are projected to {-1, 0, +1} * scale
    during the forward pass, with a straight-through estimator (STE)
    for the backward pass.

    Group-wise quantization: groups of `group_size` weights share one FP32
    scale factor, computed as s = mean(|W_group|).
    """

    def __init__(self, in_features: int, out_features: int, group_size: int = 128):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.group_size = group_size

        if in_features % group_size != 0:
            raise ValueError(
                f"in_features ({in_features}) must be divisible by group_size ({group_size})"
            )

        self.num_groups = in_features // group_size

        # Latent weights in float32 (trainable)
        scale = (1.0 / in_features) ** 0.5
        self.weight = mx.random.normal((out_features, in_features), scale=scale)

    @classmethod
    def from_linear(cls, linear: nn.Linear, group_size: int = 128):
        """Initialize from an existing Linear layer."""
        in_features = linear.weight.shape[1]
        out_features = linear.weight.shape[0]
        layer = cls(in_features, out_features, group_size)
        # Reinitialize weights randomly for training from scratch
        # rather than copying pretrained weights
        scale = (1.0 / in_features) ** 0.5
        layer.weight = mx.random.normal((out_features, in_features), scale=scale)
        return layer

    def _quantize(self, weight):
        """Project latent weights to ternary using group-wise scales."""
        # Reshape to (out_features, num_groups, group_size)
        w_reshaped = weight.reshape(self.out_features, self.num_groups, self.group_size)

        # Compute scale per group: s = mean(|W|)
        scales = mx.mean(mx.abs(w_reshaped), axis=-1, keepdims=True)

        # Quantize to {-1, 0, +1}
        epsilon = 1e-8
        w_norm = w_reshaped / (scales + epsilon)
        w_quant = mx.clip(mx.round(w_norm), -1, 1)

        # Dequantize back
        w_ternary = w_quant * scales

        return w_ternary.reshape(self.out_features, self.in_features), scales

    def __call__(self, x):
        """Forward pass with STE."""
        w_ternary, _ = self._quantize(mx.stop_gradient(self.weight))

        # STE: the forward value is the ternary weight, but the gradient of
        # (self.weight - stop_gradient(self.weight)) is the identity, so the
        # backward pass updates the latent weights directly.
        w_effective = w_ternary + (self.weight - mx.stop_gradient(self.weight))

        return x @ w_effective.T

    def get_ternary_weights(self):
        """Get the actual ternary-projected weights."""
        w_ternary, scales = self._quantize(self.weight)
        return w_ternary, scales

    def verify_ternary(self, tol=1e-3):
        """Verify that weights project cleanly to {-1, 0, +1} * scale."""
        w_ternary, scales = self.get_ternary_weights()
        w_reshaped = w_ternary.reshape(self.out_features, self.num_groups, self.group_size)

        w_norm = w_reshaped / (scales + 1e-8)
        w_rounded = mx.round(w_norm)

        is_valid_value = mx.all(
            (mx.abs(w_rounded - (-1.0)) < 1e-3)
            | (mx.abs(w_rounded - 0.0) < 1e-3)
            | (mx.abs(w_rounded - 1.0) < 1e-3)
        )

        is_ternary = mx.all(mx.abs(w_norm - w_rounded) < tol)

        return is_ternary.item() and is_valid_value.item()
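

# Minimal usage sketch (illustrative only, not part of the training pipeline):
# it checks the group-wise projection on a toy layer and that gradients flow
# to the latent weights through the STE. The tiny layer sizes are assumptions
# for the demo, not values used by the script.
def _demo_ternary_linear():
    layer = TernaryLinear(in_features=8, out_features=2, group_size=4)

    # Every dequantized weight divided by its group scale lands in {-1, 0, +1}.
    assert layer.verify_ternary()

    def toy_loss(model, x):
        return mx.sum(model(x))

    # With the STE, the gradient w.r.t. the latent weights is dense, as if the
    # forward pass had used them directly.
    loss_and_grad = nn.value_and_grad(layer, toy_loss)
    loss, grads = loss_and_grad(layer, mx.ones((1, 8)))
    print(grads["weight"].shape)  # (2, 8)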


# ==============================================================================
# Model Conversion Utilities
# ==============================================================================

def convert_qwen3_to_ternary(model: Model, group_size: int = 128) -> Model:
    """
    Convert all linear layers in a Qwen3 model to ternary.
    Keeps RMSNorm and embeddings in float.
    """
    print("Converting model to ternary...")

    # Skip embedding - it's an Embedding layer, not Linear
    if hasattr(model.model, 'embed_tokens'):
        print(f"  Skipping embedding (not Linear): {model.model.embed_tokens.weight.shape}")

    # Convert each transformer block
    for i, layer in enumerate(model.model.layers):
        print(f"\n  Layer {i}:")

        # Attention projections
        if hasattr(layer, 'self_attn'):
            attn = layer.self_attn
            for proj_name in ['q_proj', 'k_proj', 'v_proj', 'o_proj']:
                if hasattr(attn, proj_name):
                    proj = getattr(attn, proj_name)
                    if isinstance(proj, nn.Linear):
                        setattr(attn, proj_name, TernaryLinear.from_linear(proj, group_size))
                        print(f"    {proj_name}: {proj.weight.shape}")

        # MLP projections
        if hasattr(layer, 'mlp'):
            mlp = layer.mlp
            for proj_name in ['gate_proj', 'up_proj', 'down_proj']:
                if hasattr(mlp, proj_name):
                    proj = getattr(mlp, proj_name)
                    if isinstance(proj, nn.Linear):
                        setattr(mlp, proj_name, TernaryLinear.from_linear(proj, group_size))
                        print(f"    {proj_name}: {proj.weight.shape}")

    # Skip LM head if tied or not Linear
    if hasattr(model, 'lm_head'):
        lm = model.lm_head
        if isinstance(lm, nn.Linear):
            in_features = lm.weight.shape[1]
            if in_features % group_size == 0:
                model.lm_head = TernaryLinear.from_linear(lm, group_size)
                print(f"  Converting lm_head: {lm.weight.shape}")
            else:
                print(f"  Skipping lm_head (not divisible): {lm.weight.shape}")
        else:
            print(f"  Skipping lm_head (not Linear): {type(lm)}")

    print("\nConversion complete!")
    return model
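

# Typical use (this mirrors what main() does below):
#
#   model, tokenizer = load("Qwen/Qwen3-0.6B")
#   model = convert_qwen3_to_ternary(model, group_size=128)
#   print(count_ternary_layers(model), "TernaryLinear layers")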


def count_ternary_layers(model):
    """Count the number of TernaryLinear layers in the model."""
    count = 0

    def count_module(module):
        nonlocal count
        if isinstance(module, TernaryLinear):
            count += 1
        if hasattr(module, 'items'):
            for _, child in module.items():
                count_module(child)
        elif isinstance(module, list):
            for child in module:
                count_module(child)

    count_module(model)
    return count


# ==============================================================================
# Verification
# ==============================================================================

def verify_model_ternary(model: Model) -> Tuple[bool, List[str]]:
    """Verify all TernaryLinear layers produce clean ternary weights."""
    all_pass = True
    failed_layers = []

    def check_module(module, name=""):
        nonlocal all_pass
        if isinstance(module, TernaryLinear):
            is_ternary = module.verify_ternary()
            if not is_ternary:
                all_pass = False
                failed_layers.append(name)
                print(f"  FAIL: {name}")
            else:
                print(f"  PASS: {name}")

        if hasattr(module, 'items'):
            for child_name, child in module.items():
                check_module(child, f"{name}.{child_name}" if name else child_name)
        elif isinstance(module, list):
            for i, child in enumerate(module):
                check_module(child, f"{name}[{i}]" if name else f"[{i}]")

    check_module(model)
    return all_pass, failed_layers


# ==============================================================================
# Dataset Utilities
# ==============================================================================

def load_wikitext_data(tokenizer, split="train", max_samples=1000, seq_length=256):
    """Load WikiText-2 dataset and tokenize."""
    try:
        from datasets import load_dataset
        dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split=split)
    except Exception as e:
        print(f"Could not load dataset: {e}")
        print("Using fallback sample text...")
        return create_fallback_data(tokenizer, seq_length)

    # Tokenize
    all_tokens = []
    for i, example in enumerate(dataset):
        if i >= max_samples:
            break
        text = example["text"].strip()
        if len(text) < 50:  # Skip very short lines
            continue
        tokens = tokenizer.encode(text)
        if len(tokens) > 10:
            all_tokens.append(tokens)

    print(f"Loaded {len(all_tokens)} sequences from WikiText-2 {split}")
    return all_tokens


def create_fallback_data(tokenizer, seq_length=256, num_samples=500):
    """Create simple fallback training data."""
    sample_texts = [
        "The quick brown fox jumps over the lazy dog. ",
        "In machine learning, neural networks are powerful models. ",
        "The Earth orbits around the Sun in an elliptical path. ",
        "Python is a popular programming language for data science. ",
        "The history of artificial intelligence dates back to the 1950s. ",
        "Deep learning models can process images, text, and speech. ",
        "The capital of France is Paris, known for the Eiffel Tower. ",
        "Water boils at 100 degrees Celsius at standard pressure. ",
        "The human brain contains approximately 86 billion neurons. ",
        "Quantum computing uses quantum bits to perform calculations. ",
    ]

    all_tokens = []
    for i in range(num_samples):
        # Repeat the sample text to fill the sequence
        text = sample_texts[i % len(sample_texts)] * 20
        tokens = tokenizer.encode(text)[:seq_length]
        if len(tokens) > 10:
            all_tokens.append(tokens)

    print(f"Created {len(all_tokens)} fallback sequences")
    return all_tokens


def create_batches(token_sequences, batch_size=4, seq_length=256):
    """Create batches of token sequences."""
    batches = []
    current_batch = []

    for tokens in token_sequences:
        if len(tokens) < 2:
            continue
        # Truncate or pad to seq_length
        if len(tokens) > seq_length:
            tokens = tokens[:seq_length]
        else:
            tokens = tokens + [0] * (seq_length - len(tokens))
        current_batch.append(tokens)

        if len(current_batch) == batch_size:
            batches.append(mx.array(current_batch))
            current_batch = []

    if current_batch:
        # Pad last batch
        while len(current_batch) < batch_size:
            current_batch.append([0] * seq_length)
        batches.append(mx.array(current_batch))

    return batches
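

# For example (illustrative): with batch_size=2 and seq_length=4, the input
# [[5, 6, 7]] becomes a single batch [[5, 6, 7, 0], [0, 0, 0, 0]] -- the
# sequence is zero-padded to seq_length and the batch is padded with zero rows.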


# ==============================================================================
# Training Utilities
# ==============================================================================

def loss_fn(model, inputs, targets):
    """Compute cross-entropy loss for next-token prediction."""
    logits = model(inputs)
    # logits shape: (batch, seq_len, vocab_size)

    # Flatten
    logits_flat = logits.reshape(-1, logits.shape[-1])
    targets_flat = targets.reshape(-1)

    # Numerically stable log-softmax (avoids log(softmax(x)) underflow)
    log_probs = logits_flat - mx.logsumexp(logits_flat, axis=-1, keepdims=True)

    # Gather each target's log-probability with advanced indexing:
    # log_probs has shape (batch*seq, vocab), targets_flat has shape (batch*seq,),
    # and we want log_probs[i, targets_flat[i]] for each i.
    batch_seq_len = logits_flat.shape[0]
    indices = mx.arange(batch_seq_len)
    target_log_probs = log_probs[indices, targets_flat]
    nll = -target_log_probs

    # Mask padding: create_batches pads with token id 0, so positions whose
    # target is 0 are excluded (this also drops any genuine token id 0).
    mask = (targets_flat != 0).astype(nll.dtype)
    nll = nll * mask

    return mx.sum(nll) / mx.maximum(mx.sum(mask), 1)
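

# Quick sanity sketch (illustrative only): with a stub "model" that returns
# fixed logits, the masked loss above averages only over non-pad targets.
def _demo_loss_masking():
    class _ConstantLogits:
        def __call__(self, inputs):
            # (batch=1, seq=3, vocab=5): uniform logits -> nll = log(5) per token
            return mx.zeros((1, 3, 5))

    inputs = mx.array([[1, 2, 0]])
    targets = mx.array([[2, 3, 0]])  # last position is padding (id 0)
    loss = loss_fn(_ConstantLogits(), inputs, targets)
    print(loss.item())  # ~log(5) = 1.609, averaged over the 2 real targets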


def step(model, inputs, targets, optimizer):
    """Single training step."""
    # nn.value_and_grad differentiates w.r.t. the model's trainable parameters
    loss_and_grad = nn.value_and_grad(model, loss_fn)
    loss, grads = loss_and_grad(model, inputs, targets)

    # Update parameters
    optimizer.update(model, grads)

    return loss


def compute_perplexity(model, tokens_batch):
    """Compute perplexity on a batch of token sequences."""
    total_loss = 0.0
    total_tokens = 0

    for tokens in tokens_batch:
        if len(tokens) < 2:
            continue
        inputs = mx.array(tokens[:-1])
        targets = mx.array(tokens[1:])

        logits = model(inputs[None, :])
        logits_flat = logits.reshape(-1, logits.shape[-1])
        targets_flat = targets.reshape(-1)

        # Numerically stable log-softmax
        log_probs = logits_flat - mx.logsumexp(logits_flat, axis=-1, keepdims=True)

        # Gather per-token log-probabilities with advanced indexing
        seq_len = logits_flat.shape[0]
        indices = mx.arange(seq_len)
        target_log_probs = log_probs[indices, targets_flat]
        nll = -target_log_probs

        total_loss += mx.sum(nll).item()
        total_tokens += len(targets_flat)

    if total_tokens == 0:
        return float('inf')

    avg_loss = total_loss / total_tokens
    return np.exp(avg_loss)


def generate_text(model, tokenizer, prompt, max_tokens=30, temperature=1.0, top_k=None):
    """Generate text from a prompt using greedy decoding or top-k sampling."""
    tokens = mx.array(tokenizer.encode(prompt))

    for _ in range(max_tokens):
        logits = model(tokens[None, :])
        next_token_logits = logits[0, -1, :] / temperature

        if top_k is not None and top_k > 0:
            # Top-k filtering: mx.topk returns the k largest values, so push
            # every logit below the k-th largest to -inf, then sample.
            top_values = mx.topk(next_token_logits, top_k)
            threshold = mx.min(top_values)
            filtered_logits = mx.where(
                next_token_logits < threshold, mx.array(-1e10), next_token_logits
            )
            next_token = mx.random.categorical(filtered_logits).astype(tokens.dtype)
        else:
            # Greedy
            next_token = mx.argmax(next_token_logits).astype(tokens.dtype)

        tokens = mx.concatenate([tokens, next_token[None]])

    return tokenizer.decode(tokens.tolist())
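

# For example (illustrative): generate_text(model, tokenizer, "The capital of
# France is", max_tokens=10, temperature=0.8, top_k=50) samples each step from
# the 50 highest-probability tokens, while top_k=None decodes greedily.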


# ==============================================================================
# Main Training Script
# ==============================================================================

def main():
    print("=" * 80)
    print("Ternary Bonsai Training - Qwen3-0.6B")
    print("=" * 80)

    # Hyperparameters
    GROUP_SIZE = 128
    SEQ_LENGTH = 128
    BATCH_SIZE = 2  # Small batch for M4 Mac
    NUM_STEPS = 500
    LEARNING_RATE = 5e-5
    WARMUP_STEPS = 50
    EVAL_EVERY = 50
    GRAD_CLIP = 1.0

    print("\nHyperparameters:")
    print(f"  Group size: {GROUP_SIZE}")
    print(f"  Sequence length: {SEQ_LENGTH}")
    print(f"  Batch size: {BATCH_SIZE}")
    print(f"  Training steps: {NUM_STEPS}")
    print(f"  Learning rate: {LEARNING_RATE}")
    print(f"  Warmup steps: {WARMUP_STEPS}")
    print(f"  Grad clip: {GRAD_CLIP}")

    # Load model
    print("\n[1/6] Loading Qwen3-0.6B...")
    model, tokenizer = load("Qwen/Qwen3-0.6B")
    print("Model loaded successfully")

    # Convert to ternary
    print("\n[2/6] Converting to ternary...")
    model = convert_qwen3_to_ternary(model, group_size=GROUP_SIZE)
    print(f"Converted {count_ternary_layers(model)} linear layers to ternary")

    # Verify
    print("\n[3/6] Verifying ternary projection...")
    all_pass, failed = verify_model_ternary(model)
    if all_pass:
        print("All layers pass ternary verification!")
    else:
        print(f"Failed layers: {failed}")
        return

    # Load dataset
    print("\n[4/6] Loading dataset...")
    train_data = load_wikitext_data(tokenizer, split="train", max_samples=2000, seq_length=SEQ_LENGTH)
    val_data = load_wikitext_data(tokenizer, split="validation", max_samples=200, seq_length=SEQ_LENGTH)

    train_batches = create_batches(train_data, batch_size=BATCH_SIZE, seq_length=SEQ_LENGTH)
    print(f"Created {len(train_batches)} training batches")

    # Test generation before training
    print("\n[5/6] Testing generation (pre-training)...")
    prompt = "The quick brown fox"
    generated = generate_text(model, tokenizer, prompt, max_tokens=20)
    print(f"Prompt: '{prompt}'")
    print(f"Generated: '{generated}'")

    # Initialize optimizer
    print("\n[6/6] Starting training...")
    optimizer = optim.AdamW(learning_rate=LEARNING_RATE)

    # Training loop
    losses = []
    start_time = time.time()

    def get_lr(step_num):
        """Learning rate schedule with linear warmup and cosine decay."""
        if step_num < WARMUP_STEPS:
            return LEARNING_RATE * (step_num + 1) / WARMUP_STEPS
        else:
            progress = (step_num - WARMUP_STEPS) / (NUM_STEPS - WARMUP_STEPS)
            return LEARNING_RATE * 0.5 * (1 + np.cos(np.pi * progress))
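
    # Illustratively, with the defaults above: get_lr(0) = 1e-6 (warmup start),
    # get_lr(49) = 5e-5 (peak), and the cosine decay then takes the rate
    # toward 0 by the end of training.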

    # Build the loss-and-grad transform once, outside the loop;
    # nn.value_and_grad differentiates w.r.t. the model's trainable parameters.
    loss_and_grad = nn.value_and_grad(model, loss_fn)

    for step_num in range(NUM_STEPS):
        # Update learning rate
        current_lr = get_lr(step_num)
        optimizer.learning_rate = current_lr

        # Get batch
        batch_idx = step_num % len(train_batches)
        batch = train_batches[batch_idx]

        inputs = batch[:, :-1]
        targets = batch[:, 1:]

        # Training step
        loss, grads = loss_and_grad(model, inputs, targets)

        # Element-wise gradient clipping
        if GRAD_CLIP > 0:
            def clip_grads(g):
                if isinstance(g, dict):
                    return {k: clip_grads(v) for k, v in g.items()}
                elif isinstance(g, list):
                    return [clip_grads(v) for v in g]
                else:
                    return mx.clip(g, -GRAD_CLIP, GRAD_CLIP)

            grads = clip_grads(grads)

        optimizer.update(model, grads)
        # Force evaluation of the loss, updated parameters, and optimizer
        # state so the lazy graph does not grow across steps
        mx.eval(loss, model.parameters(), optimizer.state)

        losses.append(loss.item())

        # Logging
        if (step_num + 1) % 10 == 0:
            avg_loss = np.mean(losses[-10:])
            print(f"Step {step_num + 1}/{NUM_STEPS} | Loss: {avg_loss:.4f} | LR: {current_lr:.2e} | Time: {time.time() - start_time:.1f}s")

        # Evaluation
        if (step_num + 1) % EVAL_EVERY == 0:
            print(f"\n--- Evaluation at step {step_num + 1} ---")

            # Generate sample
            prompt = "Artificial intelligence is"
            generated = generate_text(model, tokenizer, prompt, max_tokens=30, temperature=0.8)
            print(f"Prompt: '{prompt}'")
            print(f"Generated: '{generated}'")

            # Compute perplexity on a small validation subset
            if val_data:
                ppl = compute_perplexity(model, val_data[:20])
                print(f"Perplexity: {ppl:.2f}")

            # Verify ternary
            all_pass, _ = verify_model_ternary(model)
            print(f"Ternary verification: {'PASS' if all_pass else 'FAIL'}")
            print("-" * 40 + "\n")

    # Final evaluation
    print("\n" + "=" * 80)
    print("FINAL EVALUATION")
    print("=" * 80)

    # Loss curve
    print(f"\nInitial loss: {losses[0]:.4f}")
    print(f"Final loss: {losses[-1]:.4f}")
    print(f"Loss decrease: {losses[0] - losses[-1]:.4f}")

    # Generate multiple samples
    prompts = [
        "The capital of France is",
        "Machine learning is a type of",
        "In 1492, Christopher Columbus",
    ]

    print("\n--- Generation Samples ---")
    for prompt in prompts:
        generated = generate_text(model, tokenizer, prompt, max_tokens=30, temperature=0.8)
        print(f"Prompt: '{prompt}'")
        print(f"Generated: '{generated}'")
        print()

    # Perplexity
    if val_data:
        ppl = compute_perplexity(model, val_data[:50])
        print(f"Final perplexity: {ppl:.2f}")

    # Verify ternary one final time
    print("\n--- Ternary Verification ---")
    all_pass, failed = verify_model_ternary(model)
    print(f"All layers ternary: {all_pass}")
    if failed:
        print(f"Failed: {failed}")

    # Save results
    results = {
        "hyperparameters": {
            "group_size": GROUP_SIZE,
            "seq_length": SEQ_LENGTH,
            "batch_size": BATCH_SIZE,
            "num_steps": NUM_STEPS,
            "learning_rate": LEARNING_RATE,
        },
        "training": {
            "initial_loss": float(losses[0]),
            "final_loss": float(losses[-1]),
            "loss_curve": [float(l) for l in losses],
        },
        "verification": {
            "all_ternary": all_pass,
            "failed_layers": failed,
        },
        "perplexity": float(ppl) if val_data else None,
    }

    with open("training_results.json", "w") as f:
        json.dump(results, f, indent=2)

    print("\nResults saved to training_results.json")
    print("=" * 80)


if __name__ == "__main__":
    main()