Files
deep_pro_judge/glm5/ternary_training/run_ternary.py
T
sleepy 45c3aad453 feat: expand to 6 models, 8 challenges; rewrite README with DeepSeek V4 Pro analysis
- Add Claude Opus 4.7, Kimi K2.6, GLM-5.1 to existing GLM-5, Qwen3-6, MiniMax-M2.7
- Add 5 new challenges: flash attention fwd/bwd, beam search, DFlash, ternary training
- Rewrite README with TL;DR rankings, grade matrix, and DeepSeek V4 Pro attribution
- Add analysis/ folder with cross-model comparisons and per-challenge deep dives
- Add deploy_challenges.sh script
- Expand .gitignore to exclude Python envs, ML weights, and build artifacts
2026-04-27 18:49:22 +02:00

674 lines
25 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""
Ternary Bonsai Training Script
===============================
Self-contained script that:
1. Loads Qwen3-0.6B
2. Converts to ternary model (TernaryLinear layers with group-wise quantization)
3. Fine-tunes on WikiText-2 for 200+ steps using a straight-through estimator (STE)
4. Verifies ternary projection, generates text, measures perplexity
5. Reports findings
Architecture: Qwen3 with ALL linear layers ternary {-1, 0, +1} × group scale
Group size: 128, Scale: mean(|W_group|), STE: gradient passthrough
"""
import argparse
import math
import os
import sys
import time
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from mlx.optimizers import AdamW
# =============================================================================
# TERNARY LINEAR LAYER
# =============================================================================
# Number of consecutive weights along the last (input) axis that share one
# scale in ternary_projection() and verify_ternary().
GROUP_SIZE = 128
@mx.custom_function
def ternary_projection(w):
    """Project ``w`` onto group-wise ternary values ``{-s, 0, +s}``.

    The last axis is cut into chunks of ``GROUP_SIZE`` entries (zero-padded on
    the right when it does not divide evenly). Each chunk gets the scale
    ``s = mean(|chunk|)`` and every entry becomes
    ``clip(round(x / s), -1, 1) * s``. Padding is stripped again, so the
    result has the same shape as ``w``.

    Padded zeros take part in the mean of a partial trailing chunk.
    """
    input_shape = w.shape
    n_cols = input_shape[-1]
    rows = w.reshape(-1, n_cols)
    n_rows = rows.shape[0]

    # Columns to append so that the row length is a multiple of GROUP_SIZE.
    extra = -n_cols % GROUP_SIZE
    if extra > 0:
        rows = mx.pad(rows, [(0, 0), (0, extra)], constant_values=0.0)
    total_cols = rows.shape[-1]
    chunks = rows.reshape(n_rows, total_cols // GROUP_SIZE, GROUP_SIZE)

    # s = mean(|W_group|); near-zero scales are replaced by 1 to avoid 0/0.
    group_scale = mx.mean(mx.abs(chunks), axis=-1, keepdims=True)
    group_scale = mx.where(
        group_scale < 1e-8, mx.ones_like(group_scale), group_scale
    )

    # Ternary codes in {-1, 0, +1}, then rescaled back to weight magnitude.
    codes = mx.clip(mx.round(chunks / group_scale), -1.0, 1.0)
    dequantized = (codes * group_scale).reshape(n_rows, total_cols)
    if extra > 0:
        dequantized = dequantized[:, :n_cols]
    return dequantized.reshape(input_shape)


@ternary_projection.vjp
def ternary_projection_vjp(primals, cotangent, output):
    """Backward rule: hand the incoming cotangent through unchanged.

    The rounding in the forward pass is ignored, i.e. the gradient with
    respect to ``w`` is taken to be the gradient with respect to the output.
    """
    return (cotangent,)
class TernaryLinear(nn.Module):
    """Linear layer whose weight is ternary-projected on every forward pass.

    The stored ``weight`` of shape ``(out_features, in_features)`` is run
    through ``ternary_projection`` inside ``__call__`` before it is used.
    """

    def __init__(self, in_features, out_features, bias=False):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # BitNet-style init: standard normal scaled by fan_in^(-0.5).
        init_scale = in_features ** (-0.5)
        self.weight = mx.random.normal(shape=(out_features, in_features)) * init_scale
        self.bias = mx.zeros((out_features,)) if bias else None

    def __call__(self, x):
        """Return ``x @ ternary(weight).T`` plus the bias when one exists."""
        quantized = ternary_projection(self.weight)
        projected = x @ quantized.T
        if self.bias is None:
            return projected
        return projected + self.bias
class TernaryEmbedding(nn.Module):
    """Embedding table that is ternary-projected before every use.

    The same projected table serves both the lookup (``__call__``) and the
    output projection (``as_linear``).
    """

    def __init__(self, num_embeddings, embedding_dim):
        super().__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        init_scale = embedding_dim ** (-0.5)
        self.weight = mx.random.normal(shape=(num_embeddings, embedding_dim)) * init_scale

    def __call__(self, x):
        """Index rows of the projected table with ``x``."""
        table = ternary_projection(self.weight)
        return table[x]

    def as_linear(self, x):
        """Multiply ``x`` by the transposed projected table."""
        table = ternary_projection(self.weight)
        return x @ table.T
# =============================================================================
# TERNARY QWEN3 MODEL
# =============================================================================
from dataclasses import dataclass
from typing import Any, Dict, Optional, Union
from mlx_lm.models.base import create_attention_mask, scaled_dot_product_attention
from mlx_lm.models.activations import swiglu
from mlx_lm.models.rope_utils import initialize_rope
@dataclass
class ModelArgs:
    """Hyper-parameters consumed by the ternary model classes in this file.

    main() builds an instance by copying every field from the loaded source
    model's ``args``, so the defaults below only apply when the class is
    constructed directly.
    """

    model_type: str = "qwen3"
    # Width of the residual stream.
    hidden_size: int = 1024
    # Number of TernaryTransformerBlock instances.
    num_hidden_layers: int = 28
    # Hidden width of TernaryMLP.
    intermediate_size: int = 3072
    # Number of query heads.
    num_attention_heads: int = 16
    # Epsilon passed to every nn.RMSNorm.
    rms_norm_eps: float = 1e-6
    vocab_size: int = 151936
    # Number of key/value heads.
    num_key_value_heads: int = 8
    # Forwarded to initialize_rope().
    max_position_embeddings: int = 40960
    # Forwarded to initialize_rope() as ``base``.
    rope_theta: float = 1000000.0
    # Per-head dimension; sizes the q/k/v/o projections and q/k norms.
    head_dim: int = 128
    # When True the embedding table doubles as the output projection.
    tie_word_embeddings: bool = True
    # Forwarded to initialize_rope() as ``scaling_config``.
    rope_scaling: Optional[Dict[str, Union[float, str]]] = None
class TernaryAttention(nn.Module):
    """Self-attention with TernaryLinear q/k/v/o projections.

    Query and key/value head counts are configured separately. Queries and
    keys are RMS-normalised per head and rotary-embedded before attention.
    """

    def __init__(self, args):
        super().__init__()
        hidden = args.hidden_size
        per_head = args.head_dim
        self.n_heads = args.num_attention_heads
        self.n_kv_heads = args.num_key_value_heads
        self.scale = per_head ** -0.5

        q_width = self.n_heads * per_head
        kv_width = self.n_kv_heads * per_head
        self.q_proj = TernaryLinear(hidden, q_width)
        self.k_proj = TernaryLinear(hidden, kv_width)
        self.v_proj = TernaryLinear(hidden, kv_width)
        self.o_proj = TernaryLinear(q_width, hidden)

        self.q_norm = nn.RMSNorm(per_head, eps=args.rms_norm_eps)
        self.k_norm = nn.RMSNorm(per_head, eps=args.rms_norm_eps)
        self.rope = initialize_rope(
            per_head,
            base=args.rope_theta,
            traditional=False,
            scaling_config=args.rope_scaling,
            max_position_embeddings=args.max_position_embeddings,
        )

    def __call__(self, x, mask=None, cache=None):
        """Attend over ``x`` (three axes: batch, sequence, features).

        When ``cache`` is given, rotary offsets come from ``cache.offset`` and
        keys/values are replaced by what ``cache.update_and_fetch`` returns.
        """
        batch, length, _ = x.shape

        def to_heads(tensor, n_heads):
            # (batch, length, n_heads * d) -> (batch, n_heads, length, d)
            return tensor.reshape(batch, length, n_heads, -1)

        q = self.q_norm(to_heads(self.q_proj(x), self.n_heads)).transpose(0, 2, 1, 3)
        k = self.k_norm(to_heads(self.k_proj(x), self.n_kv_heads)).transpose(0, 2, 1, 3)
        v = to_heads(self.v_proj(x), self.n_kv_heads).transpose(0, 2, 1, 3)

        if cache is None:
            q = self.rope(q)
            k = self.rope(k)
        else:
            q = self.rope(q, offset=cache.offset)
            k = self.rope(k, offset=cache.offset)
            k, v = cache.update_and_fetch(k, v)

        attended = scaled_dot_product_attention(
            q, k, v, cache=cache, scale=self.scale, mask=mask
        )
        merged = attended.transpose(0, 2, 1, 3).reshape(batch, length, -1)
        return self.o_proj(merged)
class TernaryMLP(nn.Module):
    """Gated feed-forward block built from three TernaryLinear layers."""

    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.gate_proj = TernaryLinear(dim, hidden_dim)
        self.down_proj = TernaryLinear(hidden_dim, dim)
        self.up_proj = TernaryLinear(dim, hidden_dim)

    def __call__(self, x):
        """Return ``down_proj(swiglu(gate_proj(x), up_proj(x)))``."""
        gate = self.gate_proj(x)
        up = self.up_proj(x)
        activated = swiglu(gate, up)
        return self.down_proj(activated)
class TernaryTransformerBlock(nn.Module):
    """Pre-norm block: attention and MLP, each wrapped in a residual add."""

    def __init__(self, args):
        super().__init__()
        self.num_attention_heads = args.num_attention_heads
        self.hidden_size = args.hidden_size
        self.self_attn = TernaryAttention(args)
        self.mlp = TernaryMLP(args.hidden_size, args.intermediate_size)
        self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
        self.post_attention_layernorm = nn.RMSNorm(
            args.hidden_size, eps=args.rms_norm_eps
        )

    def __call__(self, x, mask=None, cache=None):
        """Apply normalised attention, then the normalised MLP, adding each to the stream."""
        attn_out = self.self_attn(self.input_layernorm(x), mask, cache)
        after_attn = x + attn_out
        mlp_out = self.mlp(self.post_attention_layernorm(after_attn))
        return after_attn + mlp_out
class TernaryQwen3Model(nn.Module):
    """Token embedding, a stack of TernaryTransformerBlock, and a final RMSNorm."""

    def __init__(self, args):
        super().__init__()
        self.args = args
        self.vocab_size = args.vocab_size
        self.num_hidden_layers = args.num_hidden_layers
        self.embed_tokens = TernaryEmbedding(args.vocab_size, args.hidden_size)
        self.layers = [
            TernaryTransformerBlock(args) for _ in range(args.num_hidden_layers)
        ]
        self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)

    def __call__(self, inputs, cache=None, input_embeddings=None):
        """Return normalised hidden states.

        ``input_embeddings``, when given, replaces the embedding lookup of
        ``inputs``. ``cache`` is a per-layer sequence; ``None`` means no cache
        for any layer.
        """
        if input_embeddings is not None:
            hidden = input_embeddings
        else:
            hidden = self.embed_tokens(inputs)

        if cache is None:
            cache = [None] * len(self.layers)

        # The mask is built once from the first layer's cache entry and shared.
        mask = create_attention_mask(hidden, cache[0])
        for block, layer_cache in zip(self.layers, cache):
            hidden = block(hidden, mask, layer_cache)
        return self.norm(hidden)
class TernaryModel(nn.Module):
    """Language-model wrapper: backbone plus output projection to the vocabulary.

    With tied embeddings the output projection reuses the embedding table;
    otherwise a separate TernaryLinear ``lm_head`` is created.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        self.model_type = args.model_type
        self.model = TernaryQwen3Model(args)
        if not args.tie_word_embeddings:
            self.lm_head = TernaryLinear(args.hidden_size, args.vocab_size)

    def __call__(self, inputs, cache=None, input_embeddings=None):
        """Run the backbone and project its output with the tied table or ``lm_head``."""
        hidden = self.model(inputs, cache, input_embeddings)
        if not self.args.tie_word_embeddings:
            return self.lm_head(hidden)
        return self.model.embed_tokens.as_linear(hidden)

    @property
    def layers(self):
        """The backbone's list of transformer blocks."""
        return self.model.layers
