feat: add model comparisons and sanitize session files
- Rename gamma to glm5 and model to minimax-m2.7 - Add model_comparison/ directory with head-to-head analyses - Sanitize all session.jsonl files: remove absolute paths and usernames - Remove __pycache__ artifacts - Add .gitignore
This commit is contained in:
@@ -0,0 +1,397 @@
|
||||
"""
|
||||
Transformer Layer with KV-Cache Integration
|
||||
|
||||
Implements a complete decoder transformer layer that:
|
||||
- Computes Q, K, V projections
|
||||
- Stores K, V in the cache
|
||||
- Performs cached attention
|
||||
- Applies MLP with residual connections and layer norm
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
from typing import Optional, Tuple, List
|
||||
from kv_cache import KVCache, CacheConfig, BatchedKVCache
|
||||
from attention import (
|
||||
cached_attention,
|
||||
cached_attention_with_mask,
|
||||
prompt_attention,
|
||||
)
|
||||
|
||||
|
||||
class Linear:
    """Dense affine layer (y = x @ W^T + b) built directly on NumPy."""

    def __init__(self, in_features: int, out_features: int,
                 dtype=np.float32, seed: int = None):
        """
        Args:
            in_features: input feature dimension (fan-in).
            out_features: output feature dimension.
            dtype: parameter/activation dtype.
            seed: optional seed for the global NumPy RNG before init.
        """
        if seed is not None:
            np.random.seed(seed)
        # Kaiming initialization: std = sqrt(2 / fan_in).
        std = np.sqrt(2.0 / in_features)
        self.weight = np.random.randn(out_features, in_features).astype(dtype) * std
        self.bias = np.zeros(out_features, dtype=dtype)
        self.dtype = dtype

    def forward(self, x: np.ndarray) -> np.ndarray:
        """Apply the affine transform and cast back to the layer dtype."""
        out = x @ self.weight.T + self.bias
        return out.astype(self.dtype)
|
||||
|
||||
|
||||
class LayerNorm:
    """Layer normalization over the last axis with learnable affine params."""

    def __init__(self, dim: int, eps: float = 1e-5, dtype=np.float32):
        """
        Args:
            dim: normalized feature dimension.
            eps: variance floor for numerical stability.
            dtype: parameter/output dtype.
        """
        self.dim = dim
        self.eps = eps
        # Scale/shift start as the identity transform.
        self.weight = np.ones(dim, dtype=dtype)
        self.bias = np.zeros(dim, dtype=dtype)
        self.dtype = dtype

    def forward(self, x: np.ndarray) -> np.ndarray:
        """Normalize x over its last axis; statistics are computed in float32."""
        xf = x.astype(np.float32)
        mu = np.mean(xf, axis=-1, keepdims=True)
        sigma2 = np.var(xf, axis=-1, keepdims=True)
        normed = (xf - mu) / np.sqrt(sigma2 + self.eps)
        out = normed * self.weight + self.bias
        return out.astype(self.dtype)
|
||||
|
||||
|
||||
class MLP:
    """Feed-forward network: linear -> GELU (tanh approximation) -> linear.

    Args:
        dim: input/output feature dimension.
        hidden_dim: width of the intermediate layer.
        dtype: parameter/activation dtype.
        seed: optional base RNG seed; fc2 uses seed + 1 so the two layers
            draw independent weights.
    """

    def __init__(self, dim: int, hidden_dim: int, dtype=np.float32, seed: int = None):
        self.fc1 = Linear(dim, hidden_dim, dtype=dtype, seed=seed)
        # BUG FIX: `seed + 1 if seed else None` treated seed=0 as "no seed",
        # leaving fc2 unseeded while fc1 was seeded with 0. Test for None
        # explicitly so seed=0 is honored.
        self.fc2 = Linear(hidden_dim, dim, dtype=dtype,
                          seed=seed + 1 if seed is not None else None)
        self.dtype = dtype

    def forward(self, x: np.ndarray) -> np.ndarray:
        """Apply fc1 -> GELU -> fc2 and return (…, dim)-shaped output."""
        h = self.fc1.forward(x)
        # GELU tanh approximation (Hendrycks & Gimpel).
        h = h * (1 + np.tanh(np.sqrt(2 / np.pi) * (h + 0.044715 * h ** 3))) * 0.5
        return self.fc2.forward(h)
|
||||
|
||||
|
||||
class TransformerDecoderLayer:
    """
    Single decoder transformer layer with KV-cache support.

    Architecture (pre-norm variant, used by most modern models):
        x -> LayerNorm -> Self-Attention -> Residual -> LayerNorm -> MLP -> Residual
    """

    def __init__(self, dim: int, num_heads: int, mlp_hidden: int,
                 dtype=np.float32, seed: int = None):
        """
        Args:
            dim: model (embedding) dimension; must be divisible by num_heads.
            num_heads: number of attention heads.
            mlp_hidden: hidden width of the feed-forward block.
            dtype: parameter/activation dtype.
            seed: optional base RNG seed; sub-modules use seed + k offsets.
        """
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = 1.0 / np.sqrt(self.head_dim)
        self.dtype = dtype

        def sub_seed(offset: int):
            # BUG FIX: the previous `seed + k if seed else None` treated
            # seed=0 as "no seed", leaving wk/wv/wo/mlp unseeded when
            # seed=0 while wq was seeded. Test for None explicitly.
            return None if seed is None else seed + offset

        # Q, K, V projections
        self.wq = Linear(dim, dim, dtype=dtype, seed=seed)
        self.wk = Linear(dim, dim, dtype=dtype, seed=sub_seed(1))
        self.wv = Linear(dim, dim, dtype=dtype, seed=sub_seed(2))

        # Output projection
        self.wo = Linear(dim, dim, dtype=dtype, seed=sub_seed(3))

        # Pre-attention / pre-MLP normalizations
        self.norm1 = LayerNorm(dim, dtype=dtype)
        self.norm2 = LayerNorm(dim, dtype=dtype)

        # Feed-forward block
        self.mlp = MLP(dim, mlp_hidden, dtype=dtype, seed=sub_seed(4))

    def _to_heads(self, x: np.ndarray) -> np.ndarray:
        """Reshape (batch, seq, dim) -> (batch, seq, heads, head_dim)."""
        batch, seq, _ = x.shape
        return x.reshape(batch, seq, self.num_heads, self.head_dim)

    def _from_heads(self, x: np.ndarray) -> np.ndarray:
        """Reshape (batch, seq, heads, head_dim) -> (batch, seq, dim)."""
        batch, seq, _, _ = x.shape
        return x.reshape(batch, seq, self.dim)

    def forward_prefill(
        self,
        x: np.ndarray,
        cache: KVCache,
        lengths: Optional[np.ndarray] = None,
    ) -> np.ndarray:
        """
        Process the full prompt (prefill phase).

        Args:
            x: (batch, prompt_len, dim) input hidden states.
            cache: KVCache to populate with this layer's K, V.
            lengths: optional per-batch-item prompt lengths.

        Returns:
            output: (batch, prompt_len, dim) hidden states.
        """
        batch, seq_len, _ = x.shape

        # Self-attention with residual (pre-norm).
        residual = x
        x_norm = self.norm1.forward(x)

        # Project to Q, K, V.
        q = self.wq.forward(x_norm)  # (batch, seq, dim)
        k = self.wk.forward(x_norm)
        v = self.wv.forward(x_norm)

        # Reshape to multi-head layout: (batch, heads, seq, head_dim).
        q = self._to_heads(q).transpose(0, 2, 1, 3)
        k = self._to_heads(k).transpose(0, 2, 1, 3)
        v = self._to_heads(v).transpose(0, 2, 1, 3)

        # Cached attention over the prompt; prompt_attention also stores
        # K, V into `cache` (per the attention module's contract).
        attn_out, _, _ = prompt_attention(
            q, k, v, cache, self.scale, lengths=lengths
        )
        # attn_out: (batch, heads, seq, head_dim)

        # Merge heads and project output back to model dim.
        attn_out = attn_out.transpose(0, 2, 1, 3)  # (batch, seq, heads, head_dim)
        attn_out = self._from_heads(attn_out)      # (batch, seq, dim)
        attn_out = self.wo.forward(attn_out)

        x = residual + attn_out

        # MLP with residual (pre-norm).
        residual = x
        x_norm = self.norm2.forward(x)
        mlp_out = self.mlp.forward(x_norm)
        x = residual + mlp_out

        return x

    def forward_generate(
        self,
        x: np.ndarray,
        cache: KVCache,
        lengths: Optional[np.ndarray] = None,
    ) -> np.ndarray:
        """
        Process one token (generation phase).

        Args:
            x: (batch, 1, dim) — single-token hidden states.
            cache: KVCache holding previous K, V; updated in place.
            lengths: optional per-batch-item sequence lengths.

        Returns:
            output: (batch, 1, dim) hidden states.
        """
        # Self-attention with residual (pre-norm).
        residual = x
        x_norm = self.norm1.forward(x)

        # Project to Q, K, V.
        q = self.wq.forward(x_norm)  # (batch, 1, dim)
        k = self.wk.forward(x_norm)
        v = self.wv.forward(x_norm)

        # Reshape to multi-head layout: (batch, heads, 1, head_dim).
        q = self._to_heads(q).transpose(0, 2, 1, 3)
        k = self._to_heads(k).transpose(0, 2, 1, 3)
        v = self._to_heads(v).transpose(0, 2, 1, 3)

        # Append this token's K, V to the cache.
        cache.update(k, v)

        # Cached attention over everything stored so far.
        if lengths is not None:
            attn_out = cached_attention_with_mask(
                q, cache, self.scale, lengths=lengths
            )
        else:
            attn_out = cached_attention(q, cache, self.scale)
        # attn_out: (batch, heads, 1, head_dim)

        # Merge heads and project output back to model dim.
        attn_out = attn_out.transpose(0, 2, 1, 3)  # (batch, 1, heads, head_dim)
        attn_out = self._from_heads(attn_out)      # (batch, 1, dim)
        attn_out = self.wo.forward(attn_out)

        x = residual + attn_out

        # MLP with residual (pre-norm).
        residual = x
        x_norm = self.norm2.forward(x)
        mlp_out = self.mlp.forward(x_norm)
        x = residual + mlp_out

        return x
|
||||
|
||||
|
||||
class TransformerDecoder:
    """
    Full transformer decoder with KV-cache management.

    Orchestrates prefill and generation across all layers.
    """

    def __init__(self, num_layers: int, dim: int, num_heads: int,
                 mlp_hidden: int, vocab_size: int, max_seq_len: int,
                 batch_size: int = 1, dtype=np.float32, seed: int = 42):
        """
        Args:
            num_layers: number of decoder layers.
            dim: model dimension; must be divisible by num_heads.
            num_heads: attention heads per layer.
            mlp_hidden: feed-forward hidden width.
            vocab_size: vocabulary size for embedding and LM head.
            max_seq_len: maximum cached sequence length.
            batch_size: cache batch size.
            dtype: parameter/activation dtype.
            seed: RNG seed for reproducible weights (None = don't seed).
        """
        self.num_layers = num_layers
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.vocab_size = vocab_size
        self.dtype = dtype

        # BUG FIX: previously `seed` was only forwarded to the layers, so the
        # embedding tables were drawn from whatever ambient global RNG state
        # existed at construction time. Seed once here so the whole model is
        # reproducible.
        if seed is not None:
            np.random.seed(seed)

        # Token embedding table.
        self.embedding = np.random.randn(vocab_size, dim).astype(dtype) * 0.02

        # Learnable absolute positional encoding.
        self.pos_embedding = np.random.randn(max_seq_len, dim).astype(dtype) * 0.02

        # Decoder layers, each with a distinct base seed.
        self.layers = [
            TransformerDecoderLayer(dim, num_heads, mlp_hidden,
                                    dtype=dtype, seed=seed + i * 100)
            for i in range(num_layers)
        ]

        # Final normalization and LM head (tied to the embedding weights).
        self.final_norm = LayerNorm(dim, dtype=dtype)
        self.lm_head_weight = self.embedding.T  # weight tying

        # One KV cache per layer, managed together.
        cache_config = CacheConfig(
            batch_size=batch_size,
            num_heads=num_heads,
            head_dim=self.head_dim,
            max_seq_len=max_seq_len,
            dtype=dtype,
        )
        self.cache = BatchedKVCache(num_layers, cache_config)

    def _add_positional_encoding(self, x: np.ndarray, start_pos: int = 0) -> np.ndarray:
        """Add positional encoding for positions [start_pos, start_pos + seq)."""
        batch, seq, _ = x.shape
        pos_enc = self.pos_embedding[start_pos:start_pos + seq]
        return (x + pos_enc[None, :, :]).astype(self.dtype)

    def prefill(self, token_ids: np.ndarray,
                lengths: Optional[np.ndarray] = None) -> np.ndarray:
        """
        Process the full prompt, populating the KV caches.

        Args:
            token_ids: (batch, prompt_len) integer token IDs.
            lengths: optional (batch,) actual lengths per batch item.

        Returns:
            hidden: (batch, prompt_len, dim) — hidden states after all layers.
        """
        batch, prompt_len = token_ids.shape

        # Embed + positional encoding.
        x = self.embedding[token_ids]  # (batch, prompt_len, dim)
        x = self._add_positional_encoding(x, start_pos=0)

        # Run every layer against its own cache.
        for i, layer in enumerate(self.layers):
            x = layer.forward_prefill(x, self.cache.caches[i], lengths=lengths)

        return x

    def generate_step(
        self,
        token_ids: np.ndarray,
        lengths: Optional[np.ndarray] = None,
    ) -> np.ndarray:
        """
        Generate logits for one token.

        Args:
            token_ids: (batch, 1) — the token to process.
            lengths: optional (batch,) current sequence lengths.

        Returns:
            logits: (batch, vocab_size) — output logits for the next token.
        """
        # NOTE(review): assumes cache.write_pos already points one past this
        # token's slot when generate_step runs — confirm against the
        # KVCache.write_pos semantics in kv_cache.py.
        current_pos = self.cache.caches[0].write_pos - 1  # position of this token

        # Embed + positional encoding for this single position.
        x = self.embedding[token_ids]  # (batch, 1, dim)
        x = self._add_positional_encoding(x, start_pos=current_pos)

        # Run every layer against its own cache.
        for i, layer in enumerate(self.layers):
            x = layer.forward_generate(x, self.cache.caches[i], lengths=lengths)

        # Final norm + tied LM head.
        x = self.final_norm.forward(x)     # (batch, 1, dim)
        logits = x @ self.lm_head_weight   # (batch, 1, vocab_size)
        return logits[:, 0, :]             # (batch, vocab_size)

    def generate(self, prompt_ids: np.ndarray, num_tokens: int,
                 temperature: float = 1.0, top_k: int = None,
                 lengths: Optional[np.ndarray] = None) -> List[np.ndarray]:
        """
        Full generation loop: prefill, then sample num_tokens tokens.

        Args:
            prompt_ids: (batch, prompt_len) prompt token IDs.
            num_tokens: number of tokens to generate.
            temperature: sampling temperature (logits are divided by it).
            top_k: if given, restrict sampling to the k largest logits.
            lengths: optional per-batch-item prompt lengths.

        Returns:
            generated_ids: list of (batch,) sampled token arrays, one per step.
        """
        # Start from an empty cache each call.
        self.cache.reset()

        # Prefill the prompt.
        self.prefill(prompt_ids, lengths=lengths)

        # Generation starts from the last prompt token.
        batch = prompt_ids.shape[0]
        last_tokens = prompt_ids[:, -1:]  # (batch, 1)

        # Track current lengths (start from prompt lengths).
        if lengths is not None:
            cur_lengths = lengths.copy()
        else:
            cur_lengths = np.full(batch, prompt_ids.shape[1], dtype=np.int32)

        generated = []
        for step in range(num_tokens):
            logits = self.generate_step(last_tokens, lengths=cur_lengths)

            # Apply temperature.
            logits = logits / temperature

            # Top-k filtering.
            if top_k is not None:
                # BUG FIX: the threshold must be the SMALLEST of the top-k
                # values. The old code compared against the largest
                # (`top_k_values[:, -1:]`), which masked every logit except
                # the argmax and made top_k behave like greedy decoding.
                kth = np.sort(logits, axis=-1)[:, -top_k][:, None]  # (batch, 1)
                logits = np.where(logits < kth, -np.inf, logits)

            # Numerically-stable softmax.
            probs = np.exp(logits - np.max(logits, axis=-1, keepdims=True))
            probs = probs / np.sum(probs, axis=-1, keepdims=True)

            # Sample one token per batch item (re-normalize to absorb
            # floating-point drift, which np.random.choice rejects).
            sampled = np.array([
                np.random.choice(len(probs[b]), p=probs[b] / probs[b].sum())
                for b in range(batch)
            ])

            generated.append(sampled)
            last_tokens = sampled[:, None]  # (batch, 1)

            # One more token is now cached for every batch item.
            cur_lengths = cur_lengths + 1

        return generated

    def memory_report(self) -> dict:
        """Get the KV cache's memory usage report."""
        return self.cache.memory_report()
|
||||
Reference in New Issue
Block a user