diff --git a/content/course/submissions/scratch-1/jdvakil.mdx b/content/course/submissions/scratch-1/jdvakil.mdx
new file mode 100644
index 00000000..d513c560
--- /dev/null
+++ b/content/course/submissions/scratch-1/jdvakil.mdx
@@ -0,0 +1,79 @@
+---
+title: "Scratch-1: jdvakil"
+student: "jdvakil"
+date: "2026-02-03"
+---
+
+# Scratch-1: The Transformer Backbone
+
+## Training Loss
+
+Trained the transformer backbone for 10 epochs on 10k discretized trajectories. Loss dropped fast in epoch 1, then improved slowly from there.
+![Backbone Training Loss](../../../../src/assignments/scratch-1/training_loss.png)
+Final loss was **~1.41** with perplexity **~4.09**. The curve looks stable, with no weird spikes or divergence.
+
+## Ablation Studies: RoPE vs Sinusoidal
+
+Ran a comparison between RoPE and standard sinusoidal positional embeddings.
+![Ablation Comparison](../../../../src/assignments/scratch-1/rope_vs_sinusoidal_ablation.png)
+
+- **RoPE**: Hit a loss of **1.98** (perplexity 7.25) in just 3 epochs
+- **Sinusoidal**: Stuck around **~4.40** and basically didn't learn
+
+RoPE works better here because it encodes relative positions directly in the attention computation rather than adding absolute position info to the embeddings. The sinusoidal results surprised me; I expected it to at least converge somewhat, but it just sat there.
+
+## Inference Benchmark
+
+Tested generation speed with and without KV caching.
+![Benchmark Speed](../../../../src/assignments/scratch-1/kv_cache_vs_native_benchmark.png)
+
+- **With Cache**: 229.5 tokens/sec
+- **No Cache**: 209.3 tokens/sec
+- **Speedup**: ~1.10x
+
+Not a huge difference for these short sequences, but it would matter more for longer generation.
+
+## Attention Visualization
+
+Plotted attention maps from Layer 0 to see what the model learned.
+![Attention Maps](../../../../src/assignments/scratch-1/attention_maps.png)
+You can see the lower-triangular pattern from the causal mask. The heads are clearly attending to previous tokens, which is what we want for next-token prediction.
+
+## The Audit: Removing the Causal Mask
+
+Removed the `torch.tril` causal mask to see what happens when the model can peek at future tokens.
+Training loss dropped to ~3.1 (from a starting value of ~3.7), and it fell much faster than normal. But this is fake progress: at inference time there are no future tokens to look at, so the model is useless. It learned to copy instead of predict.
+
+### Why the Model "Cheats"
+
+Without the causal mask, the token at position $t$ can see position $t+1$. The target for position $t$ is literally the token at $t+1$, so the model just copies it. No actual learning of dynamics happens.
+
+## Code Highlights
+
+Implemented RoPE with KV-caching support in `backbone.py`. Also wrote the RoPE vs. sinusoidal ablation and the KV-cache benchmark in `ablations.py`.
+Collect data:
+
+```bash
+python generate_data.py --num_trajectories 10000 --seq_length 50 --output data/trajectories.pkl
+```
+
+Train:
+
+```bash
+python backbone.py
+```
+
+Ablations:
+
+```bash
+python ablations.py
+```
+
+Extra packages:
+
+```bash
+pip install pillow six seaborn
+```
+
+## Challenges and Solutions
+
+**Problem**: Loss was flat at ~4.6 even though the code was correct.
+I spent way too long debugging the model before I thought to check the data. Looking at `generate_data.py`, I found the issue: the signal scale was 0.01 and the noise scale was 0.05, so the SNR was terrible. Bumping the signal to 0.1 and dropping the noise to 0.001 fixed it: the loss immediately started decreasing and converged to ~1.4.
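
## Worked Examples

To make the relative-position argument from the ablation section concrete, here is a small standalone sketch (not taken from `backbone.py`; the `rope_rotate` helper below is illustrative) of the half-split RoPE rotation. After rotating queries and keys by their absolute positions, the attention dot product depends only on the position offset, whereas adding absolute sinusoidal encodings to the embeddings leaves it to the model to recover relative offsets on its own.

```python
import torch

def rope_rotate(x: torch.Tensor, pos: int, dim: int) -> torch.Tensor:
    """Apply the half-split RoPE rotation to a (dim,) vector at absolute position `pos`."""
    inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
    angles = pos * inv_freq                          # (dim/2,)
    cos = torch.cat([angles.cos(), angles.cos()])    # (dim,)
    sin = torch.cat([angles.sin(), angles.sin()])
    x1, x2 = x.chunk(2)
    rotated_half = torch.cat([-x2, x1])              # same trick as rotate_half() in backbone.py
    return x * cos + rotated_half * sin

dim = 8
q, k = torch.randn(dim), torch.randn(dim)

# Same offset (query 3 positions ahead of key), different absolute positions:
score_a = rope_rotate(q, 5, dim) @ rope_rotate(k, 2, dim)
score_b = rope_rotate(q, 105, dim) @ rope_rotate(k, 102, dim)
print(torch.allclose(score_a, score_b, atol=1e-4))  # True -> the score sees only the relative offset
```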
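
For the inference benchmark, a rough back-of-the-envelope (using the benchmark's settings of a 10-token prompt and 50 generated tokens) shows why the measured ~1.10x is so far below the asymptotic benefit: the cache removes a lot of redundant compute, but at sequence length ~60 the wall-clock time is likely dominated by per-step overhead rather than attention math.

```python
# Back-of-the-envelope work counts; prompt_len and new_tokens mirror run_benchmark() in ablations.py.
prompt_len, new_tokens = 10, 50

# Without the cache, step t re-encodes the whole prefix of length prompt_len + t.
tokens_no_cache = sum(prompt_len + t for t in range(new_tokens))          # positions pushed through the model
scores_no_cache = sum((prompt_len + t) ** 2 for t in range(new_tokens))   # attention scores per layer/head

# With the cache, the prompt is encoded once and each step only processes the newest token.
tokens_cache = prompt_len + new_tokens
scores_cache = prompt_len ** 2 + sum(prompt_len + t + 1 for t in range(new_tokens))

print(f"positions encoded: {tokens_no_cache} vs {tokens_cache} (~{tokens_no_cache / tokens_cache:.0f}x)")
print(f"attention scores:  {scores_no_cache} vs {scores_cache} (~{scores_no_cache / scores_cache:.0f}x)")
# ~29x and ~37x in raw operation counts, yet only ~1.10x in wall-clock time at this scale.
```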
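
The audit result is easiest to see by printing the two masks side by side (standalone snippet, mirroring the `torch.tril` mask built in `backbone.py` and the all-ones mask used in the audit): without the triangular mask, position $t$ can attend to $t+1$, which is exactly its training target.

```python
import torch

T = 5
causal = torch.tril(torch.ones(T, T))  # what the normal forward pass builds
cheat = torch.ones(T, T)               # what the audit swaps in

print(causal[2])  # tensor([1., 1., 1., 0., 0.]) -> position 2 sees only positions 0..2
print(cheat[2])   # tensor([1., 1., 1., 1., 1.]) -> position 2 also sees position 3, its own target

# Masked positions get -inf before softmax, so they receive zero attention weight:
scores = torch.zeros(T, T)
weights = torch.softmax(scores.masked_fill(causal == 0, float("-inf")), dim=-1)
print(weights[2])  # tensor([0.3333, 0.3333, 0.3333, 0.0000, 0.0000])
```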
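
One subtlety when combining RoPE with a KV cache (sketched below with standalone tensors, not the classes in `backbone.py`): keys should be rotated at their own absolute positions *before* they go into the cache. If raw keys are cached and the whole concatenated cache is later rotated with the newest position's angles, every relative offset collapses to zero and the positional signal is lost during cached generation.

```python
import torch

def rope_rotate(x, pos, dim):
    # Half-split RoPE rotation of a (dim,) vector at absolute position `pos` (same form as above).
    inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
    ang = pos * inv_freq
    cos, sin = torch.cat([ang.cos(), ang.cos()]), torch.cat([ang.sin(), ang.sin()])
    x1, x2 = x.chunk(2)
    return x * cos + torch.cat([-x2, x1]) * sin

dim = 8
k_raw = torch.randn(dim)
keys = [k_raw.clone() for _ in range(4)]  # same key content placed at four different positions

# Rotate-then-cache: each key is rotated at the position it was produced, then stored.
cache_good = [rope_rotate(k, pos, dim) for pos, k in enumerate(keys)]

# Cache-then-rotate: raw keys are stored, then the whole cache is rotated at the newest position.
cache_bad = [rope_rotate(k, 3, dim) for k in keys]

q = rope_rotate(torch.randn(dim), 3, dim)  # query at the newest position (3)
print(torch.stack([q @ k for k in cache_good]))  # four different scores: offsets 3, 2, 1, 0 are distinguished
print(torch.stack([q @ k for k in cache_bad]))   # four identical scores: every offset collapses to 0
```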
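
Finally, a hedged usage sketch for the trained backbone: load `checkpoints/best_model.pt` and roll out a trajectory with the KV cache enabled. The hyperparameters below mirror `PARAMS` in `ablations.py`; I'm assuming they match the values used by `backbone.py`'s `main()`, so treat this as illustrative rather than copy-paste.

```python
import torch
from backbone import DecoderOnlyTransformer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Assumed to match the training configuration (mirrors PARAMS in ablations.py).
model = DecoderOnlyTransformer(
    vocab_size=256, dim=256, num_layers=4, num_heads=8,
    ff_hidden_dim=1024, max_seq_len=50,
).to(device)
model.load_state_dict(torch.load("checkpoints/best_model.pt", map_location=device))
model.eval()

prompt = torch.randint(0, 256, (1, 10), device=device)  # 10 seed action tokens
rollout = model.generate(prompt, max_new_tokens=40, temperature=1.0, top_k=50, use_cache=True)
print(rollout.shape)  # torch.Size([1, 50]) -- prompt plus 40 generated tokens
```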
diff --git a/src/assignments/scratch-1/ablations.py b/src/assignments/scratch-1/ablations.py new file mode 100644 index 00000000..d720e8de --- /dev/null +++ b/src/assignments/scratch-1/ablations.py @@ -0,0 +1,255 @@ +import torch +import torch.nn as nn +import torch.optim as optim +import pickle +import matplotlib.pyplot as plt +import seaborn as sns +import os +import math +import numpy as np +import time + +from backbone import DecoderOnlyTransformer, train_epoch + +PARAMS = dict( + vocab_size=256, + dim=256, + num_layers=4, + num_heads=8, + ff_hidden_dim=1024, + max_seq_len=50, +) +DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +def train_and_eval(name, model, train_loader, val_loader, epochs=3): + optimizer = optim.AdamW(model.parameters(), lr=1e-4) + + step_losses = [] + val_losses = [] + perplexities = [] + + for epoch in range(epochs): + model.train() + for batch in train_loader: + batch = batch.to(DEVICE) + inputs = batch[:, :-1].contiguous() + targets = batch[:, 1:].contiguous() + + optimizer.zero_grad() + _, loss, _ = model(inputs, targets) + loss.backward() + torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) + optimizer.step() + step_losses.append(loss.item()) + + # Validation + model.eval() + val_loss_sum = 0 + val_batches = 0 + with torch.no_grad(): + for batch in val_loader: + batch = batch.to(DEVICE) + inputs = batch[:, :-1].contiguous() + targets = batch[:, 1:].contiguous() + _, loss, _ = model(inputs, targets) + val_loss_sum += loss.item() + val_batches += 1 + + avg_val_loss = val_loss_sum / val_batches + val_losses.append(avg_val_loss) + ppl = math.exp(avg_val_loss) + perplexities.append(ppl) + print(f"{name} Epoch {epoch}, Val Loss: {avg_val_loss:.4f}, Perplexity: {ppl:.4f}") + + return step_losses, val_losses, perplexities + + +def run_ablation(): + print("RoPE vs Sinusoidal") + with open("data/trajectories.pkl", "rb") as f: + trajectories = pickle.load(f) + full_dataset = trajectories["actions"].clone().detach().long() + + N = 2000 + full_dataset = full_dataset[:N] + train_size = int(0.8 * N) + val_size = N - train_size + train_set, val_set = torch.utils.data.random_split(full_dataset, [train_size, val_size]) + + train_loader = torch.utils.data.DataLoader(train_set, batch_size=32, shuffle=True) + val_loader = torch.utils.data.DataLoader(val_set, batch_size=32, shuffle=False) + + model_rope = DecoderOnlyTransformer(**PARAMS, use_rope=True).to(DEVICE) + rope_steps, rope_val, rope_ppl = train_and_eval( + "RoPE", model_rope, train_loader, val_loader + ) + + model_sin = DecoderOnlyTransformer(**PARAMS, use_rope=False).to(DEVICE) + sin_steps, sin_val, sin_ppl = train_and_eval( + "Sinusoidal", model_sin, train_loader, val_loader + ) + + # Plotting + fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(18, 5)) + + def smooth(data, window=10): + return np.convolve(data, np.ones(window) / window, mode="valid") + + ax1.plot(smooth(rope_steps), label="RoPE", alpha=0.8) + ax1.plot(smooth(sin_steps), label="Sinusoidal", alpha=0.8) + ax1.set_title("Training Loss (Step-wise)") + ax1.legend() + ax1.grid(True, alpha=0.3) + + epochs = range(len(rope_val)) + ax2.plot(epochs, rope_val, label="RoPE", marker="o") + ax2.plot(epochs, sin_val, label="Sinusoidal", marker="x") + ax2.set_title("Validation Loss") + ax2.legend() + ax2.grid(True, alpha=0.3) + + ax3.plot(epochs, rope_ppl, label="RoPE", marker="o") + ax3.plot(epochs, sin_ppl, label="Sinusoidal", marker="x") + ax3.set_title("Validation Perplexity") + ax3.set_yscale("log") + ax3.legend() + ax3.grid(True, 
alpha=0.3) + + plt.tight_layout() + plt.savefig("rope_vs_sinusoidal_ablation.png") + + +def run_benchmark(): + print("KV-Cache vs Native") + model = DecoderOnlyTransformer(**PARAMS).to(DEVICE) + model.eval() + + input_ids = torch.randint(0, 256, (1, 10)).to(DEVICE) + max_new_tokens = 50 + num_runs = 5 + model.generate(input_ids, max_new_tokens=5, use_cache=True) + model.generate(input_ids, max_new_tokens=5, use_cache=False) + + times_cache = [] + for _ in range(num_runs): + start = time.time() + model.generate(input_ids, max_new_tokens=max_new_tokens, use_cache=True) + times_cache.append(time.time() - start) + + avg_cache = sum(times_cache) / num_runs + speed_cache = max_new_tokens / avg_cache + print(f"With Cache: {speed_cache:.2f} tok/s") + + times_no_cache = [] + for _ in range(num_runs): + start = time.time() + model.generate(input_ids, max_new_tokens=max_new_tokens, use_cache=False) + times_no_cache.append(time.time() - start) + + avg_no_cache = sum(times_no_cache) / num_runs + speed_no_cache = max_new_tokens / avg_no_cache + print(f"No Cache: {speed_no_cache:.2f} tok/s") + + plt.figure() + plt.bar( + ["Without Cache", "With Cache"], + [speed_no_cache, speed_cache], + color=["red", "blue"], + ) + plt.ylabel("Tokens / Second") + plt.title("Inference Speed Comparison") + plt.savefig("kv_cache_vs_native_benchmark.png") + + +def run_audit(): + print("Removing Causal Mask") + + with open("data/trajectories.pkl", "rb") as f: + trajectories = pickle.load(f) + dataset = trajectories["actions"].clone().detach().long()[:1000] + dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True) + + class CheatingTransformer(DecoderOnlyTransformer): + def forward(self, input_ids, targets=None, past_kv=None, use_cache=False): + batch_size, seq_len = input_ids.shape + x = self.token_embedding(input_ids) + if not self.use_rope: + x = self.pos_embedding(x) + mask = torch.ones(seq_len, seq_len, device=x.device) + + for block in self.blocks: + x, _ = block(x, mask) + + x = self.norm_final(x) + logits = self.lm_head(x) + + loss = None + if targets is not None: + loss = torch.nn.functional.cross_entropy( + logits.view(-1, self.vocab_size), targets.view(-1), ignore_index=-1 + ) + return logits, loss, None + + model = CheatingTransformer(**PARAMS, use_rope=True).to(DEVICE) + optimizer = optim.AdamW(model.parameters(), lr=1e-3) + + print("Training Cheating Model (Bidirectional Attention)...") + for epoch in range(1): + loss, _ = train_epoch(model, dataloader, optimizer, DEVICE, epoch) + print(f"Epoch {epoch}, Loss: {loss:.4f} (Expected << Initial Loss)") + + +def run_visualization(): + model = DecoderOnlyTransformer(**PARAMS).to(DEVICE) + + if os.path.exists("checkpoints/best_model.pt"): + try: + state = torch.load("checkpoints/best_model.pt", map_location=DEVICE) + model_keys = model.state_dict().keys() + filtered_state = { + k: v + for k, v in state.items() + if k in model_keys and v.size() == model.state_dict()[k].size() + } + model.load_state_dict(filtered_state, strict=False) + print("Loaded pretrained weights.") + except: + print("Using random weights.") + + model.eval() + seq_len = 20 + x = torch.randint(0, 256, (1, seq_len)).to(DEVICE) + layer = model.blocks[0].attention + qkv = layer.qkv_proj(model.token_embedding(x)) + qkv = qkv.view(1, seq_len, 3, layer.num_heads, layer.head_dim).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(0) + + if layer.use_rope: + q, k = layer.rope(q, k) + + scores = (q @ k.transpose(-2, -1)) / math.sqrt(layer.head_dim) + mask = torch.tril(torch.ones(seq_len, 
seq_len, device=DEVICE)) + scores = scores.masked_fill(mask == 0, float("-inf")) + attn_weights = torch.softmax(scores, dim=-1) + + fig, axes = plt.subplots(1, 4, figsize=(20, 5)) + for h in range(4): + sns.heatmap( + attn_weights[0, h].detach().cpu().numpy(), + ax=axes[h], + cmap="viridis", + square=True, + cbar=False, + ) + axes[h].set_title(f"Head {h}") + + plt.suptitle(f"Layer 0 Attention Patterns") + plt.savefig("attention_maps.png") + + +if __name__ == "__main__": + run_ablation() + run_benchmark() + run_audit() + run_visualization() diff --git a/src/assignments/scratch-1/attention_maps.png b/src/assignments/scratch-1/attention_maps.png new file mode 100644 index 00000000..33ae7137 Binary files /dev/null and b/src/assignments/scratch-1/attention_maps.png differ diff --git a/src/assignments/scratch-1/backbone.py b/src/assignments/scratch-1/backbone.py index 247227d1..a61c875a 100644 --- a/src/assignments/scratch-1/backbone.py +++ b/src/assignments/scratch-1/backbone.py @@ -13,6 +13,10 @@ import torch.nn as nn import torch.nn.functional as F import math +import pickle +import os +import matplotlib.pyplot as plt +import numpy as np from typing import Optional, Tuple @@ -33,7 +37,7 @@ def __init__(self, dim: int, eps: float = 1e-6): self.eps = eps # TODO: Initialize learnable scale parameter 'g' (gamma) # Hint: Use nn.Parameter with torch.ones - self.scale = None # REPLACE THIS LINE + self.scale = nn.Parameter(torch.ones(dim)) def forward(self, x: torch.Tensor) -> torch.Tensor: """ @@ -48,9 +52,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # Step 3: Apply learnable scale parameter # HINT: Use torch.mean, torch.rsqrt for efficiency - # rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) - - raise NotImplementedError("TODO: Implement RMSNorm forward pass") + rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + return x * rms * self.scale class RotaryPositionalEmbedding(nn.Module): @@ -89,21 +92,27 @@ def rotate_half(self, x: torch.Tensor) -> torch.Tensor: x1, x2 = x.chunk(2, dim=-1) return torch.cat([-x2, x1], dim=-1) - def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + def forward(self, q: torch.Tensor, k: torch.Tensor, seq_start_pos: int = 0) -> Tuple[torch.Tensor, torch.Tensor]: """ Apply rotary embeddings to query and key tensors Args: q: Query tensor (batch, num_heads, seq_len, head_dim) k: Key tensor (batch, num_heads, seq_len, head_dim) + seq_start_pos: Starting position index for the sequence Returns: Rotated (q, k) tensors """ seq_len = q.shape[2] + end_pos = seq_start_pos + seq_len + + # Get cached cos/sin values for the specific positions + cos = self.cos_cached[seq_start_pos:end_pos, ...] + sin = self.sin_cached[seq_start_pos:end_pos, ...] - # Get cached cos/sin values - cos = self.cos_cached[:seq_len, ...] - sin = self.sin_cached[:seq_len, ...] 
+ # Expand for batch and num_heads: (seq_len, head_dim) -> (1, 1, seq_len, head_dim) + cos = cos[None, None, :, :] + sin = sin[None, None, :, :] # Apply rotation: q_rot = q * cos + rotate_half(q) * sin q_rot = (q * cos) + (self.rotate_half(q) * sin) @@ -112,6 +121,33 @@ def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch return q_rot, k_rot +class SinusoidalPositionalEmbedding(nn.Module): + """ + Standard Sinusoidal Positional Embedding + """ + + def __init__(self, dim: int, max_seq_len: int = 2048): + super().__init__() + pe = torch.zeros(max_seq_len, dim) + position = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, dim, 2).float() * (-math.log(10000.0) / dim) + ) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + self.register_buffer("pe", pe) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x: Input tensor (batch, seq_len, dim) + Returns: + x + positional_encoding + """ + seq_len = x.size(1) + return x + self.pe[:seq_len, :] + + class CausalSelfAttention(nn.Module): """ Multi-Head Causal Self-Attention with RoPE @@ -121,14 +157,17 @@ class CausalSelfAttention(nn.Module): - Uses RoPE instead of absolute positional embeddings """ - def __init__(self, dim: int, num_heads: int, dropout: float = 0.1): + def __init__( + self, dim: int, num_heads: int, dropout: float = 0.1, use_rope: bool = True + ): super().__init__() assert dim % num_heads == 0, "dim must be divisible by num_heads" self.dim = dim self.num_heads = num_heads self.head_dim = dim // num_heads - self.scale = self.head_dim ** -0.5 + self.scale = self.head_dim**-0.5 + self.use_rope = use_rope # Linear projections for Q, K, V self.qkv_proj = nn.Linear(dim, 3 * dim, bias=False) @@ -139,51 +178,80 @@ def __init__(self, dim: int, num_heads: int, dropout: float = 0.1): self.resid_dropout = nn.Dropout(dropout) # Rotary embeddings - self.rope = RotaryPositionalEmbedding(self.head_dim) + if self.use_rope: + self.rope = RotaryPositionalEmbedding(self.head_dim) - def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor: + def forward( + self, + x: torch.Tensor, + mask: Optional[torch.Tensor] = None, + past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]: """ Args: x: Input tensor (batch, seq_len, dim) mask: Optional attention mask (seq_len, seq_len) + past_kv: Key/Value cache from previous step Returns: Output tensor (batch, seq_len, dim) + present_kv: New KV cache """ batch_size, seq_len, _ = x.shape - # TODO: Implement Causal Self-Attention - # Step 1: Project input to Q, K, V - # qkv = self.qkv_proj(x) # (batch, seq_len, 3*dim) - # Split into Q, K, V and reshape for multi-head attention - # Hint: Use .view() and .transpose() to get shape (batch, num_heads, seq_len, head_dim) + qkv = self.qkv_proj(x) # (batch, seq_len, 3*dim) + qkv = qkv.view(batch_size, seq_len, 3, self.num_heads, self.head_dim) + qkv = qkv.permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(0) + + # Handle KV Cache + if past_kv is not None: + past_k, past_v = past_kv + k = torch.cat([past_k, k], dim=2) + v = torch.cat([past_v, v], dim=2) + + present_kv = (k, v) # Step 2: Apply RoPE to Q and K - # q, k = self.rope(q, k) + if self.use_rope: + # If using cache, we need to offset the position for RoPE + seq_start_pos = past_kv[0].shape[2] if past_kv is not None else 0 + q, k = self.rope(q, k, seq_start_pos=seq_start_pos) # 
Step 3: Compute attention scores - # scores = (Q @ K^T) / sqrt(d_k) - # Hint: Use torch.matmul or @ operator - # Shape should be (batch, num_heads, seq_len, seq_len) + # q: (batch, n_heads, seq_len_q, head_dim) + # k: (batch, n_heads, seq_len_k, head_dim) + # For cached inference: seq_len_q=1, seq_len_k=context_len + scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim) # Step 4: Apply causal mask - # The mask should prevent position i from attending to positions > i - # Hint: Create a lower-triangular matrix using torch.tril - # Set masked positions to -inf BEFORE softmax - # Example: scores = scores.masked_fill(mask == 0, float('-inf')) + if mask is not None: + # During generation with cache: + # scores shape: (batch, heads, 1, total_seq_len) + # mask shape should align. + # If standard forward pass (no cache logic or full sequence), standard mask works. + # For 1-token generation with cache, we attend to all previous tokens (mask is all 1s effectively, usually handled by shape) + + # If mask is provided and shapes mismatch (e.g. during generation), we might need to slice + if mask.shape[-2] == seq_len and mask.shape[-1] == k.shape[2]: + scores = scores.masked_fill(mask == 0, float("-inf")) + elif ( + past_kv is None + ): # Only apply standard triangular mask during training/full forward + scores = scores.masked_fill(mask == 0, float("-inf")) # Step 5: Apply softmax and dropout - # attn_weights = F.softmax(scores, dim=-1) - # attn_weights = self.attn_dropout(attn_weights) + attn_weights = F.softmax(scores, dim=-1) + attn_weights = self.attn_dropout(attn_weights) # Step 6: Apply attention to values - # out = attn_weights @ V + out = attn_weights @ v # Step 7: Reshape and project back - # Concatenate heads and apply output projection - # Hint: Use .transpose() and .contiguous().view() to reshape + out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, self.dim) + out = self.resid_dropout(self.out_proj(out)) - raise NotImplementedError("TODO: Implement CausalSelfAttention forward pass") + return out, present_kv class FeedForward(nn.Module): @@ -223,25 +291,41 @@ class TransformerBlock(nn.Module): x = x + FeedForward(RMSNorm(x)) """ - def __init__(self, dim: int, num_heads: int, ff_hidden_dim: int, dropout: float = 0.1): + def __init__( + self, + dim: int, + num_heads: int, + ff_hidden_dim: int, + dropout: float = 0.1, + use_rope: bool = True, + ): super().__init__() - self.attention = CausalSelfAttention(dim, num_heads, dropout) + self.attention = CausalSelfAttention(dim, num_heads, dropout, use_rope=use_rope) self.feed_forward = FeedForward(dim, ff_hidden_dim, dropout) self.norm1 = RMSNorm(dim) self.norm2 = RMSNorm(dim) - def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor: + def forward( + self, + x: torch.Tensor, + mask: Optional[torch.Tensor] = None, + past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]: """ Args: x: Input tensor (batch, seq_len, dim) mask: Optional attention mask + past_kv: Previous KV cache Returns: Output tensor (batch, seq_len, dim) + present_kv: New KV cache """ # Pre-norm architecture (norm before attention/FF) - x = x + self.attention(self.norm1(x), mask) + norm_x = self.norm1(x) + attn_out, present_kv = self.attention(norm_x, mask, past_kv) + x = x + attn_out x = x + self.feed_forward(self.norm2(x)) - return x + return x, present_kv class DecoderOnlyTransformer(nn.Module): @@ -260,20 +344,30 @@ def __init__( ff_hidden_dim: int, 
max_seq_len: int = 2048, dropout: float = 0.1, + use_rope: bool = True, ): super().__init__() self.vocab_size = vocab_size self.dim = dim self.max_seq_len = max_seq_len + self.use_rope = use_rope # Token embedding self.token_embedding = nn.Embedding(vocab_size, dim) + # Positional embedding (if not using RoPE) + if not use_rope: + self.pos_embedding = SinusoidalPositionalEmbedding(dim, max_seq_len) + # Transformer blocks - self.blocks = nn.ModuleList([ - TransformerBlock(dim, num_heads, ff_hidden_dim, dropout) - for _ in range(num_layers) - ]) + self.blocks = nn.ModuleList( + [ + TransformerBlock( + dim, num_heads, ff_hidden_dim, dropout, use_rope=use_rope + ) + for _ in range(num_layers) + ] + ) # Final norm and projection to vocabulary self.norm_final = RMSNorm(dim) @@ -290,31 +384,49 @@ def _init_weights(self, module): torch.nn.init.zeros_(module.bias) elif isinstance(module, nn.Embedding): torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) + elif isinstance(module, SinusoidalPositionalEmbedding): + pass # No weights to initialize def forward( self, input_ids: torch.Tensor, targets: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + past_kv: Optional[list] = None, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[list]]: """ Args: input_ids: Input token indices (batch, seq_len) targets: Target token indices for training (batch, seq_len) + past_kv: List of KV caches for each layer + use_cache: whether to use KV caching Returns: logits: Output logits (batch, seq_len, vocab_size) loss: Cross-entropy loss if targets provided, else None + present_kv: List of new KV caches """ batch_size, seq_len = input_ids.shape # Embed tokens x = self.token_embedding(input_ids) # (batch, seq_len, dim) + # Add positional embeddings if not using RoPE + if not self.use_rope: + x = self.pos_embedding(x) + # Create causal mask (lower triangular) - mask = torch.tril(torch.ones(seq_len, seq_len, device=x.device)) + # If using cache (generating one token), we don't need a mask usually or it's implicitly handled + mask = None + if not use_cache or past_kv is None: + mask = torch.tril(torch.ones(seq_len, seq_len, device=x.device)) # Apply transformer blocks - for block in self.blocks: - x = block(x, mask) + present_kv = [] + for i, block in enumerate(self.blocks): + layer_past = past_kv[i] if past_kv is not None else None + x, layer_present = block(x, mask, past_kv=layer_past) + if use_cache: + present_kv.append(layer_present) # Final norm and projection x = self.norm_final(x) @@ -330,7 +442,7 @@ def forward( ignore_index=-1, # Ignore padding tokens ) - return logits, loss + return logits, loss, present_kv if use_cache else None @torch.no_grad() def generate( @@ -339,6 +451,7 @@ def generate( max_new_tokens: int, temperature: float = 1.0, top_k: Optional[int] = None, + use_cache: bool = True, ) -> torch.Tensor: """ Autoregressive generation @@ -348,21 +461,39 @@ def generate( max_new_tokens: Number of tokens to generate temperature: Sampling temperature top_k: Top-k sampling (if None, use full distribution) + use_cache: Whether to use KV caching Returns: Generated sequence (batch, seq_len + max_new_tokens) """ + past_kv = None + for _ in range(max_new_tokens): - # Crop context if too long - input_context = input_ids if input_ids.size(1) <= self.max_seq_len else input_ids[:, -self.max_seq_len:] + # Crop context if too long (only if NOT using cache or if cache refill logic handles it) + # If using cache, we only pass in the last token anyway, 
so context length limits are handled + # by the cache size implicitly or we essentially just run indefinitely until OOM/end. + # But RoPE has limit `max_seq_len`. + + if use_cache and past_kv is not None: + # Only pass the last token + input_context = input_ids[:, -1:] + else: + # Standard full context + input_context = ( + input_ids + if input_ids.size(1) <= self.max_seq_len + else input_ids[:, -self.max_seq_len :] + ) # Forward pass - logits, _ = self.forward(input_context) + logits, _, past_kv = self.forward( + input_context, past_kv=past_kv, use_cache=use_cache + ) logits = logits[:, -1, :] / temperature # Get last token logits # Optional top-k sampling if top_k is not None: values, _ = torch.topk(logits, min(top_k, logits.size(-1))) - logits[logits < values[:, [-1]]] = -float('Inf') + logits[logits < values[:, [-1]]] = -float("Inf") # Sample from distribution probs = F.softmax(logits, dim=-1) @@ -380,7 +511,7 @@ def train_epoch( optimizer: torch.optim.Optimizer, device: torch.device, epoch: int, -) -> float: +) -> Tuple[float, list]: """ Train for one epoch @@ -396,21 +527,53 @@ def train_epoch( model.train() total_loss = 0.0 num_batches = 0 + step_losses = [] # TODO: Implement training loop # For each batch: - # 1. Move data to device - # 2. Forward pass (get logits and loss) - # 3. Backward pass - # 4. Gradient clipping (max_norm=1.0) - # 5. Optimizer step - # 6. Zero gradients - # 7. Accumulate loss - - # Hint: Use torch.nn.utils.clip_grad_norm_ for gradient clipping - # Hint: Print progress every 100 batches + for batch_idx, batch in enumerate(dataloader): + + # 1. Move data to device + sequences = batch.to(device) + inputs = sequences[:, :-1].contiguous() + targets = sequences[:, 1:].contiguous() + + # 2. Forward pass (get logits and loss) + optimizer.zero_grad() + logits, loss, _ = model(inputs, targets) + + # 3. Backward pass + loss.backward() + + # 4. Gradient clipping (max_norm=1.0) + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) + + # 5. Optimizer step + optimizer.step() + + # 6. Zero gradients + # (already done above before forward pass) + + # 7. Accumulate loss + total_loss += loss.item() + step_losses.append(loss.item()) + num_batches += 1 + + # Save checkpoint every 1000 steps + step = epoch * len(dataloader) + batch_idx + 1 + if step % 1000 == 0: + torch.save(model.state_dict(), f"checkpoints/step_{step}.pt") + + # Hint: Use torch.nn.utils.clip_grad_norm_ for gradient clipping + # Hint: Print progress every 100 batches + if (batch_idx + 1) % 100 == 0: + avg_loss = total_loss / num_batches + perplexity = math.exp(avg_loss) + print( + f"Epoch {epoch}, Batch {batch_idx+1}/{len(dataloader)}, Loss: {avg_loss:.4f}, Perplexity: {perplexity:.4f}" + ) - raise NotImplementedError("TODO: Implement training loop") + return total_loss / num_batches, step_losses def main(): @@ -438,22 +601,62 @@ def main(): # TODO: Load dataset # Use the generate_data.py script to create synthetic trajectories # Load from data/trajectories.pkl + with open("data/trajectories.pkl", "rb") as f: + trajectories = pickle.load(f) + dataset = trajectories["actions"].clone().detach().long() + dataloader = torch.utils.data.DataLoader( + dataset, batch_size=batch_size, shuffle=True + ) # TODO: Create model # model = DecoderOnlyTransformer(...) 
+ model = DecoderOnlyTransformer( + vocab_size=vocab_size, + dim=dim, + num_layers=num_layers, + num_heads=num_heads, + ff_hidden_dim=ff_hidden_dim, + max_seq_len=max_seq_len, + ).to(device) # TODO: Create optimizer # optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate) + optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate) # TODO: Training loop # for epoch in range(num_epochs): # train_loss = train_epoch(model, train_loader, optimizer, device, epoch) # print(f"Epoch {epoch+1}/{num_epochs} - Loss: {train_loss:.4f}") + os.makedirs("checkpoints", exist_ok=True) + all_losses = [] + for epoch in range(num_epochs): + train_loss, step_losses = train_epoch( + model, dataloader, optimizer, device, epoch + ) + all_losses.extend(step_losses) + perplexity = math.exp(train_loss) + print( + f"Epoch {epoch+1}/{num_epochs} - Loss: {train_loss:.4f}, Perplexity: {perplexity:.4f}" + ) + torch.save(model.state_dict(), f"checkpoints/epoch_{epoch+1}.pt") + + # Save losses for plotting + with open("checkpoints/losses.pkl", "wb") as f: + pickle.dump(all_losses, f) # TODO: Save checkpoint # torch.save(model.state_dict(), "checkpoints/best_model.pt") - - print("TODO: Complete the main training script") + torch.save(model.state_dict(), "checkpoints/best_model.pt") + print("Training complete") + + # Plot training loss + plt.figure(figsize=(10, 6)) + plt.plot(all_losses) + plt.title("Training Loss") + plt.xlabel("Step") + plt.ylabel("Loss") + plt.grid(True, alpha=0.3) + plt.savefig("training_loss.png") if __name__ == "__main__": diff --git a/src/assignments/scratch-1/data/trajectories.pkl b/src/assignments/scratch-1/data/trajectories.pkl new file mode 100644 index 00000000..0502ce4d Binary files /dev/null and b/src/assignments/scratch-1/data/trajectories.pkl differ diff --git a/src/assignments/scratch-1/kv_cache_vs_native_benchmark.png b/src/assignments/scratch-1/kv_cache_vs_native_benchmark.png new file mode 100644 index 00000000..1ae7f563 Binary files /dev/null and b/src/assignments/scratch-1/kv_cache_vs_native_benchmark.png differ diff --git a/src/assignments/scratch-1/rope_vs_sinusoidal_ablation.png b/src/assignments/scratch-1/rope_vs_sinusoidal_ablation.png new file mode 100644 index 00000000..95ab4a03 Binary files /dev/null and b/src/assignments/scratch-1/rope_vs_sinusoidal_ablation.png differ diff --git a/src/assignments/scratch-1/training_loss.png b/src/assignments/scratch-1/training_loss.png new file mode 100644 index 00000000..4ea446cd Binary files /dev/null and b/src/assignments/scratch-1/training_loss.png differ