# =============================================================================
# WEIGHT CONVERSION
# =============================================================================
def convert_weights(src_model, dst_model):
    """Copy weights from original Qwen3 to ternary model.

    Source tensors are gathered from every ``nn.Linear``, ``nn.Embedding`` and
    ``nn.RMSNorm`` under ``src_model.model`` (or ``src_model`` itself when it
    has no ``model`` attribute), keyed by dotted path with a ``model.``
    prefix. A top-level ``src_model.lm_head`` is gathered as well so that an
    untied output head is not left at its random initialisation.

    Each ``TernaryLinear``, ``TernaryEmbedding`` and ``nn.RMSNorm`` in
    ``dst_model`` whose path matches receives the tensor cast to float32.
    Norm weights use float32 like everything else: in float16 the small AdamW
    steps used by this script are below the representable resolution near 1.0.

    Destination parameters with no matching source key keep their current
    values; their paths are listed in a printed warning.

    Args:
        src_model: Loaded source model.
        dst_model: Ternary model to fill in place.
    """
    src_m = src_model.model if hasattr(src_model, 'model') else src_model
    src_weights = {}

    def collect_src(module, prefix=''):
        for name in module:
            obj = module[name]
            full = f'{prefix}{name}'
            if isinstance(obj, nn.Linear):
                src_weights[f'{full}.weight'] = obj.weight
                # Bias-free layers may not expose the attribute at all.
                bias = getattr(obj, 'bias', None)
                if bias is not None:
                    src_weights[f'{full}.bias'] = bias
            elif isinstance(obj, (nn.Embedding, nn.RMSNorm)):
                src_weights[f'{full}.weight'] = obj.weight
            elif isinstance(obj, (list, tuple)):
                for i, item in enumerate(obj):
                    if isinstance(item, nn.Module):
                        collect_src(item, f'{full}.{i}.')
            elif isinstance(obj, nn.Module):
                collect_src(obj, f'{full}.')

    collect_src(src_m, 'model.')

    # TernaryModel keeps an untied head at the top level as `lm_head`, so the
    # matching source key has no `model.` prefix.
    src_head = getattr(src_model, 'lm_head', None)
    if isinstance(src_head, nn.Linear):
        src_weights['lm_head.weight'] = src_head.weight

    missing = []

    def set_dst(module, prefix=''):
        for name in module:
            obj = module[name]
            full = f'{prefix}{name}'
            if isinstance(obj, (TernaryLinear, TernaryEmbedding, nn.RMSNorm)):
                key = f'{full}.weight'
                if key in src_weights:
                    obj.weight = src_weights[key].astype(mx.float32)
                else:
                    missing.append(key)
                bias_key = f'{full}.bias'
                if (
                    isinstance(obj, TernaryLinear)
                    and obj.bias is not None
                    and bias_key in src_weights
                ):
                    obj.bias = src_weights[bias_key].astype(mx.float32)
            elif isinstance(obj, (list, tuple)):
                for i, item in enumerate(obj):
                    if isinstance(item, nn.Module):
                        set_dst(item, f'{full}.{i}.')
            elif isinstance(obj, nn.Module):
                set_dst(obj, f'{full}.')

    set_dst(dst_model, '')

    if missing:
        preview = ', '.join(missing[:5])
        suffix = ', ...' if len(missing) > 5 else ''
        print(f"WARNING: {len(missing)} destination weights had no source match "
              f"and keep their initial values: {preview}{suffix}")
