Scratch-1: jdvakil #49

@@ -0,0 +1,79 @@
---
title: "Scratch-1: jdvakil"
student: "jdvakil"
date: "2026-02-03"
---

# Scratch-1: The Transformer Backbone
## Training Loss

I trained the transformer backbone for 10 epochs on 10k discretized trajectories. The loss dropped sharply during epoch 1 and improved gradually after that.

![Training Loss](results/training_loss.png)

The final loss was **~1.41**, corresponding to a perplexity of **~4.09**. The curve is stable, with no spikes or divergence.
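Perplexity here is just the exponential of the cross-entropy loss, so the two numbers above are consistent:

```python
import math

# exp(1.41) is about 4.10, matching the reported perplexity of ~4.09.
print(math.exp(1.41))
```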
## Ablation Studies: RoPE vs Sinusoidal

I compared RoPE against standard sinusoidal positional embeddings.

![RoPE vs Sinusoidal](results/rope_vs_sinusoidal_ablation.png)

- **RoPE**: reached a loss of **1.98** (perplexity 7.25) after just 3 epochs.
- **Sinusoidal**: stayed around **~4.40** and essentially did not learn.

RoPE works better here because it encodes relative positions directly in the attention computation, by rotating the query and key vectors, rather than adding absolute position information to the token embeddings. The sinusoidal result surprised me; I expected it to converge at least partially, but it stayed flat. A sketch of the rotation is below.
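For reference, here is a minimal sketch of the rotation idea, assuming a `(batch, heads, seq_len, head_dim)` layout and the rotate-half channel pairing; the actual implementation in `backbone.py` may differ in details.

```python
import torch

def rope_rotate(x, base=10000.0):
    # x: (batch, heads, seq_len, head_dim), head_dim must be even.
    *_, seq_len, dim = x.shape
    half = dim // 2
    # One frequency per channel pair, geometric in the channel index (RoPE paper).
    freqs = base ** (-torch.arange(half, dtype=torch.float32) / half)
    angles = torch.arange(seq_len, dtype=torch.float32)[:, None] * freqs[None, :]
    cos, sin = angles.cos(), angles.sin()
    x1, x2 = x[..., :half], x[..., half:]
    # Rotate each (x1, x2) pair by a position-dependent angle.
    return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)

# Applied to q and k (not v) before the dot product, the rotation makes
# q_i · k_j depend only on the relative offset i - j.
q = torch.randn(1, 8, 50, 32)
print(rope_rotate(q).shape)  # torch.Size([1, 8, 50, 32])
```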
## Inference Benchmark

I measured generation speed with and without KV caching.

![KV Cache Benchmark](results/kv_cache_vs_native_benchmark.png)

- **With cache**: 229.5 tokens/sec
- **No cache**: 209.3 tokens/sec
- **Speedup**: ~1.10x

The difference is small for these short sequences (50 new tokens from a 10-token prompt), but it would grow for longer generations, since the uncached path re-encodes the entire prefix at every step. The sketch below shows the two decoding loops.
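As a rough illustration of what is being benchmarked, here are the two decoding loops. This is a sketch that assumes the `(logits, loss, past_kv)` return signature used elsewhere in this diff; the real `generate` in `backbone.py` may sample differently and handle the cache internally.

```python
import torch

@torch.no_grad()
def generate_no_cache(model, ids, n_new):
    for _ in range(n_new):
        logits, _, _ = model(ids)  # re-encodes the entire prefix every step
        ids = torch.cat([ids, logits[:, -1:].argmax(-1)], dim=-1)
    return ids

@torch.no_grad()
def generate_with_cache(model, ids, n_new):
    logits, _, past = model(ids, use_cache=True)  # prefix pass fills the K/V cache
    for _ in range(n_new):
        next_id = logits[:, -1:].argmax(-1)
        ids = torch.cat([ids, next_id], dim=-1)
        # Only the new token is fed; keys/values of the prefix come from the cache.
        logits, _, past = model(next_id, past_kv=past, use_cache=True)
    return ids
```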
## Attention Visualization

I plotted attention maps from Layer 0 to see what the model learned.

![Attention Maps](results/attention_maps.png)

The lower-triangular pattern from the causal mask is clearly visible: each head attends only to previous tokens, which is exactly what we want for next-token prediction.
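A minimal sketch of how such maps can be plotted with the seaborn dependency listed under extra packages; the attention tensor here is a masked stand-in, since the hook that extracts the real Layer 0 weights lives in `backbone.py`.

```python
import torch
import seaborn as sns
import matplotlib.pyplot as plt

seq_len, num_heads = 50, 8
causal = torch.tril(torch.ones(seq_len, seq_len))

# Stand-in for real Layer 0 attention weights of shape (heads, seq, seq).
attn = torch.rand(num_heads, seq_len, seq_len) * causal
attn = attn / attn.sum(dim=-1, keepdim=True)

fig, axes = plt.subplots(1, 4, figsize=(16, 4))
for head, ax in enumerate(axes):
    sns.heatmap(attn[head].numpy(), ax=ax, cbar=False, square=True)
    ax.set_title(f"Layer 0, Head {head}")
plt.tight_layout()
plt.savefig("attention_maps_sketch.png")
```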
## The Audit: Removing the Causal Mask

I removed `torch.tril` to see what happens when the model can peek at future tokens.

Training loss dropped to ~3.1 (from a starting point of ~3.7) far faster than normal. But this is fake progress: at inference time there are no future tokens to look at, so the model is useless. It learned to copy rather than predict.
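To make the audit concrete, here is the masking difference in isolation (illustrative only; the exact masking code in `backbone.py` may differ):

```python
import torch

seq_len = 5
scores = torch.randn(seq_len, seq_len)  # raw q·k attention scores

# Normal: the causal mask removes positions j > i before the softmax.
causal = torch.tril(torch.ones(seq_len, seq_len))
print(torch.softmax(scores.masked_fill(causal == 0, float("-inf")), dim=-1))

# Audit: an all-ones mask blocks nothing, so every position can attend
# to future tokens -- the source of the artificially low training loss.
full = torch.ones(seq_len, seq_len)
print(torch.softmax(scores.masked_fill(full == 0, float("-inf")), dim=-1))
```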
### Why the Model "Cheats"

Without the causal mask, the token at position $t$ can see position $t+1$. The target for position $t$ is literally the value at $t+1$, so the model just copies it; no actual dynamics are being learned.
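The leak follows directly from the standard one-token shift between inputs and targets used in the training loop below:

```python
import torch

batch = torch.tensor([[7, 3, 9, 1, 4]])        # a toy discretized trajectory
inputs, targets = batch[:, :-1], batch[:, 1:]  # same shift as in train_and_eval
print(inputs)   # tensor([[7, 3, 9, 1]])
print(targets)  # tensor([[3, 9, 1, 4]]) -- each label is the next input token,
                # which an unmasked model can attend to directly and copy
```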
## Code Highlights

I implemented RoPE with KV-caching support in `backbone.py`, and wrote the RoPE-vs-sinusoidal ablation and the KV-cache benchmark in `ablations.py`.

Collect data:

```
python generate_data.py --num_trajectories 10000 --seq_length 50 --output data/trajectories.pkl
```

Train:

```
python backbone.py
```

Ablations:

```
python ablations.py
```

Extra packages:

```
pip install pillow six seaborn
```
## Challenges and Solutions

**Problem**: The loss was flat at ~4.6 even though the model code was correct.

**Solution**: I spent far too long debugging the model before thinking to check the data. Looking at `generate_data.py`, I found the issue: the signal was 0.01 and the noise was 0.05, so the signal-to-noise ratio was terrible. Bumping the signal to 0.1 and dropping the noise to 0.001 fixed it; the loss immediately started decreasing and converged to ~1.4.
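A quick ratio check on the values quoted above:

```python
# Signal-to-noise ratio of the quoted settings, before vs after the fix.
before = 0.01 / 0.05   # 0.2  -> noise dominates, the discretized tokens are nearly random
after = 0.1 / 0.001    # 100  -> the tokens actually carry the trajectory dynamics
print(before, after)
```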

@@ -0,0 +1,255 @@
import torch
import torch.nn as nn
import torch.optim as optim
import pickle
import matplotlib.pyplot as plt
import seaborn as sns
import os
import math
import numpy as np
import time

from backbone import DecoderOnlyTransformer, train_epoch

PARAMS = dict(
    vocab_size=256,
    dim=256,
    num_layers=4,
    num_heads=8,
    ff_hidden_dim=1024,
    max_seq_len=50,
)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def train_and_eval(name, model, train_loader, val_loader, epochs=3):
    optimizer = optim.AdamW(model.parameters(), lr=1e-4)

    step_losses = []
    val_losses = []
    perplexities = []

    for epoch in range(epochs):
        model.train()
        for batch in train_loader:
            batch = batch.to(DEVICE)
            # Next-token setup: inputs drop the last token, targets are shifted by one.
            inputs = batch[:, :-1].contiguous()
            targets = batch[:, 1:].contiguous()

            optimizer.zero_grad()
            _, loss, _ = model(inputs, targets)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            step_losses.append(loss.item())

        # Validation
        model.eval()
        val_loss_sum = 0
        val_batches = 0
        with torch.no_grad():
            for batch in val_loader:
                batch = batch.to(DEVICE)
                inputs = batch[:, :-1].contiguous()
                targets = batch[:, 1:].contiguous()
                _, loss, _ = model(inputs, targets)
                val_loss_sum += loss.item()
                val_batches += 1

        avg_val_loss = val_loss_sum / val_batches
        val_losses.append(avg_val_loss)
        ppl = math.exp(avg_val_loss)
        perplexities.append(ppl)
        print(f"{name} Epoch {epoch}, Val Loss: {avg_val_loss:.4f}, Perplexity: {ppl:.4f}")

    return step_losses, val_losses, perplexities

def run_ablation():
    print("RoPE vs Sinusoidal")
    with open("data/trajectories.pkl", "rb") as f:
        trajectories = pickle.load(f)
    full_dataset = trajectories["actions"].clone().detach().long()

    N = 2000
    full_dataset = full_dataset[:N]
    train_size = int(0.8 * N)
    val_size = N - train_size
    train_set, val_set = torch.utils.data.random_split(full_dataset, [train_size, val_size])

    train_loader = torch.utils.data.DataLoader(train_set, batch_size=32, shuffle=True)
    val_loader = torch.utils.data.DataLoader(val_set, batch_size=32, shuffle=False)

    model_rope = DecoderOnlyTransformer(**PARAMS, use_rope=True).to(DEVICE)
    rope_steps, rope_val, rope_ppl = train_and_eval(
        "RoPE", model_rope, train_loader, val_loader
    )

    model_sin = DecoderOnlyTransformer(**PARAMS, use_rope=False).to(DEVICE)
    sin_steps, sin_val, sin_ppl = train_and_eval(
        "Sinusoidal", model_sin, train_loader, val_loader
    )

    # Plotting
    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(18, 5))

    def smooth(data, window=10):
        return np.convolve(data, np.ones(window) / window, mode="valid")

    ax1.plot(smooth(rope_steps), label="RoPE", alpha=0.8)
    ax1.plot(smooth(sin_steps), label="Sinusoidal", alpha=0.8)
    ax1.set_title("Training Loss (Step-wise)")
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    epochs = range(len(rope_val))
    ax2.plot(epochs, rope_val, label="RoPE", marker="o")
    ax2.plot(epochs, sin_val, label="Sinusoidal", marker="x")
    ax2.set_title("Validation Loss")
    ax2.legend()
    ax2.grid(True, alpha=0.3)

    ax3.plot(epochs, rope_ppl, label="RoPE", marker="o")
    ax3.plot(epochs, sin_ppl, label="Sinusoidal", marker="x")
    ax3.set_title("Validation Perplexity")
    ax3.set_yscale("log")
    ax3.legend()
    ax3.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig("rope_vs_sinusoidal_ablation.png")

def run_benchmark():
    print("KV-Cache vs Native")
    model = DecoderOnlyTransformer(**PARAMS).to(DEVICE)
    model.eval()

    input_ids = torch.randint(0, 256, (1, 10)).to(DEVICE)
    max_new_tokens = 50
    num_runs = 5
    # Warm-up generations so one-time setup costs don't skew the timings.
    model.generate(input_ids, max_new_tokens=5, use_cache=True)
    model.generate(input_ids, max_new_tokens=5, use_cache=False)

    times_cache = []
    for _ in range(num_runs):
        start = time.time()
        model.generate(input_ids, max_new_tokens=max_new_tokens, use_cache=True)
        times_cache.append(time.time() - start)

    avg_cache = sum(times_cache) / num_runs
    speed_cache = max_new_tokens / avg_cache
    print(f"With Cache: {speed_cache:.2f} tok/s")

    times_no_cache = []
    for _ in range(num_runs):
        start = time.time()
        model.generate(input_ids, max_new_tokens=max_new_tokens, use_cache=False)
        times_no_cache.append(time.time() - start)

    avg_no_cache = sum(times_no_cache) / num_runs
    speed_no_cache = max_new_tokens / avg_no_cache
    print(f"No Cache: {speed_no_cache:.2f} tok/s")

    plt.figure()
    plt.bar(
        ["Without Cache", "With Cache"],
        [speed_no_cache, speed_cache],
        color=["red", "blue"],
    )
    plt.ylabel("Tokens / Second")
    plt.title("Inference Speed Comparison")
    plt.savefig("kv_cache_vs_native_benchmark.png")

def run_audit():
    print("Removing Causal Mask")

    with open("data/trajectories.pkl", "rb") as f:
        trajectories = pickle.load(f)
    dataset = trajectories["actions"].clone().detach().long()[:1000]
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)

    class CheatingTransformer(DecoderOnlyTransformer):
        def forward(self, input_ids, targets=None, past_kv=None, use_cache=False):
            batch_size, seq_len = input_ids.shape
            x = self.token_embedding(input_ids)
            if not self.use_rope:
                x = self.pos_embedding(x)
            # All-ones mask instead of torch.tril: every position can attend to the future.
            mask = torch.ones(seq_len, seq_len, device=x.device)

            for block in self.blocks:
                x, _ = block(x, mask)

            x = self.norm_final(x)
            logits = self.lm_head(x)

            loss = None
            if targets is not None:
                loss = torch.nn.functional.cross_entropy(
                    logits.view(-1, self.vocab_size), targets.view(-1), ignore_index=-1
                )
            return logits, loss, None

    model = CheatingTransformer(**PARAMS, use_rope=True).to(DEVICE)
    optimizer = optim.AdamW(model.parameters(), lr=1e-3)

    print("Training Cheating Model (Bidirectional Attention)...")
    for epoch in range(1):
        loss, _ = train_epoch(model, dataloader, optimizer, DEVICE, epoch)
        print(f"Epoch {epoch}, Loss: {loss:.4f} (Expected << Initial Loss)")

def run_visualization():
    model = DecoderOnlyTransformer(**PARAMS).to(DEVICE)

    if os.path.exists("checkpoints/best_model.pt"):
        try:
            state = torch.load("checkpoints/best_model.pt", map_location=DEVICE)
            model_keys = model.state_dict().keys()
            filtered_state = {
                k: v
                for k, v in state.items()
                if k in model_keys and v.size() == model.state_dict()[k].size()
            }
            model.load_state_dict(filtered_state, strict=False)
            print("Loaded pretrained weights.")
        except:
            print("Using random weights.")
Comment on lines +217 to +218, suggested change:

-        except:
-            print("Using random weights.")
+        except Exception as e:
+            print(f"Failed to load pretrained weights, using random weights instead. Error: {e}")

Comment: Import of 'nn' is not used.