Simplify: remove learnable scale, use abs_mean+round, slow warmup

2026-04-24 04:10:21 +02:00
parent 27e9faf4f5
commit 853019baf2
3 changed files with 2120 additions and 514 deletions
Binary file not shown.
+2000 -500   File diff suppressed because it is too large
+120 -14
@@ -44,7 +44,8 @@ DEFAULTS = dict(
     group_size=128,            # Bonsai uses 128
     quant_warmup_steps=2000,   # lambda warmup over N steps
     activation_bits=16,        # 16 = no activation quant (use 8 for INT8)
-    threshold=0.0,             # deadzone threshold (0 = no deadzone, 0.5 = Bonsai-style)
+    threshold=0.5,             # deadzone threshold (Bonsai: 0.5)
+    soft_quant=True,           # use tanh proxy for smooth gradients
     # Data
     train_dataset="roneneldan/TinyStories",
     eval_data_path=str(Path(__file__).parent / "data" / "wikitext_eval.json"),
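Note: the `quant_warmup_steps` default above feeds a lambda warmup whose implementation lives in the suppressed file, so it is not visible in this diff. Below is only a hedged sketch of how such a ramp is commonly wired up; the function name and blending rule are illustrative, not taken from the repo.

```python
# Illustrative only: the repo's warmup code is in the suppressed file.
# A typical "lambda warmup" ramps a blend coefficient from 0 to 1 over
# quant_warmup_steps, easing the model from FP weights into quantized ones.
def quant_lambda(step: int, warmup_steps: int = 2000) -> float:
    """Linear ramp: 0.0 at step 0, 1.0 once warmup has finished."""
    if warmup_steps <= 0:
        return 1.0
    return min(1.0, step / warmup_steps)

# e.g. w_effective = (1 - lam) * w_fp + lam * w_quant_dequantized
```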
@@ -57,7 +58,7 @@ DEFAULTS = dict(
 # ---------------------------------------------------------------------------
 def ternary_quantize(w, group_size=128, threshold=0.5):
-    """Quantize weights to {-1, 0, +1} with per-group scale.
+    """Quantize weights to {-1, 0, +1} with per-group abs_max scale.

     Groups are formed by flattening the weight tensor and taking consecutive
     chunks of `group_size` elements. The last group may be smaller.
@@ -82,20 +83,18 @@ def ternary_quantize(w, group_size=128, threshold=0.5):
     w_groups = w_flat.reshape(-1, group_size)

-    # Scale: mean(|w|) per group (balanced ternary distribution)
-    abs_mean = w_groups.abs().mean(dim=-1, keepdim=True).clamp(min=1e-6)
-    scale = abs_mean
+    # Scale: abs_max per group
+    abs_max = w_groups.abs().amax(dim=-1, keepdim=True).clamp(min=1e-6)
+    scale = abs_max

-    # Normalize, clamp to [-1, 1], apply threshold, round to nearest ternary
+    # Normalize, clamp, round to nearest ternary
     w_norm = w_groups / scale
     w_clamped = w_norm.clamp(-1.0, 1.0)
     if threshold > 0:
-        # Hard threshold: |w_norm| < threshold → 0, else ±1
         w_quant = torch.where(w_clamped.abs() < threshold,
                               torch.zeros_like(w_clamped),
                               torch.sign(w_clamped))
     else:
-        # Soft rounding: round to nearest ternary
         w_quant = torch.round(w_clamped)

     # Reshape back to original (trim padding)
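For reference, a small standalone example of what the abs_max scale plus deadzone threshold in this hunk produce on a single group (same operations as above; the values are chosen only for illustration):

```python
import torch

w = torch.tensor([0.8, -0.1, 0.3, -0.6])      # one group of 4 weights
scale = w.abs().amax().clamp(min=1e-6)         # abs_max scale -> 0.8
w_norm = (w / scale).clamp(-1.0, 1.0)          # [1.0, -0.125, 0.375, -0.75]
w_quant = torch.where(w_norm.abs() < 0.5,      # threshold = 0.5 deadzone
                      torch.zeros_like(w_norm),
                      torch.sign(w_norm))       # [1., 0., 0., -1.]
w_dequant = w_quant * scale                     # [0.8, 0., 0., -0.8]
```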
@@ -104,6 +103,107 @@ def ternary_quantize(w, group_size=128, threshold=0.5):
     return w_quant, scale


+def soft_ternary(w, group_size=128, temperature=0.05, threshold=0.5):
+    """Soft ternary quantization with differentiable tanh proxy.
+
+    Uses tanh(w / temperature) as a smooth approximation to ternary {-1, 0, +1}.
+    This provides smooth gradients for weight optimization.
+
+    Args:
+        w: weight tensor of any shape
+        group_size: number of weights per quantization group
+        temperature: controls sharpness (lower = closer to hard ternary)
+        threshold: deadzone threshold for soft zero region
+
+    Returns:
+        w_soft: soft-quantized weights in original shape
+        scale: per-group scale factors, shape (num_groups, 1)
+    """
+    original_shape = w.shape
+    w_flat = w.reshape(-1)
+
+    # Pad to multiple of group_size if needed
+    n = w_flat.numel()
+    pad = (group_size - n % group_size) % group_size
+    if pad > 0:
+        w_flat = torch.cat([w_flat, w_flat.new_zeros(pad)])
+    w_groups = w_flat.reshape(-1, group_size)
+
+    # Scale: abs_max per group
+    abs_max = w_groups.abs().amax(dim=-1, keepdim=True).clamp(min=1e-6)
+    scale = abs_max
+
+    # Normalize, apply tanh for soft ternary
+    w_norm = w_groups / scale
+    w_clamped = w_norm.clamp(-1.0, 1.0)
+    if threshold > 0:
+        # Soft deadzone: shrink values near zero
+        w_deadzone = w_clamped * (1.0 - threshold / (w_clamped.abs() + threshold))
+    else:
+        w_deadzone = w_clamped
+    w_soft = torch.tanh(w_deadzone / temperature)  # (-1, 1) smooth
+
+    # Reshape back to original (trim padding)
+    w_soft = w_soft.reshape(-1)[:n].reshape(original_shape)
+    return w_soft, scale
+
+
+def ternary_quantize_learnable(w, scale, group_size=128, threshold=0.5):
+    """Quantize weights to {-1, 0, +1} using a learnable per-group scale.
+
+    The scale is a learnable nn.Parameter, allowing the network to optimize
+    the quantization scale during training.
+
+    Uses STE: forward pass uses ternary weights, backward pass gradients
+    flow through the continuous weights.
+
+    Args:
+        w: weight tensor of any shape
+        scale: learnable per-group scale, shape (num_groups,)
+        group_size: number of weights per quantization group
+        threshold: deadzone threshold (0 < t < 1)
+
+    Returns:
+        w_dequant: dequantized weights in original shape (for forward pass)
+    """
+    original_shape = w.shape
+    w_flat = w.reshape(-1)
+    n = w_flat.numel()
+    pad = (group_size - n % group_size) % group_size
+    if pad > 0:
+        w_flat = torch.cat([w_flat, w_flat.new_zeros(pad)])
+    w_groups = w_flat.reshape(-1, group_size)
+
+    # Normalize using learnable scale
+    scale = scale.to(w.device)  # ensure scale is on same device as weights
+    scale_expanded = scale.unsqueeze(-1).expand(-1, group_size)
+    w_norm = w_groups / scale_expanded.clamp(min=1e-6)
+    w_clamped = w_norm.clamp(-1.0, 1.0)
+    if threshold > 0:
+        w_quant = torch.where(w_clamped.abs() < threshold,
+                              torch.zeros_like(w_clamped),
+                              torch.sign(w_clamped))
+    else:
+        w_quant = torch.round(w_clamped)
+
+    # Dequantize: w_dequant = w_quant * scale
+    w_dequant = w_quant * scale_expanded
+
+    # Reshape back to original (trim padding)
+    w_dequant = w_dequant.reshape(-1)[:n].reshape(original_shape)
+
+    # STE: w_dequant = w + (w_dequant - w).detach()
+    # Gradients flow through w, not through the quantization
+    w_dequant = w + (w_dequant - w).detach()
+    return w_dequant
+
+
 def ternary_dequantize(w_quant, scale, group_size=128):
     """Reconstruct weights from ternary codes and scales.
@@ -162,13 +262,14 @@ class BitLinear(nn.Module):
""" """
def __init__(self, in_features, out_features, bias=True, group_size=128, def __init__(self, in_features, out_features, bias=True, group_size=128,
activation_bits=8, threshold=0.5): activation_bits=8, threshold=0.5, soft_quant=False):
super().__init__() super().__init__()
self.in_features = in_features self.in_features = in_features
self.out_features = out_features self.out_features = out_features
self.group_size = group_size self.group_size = group_size
self.activation_bits = activation_bits self.activation_bits = activation_bits
self.threshold = threshold self.threshold = threshold
self.soft_quant = soft_quant # use tanh proxy for smooth gradients
# FP32 weights (learned in full precision, quantized at forward time) # FP32 weights (learned in full precision, quantized at forward time)
self.weight = nn.Parameter(torch.randn(out_features, in_features, dtype=DEFAULTS['dtype']) * 0.02) self.weight = nn.Parameter(torch.randn(out_features, in_features, dtype=DEFAULTS['dtype']) * 0.02)
@@ -188,11 +289,9 @@ class BitLinear(nn.Module):
             out = F.linear(x, self.weight, self.bias)
             return out

-        # Quantize weights
+        # Quantize weights (STE: gradients flow through FP weights)
         w_quant, scale = ternary_quantize(self.weight, group_size=self.group_size,
                                           threshold=self.threshold)
-        # Dequantize for forward pass
         w_dequant = ternary_dequantize(w_quant, scale, self.group_size)

         # Quantize activations (optional)
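This hunk only touches the weight-quantization comments; the rest of `BitLinear.forward` is not shown in the diff. As a hedged sketch of how a fake-quantized forward of this shape typically composes the helpers defined above (the STE placement here is an assumption, not confirmed by the visible lines):

```python
import torch.nn.functional as F

def fake_quant_linear(x, weight, bias=None, group_size=128, threshold=0.5):
    # Sketch only: quantize -> dequantize -> STE -> linear, using the
    # ternary_quantize / ternary_dequantize helpers defined in this file.
    w_q, scale = ternary_quantize(weight, group_size=group_size, threshold=threshold)
    w_dq = ternary_dequantize(w_q, scale, group_size)
    w_ste = weight + (w_dq - weight).detach()   # gradients flow to the FP weight
    return F.linear(x, w_ste, bias)
```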
@@ -223,7 +322,8 @@ class BitLinear(nn.Module):
 # ---------------------------------------------------------------------------
 def replace_linears_with_bitlinear(model, group_size=128, activation_bits=8,
-                                   exclude_embeddings=True, threshold=0.5):
+                                   exclude_embeddings=True, threshold=0.5,
+                                   soft_quant=False):
     """Replace all nn.Linear layers in model with BitLinear.

     Args:
@@ -232,6 +332,7 @@ def replace_linears_with_bitlinear(model, group_size=128, activation_bits=8,
         activation_bits: activation quantization bits (16 = no quant)
         exclude_embeddings: don't replace lm_head/embedding (usually)
         threshold: deadzone threshold for ternary quantization (0 < t < 1)
+        soft_quant: use tanh proxy for smooth gradients (vs hard ternary)
     """
     count = 0
     for name, module in model.named_modules():
@@ -248,6 +349,7 @@ def replace_linears_with_bitlinear(model, group_size=128, activation_bits=8,
                 group_size=group_size,
                 activation_bits=activation_bits,
                 threshold=threshold,
+                soft_quant=soft_quant,
             )

             # Initialize from FP weights (critical for warmup to work)
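The body of `replace_linears_with_bitlinear` is only partially visible here (the BitLinear construction and the FP-weight initialization comment). For orientation, a hedged sketch of the usual swap pattern such a helper follows; `swap_linears` and `make_layer` are illustrative names, not the repo's:

```python
import torch.nn as nn

def swap_linears(model: nn.Module, make_layer):
    """Replace every nn.Linear child with a layer built by make_layer,
    copying the FP weights into the replacement (as the comment above stresses)."""
    for module in list(model.modules()):
        for child_name, child in list(module.named_children()):
            if isinstance(child, nn.Linear):
                new = make_layer(child.in_features, child.out_features,
                                 bias=child.bias is not None)
                new.weight.data.copy_(child.weight.data)   # init from FP weights
                if child.bias is not None:
                    new.bias.data.copy_(child.bias.data)
                setattr(module, child_name, new)

# e.g. swap_linears(model, lambda i, o, bias: BitLinear(i, o, bias=bias))
```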
@@ -291,7 +393,8 @@ def get_quant_stats(model):
     for module in model.modules():
         if isinstance(module, BitLinear):
-            w_q, _ = ternary_quantize(module.weight, group_size=module.group_size)
+            w_q, _ = ternary_quantize(module.weight, group_size=module.group_size,
+                                      threshold=module.threshold)
             n = w_q.numel()
             total += n
             count_neg1 += (w_q == -1).sum().item()
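A standalone illustration of the statistic this loop accumulates: the fraction of ternary codes equal to -1, 0, and +1 (the 0 fraction is the layer's sparsity):

```python
import torch

w_q = torch.tensor([-1., 0., 0., 1., 0., -1., 1., 0.])
n = w_q.numel()
dist = {v: (w_q == v).sum().item() / n for v in (-1, 0, 1)}
print(dist)  # {-1: 0.25, 0: 0.5, 1: 0.25}
```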
@@ -424,6 +527,7 @@ def train(args):
         group_size=args.group_size,
         activation_bits=args.activation_bits,
         threshold=args.threshold,
+        soft_quant=args.soft_quant,
     )
     print(f"Replaced {n_replaced} Linear layers with BitLinear")
@@ -596,6 +700,8 @@ def main():
parser.add_argument("--quant-warmup-steps", type=int, default=DEFAULTS["quant_warmup_steps"], dest="quant_warmup_steps") parser.add_argument("--quant-warmup-steps", type=int, default=DEFAULTS["quant_warmup_steps"], dest="quant_warmup_steps")
parser.add_argument("--activation-bits", type=int, default=DEFAULTS["activation_bits"], dest="activation_bits") parser.add_argument("--activation-bits", type=int, default=DEFAULTS["activation_bits"], dest="activation_bits")
parser.add_argument("--threshold", type=float, default=DEFAULTS["threshold"], dest="threshold") parser.add_argument("--threshold", type=float, default=DEFAULTS["threshold"], dest="threshold")
parser.add_argument("--soft-quant", action='store_true', default=DEFAULTS["soft_quant"], dest="soft_quant")
parser.add_argument("--no-soft-quant", action='store_false', dest="soft_quant")
parser.add_argument("--train-dataset", default=DEFAULTS["train_dataset"], dest="train_dataset") parser.add_argument("--train-dataset", default=DEFAULTS["train_dataset"], dest="train_dataset")
parser.add_argument("--eval-data-path", default=DEFAULTS["eval_data_path"], dest="eval_data_path") parser.add_argument("--eval-data-path", default=DEFAULTS["eval_data_path"], dest="eval_data_path")
parser.add_argument("--log-file", default=DEFAULTS["log_file"]) parser.add_argument("--log-file", default=DEFAULTS["log_file"])