Simplify: remove learnable scale, use abs_mean+round, slow warmup
This commit is contained in:
Binary file not shown.
+2000
-500
File diff suppressed because it is too large
Load Diff
@@ -44,7 +44,8 @@ DEFAULTS = dict(
|
||||
group_size=128, # Bonsai uses 128
|
||||
quant_warmup_steps=2000, # lambda warmup over N steps
|
||||
activation_bits=16, # 16 = no activation quant (use 8 for INT8)
|
||||
threshold=0.0, # deadzone threshold (0 = no deadzone, 0.5 = Bonsai-style)
|
||||
threshold=0.5, # deadzone threshold (Bonsai: 0.5)
|
||||
soft_quant=True, # use tanh proxy for smooth gradients
|
||||
# Data
|
||||
train_dataset="roneneldan/TinyStories",
|
||||
eval_data_path=str(Path(__file__).parent / "data" / "wikitext_eval.json"),
|
||||
@@ -57,7 +58,7 @@ DEFAULTS = dict(
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def ternary_quantize(w, group_size=128, threshold=0.5):
|
||||
"""Quantize weights to {-1, 0, +1} with per-group scale.
|
||||
"""Quantize weights to {-1, 0, +1} with per-group abs_max scale.
|
||||
|
||||
Groups are formed by flattening the weight tensor and taking consecutive
|
||||
chunks of `group_size` elements. The last group may be smaller.
|
||||
@@ -82,20 +83,18 @@ def ternary_quantize(w, group_size=128, threshold=0.5):
|
||||
|
||||
w_groups = w_flat.reshape(-1, group_size)
|
||||
|
||||
# Scale: mean(|w|) per group (balanced ternary distribution)
|
||||
abs_mean = w_groups.abs().mean(dim=-1, keepdim=True).clamp(min=1e-6)
|
||||
scale = abs_mean
|
||||
# Scale: abs_max per group
|
||||
abs_max = w_groups.abs().amax(dim=-1, keepdim=True).clamp(min=1e-6)
|
||||
scale = abs_max
|
||||
|
||||
# Normalize, clamp to [-1, 1], apply threshold, round to nearest ternary
|
||||
# Normalize, clamp, round to nearest ternary
|
||||
w_norm = w_groups / scale
|
||||
w_clamped = w_norm.clamp(-1.0, 1.0)
|
||||
if threshold > 0:
|
||||
# Hard threshold: |w_norm| < threshold → 0, else ±1
|
||||
w_quant = torch.where(w_clamped.abs() < threshold,
|
||||
torch.zeros_like(w_clamped),
|
||||
torch.sign(w_clamped))
|
||||
else:
|
||||
# Soft rounding: round to nearest ternary
|
||||
w_quant = torch.round(w_clamped)
|
||||
|
||||
# Reshape back to original (trim padding)
|
||||
@@ -104,6 +103,107 @@ def ternary_quantize(w, group_size=128, threshold=0.5):
|
||||
return w_quant, scale
|
||||
|
||||
|
||||
def soft_ternary(w, group_size=128, temperature=0.05, threshold=0.5):
    """Differentiable soft ternary quantization via a tanh proxy.

    tanh(w / temperature) smoothly approximates the hard ternary set
    {-1, 0, +1}, giving usable gradients for weight optimization.

    Args:
        w: weight tensor of any shape
        group_size: number of weights per quantization group
        temperature: sharpness control (lower = closer to hard ternary)
        threshold: deadzone threshold for the soft zero region

    Returns:
        w_soft: soft-quantized weights in the original shape
        scale: per-group scale factors, shape (num_groups, 1)
    """
    shape = w.shape
    flat = w.reshape(-1)
    numel = flat.numel()

    # Zero-pad so the flattened tensor splits evenly into groups.
    remainder = numel % group_size
    if remainder:
        flat = torch.cat([flat, flat.new_zeros(group_size - remainder)])
    groups = flat.reshape(-1, group_size)

    # Per-group abs-max scale, clamped to avoid division by zero.
    scale = groups.abs().amax(dim=-1, keepdim=True).clamp(min=1e-6)

    normed = (groups / scale).clamp(-1.0, 1.0)
    if threshold > 0:
        # Soft deadzone: smoothly shrink values near zero toward zero.
        normed = normed * (1.0 - threshold / (normed.abs() + threshold))
    soft = torch.tanh(normed / temperature)  # smooth values in (-1, 1)

    # Drop the padding and restore the caller's shape.
    return soft.reshape(-1)[:numel].reshape(shape), scale
|
||||
|
||||
|
||||
def ternary_quantize_learnable(w, scale, group_size=128, threshold=0.5):
    """Quantize weights to {-1, 0, +1} using a learnable per-group scale.

    The scale is a learnable nn.Parameter, allowing the network to optimize
    the quantization scale during training.

    Uses a straight-through estimator (STE) on the rounding step ONLY:
    the forward pass sees ternary weights, while gradients flow through
    both the continuous weights AND the learnable scale.

    Fix: the previous implementation detached the entire dequantized
    tensor (`w + (w_dequant - w).detach()`), which cut `scale` out of
    the autograd graph — the "learnable" scale never received gradients.

    Args:
        w: weight tensor of any shape
        scale: learnable per-group scale, shape (num_groups,)
        group_size: number of weights per quantization group
        threshold: deadzone threshold (0 < t < 1); 0 = round to nearest

    Returns:
        w_dequant: dequantized weights in original shape (for forward pass)
    """
    original_shape = w.shape
    w_flat = w.reshape(-1)

    # Pad to a multiple of group_size so reshape into groups is exact.
    n = w_flat.numel()
    pad = (group_size - n % group_size) % group_size
    if pad > 0:
        w_flat = torch.cat([w_flat, w_flat.new_zeros(pad)])

    w_groups = w_flat.reshape(-1, group_size)

    # Normalize using the learnable scale (kept on the weight's device).
    scale = scale.to(w.device)
    scale_expanded = scale.unsqueeze(-1).expand(-1, group_size)
    w_norm = w_groups / scale_expanded.clamp(min=1e-6)
    w_clamped = w_norm.clamp(-1.0, 1.0)

    if threshold > 0:
        # Deadzone: |w_norm| < threshold -> 0, else ±1.
        w_quant = torch.where(w_clamped.abs() < threshold,
                              torch.zeros_like(w_clamped),
                              torch.sign(w_clamped))
    else:
        # Round to nearest ternary value.
        w_quant = torch.round(w_clamped)

    # STE on the rounding only: forward value is w_quant, gradients pass
    # through w_clamped (and hence through w). The scale multiplication
    # below stays in the graph, so `scale` receives gradients too.
    w_quant = w_clamped + (w_quant - w_clamped).detach()

    # Dequantize: w_dequant = w_quant * scale
    w_dequant = w_quant * scale_expanded

    # Reshape back to original (trim padding)
    w_dequant = w_dequant.reshape(-1)[:n].reshape(original_shape)

    return w_dequant
|
||||
|
||||
|
||||
def ternary_dequantize(w_quant, scale, group_size=128):
|
||||
"""Reconstruct weights from ternary codes and scales.
|
||||
|
||||
@@ -162,13 +262,14 @@ class BitLinear(nn.Module):
|
||||
"""
|
||||
|
||||
def __init__(self, in_features, out_features, bias=True, group_size=128,
|
||||
activation_bits=8, threshold=0.5):
|
||||
activation_bits=8, threshold=0.5, soft_quant=False):
|
||||
super().__init__()
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
self.group_size = group_size
|
||||
self.activation_bits = activation_bits
|
||||
self.threshold = threshold
|
||||
self.soft_quant = soft_quant # use tanh proxy for smooth gradients
|
||||
|
||||
# FP32 weights (learned in full precision, quantized at forward time)
|
||||
self.weight = nn.Parameter(torch.randn(out_features, in_features, dtype=DEFAULTS['dtype']) * 0.02)
|
||||
@@ -188,11 +289,9 @@ class BitLinear(nn.Module):
|
||||
out = F.linear(x, self.weight, self.bias)
|
||||
return out
|
||||
|
||||
# Quantize weights
|
||||
# Quantize weights (STE: gradients flow through FP weights)
|
||||
w_quant, scale = ternary_quantize(self.weight, group_size=self.group_size,
|
||||
threshold=self.threshold)
|
||||
|
||||
# Dequantize for forward pass
|
||||
w_dequant = ternary_dequantize(w_quant, scale, self.group_size)
|
||||
|
||||
# Quantize activations (optional)
|
||||
@@ -223,7 +322,8 @@ class BitLinear(nn.Module):
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def replace_linears_with_bitlinear(model, group_size=128, activation_bits=8,
|
||||
exclude_embeddings=True, threshold=0.5):
|
||||
exclude_embeddings=True, threshold=0.5,
|
||||
soft_quant=False):
|
||||
"""Replace all nn.Linear layers in model with BitLinear.
|
||||
|
||||
Args:
|
||||
@@ -232,6 +332,7 @@ def replace_linears_with_bitlinear(model, group_size=128, activation_bits=8,
|
||||
activation_bits: activation quantization bits (16 = no quant)
|
||||
exclude_embeddings: don't replace lm_head/embedding (usually)
|
||||
threshold: deadzone threshold for ternary quantization (0 < t < 1)
|
||||
soft_quant: use tanh proxy for smooth gradients (vs hard ternary)
|
||||
"""
|
||||
count = 0
|
||||
for name, module in model.named_modules():
|
||||
@@ -248,6 +349,7 @@ def replace_linears_with_bitlinear(model, group_size=128, activation_bits=8,
|
||||
group_size=group_size,
|
||||
activation_bits=activation_bits,
|
||||
threshold=threshold,
|
||||
soft_quant=soft_quant,
|
||||
)
|
||||
|
||||
# Initialize from FP weights (critical for warmup to work)
|
||||
@@ -291,7 +393,8 @@ def get_quant_stats(model):
|
||||
|
||||
for module in model.modules():
|
||||
if isinstance(module, BitLinear):
|
||||
w_q, _ = ternary_quantize(module.weight, group_size=module.group_size)
|
||||
w_q, _ = ternary_quantize(module.weight, group_size=module.group_size,
|
||||
threshold=module.threshold)
|
||||
n = w_q.numel()
|
||||
total += n
|
||||
count_neg1 += (w_q == -1).sum().item()
|
||||
@@ -424,6 +527,7 @@ def train(args):
|
||||
group_size=args.group_size,
|
||||
activation_bits=args.activation_bits,
|
||||
threshold=args.threshold,
|
||||
soft_quant=args.soft_quant,
|
||||
)
|
||||
print(f"Replaced {n_replaced} Linear layers with BitLinear")
|
||||
|
||||
@@ -596,6 +700,8 @@ def main():
|
||||
parser.add_argument("--quant-warmup-steps", type=int, default=DEFAULTS["quant_warmup_steps"], dest="quant_warmup_steps")
|
||||
parser.add_argument("--activation-bits", type=int, default=DEFAULTS["activation_bits"], dest="activation_bits")
|
||||
parser.add_argument("--threshold", type=float, default=DEFAULTS["threshold"], dest="threshold")
|
||||
parser.add_argument("--soft-quant", action='store_true', default=DEFAULTS["soft_quant"], dest="soft_quant")
|
||||
parser.add_argument("--no-soft-quant", action='store_false', dest="soft_quant")
|
||||
parser.add_argument("--train-dataset", default=DEFAULTS["train_dataset"], dest="train_dataset")
|
||||
parser.add_argument("--eval-data-path", default=DEFAULTS["eval_data_path"], dest="eval_data_path")
|
||||
parser.add_argument("--log-file", default=DEFAULTS["log_file"])
|
||||
|
||||
Reference in New Issue
Block a user