Conversation
Pull request overview
This pull request implements a decoder-only Transformer from scratch for robot trajectory prediction. The student (jdvakil) has completed all TODO sections in the backbone file, added KV-caching support for efficient inference, and conducted comprehensive ablation studies comparing RoPE vs sinusoidal positional embeddings.
Changes:
- Implemented RMSNorm, CausalSelfAttention with KV-caching, and the training loop in backbone.py
- Added SinusoidalPositionalEmbedding class as an alternative to RoPE for ablation experiments
- Created ablations.py with experiments comparing positional embeddings, KV-cache benchmarking, causal mask audit, and attention visualization
- Generated training loss plots, ablation comparison plots, KV-cache benchmark results, and attention heatmaps
- Documented findings in a submission writeup with clear explanations of results
Reviewed changes
Copilot reviewed 3 out of 8 changed files in this pull request and generated 8 comments.
| File | Description |
|---|---|
| src/assignments/scratch-1/backbone.py | Completed TODOs: RMSNorm forward pass, CausalSelfAttention with KV-cache support, training loop. Added SinusoidalPositionalEmbedding and KV-cache propagation through transformer blocks |
| src/assignments/scratch-1/ablations.py | New file implementing RoPE vs Sinusoidal ablation study, KV-cache speed benchmark, causal mask audit experiment, and attention map visualization |
| content/course/submissions/scratch-1/jdvakil.mdx | Student submission writeup documenting training results, ablation findings, and insights |
| src/assignments/scratch-1/*.png | Generated visualization outputs: training loss curve, ablation comparison, benchmark results, and attention maps |
```python
# During generation with cache:
# scores shape: (batch, heads, 1, total_seq_len)
# mask shape should align.
# If standard forward pass (no cache logic or full sequence), standard mask works.
# For 1-token generation with cache, we attend to all previous tokens (mask is all 1s effectively, usually handled by shape)

# If mask is provided and shapes mismatch (e.g. during generation), we might need to slice
if mask.shape[-2] == seq_len and mask.shape[-1] == k.shape[2]:
    scores = scores.masked_fill(mask == 0, float("-inf"))
elif (
    past_kv is None
):  # Only apply standard triangular mask during training/full forward
    scores = scores.masked_fill(mask == 0, float("-inf"))
```
The mask handling logic at lines 236-241 has a subtle issue. When past_kv is not None but the mask shape matches (the line 236 condition), the mask is applied. However, during cached generation with seq_len=1, the mask will have shape (1, 1) or be the original full mask, while k.shape[2] is the full cached sequence length, so the condition at line 236 will likely fail in the cached case, falling through to lines 238-241, which only apply the mask when past_kv is None. This means that during cached generation no mask is applied at all, which could allow the model to attend to all positions. While this happens to work because the single query token may validly attend to all previous tokens, the logic is convoluted and unclear.
Suggested change:
```python
# Align mask's last two dimensions with scores (seq_len_q, seq_len_k).
# This supports both full-sequence masks and cached generation where seq_len_q=1.
if (
    mask.shape[-2] != scores.shape[-2]
    or mask.shape[-1] != scores.shape[-1]
):
    # Slice from the end to match the current query/key window.
    mask = mask[..., -scores.shape[-2]:, -scores.shape[-1]:]
scores = scores.masked_fill(mask == 0, float("-inf"))
```
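For intuition, here is a quick standalone check of what that slice does during cached generation (the shapes are hypothetical, not taken from the PR):

```python
import torch

# Hypothetical shapes: a full causal mask for max_len=8,
# and a single cached query attending to 5 total keys.
mask = torch.tril(torch.ones(8, 8))   # (max_len, max_len) lower-triangular causal mask
scores = torch.randn(1, 4, 1, 5)      # (batch, heads, seq_len_q=1, seq_len_k=5)

# The same slice as the suggestion: last seq_len_q rows, last seq_len_k columns.
sliced = mask[..., -scores.shape[-2]:, -scores.shape[-1]:]
print(sliced.shape)  # torch.Size([1, 5])
print(sliced)        # tensor([[1., 1., 1., 1., 1.]])
```

The sliced mask is all ones, so masked_fill changes nothing, which is exactly the valid behavior for single-token generation: the newest query may attend to every cached key.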
```python
except:
    print("Using random weights.")
```
Bare except clause at line 217 catches all exceptions including KeyboardInterrupt and SystemExit, which is bad practice. The code should catch specific exceptions (e.g., RuntimeError, KeyError, FileNotFoundError) or at minimum catch Exception instead of all BaseException types. Additionally, silently falling back to random weights without informing the user why the load failed makes debugging difficult.
Suggested change:
```python
except Exception as e:
    print(f"Failed to load pretrained weights, using random weights instead. Error: {e}")
```
```diff
     return torch.cat([-x2, x1], dim=-1)

-    def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+    def forward(self, q: torch.Tensor, k: torch.Tensor, seq_start_pos: int = 0) -> Tuple[torch.Tensor, torch.Tensor]:
```
The seq_start_pos parameter is added to the forward signature, but CausalSelfAttention must compute it manually from past_kv, which couples the two components. Consider removing seq_start_pos and instead computing the offset inside the RoPE forward method from an optional kv_seq_len parameter, or handling the offset at the attention layer level only.
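A minimal sketch of the first option, written as a free function for clarity. rotate_half matches the snippet above; apply_rope, cos_cached, and sin_cached are hypothetical names, not from the PR:

```python
import torch
from typing import Optional, Tuple

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat([-x2, x1], dim=-1)

def apply_rope(
    q: torch.Tensor,                  # (batch, heads, seq_len, head_dim)
    k: torch.Tensor,                  # (batch, heads, seq_len, head_dim)
    cos_cached: torch.Tensor,         # (max_len, head_dim), precomputed
    sin_cached: torch.Tensor,         # (max_len, head_dim), precomputed
    kv_seq_len: Optional[int] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    seq_len = q.shape[2]
    # Derive the rotation offset internally: with a KV-cache, kv_seq_len is the
    # total key length (past + current), so the new tokens start at the difference.
    start = (kv_seq_len - seq_len) if kv_seq_len is not None else 0
    cos = cos_cached[start : start + seq_len]
    sin = sin_cached[start : start + seq_len]
    return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin
```

The attention layer then passes only the total key length it already knows from the cache, and no caller has to compute a start position by hand.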
```python
optimizer.zero_grad()
logits, loss, _ = model(inputs, targets)

# 3. Backward pass
loss.backward()

# 4. Gradient clipping (max_norm=1.0)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

# 5. Optimizer step
optimizer.step()

# 6. Zero gradients
# (already done above before forward pass)
```
optimizer.zero_grad() is called at line 542, before the forward pass, but the comment at lines 554-555 then says it has already been done. According to PyTorch convention and the hint in the TODO comment, zero_grad should be called after optimizer.step(), not before the forward pass. While the two orderings are functionally equivalent in this simple loop, the current placement is unconventional and the redundant comment is confusing.
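The conventional ordering the hint points to looks like this; the model, optimizer, and data below are toy stand-ins so the snippet runs on its own:

```python
import torch

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
inputs, targets = torch.randn(4, 8), torch.randn(4, 8)

for _ in range(3):
    loss = torch.nn.functional.mse_loss(model(inputs), targets)       # forward
    loss.backward()                                                   # backward
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # clip gradients
    optimizer.step()                                                  # update weights
    optimizer.zero_grad()                                             # zero after step
```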
```python
# Save checkpoint every 1000 steps
step = epoch * len(dataloader) + batch_idx + 1
if step % 1000 == 0:
    torch.save(model.state_dict(), f"checkpoints/step_{step}.pt")
```
The checkpoint directory is created earlier, but there is no check that it still exists before saving at line 565. If the directory creation fails silently, or if the code is run from a different working directory, the checkpoint save will fail. Consider adding error handling, or at least verifying that the directory exists right before saving.
Suggested change:
```python
checkpoint_dir = "checkpoints"
checkpoint_path = os.path.join(checkpoint_dir, f"step_{step}.pt")
try:
    os.makedirs(checkpoint_dir, exist_ok=True)
    torch.save(model.state_dict(), checkpoint_path)
except OSError as e:
    print(f"Warning: failed to save checkpoint to '{checkpoint_path}': {e}")
```
```python
)

# Plotting
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(18, 5))
```
Variable fig is not used.
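One common way to silence this is the underscore convention for discarding the unused figure handle:

```python
_, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(18, 5))
```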
```python
import torch
import torch.nn as nn
```
Import of 'nn' is not used.
Suggested change:
```diff
-import torch.nn as nn
```
```python
import pickle
import os
import matplotlib.pyplot as plt
import numpy as np
```
Import of 'np' is not used.
Suggested change:
```diff
-import numpy as np
```
Completed all the TODOs in the backbone file and extended the code to run ablations on RoPE vs. sinusoidal positional embeddings, benchmark KV-caching, and plot the attention maps.