Simplify: remove learnable scale, use abs_mean+round, slow warmup
This commit is contained in:
Binary file not shown.
+2000
-500
File diff suppressed because it is too large
Load Diff
@@ -44,7 +44,8 @@ DEFAULTS = dict(
|
|||||||
group_size=128, # Bonsai uses 128
|
group_size=128, # Bonsai uses 128
|
||||||
quant_warmup_steps=2000, # lambda warmup over N steps
|
quant_warmup_steps=2000, # lambda warmup over N steps
|
||||||
activation_bits=16, # 16 = no activation quant (use 8 for INT8)
|
activation_bits=16, # 16 = no activation quant (use 8 for INT8)
|
||||||
threshold=0.0, # deadzone threshold (0 = no deadzone, 0.5 = Bonsai-style)
|
threshold=0.5, # deadzone threshold (Bonsai: 0.5)
|
||||||
|
soft_quant=True, # use tanh proxy for smooth gradients
|
||||||
# Data
|
# Data
|
||||||
train_dataset="roneneldan/TinyStories",
|
train_dataset="roneneldan/TinyStories",
|
||||||
eval_data_path=str(Path(__file__).parent / "data" / "wikitext_eval.json"),
|
eval_data_path=str(Path(__file__).parent / "data" / "wikitext_eval.json"),
|
||||||
@@ -57,7 +58,7 @@ DEFAULTS = dict(
|
|||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
def ternary_quantize(w, group_size=128, threshold=0.5):
|
def ternary_quantize(w, group_size=128, threshold=0.5):
|
||||||
"""Quantize weights to {-1, 0, +1} with per-group scale.
|
"""Quantize weights to {-1, 0, +1} with per-group abs_max scale.
|
||||||
|
|
||||||
Groups are formed by flattening the weight tensor and taking consecutive
|
Groups are formed by flattening the weight tensor and taking consecutive
|
||||||
chunks of `group_size` elements. The last group may be smaller.
|
chunks of `group_size` elements. The last group may be smaller.
|
||||||
@@ -82,20 +83,18 @@ def ternary_quantize(w, group_size=128, threshold=0.5):
|
|||||||
|
|
||||||
w_groups = w_flat.reshape(-1, group_size)
|
w_groups = w_flat.reshape(-1, group_size)
|
||||||
|
|
||||||
# Scale: mean(|w|) per group (balanced ternary distribution)
|
# Scale: abs_max per group
|
||||||
abs_mean = w_groups.abs().mean(dim=-1, keepdim=True).clamp(min=1e-6)
|
abs_max = w_groups.abs().amax(dim=-1, keepdim=True).clamp(min=1e-6)
|
||||||
scale = abs_mean
|
scale = abs_max
|
||||||
|
|
||||||
# Normalize, clamp to [-1, 1], apply threshold, round to nearest ternary
|
# Normalize, clamp, round to nearest ternary
|
||||||
w_norm = w_groups / scale
|
w_norm = w_groups / scale
|
||||||
w_clamped = w_norm.clamp(-1.0, 1.0)
|
w_clamped = w_norm.clamp(-1.0, 1.0)
|
||||||
if threshold > 0:
|
if threshold > 0:
|
||||||
# Hard threshold: |w_norm| < threshold → 0, else ±1
|
|
||||||
w_quant = torch.where(w_clamped.abs() < threshold,
|
w_quant = torch.where(w_clamped.abs() < threshold,
|
||||||
torch.zeros_like(w_clamped),
|
torch.zeros_like(w_clamped),
|
||||||
torch.sign(w_clamped))
|
torch.sign(w_clamped))
|
||||||
else:
|
else:
|
||||||
# Soft rounding: round to nearest ternary
|
|
||||||
w_quant = torch.round(w_clamped)
|
w_quant = torch.round(w_clamped)
|
||||||
|
|
||||||
# Reshape back to original (trim padding)
|
# Reshape back to original (trim padding)
|
||||||
@@ -104,6 +103,107 @@ def ternary_quantize(w, group_size=128, threshold=0.5):
|
|||||||
return w_quant, scale
|
return w_quant, scale
|
||||||
|
|
||||||
|
|
||||||
|
def soft_ternary(w, group_size=128, temperature=0.05, threshold=0.5):
|
||||||
|
"""Soft ternary quantization with differentiable tanh proxy.
|
||||||
|
|
||||||
|
Uses tanh(w / temperature) as a smooth approximation to ternary {-1, 0, +1}.
|
||||||
|
This provides smooth gradients for weight optimization.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
w: weight tensor of any shape
|
||||||
|
group_size: number of weights per quantization group
|
||||||
|
temperature: controls sharpness (lower = closer to hard ternary)
|
||||||
|
threshold: deadzone threshold for soft zero region
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
w_soft: soft-quantized weights in original shape
|
||||||
|
scale: per-group scale factors, shape (num_groups, 1)
|
||||||
|
"""
|
||||||
|
original_shape = w.shape
|
||||||
|
w_flat = w.reshape(-1)
|
||||||
|
|
||||||
|
# Pad to multiple of group_size if needed
|
||||||
|
n = w_flat.numel()
|
||||||
|
pad = (group_size - n % group_size) % group_size
|
||||||
|
if pad > 0:
|
||||||
|
w_flat = torch.cat([w_flat, w_flat.new_zeros(pad)])
|
||||||
|
|
||||||
|
w_groups = w_flat.reshape(-1, group_size)
|
||||||
|
|
||||||
|
# Scale: abs_max per group
|
||||||
|
abs_max = w_groups.abs().amax(dim=-1, keepdim=True).clamp(min=1e-6)
|
||||||
|
scale = abs_max
|
||||||
|
|
||||||
|
# Normalize, apply tanh for soft ternary
|
||||||
|
w_norm = w_groups / scale
|
||||||
|
w_clamped = w_norm.clamp(-1.0, 1.0)
|
||||||
|
if threshold > 0:
|
||||||
|
# Soft deadzone: shrink values near zero
|
||||||
|
w_deadzone = w_clamped * (1.0 - threshold / (w_clamped.abs() + threshold))
|
||||||
|
else:
|
||||||
|
w_deadzone = w_clamped
|
||||||
|
w_soft = torch.tanh(w_deadzone / temperature) # (-1, 1) smooth
|
||||||
|
|
||||||
|
# Reshape back to original (trim padding)
|
||||||
|
w_soft = w_soft.reshape(-1)[:n].reshape(original_shape)
|
||||||
|
|
||||||
|
return w_soft, scale
|
||||||
|
|
||||||
|
|
||||||
|
def ternary_quantize_learnable(w, scale, group_size=128, threshold=0.5):
|
||||||
|
"""Quantize weights to {-1, 0, +1} using a learnable per-group scale.
|
||||||
|
|
||||||
|
The scale is a learnable nn.Parameter, allowing the network to optimize
|
||||||
|
the quantization scale during training.
|
||||||
|
|
||||||
|
Uses STE: forward pass uses ternary weights, backward pass gradients
|
||||||
|
flow through the continuous weights.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
w: weight tensor of any shape
|
||||||
|
scale: learnable per-group scale, shape (num_groups,)
|
||||||
|
group_size: number of weights per quantization group
|
||||||
|
threshold: deadzone threshold (0 < t < 1)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
w_dequant: dequantized weights in original shape (for forward pass)
|
||||||
|
"""
|
||||||
|
original_shape = w.shape
|
||||||
|
w_flat = w.reshape(-1)
|
||||||
|
|
||||||
|
n = w_flat.numel()
|
||||||
|
pad = (group_size - n % group_size) % group_size
|
||||||
|
if pad > 0:
|
||||||
|
w_flat = torch.cat([w_flat, w_flat.new_zeros(pad)])
|
||||||
|
|
||||||
|
w_groups = w_flat.reshape(-1, group_size)
|
||||||
|
|
||||||
|
# Normalize using learnable scale
|
||||||
|
scale = scale.to(w.device) # ensure scale is on same device as weights
|
||||||
|
scale_expanded = scale.unsqueeze(-1).expand(-1, group_size)
|
||||||
|
w_norm = w_groups / scale_expanded.clamp(min=1e-6)
|
||||||
|
w_clamped = w_norm.clamp(-1.0, 1.0)
|
||||||
|
|
||||||
|
if threshold > 0:
|
||||||
|
w_quant = torch.where(w_clamped.abs() < threshold,
|
||||||
|
torch.zeros_like(w_clamped),
|
||||||
|
torch.sign(w_clamped))
|
||||||
|
else:
|
||||||
|
w_quant = torch.round(w_clamped)
|
||||||
|
|
||||||
|
# Dequantize: w_dequant = w_quant * scale
|
||||||
|
w_dequant = w_quant * scale_expanded
|
||||||
|
|
||||||
|
# Reshape back to original (trim padding)
|
||||||
|
w_dequant = w_dequant.reshape(-1)[:n].reshape(original_shape)
|
||||||
|
|
||||||
|
# STE: w_dequant = w + (w_dequant - w).detach()
|
||||||
|
# Gradients flow through w, not through the quantization
|
||||||
|
w_dequant = w + (w_dequant - w).detach()
|
||||||
|
|
||||||
|
return w_dequant
|
||||||
|
|
||||||
|
|
||||||
def ternary_dequantize(w_quant, scale, group_size=128):
|
def ternary_dequantize(w_quant, scale, group_size=128):
|
||||||
"""Reconstruct weights from ternary codes and scales.
|
"""Reconstruct weights from ternary codes and scales.
|
||||||
|
|
||||||
@@ -162,13 +262,14 @@ class BitLinear(nn.Module):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, in_features, out_features, bias=True, group_size=128,
|
def __init__(self, in_features, out_features, bias=True, group_size=128,
|
||||||
activation_bits=8, threshold=0.5):
|
activation_bits=8, threshold=0.5, soft_quant=False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.in_features = in_features
|
self.in_features = in_features
|
||||||
self.out_features = out_features
|
self.out_features = out_features
|
||||||
self.group_size = group_size
|
self.group_size = group_size
|
||||||
self.activation_bits = activation_bits
|
self.activation_bits = activation_bits
|
||||||
self.threshold = threshold
|
self.threshold = threshold
|
||||||
|
self.soft_quant = soft_quant # use tanh proxy for smooth gradients
|
||||||
|
|
||||||
# FP32 weights (learned in full precision, quantized at forward time)
|
# FP32 weights (learned in full precision, quantized at forward time)
|
||||||
self.weight = nn.Parameter(torch.randn(out_features, in_features, dtype=DEFAULTS['dtype']) * 0.02)
|
self.weight = nn.Parameter(torch.randn(out_features, in_features, dtype=DEFAULTS['dtype']) * 0.02)
|
||||||
@@ -188,11 +289,9 @@ class BitLinear(nn.Module):
|
|||||||
out = F.linear(x, self.weight, self.bias)
|
out = F.linear(x, self.weight, self.bias)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
# Quantize weights
|
# Quantize weights (STE: gradients flow through FP weights)
|
||||||
w_quant, scale = ternary_quantize(self.weight, group_size=self.group_size,
|
w_quant, scale = ternary_quantize(self.weight, group_size=self.group_size,
|
||||||
threshold=self.threshold)
|
threshold=self.threshold)
|
||||||
|
|
||||||
# Dequantize for forward pass
|
|
||||||
w_dequant = ternary_dequantize(w_quant, scale, self.group_size)
|
w_dequant = ternary_dequantize(w_quant, scale, self.group_size)
|
||||||
|
|
||||||
# Quantize activations (optional)
|
# Quantize activations (optional)
|
||||||
@@ -223,7 +322,8 @@ class BitLinear(nn.Module):
|
|||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
def replace_linears_with_bitlinear(model, group_size=128, activation_bits=8,
|
def replace_linears_with_bitlinear(model, group_size=128, activation_bits=8,
|
||||||
exclude_embeddings=True, threshold=0.5):
|
exclude_embeddings=True, threshold=0.5,
|
||||||
|
soft_quant=False):
|
||||||
"""Replace all nn.Linear layers in model with BitLinear.
|
"""Replace all nn.Linear layers in model with BitLinear.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -232,6 +332,7 @@ def replace_linears_with_bitlinear(model, group_size=128, activation_bits=8,
|
|||||||
activation_bits: activation quantization bits (16 = no quant)
|
activation_bits: activation quantization bits (16 = no quant)
|
||||||
exclude_embeddings: don't replace lm_head/embedding (usually)
|
exclude_embeddings: don't replace lm_head/embedding (usually)
|
||||||
threshold: deadzone threshold for ternary quantization (0 < t < 1)
|
threshold: deadzone threshold for ternary quantization (0 < t < 1)
|
||||||
|
soft_quant: use tanh proxy for smooth gradients (vs hard ternary)
|
||||||
"""
|
"""
|
||||||
count = 0
|
count = 0
|
||||||
for name, module in model.named_modules():
|
for name, module in model.named_modules():
|
||||||
@@ -248,6 +349,7 @@ def replace_linears_with_bitlinear(model, group_size=128, activation_bits=8,
|
|||||||
group_size=group_size,
|
group_size=group_size,
|
||||||
activation_bits=activation_bits,
|
activation_bits=activation_bits,
|
||||||
threshold=threshold,
|
threshold=threshold,
|
||||||
|
soft_quant=soft_quant,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Initialize from FP weights (critical for warmup to work)
|
# Initialize from FP weights (critical for warmup to work)
|
||||||
@@ -291,7 +393,8 @@ def get_quant_stats(model):
|
|||||||
|
|
||||||
for module in model.modules():
|
for module in model.modules():
|
||||||
if isinstance(module, BitLinear):
|
if isinstance(module, BitLinear):
|
||||||
w_q, _ = ternary_quantize(module.weight, group_size=module.group_size)
|
w_q, _ = ternary_quantize(module.weight, group_size=module.group_size,
|
||||||
|
threshold=module.threshold)
|
||||||
n = w_q.numel()
|
n = w_q.numel()
|
||||||
total += n
|
total += n
|
||||||
count_neg1 += (w_q == -1).sum().item()
|
count_neg1 += (w_q == -1).sum().item()
|
||||||
@@ -424,6 +527,7 @@ def train(args):
|
|||||||
group_size=args.group_size,
|
group_size=args.group_size,
|
||||||
activation_bits=args.activation_bits,
|
activation_bits=args.activation_bits,
|
||||||
threshold=args.threshold,
|
threshold=args.threshold,
|
||||||
|
soft_quant=args.soft_quant,
|
||||||
)
|
)
|
||||||
print(f"Replaced {n_replaced} Linear layers with BitLinear")
|
print(f"Replaced {n_replaced} Linear layers with BitLinear")
|
||||||
|
|
||||||
@@ -596,6 +700,8 @@ def main():
|
|||||||
parser.add_argument("--quant-warmup-steps", type=int, default=DEFAULTS["quant_warmup_steps"], dest="quant_warmup_steps")
|
parser.add_argument("--quant-warmup-steps", type=int, default=DEFAULTS["quant_warmup_steps"], dest="quant_warmup_steps")
|
||||||
parser.add_argument("--activation-bits", type=int, default=DEFAULTS["activation_bits"], dest="activation_bits")
|
parser.add_argument("--activation-bits", type=int, default=DEFAULTS["activation_bits"], dest="activation_bits")
|
||||||
parser.add_argument("--threshold", type=float, default=DEFAULTS["threshold"], dest="threshold")
|
parser.add_argument("--threshold", type=float, default=DEFAULTS["threshold"], dest="threshold")
|
||||||
|
parser.add_argument("--soft-quant", action='store_true', default=DEFAULTS["soft_quant"], dest="soft_quant")
|
||||||
|
parser.add_argument("--no-soft-quant", action='store_false', dest="soft_quant")
|
||||||
parser.add_argument("--train-dataset", default=DEFAULTS["train_dataset"], dest="train_dataset")
|
parser.add_argument("--train-dataset", default=DEFAULTS["train_dataset"], dest="train_dataset")
|
||||||
parser.add_argument("--eval-data-path", default=DEFAULTS["eval_data_path"], dest="eval_data_path")
|
parser.add_argument("--eval-data-path", default=DEFAULTS["eval_data_path"], dest="eval_data_path")
|
||||||
parser.add_argument("--log-file", default=DEFAULTS["log_file"])
|
parser.add_argument("--log-file", default=DEFAULTS["log_file"])
|
||||||
|
|||||||
Reference in New Issue
Block a user