79 changes: 79 additions & 0 deletions content/course/submissions/scratch-1/jdvakil.mdx
@@ -0,0 +1,79 @@
---
title: "Scratch-1: jdvakil"
student: "jdvakil"
date: "2026-02-03"
---

# Scratch-1: The Transformer Backbone

## Training Loss

Trained the transformer backbone for 10 epochs on 10k discretized trajectories. Loss dropped fast in epoch 1, then improved slowly from there.
![Backbone Training Loss](../../../../src/assignments/scratch-1/training_loss.png)
Final loss was **~1.41** with perplexity **~4.09**. The curve looks stable, with no spikes or divergence.
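Perplexity here is just the exponential of the mean cross-entropy loss, so the two reported numbers are consistent with each other:

```python
import math

val_loss = 1.41                    # final loss from the run above
perplexity = math.exp(val_loss)    # perplexity = exp(cross-entropy)
print(round(perplexity, 2))        # ~4.1
```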

## Ablation Studies: RoPE vs Sinusoidal

Ran a comparison between RoPE and standard sinusoidal positional embeddings.
![Ablation Comparison](../../../../src/assignments/scratch-1/rope_vs_sinusoidal_ablation.png)

- **RoPE**: Hit a loss of **1.98** (perplexity 7.25) in just 3 epochs
- **Sinusoidal**: Stuck around **~4.40** and basically didn't learn

RoPE works better here because it encodes relative positions directly in the attention computation rather than adding absolute position info to the embeddings. The sinusoidal result surprised me; I expected it to at least converge somewhat, but it just sat there.
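To make "encodes relative positions in the attention computation" concrete, here is a minimal RoPE sketch (using the split-half pairing convention; the actual `backbone.py` implementation may pair dimensions differently). Each feature pair is rotated by a position-dependent angle, so the q·k dot product ends up depending only on the relative offset between positions:

```python
import torch

def apply_rope(x, base=10000.0):
    """Rotate feature pairs of x (batch, heads, seq, head_dim) by position-dependent angles."""
    _, _, t, d = x.shape
    half = d // 2
    # One frequency per feature pair, geometrically spaced as in the RoPE paper
    freqs = base ** (-torch.arange(half, dtype=torch.float32) / half)        # (half,)
    angles = torch.arange(t, dtype=torch.float32)[:, None] * freqs[None, :]  # (t, half)
    cos, sin = angles.cos(), angles.sin()
    x1, x2 = x[..., :half], x[..., half:]
    # 2D rotation of each pair: (x1, x2) -> (x1*cos - x2*sin, x1*sin + x2*cos)
    return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
```

Applied to both q and k before the attention scores, position 0 is left unchanged (angle zero) and every rotation preserves vector norms, so only the *relative* angle between two positions affects their score.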

## Inference Benchmark

Tested generation speed with and without KV caching.
![Benchmark Speed](../../../../src/assignments/scratch-1/kv_cache_vs_native_benchmark.png)

- **With Cache**: 229.5 tokens/sec
- **No Cache**: 209.3 tokens/sec
- **Speedup**: ~1.10x

Not a huge difference for these short sequences, but it would matter more for longer generation, where the no-cache path re-encodes the entire prefix at every step.
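The mechanics behind the speedup, as an illustrative single-step sketch (hypothetical names, not the actual `generate()` internals): with a cache, each decoding step only computes K/V for the newest token and appends them, so attention cost per step is O(t) instead of O(t²) for re-encoding the prefix:

```python
import torch

def cached_attention_step(q_new, k_new, v_new, cache):
    """One decoding step: append the new token's K/V to the cache,
    then attend from the single new query over all cached positions."""
    # cache holds tensors of shape (batch, heads, seq_so_far, head_dim)
    cache["k"] = torch.cat([cache["k"], k_new], dim=2)
    cache["v"] = torch.cat([cache["v"], v_new], dim=2)
    scores = q_new @ cache["k"].transpose(-2, -1) / cache["k"].shape[-1] ** 0.5
    # No causal mask needed here: the newest query may attend to every past position
    return torch.softmax(scores, dim=-1) @ cache["v"]
```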

## Attention Visualization

Plotted attention maps from Layer 0 to see what the model learned.
![Attention Maps](../../../../src/assignments/scratch-1/attention_maps.png)
You can see the lower-triangular pattern from the causal mask. The heads are clearly attending to previous tokens, which is what we want for next-token prediction.
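The lower-triangular structure falls straight out of how the mask is applied: scores above the diagonal are set to -inf before the softmax, so each row t spreads its attention only over positions 0..t. A minimal reproduction:

```python
import torch

T = 4
scores = torch.zeros(T, T)              # uniform pre-softmax scores, for illustration
mask = torch.tril(torch.ones(T, T))     # 1s on and below the diagonal
weights = torch.softmax(scores.masked_fill(mask == 0, float("-inf")), dim=-1)
print(weights)
# Row t attends uniformly over positions 0..t; everything above the diagonal is exactly 0
```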

## The Audit: Removing the Causal Mask

Removed `torch.tril` to see what happens when the model can peek at future tokens.
Training loss dropped to ~3.1 (from a starting point of ~3.7) far faster than normal. But this is fake progress: at inference time there are no future tokens to look at, so the model is useless. It learned to copy instead of predict.

### Why the Model "Cheats"

Without the causal mask, the token at position $t$ can attend to position $t+1$. The target for position $t$ is literally the input at $t+1$, so the model just copies it. No actual learning of dynamics happens.
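The copy shortcut is easy to demonstrate directly (a tiny hypothetical check, not part of the assignment code): a "model" that can see the leaked future token and simply emits it achieves near-zero cross-entropy without learning anything about the dynamics.

```python
import torch
import torch.nn.functional as F

vocab = 8
seq = torch.randint(0, vocab, (1, 6))
inputs, targets = seq[:, :-1], seq[:, 1:]   # target at t is the input at t+1

# A cheater that sees position t+1 just copies it: put a huge logit on the target
logits = F.one_hot(targets, num_classes=vocab).float() * 100.0
loss = F.cross_entropy(logits.view(-1, vocab), targets.view(-1))
print(loss.item())  # ~0: copying the leaked future token solves the objective
```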

## Code Highlights

Implemented RoPE with KV-caching support in `backbone.py`. Also wrote ablations for RoPE vs sinusoidal and KV cache benchmarking.
Collect data:

```
python generate_data.py --num_trajectories 10000 --seq_length 50 --output data/trajectories.pkl
```

Train:

```
python backbone.py
```

Ablations:

```
python ablations.py
```

Extra packages:

```
pip install pillow six seaborn
```

## Challenges and Solutions

**Problem**: Loss was flat at ~4.6 even though the model code was correct.

**Solution**: I spent way too long debugging the model before thinking to check the data. Looking at `generate_data.py`, I found the issue: the signal amplitude was 0.01 and the noise amplitude was 0.05, so the SNR was terrible. Bumping the signal to 0.1 and dropping the noise to 0.001 fixed it; loss immediately started decreasing and converged to ~1.4.
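As a rough back-of-the-envelope (treating the squared amplitude ratio as a power SNR, which is an assumption about how `generate_data.py` mixes signal and noise), the fix moves the SNR by several orders of magnitude:

```python
signal_amp, noise_amp = 0.01, 0.05        # original settings: noise 5x the signal
snr_before = (signal_amp / noise_amp) ** 2
signal_amp, noise_amp = 0.1, 0.001        # after the fix
snr_after = (signal_amp / noise_amp) ** 2
print(snr_before, snr_after)              # ~0.04 vs ~10000
```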
255 changes: 255 additions & 0 deletions src/assignments/scratch-1/ablations.py
@@ -0,0 +1,255 @@
import torch
import torch.optim as optim
import pickle
import matplotlib.pyplot as plt
import seaborn as sns
import os
import math
import numpy as np
import time

from backbone import DecoderOnlyTransformer, train_epoch

PARAMS = dict(
    vocab_size=256,
    dim=256,
    num_layers=4,
    num_heads=8,
    ff_hidden_dim=1024,
    max_seq_len=50,
)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def train_and_eval(name, model, train_loader, val_loader, epochs=3):
    optimizer = optim.AdamW(model.parameters(), lr=1e-4)

    step_losses = []
    val_losses = []
    perplexities = []

    for epoch in range(epochs):
        model.train()
        for batch in train_loader:
            batch = batch.to(DEVICE)
            inputs = batch[:, :-1].contiguous()
            targets = batch[:, 1:].contiguous()

            optimizer.zero_grad()
            _, loss, _ = model(inputs, targets)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            step_losses.append(loss.item())

        # Validation
        model.eval()
        val_loss_sum = 0
        val_batches = 0
        with torch.no_grad():
            for batch in val_loader:
                batch = batch.to(DEVICE)
                inputs = batch[:, :-1].contiguous()
                targets = batch[:, 1:].contiguous()
                _, loss, _ = model(inputs, targets)
                val_loss_sum += loss.item()
                val_batches += 1

        avg_val_loss = val_loss_sum / val_batches
        val_losses.append(avg_val_loss)
        ppl = math.exp(avg_val_loss)
        perplexities.append(ppl)
        print(f"{name} Epoch {epoch}, Val Loss: {avg_val_loss:.4f}, Perplexity: {ppl:.4f}")

    return step_losses, val_losses, perplexities


def run_ablation():
    print("RoPE vs Sinusoidal")
    with open("data/trajectories.pkl", "rb") as f:
        trajectories = pickle.load(f)
    full_dataset = trajectories["actions"].clone().detach().long()

    N = 2000
    full_dataset = full_dataset[:N]
    train_size = int(0.8 * N)
    val_size = N - train_size
    train_set, val_set = torch.utils.data.random_split(full_dataset, [train_size, val_size])

    train_loader = torch.utils.data.DataLoader(train_set, batch_size=32, shuffle=True)
    val_loader = torch.utils.data.DataLoader(val_set, batch_size=32, shuffle=False)

    model_rope = DecoderOnlyTransformer(**PARAMS, use_rope=True).to(DEVICE)
    rope_steps, rope_val, rope_ppl = train_and_eval(
        "RoPE", model_rope, train_loader, val_loader
    )

    model_sin = DecoderOnlyTransformer(**PARAMS, use_rope=False).to(DEVICE)
    sin_steps, sin_val, sin_ppl = train_and_eval(
        "Sinusoidal", model_sin, train_loader, val_loader
    )

    # Plotting (the figure handle is unused, so discard it)
    _, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(18, 5))

    def smooth(data, window=10):
        return np.convolve(data, np.ones(window) / window, mode="valid")

    ax1.plot(smooth(rope_steps), label="RoPE", alpha=0.8)
    ax1.plot(smooth(sin_steps), label="Sinusoidal", alpha=0.8)
    ax1.set_title("Training Loss (Step-wise)")
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    epochs = range(len(rope_val))
    ax2.plot(epochs, rope_val, label="RoPE", marker="o")
    ax2.plot(epochs, sin_val, label="Sinusoidal", marker="x")
    ax2.set_title("Validation Loss")
    ax2.legend()
    ax2.grid(True, alpha=0.3)

    ax3.plot(epochs, rope_ppl, label="RoPE", marker="o")
    ax3.plot(epochs, sin_ppl, label="Sinusoidal", marker="x")
    ax3.set_title("Validation Perplexity")
    ax3.set_yscale("log")
    ax3.legend()
    ax3.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig("rope_vs_sinusoidal_ablation.png")


def run_benchmark():
    print("KV-Cache vs Native")
    model = DecoderOnlyTransformer(**PARAMS).to(DEVICE)
    model.eval()

    input_ids = torch.randint(0, 256, (1, 10)).to(DEVICE)
    max_new_tokens = 50
    num_runs = 5
    # Warm-up runs so one-time initialization costs don't skew the timings
    model.generate(input_ids, max_new_tokens=5, use_cache=True)
    model.generate(input_ids, max_new_tokens=5, use_cache=False)

    times_cache = []
    for _ in range(num_runs):
        start = time.time()
        model.generate(input_ids, max_new_tokens=max_new_tokens, use_cache=True)
        times_cache.append(time.time() - start)

    avg_cache = sum(times_cache) / num_runs
    speed_cache = max_new_tokens / avg_cache
    print(f"With Cache: {speed_cache:.2f} tok/s")

    times_no_cache = []
    for _ in range(num_runs):
        start = time.time()
        model.generate(input_ids, max_new_tokens=max_new_tokens, use_cache=False)
        times_no_cache.append(time.time() - start)

    avg_no_cache = sum(times_no_cache) / num_runs
    speed_no_cache = max_new_tokens / avg_no_cache
    print(f"No Cache: {speed_no_cache:.2f} tok/s")

    plt.figure()
    plt.bar(
        ["Without Cache", "With Cache"],
        [speed_no_cache, speed_cache],
        color=["red", "blue"],
    )
    plt.ylabel("Tokens / Second")
    plt.title("Inference Speed Comparison")
    plt.savefig("kv_cache_vs_native_benchmark.png")


def run_audit():
    print("Removing Causal Mask")

    with open("data/trajectories.pkl", "rb") as f:
        trajectories = pickle.load(f)
    dataset = trajectories["actions"].clone().detach().long()[:1000]
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)

    class CheatingTransformer(DecoderOnlyTransformer):
        def forward(self, input_ids, targets=None, past_kv=None, use_cache=False):
            batch_size, seq_len = input_ids.shape
            x = self.token_embedding(input_ids)
            if not self.use_rope:
                x = self.pos_embedding(x)
            mask = torch.ones(seq_len, seq_len, device=x.device)

            for block in self.blocks:
                x, _ = block(x, mask)

            x = self.norm_final(x)
            logits = self.lm_head(x)

            loss = None
            if targets is not None:
                loss = torch.nn.functional.cross_entropy(
                    logits.view(-1, self.vocab_size), targets.view(-1), ignore_index=-1
                )
            return logits, loss, None

    model = CheatingTransformer(**PARAMS, use_rope=True).to(DEVICE)
    optimizer = optim.AdamW(model.parameters(), lr=1e-3)

    print("Training Cheating Model (Bidirectional Attention)...")
    for epoch in range(1):
        loss, _ = train_epoch(model, dataloader, optimizer, DEVICE, epoch)
        print(f"Epoch {epoch}, Loss: {loss:.4f} (Expected << Initial Loss)")


def run_visualization():
    model = DecoderOnlyTransformer(**PARAMS).to(DEVICE)

    if os.path.exists("checkpoints/best_model.pt"):
        try:
            state = torch.load("checkpoints/best_model.pt", map_location=DEVICE)
            model_keys = model.state_dict().keys()
            filtered_state = {
                k: v
                for k, v in state.items()
                if k in model_keys and v.size() == model.state_dict()[k].size()
            }
            model.load_state_dict(filtered_state, strict=False)
            print("Loaded pretrained weights.")
        except Exception as e:
            print(f"Failed to load pretrained weights, using random weights instead. Error: {e}")

    model.eval()
    seq_len = 20
    x = torch.randint(0, 256, (1, seq_len)).to(DEVICE)
    layer = model.blocks[0].attention
    qkv = layer.qkv_proj(model.token_embedding(x))
    qkv = qkv.view(1, seq_len, 3, layer.num_heads, layer.head_dim).permute(2, 0, 3, 1, 4)
    q, k, v = qkv.unbind(0)

    if layer.use_rope:
        q, k = layer.rope(q, k)

    scores = (q @ k.transpose(-2, -1)) / math.sqrt(layer.head_dim)
    mask = torch.tril(torch.ones(seq_len, seq_len, device=DEVICE))
    scores = scores.masked_fill(mask == 0, float("-inf"))
    attn_weights = torch.softmax(scores, dim=-1)

    fig, axes = plt.subplots(1, 4, figsize=(20, 5))
    for h in range(4):
        sns.heatmap(
            attn_weights[0, h].detach().cpu().numpy(),
            ax=axes[h],
            cmap="viridis",
            square=True,
            cbar=False,
        )
        axes[h].set_title(f"Head {h}")

    plt.suptitle("Layer 0 Attention Patterns")
    plt.savefig("attention_maps.png")


if __name__ == "__main__":
    run_ablation()
    run_benchmark()
    run_audit()
    run_visualization()
Binary file added src/assignments/scratch-1/attention_maps.png