# =============================================================================
# VERIFICATION
# =============================================================================
def verify_ternary(model):
    """Check all ternary layers project to {-1, 0, +1} correctly.

    For every ``TernaryLinear`` / ``TernaryEmbedding`` found in ``model`` the
    group-wise codes ``clip(round(w / mean|w|), -1, 1)`` are recomputed and
    inspected. Because the codes are rounded and clipped here, the membership
    test can only fail for non-finite weights; the per-layer distribution is
    the informative part of the result.

    Returns:
        ``(all_ok, results)`` where ``results`` maps each layer path to a dict
        with keys ``'is_ternary'``, ``'shape'`` and ``'distribution'`` (the
        fraction of codes equal to -1, 0 and 1).
    """
    results = {}

    def inspect_layer(layer):
        weight = layer.weight
        rows = weight.reshape(-1, weight.shape[-1])
        n_cols = rows.shape[-1]

        # Same zero-padding to a multiple of GROUP_SIZE as the projection.
        extra = -n_cols % GROUP_SIZE
        if extra > 0:
            padded = mx.pad(rows, [(0, 0), (0, extra)], constant_values=0.0)
        else:
            padded = rows
        n_rows = padded.shape[0]
        chunks = padded.reshape(n_rows, padded.shape[-1] // GROUP_SIZE, GROUP_SIZE)

        group_scale = mx.mean(mx.abs(chunks), axis=-1, keepdims=True)
        group_scale = mx.where(
            group_scale < 1e-8, mx.ones_like(group_scale), group_scale
        )
        codes = mx.clip(mx.round(chunks / group_scale), -1.0, 1.0)
        codes = codes.reshape(n_rows, -1)
        if extra > 0:
            codes = codes[:, :n_cols]
        codes = codes.reshape(-1)

        counts = {value: int(mx.sum(codes == value)) for value in (-1, 0, 1)}
        total = int(codes.size)
        in_set = (codes == -1) | (codes == 0) | (codes == 1)
        return {
            'is_ternary': bool(mx.all(in_set)),
            'shape': tuple(weight.shape),
            'distribution': {value: counts[value] / total for value in (-1, 0, 1)},
        }

    def walk(module, prefix):
        for name in module:
            child = module[name]
            path = f'{prefix}{name}'
            if isinstance(child, (TernaryLinear, TernaryEmbedding)):
                results[path] = inspect_layer(child)
            elif isinstance(child, (list, tuple)):
                for idx, item in enumerate(child):
                    if isinstance(item, nn.Module):
                        walk(item, f'{path}.{idx}.')
            elif isinstance(child, nn.Module):
                walk(child, f'{path}.')

    walk(model, '')
    all_ok = all(entry['is_ternary'] for entry in results.values())
    return all_ok, results
def generate_text(model, tokenizer, prompt, max_tokens=80, temp=0.8):
    """Sample ``max_tokens`` tokens after ``prompt`` and return the decoded text.

    The whole context (at most the last 512 tokens) is re-run through the
    model for every new token; no cache is used. Logits of the final position
    are divided by ``max(temp, 0.01)`` before categorical sampling. The
    returned string contains the prompt tokens followed by the samples.
    """
    token_ids = tokenizer.encode(prompt)
    for _ in range(max_tokens):
        # A negative slice already returns everything for short contexts.
        window = token_ids[-512:]
        logits = model(mx.array([window]))
        final_step = logits[:, -1, :]
        scaled = final_step / max(temp, 0.01)
        sampled = mx.random.categorical(scaled, axis=-1)
        token_ids.append(int(sampled[0]))
    return tokenizer.decode(token_ids)
def collect_all_params(module, prefix=''):
    """Recursively collect all parameters from model.

    Walks ``module`` and returns a flat ``{dotted_path: array}`` dict with the
    ``weight`` of every ``TernaryLinear``, ``TernaryEmbedding`` and
    ``nn.RMSNorm``. A ``TernaryLinear`` built with ``bias=True`` also has its
    ``bias`` stored under ``<path>.bias`` so a saved checkpoint is complete.

    Args:
        module: Module (or sub-module) to walk.
        prefix: Dotted path prepended to every key; used by the recursion.

    Returns:
        Dict mapping parameter paths to arrays.
    """
    params = {}
    for name in module:
        obj = module[name]
        full = f'{prefix}{name}'
        if isinstance(obj, (TernaryLinear, TernaryEmbedding)):
            params[f'{full}.weight'] = obj.weight
            # TernaryEmbedding has no bias attribute; TernaryLinear's may be None.
            bias = getattr(obj, 'bias', None)
            if bias is not None:
                params[f'{full}.bias'] = bias
        elif isinstance(obj, nn.RMSNorm):
            params[f'{full}.weight'] = obj.weight
        elif isinstance(obj, (list, tuple)):
            for i, item in enumerate(obj):
                if isinstance(item, nn.Module):
                    params.update(collect_all_params(item, f'{full}.{i}.'))
        elif isinstance(obj, nn.Module):
            params.update(collect_all_params(obj, f'{full}.'))
    return params
# =============================================================================
# TRAINING
# =============================================================================
class LRSchedule:
    """Linear warmup followed by cosine decay from ``base_lr`` to ``min_lr``.

    Args:
        base_lr: Peak learning rate, reached at the end of warmup.
        warmup_steps: Number of warmup steps; step ``s`` of warmup yields
            ``base_lr * (s + 1) / warmup_steps``.
        total_steps: Step at which the cosine decay reaches ``min_lr``.
        min_lr: Floor of the schedule.
    """

    def __init__(self, base_lr, warmup_steps, total_steps, min_lr=1e-5):
        self.base_lr = base_lr
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.min_lr = min_lr

    def __call__(self, step):
        """Return the learning rate for the zero-based ``step``."""
        if step < self.warmup_steps:
            return self.base_lr * (step + 1) / self.warmup_steps
        decay_steps = max(1, self.total_steps - self.warmup_steps)
        # Clamp so steps past total_steps stay at min_lr instead of riding the
        # cosine back up.
        progress = min(1.0, (step - self.warmup_steps) / decay_steps)
        cosine_decay = 0.5 * (1 + math.cos(math.pi * progress))
        return self.min_lr + (self.base_lr - self.min_lr) * cosine_decay
def main():
    """Run the full pipeline: load, convert, train, evaluate, save.

    Steps:
      1. Parse CLI arguments and load the source model and tokenizer.
      2. Build a TernaryModel with the same configuration and copy weights.
      3. Run verify_ternary() once before training.
      4. Tokenise WikiText-2 and cut it into non-overlapping sequences of
         ``seq_len + 1`` tokens (inputs plus shifted targets).
      5. Train with AdamW, a warmup/cosine schedule and gradient-norm
         clipping, logging and evaluating periodically.
    Afterwards: final ternary check, validation perplexity, sample
    generations, saving weights as safetensors, and a printed summary.

    Raises:
        ValueError: If training is requested but the training split yields
            fewer sequences than one batch.
    """
    # Tree helpers live in mlx.utils; imported here to keep the block
    # self-contained.
    from mlx.utils import tree_flatten, tree_map

    parser = argparse.ArgumentParser(description="Ternary Bonsai Training")
    parser.add_argument("--model-name", default="Qwen/Qwen3-0.6B")
    parser.add_argument("--batch-size", type=int, default=2)
    parser.add_argument("--seq-len", type=int, default=256)
    parser.add_argument("--steps", type=int, default=200)
    parser.add_argument("--lr", type=float, default=5e-5)
    parser.add_argument("--min-lr", type=float, default=5e-6)
    parser.add_argument("--warmup", type=int, default=20)
    parser.add_argument("--weight-decay", type=float, default=0.01)
    parser.add_argument("--log-every", type=int, default=10)
    parser.add_argument("--eval-every", type=int, default=50)
    parser.add_argument("--save-path", default="./ternary_trained")
    args = parser.parse_args()

    print("=" * 70)
    print("TERNARY BONSAI TRAINING")
    print("=" * 70)
    print(f"Model: {args.model_name}")
    print(f"Steps: {args.steps}, Batch size: {args.batch_size}, Seq len: {args.seq_len}")
    print(f"LR: {args.lr}, Warmup: {args.warmup}, Weight decay: {args.weight_decay}")
    print()

    # Step 1: Load and convert model
    print("[1/5] Loading Qwen3-0.6B...")
    from mlx_lm import load
    src_model, tokenizer = load(args.model_name)
    src_args = src_model.args
    config = ModelArgs(
        model_type=src_args.model_type,
        hidden_size=src_args.hidden_size,
        num_hidden_layers=src_args.num_hidden_layers,
        intermediate_size=src_args.intermediate_size,
        num_attention_heads=src_args.num_attention_heads,
        rms_norm_eps=src_args.rms_norm_eps,
        vocab_size=src_args.vocab_size,
        num_key_value_heads=src_args.num_key_value_heads,
        max_position_embeddings=src_args.max_position_embeddings,
        rope_theta=src_args.rope_theta,
        head_dim=src_args.head_dim,
        tie_word_embeddings=src_args.tie_word_embeddings,
        rope_scaling=src_args.rope_scaling,
    )
    print(f"Config: hidden_size={config.hidden_size}, layers={config.num_hidden_layers}, "
          f"heads={config.num_attention_heads}, kv_heads={config.num_key_value_heads}")

    print("\n[2/5] Creating ternary model and copying weights...")
    model = TernaryModel(config)
    convert_weights(src_model, model)
    # Drop the source model before training to lower peak memory.
    del src_model
    mx.clear_cache()

    # Verify ternary projection before training
    print("\n[3/5] Pre-training ternary check...")
    all_ok, results = verify_ternary(model)
    print(f" All weights ternary: {all_ok}")
    if all_ok:
        for name, r in list(results.items())[:3]:
            d = r['distribution']
            print(f" {name}: shape={r['shape']}, "
                  f"-1:{d[-1]:.3f}, 0:{d[0]:.3f}, +1:{d[1]:.3f}")

    # Load training data
    print("\n[4/5] Loading WikiText-2 dataset...")
    from datasets import load_dataset
    dataset = load_dataset("wikitext", "wikitext-2-raw-v1")
    train_text = "\n".join(dataset["train"]["text"])
    val_text = "\n".join(dataset["validation"]["text"])
    train_tokens = tokenizer.encode(train_text)
    val_tokens = tokenizer.encode(val_text)
    print(f" Train tokens: {len(train_tokens):,}")
    print(f" Val tokens: {len(val_tokens):,}")

    # Create training sequences: non-overlapping windows of seq_len + 1 tokens
    # so that inputs (first seq_len) and targets (last seq_len) can be sliced.
    seq_len = args.seq_len
    window = seq_len + 1
    n_train_seqs = len(train_tokens) // window
    n_val_seqs = min(200, len(val_tokens) // window)
    train_sequences = [
        train_tokens[i:i + window] for i in range(0, n_train_seqs * window, window)
    ]
    val_sequences = [
        val_tokens[i:i + window] for i in range(0, n_val_seqs * window, window)
    ]
    n_train = len(train_sequences)
    n_val = len(val_sequences)
    print(f" Train sequences: {n_train:,}")
    print(f" Val sequences: {n_val:,}")

    # Training loop
    print(f"\n[5/5] Training for {args.steps} steps...\n")
    lr_schedule = LRSchedule(args.lr, args.warmup, args.steps, args.min_lr)
    optimizer = AdamW(learning_rate=args.lr, weight_decay=args.weight_decay, betas=(0.9, 0.95))

    def loss_fn(model, batch):
        """Mean next-token cross-entropy over a (batch, seq_len + 1) array."""
        input_ids = mx.array(batch[:, :-1])
        targets = mx.array(batch[:, 1:])
        logits = model(input_ids)
        return nn.losses.cross_entropy(logits, targets, reduction="mean")

    def clip_grad_norm(grads, max_norm=1.0):
        """Clip gradient norm to prevent explosion.

        Returns the rescaled gradient tree and the pre-clip global L2 norm as
        a Python float (which forces evaluation of the gradients).
        """
        total_norm_sq = mx.array(0.0)
        for _, g in tree_flatten(grads):
            if isinstance(g, mx.array) and g.ndim >= 1:
                total_norm_sq = total_norm_sq + mx.sum(g ** 2)
        total_norm = mx.sqrt(total_norm_sq)
        scale = mx.where(total_norm > max_norm, max_norm / (total_norm + 1e-6), mx.array(1.0))
        # Scale all gradients
        clipped = tree_map(
            lambda g: g * scale if isinstance(g, mx.array) and g.ndim >= 1 else g,
            grads,
        )
        return clipped, float(total_norm)

    if args.steps > 0 and n_train < args.batch_size:
        raise ValueError(
            f"Need at least {args.batch_size} training sequences of "
            f"{window} tokens, got {n_train}"
        )

    loss_and_grad = nn.value_and_grad(model, loss_fn)
    step = 0
    losses = []
    start_time = time.time()
    # Keep cycling over reshuffled data until the requested step count is hit.
    while step < args.steps:
        indices = np.random.permutation(n_train)
        # Stop before the ragged final batch so every batch is full-sized.
        for i in range(0, n_train - args.batch_size + 1, args.batch_size):
            if step >= args.steps:
                break
            batch_indices = indices[i:i + args.batch_size]
            batch = np.array([train_sequences[j] for j in batch_indices])
            current_lr = lr_schedule(step)
            optimizer.learning_rate = current_lr
            loss, grads = loss_and_grad(model, batch)
            # Gradient clipping to prevent explosion
            grads, grad_norm = clip_grad_norm(grads, max_norm=1.0)
            optimizer.update(model, grads)
            # Materialise the updated parameters and optimizer state now so
            # the lazy graph does not carry over into the next step.
            mx.eval(loss, model.parameters(), optimizer.state)
            losses.append(float(loss))
            step += 1

            if args.log_every > 0 and step % args.log_every == 0:
                recent = losses[-args.log_every:]
                avg_loss = np.mean(recent)
                elapsed = time.time() - start_time
                toks_per_sec = args.log_every * args.batch_size * seq_len / max(elapsed, 0.001)
                print(f" Step {step:4d}/{args.steps} | Loss: {avg_loss:.4f} | "
                      f"GradNorm: {grad_norm:.1f} | LR: {current_lr:.2e} | Tok/s: {toks_per_sec:.0f}")
                start_time = time.time()

            if args.eval_every > 0 and step % args.eval_every == 0 and n_val > 0:
                val_indices = np.random.choice(n_val, size=min(args.batch_size, n_val), replace=False)
                val_batch = np.array([val_sequences[j] for j in val_indices])
                val_loss = loss_fn(model, val_batch)
                mx.eval(val_loss)
                val_ppl = math.exp(float(val_loss))
                print(f" >> Eval at step {step}: val_loss={float(val_loss):.4f}, val_ppl={val_ppl:.1f}")
                all_ok, _ = verify_ternary(model)
                print(f" Ternary check: {'PASS' if all_ok else 'FAIL'}")

    # Final evaluation
    print("\n" + "=" * 70)
    print("FINAL EVALUATION")
    print("=" * 70)
    all_ok, results = verify_ternary(model)
    print(f"\n1. TERNARY VERIFICATION: {'PASS' if all_ok else 'FAIL'}")
    for name, r in sorted(results.items()):
        d = r['distribution']
        status = "OK" if r['is_ternary'] else "FAIL"
        print(f" [{status}] {name}: shape={r['shape']}, "
              f"-1:{d[-1]:.3f}, 0:{d[0]:.3f}, +1:{d[1]:.3f}")

    # Validation perplexity
    print("\n2. PERPLEXITY EVALUATION:")
    eval_batch_size = min(4, n_val)
    val_losses_list = []
    if eval_batch_size > 0:
        # Batch starts run up to the last index that still yields a full
        # batch, capped at 50.
        stop = min(n_val - eval_batch_size + 1, 50)
        for i in range(0, stop, eval_batch_size):
            batch = np.array(val_sequences[i:i + eval_batch_size])
            vl = loss_fn(model, batch)
            mx.eval(vl)
            val_losses_list.append(float(vl))
    avg_val_loss = np.mean(val_losses_list) if val_losses_list else float('inf')
    vocab_size = config.vocab_size
    # Loss of a uniform guess over the vocabulary.
    random_loss = math.log(vocab_size)
    print(f" Train loss (last 50): {np.mean(losses[-50:]):.4f}")
    print(f" Val loss: {avg_val_loss:.4f}")
    print(f" Val perplexity: {math.exp(avg_val_loss):.1f}")
    print(f" Random baseline: perplexity={vocab_size} (loss={random_loss:.2f})")

    # Text generation
    print("\n3. TEXT GENERATION:")
    prompts = [
        "The history of the United States",
        "In the year 2024,",
        "The most important thing about",
        "Scientists discovered that",
    ]
    for prompt in prompts:
        # Generation is best-effort: report a failure and continue.
        try:
            generated = generate_text(model, tokenizer, prompt, max_tokens=60)
            print(f" Prompt: {prompt}")
            print(f" Output: {generated[:200]}")
            print()
        except Exception as e:
            print(f" Generation failed for '{prompt}': {e}")

    # Save
    if args.save_path:
        os.makedirs(args.save_path, exist_ok=True)
        print(f"\nSaving model to {args.save_path}...")
        params = collect_all_params(model)
        if params:
            mx.save_safetensors(
                os.path.join(args.save_path, "weights.safetensors"),
                params
            )
            print(f"Saved {len(params)} weight tensors.")
        else:
            print("WARNING: No weights collected for saving!")

    print("\n" + "=" * 70)
    print("SUMMARY")
    print("=" * 70)
    print(f"Ternary projection verified: {all_ok}")
    print(f"Final training loss: {np.mean(losses[-50:]):.4f}")
    print(f"Validation perplexity: {math.exp(avg_val_loss):.1f}")
    print(f"(Random baseline: {vocab_size})")
    print()
    print("Engineering notes:")
    print(" - Group size = 128 (balances granularity and representation)")
    print(" - Scale = mean(|W|) per group (better than max for sparse distributions)")
    print(" - STE gradient: identity pass-through (standard BitNet approach)")
    print(f" - Learning rate: {args.lr} with {args.warmup} warmup steps")
    print(f" - AdamW with weight_decay={args.weight_decay}")
# Run the training pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()