This repository is based on DistillSpec: Improving Speculative Decoding via Knowledge Distillation (Zhou et al.).
Course: Columbia University COMS4705 (Natural Language Processing)
This project investigates the impact of three divergence objectives on training draft models for Speculative Decoding (SD): Forward KL (FKL), Reverse KL (RKL), and Jensen-Shannon Divergence (JSD).
Standard off-policy distillation suffers from exposure bias. We therefore use On-Policy Distillation, in which the student learns from its own generated trajectories scored by a frozen teacher. We demonstrate that the optimal divergence metric is highly dependent on the entropy profile of the downstream task.
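The three objectives differ in how they penalize mismatch between the teacher distribution p and the student distribution q at each token position. Below is a minimal PyTorch sketch of the token-level losses (our own illustrative naming, not the training code); `beta` is the interpolation weight of the generalized JSD, with `beta = 0.5` giving the symmetric case.

```python
import math
import torch
import torch.nn.functional as F

def token_divergences(teacher_logits, student_logits, beta=0.5):
    """Per-token FKL, RKL, and generalized JSD between teacher p and student q.

    teacher_logits, student_logits: tensors of shape [batch, seq, vocab].
    Requires 0 < beta < 1 for the generalized JSD mixture.
    """
    log_p = F.log_softmax(teacher_logits, dim=-1)  # teacher log-probs
    log_q = F.log_softmax(student_logits, dim=-1)  # student log-probs
    p, q = log_p.exp(), log_q.exp()

    # Forward KL, KL(p || q): mean-seeking, spreads student mass over all teacher modes.
    fkl = (p * (log_p - log_q)).sum(-1)
    # Reverse KL, KL(q || p): mode-seeking, concentrates on high-teacher-probability tokens.
    rkl = (q * (log_q - log_p)).sum(-1)
    # Generalized JSD with mixture m = beta * p + (1 - beta) * q.
    log_m = torch.logsumexp(
        torch.stack([log_p + math.log(beta), log_q + math.log(1.0 - beta)]), dim=0
    )
    jsd = beta * (p * (log_p - log_m)).sum(-1) + (1 - beta) * (q * (log_q - log_m)).sum(-1)
    return fkl, rkl, jsd
```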
To enable fast evaluation, we first implemented a speculative decoding framework that supports batched inputs. The implementation is available at `batched_specdec` and is included as a submodule.
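For context, the core of speculative decoding is a draft-then-verify loop: the student proposes a block of tokens and the teacher accepts each one with probability min(1, p/q). The sketch below shows an illustrative single-sequence version of that acceptance rule (names are ours; the batched implementation lives in `batched_specdec`).

```python
import torch

def count_accepted(draft_probs, target_probs, draft_tokens):
    """Speculative-sampling acceptance rule for one block of drafted tokens.

    draft_probs, target_probs: [num_draft, vocab] probabilities from the student
    (draft) and teacher (target) models at each drafted position.
    draft_tokens: [num_draft] token ids proposed by the draft model.
    Returns how many drafted tokens are accepted before the first rejection.
    """
    accepted = 0
    for i, tok in enumerate(draft_tokens.tolist()):
        p = target_probs[i, tok]  # teacher probability of the drafted token
        q = draft_probs[i, tok]   # student probability of the drafted token
        if torch.rand(()) < torch.clamp(p / q, max=1.0):
            accepted += 1         # token accepted; keep verifying the block
        else:
            break                 # first rejection ends the block (teacher resamples)
    return accepted
```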
Our experiments show that Mode-Seeking behavior is preferred for reasoning, while Mean-Seeking/Balanced behavior is preferred for open-ended generation.
| Task Type | Dataset | Teacher Entropy | Best Metric | Insight |
|---|---|---|---|---|
| Math Reasoning | GSM8k | Low (Deterministic) | Reverse KL (RKL) | RKL forces the student to "snap" to the correct reasoning path, ignoring valid but unaligned tokens. |
| Summarization | CNN/DM | High (Ambiguous) | JSD | JSD balances coverage and precision, preventing collapse in high-entropy regions. |
- Qwen3 (GSM8k): RKL achieved 53.93%, outperforming the Baseline (49.52%) and FKL.
- SmolLM (CNN/DM): JSD achieved 55.35%, outperforming RKL and FKL.
Models
- Math: Teacher: `Qwen3-4B-Instruct` | Student: `Qwen3-0.6B-Instruct`
- Summarization: Teacher: `SmolLM-1.7B-Instruct` | Student: `SmolLM-360M-Instruct`
We use a white-box, token-level distillation framework built on **Hugging Face TRL's GKDTrainer**.
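A minimal configuration sketch is shown below, assuming a recent TRL release; the checkpoint ids and the toy prompt are placeholders. In TRL's `GKDConfig`, `lmbda` controls the fraction of on-policy (student-generated) batches and `beta` interpolates the generalized JSD loss (roughly: 0 behaves like forward KL, 0.5 like JSD, 1 like reverse KL); exact names and behavior may differ across TRL versions.

```python
from datasets import Dataset
from transformers import AutoTokenizer
from trl import GKDConfig, GKDTrainer

STUDENT = "student-model-id"  # placeholder, e.g. the Qwen3-0.6B student above
TEACHER = "teacher-model-id"  # placeholder, e.g. the Qwen3-4B teacher above

tokenizer = AutoTokenizer.from_pretrained(STUDENT)

# Conversational prompts; with lmbda=1.0 the completions are generated on-policy
# by the student rather than taken from the dataset.
train_dataset = Dataset.from_list([
    {"messages": [
        {"role": "user", "content": "Natalia sold clips to 48 of her friends in April, "
                                    "and then she sold half as many clips in May. "
                                    "How many clips did Natalia sell altogether?"},
    ]},
])

args = GKDConfig(
    output_dir="gkd-draft",
    lmbda=1.0,           # fully on-policy: the student generates every training trajectory
    beta=1.0,            # divergence choice: ~0 FKL, ~0.5 JSD, ~1 RKL
    max_new_tokens=256,  # length of on-policy rollouts
    per_device_train_batch_size=2,
)

trainer = GKDTrainer(
    model=STUDENT,
    teacher_model=TEACHER,
    args=args,
    processing_class=tokenizer,
    train_dataset=train_dataset,
)
trainer.train()
```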
- Zero-Label Training: No ground-truth dataset labels were used; training relied entirely on student-generated trajectories scored by the teacher.
- Metrics: We evaluate using Token-Level and Sequence-Level Acceptance Rates.
The project includes a custom evaluation harness for Speculative Decoding with Dynamic Batching.
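As an illustration of the two metrics, suppose the harness records, for every speculative block, how many of the drafted tokens were accepted. The helper below shows one plausible aggregation (token-level as the fraction of drafted tokens accepted, sequence-level as the fraction of blocks accepted in full); it is our own sketch and the exact definitions used by the harness may differ.

```python
def acceptance_rates(accepted_per_block, drafted_per_block):
    """Aggregate per-block accept counts into token- and sequence-level rates.

    accepted_per_block[i]: number of drafted tokens accepted in block i.
    drafted_per_block[i]:  number of tokens the draft model proposed in block i.
    """
    total_accepted = sum(accepted_per_block)
    total_drafted = sum(drafted_per_block)
    # Token-level acceptance rate: accepted draft tokens / proposed draft tokens.
    token_level = total_accepted / total_drafted
    # Sequence-level acceptance rate: fraction of blocks accepted in full.
    sequence_level = sum(
        a == d for a, d in zip(accepted_per_block, drafted_per_block)
    ) / len(drafted_per_block)
    return token_level, sequence_level
```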
- Python 3.11+
- PyTorch (CUDA or MPS)
- Hugging Face `transformers`, `trl`, `bitsandbytes`
Best-performing checkpoints are available here:
| Model | Dataset | Divergence | Link |
|---|---|---|---|
| Qwen-0.6B | GSM8K | Reverse KL | rishabhrj11/distillspec-qwen600 |
| SmolLM-360M | CNN/DM | JSD | rishabhrj11/distillspec-smollm-cnn-b5 |
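The distilled draft models can be pulled straight from the Hub with standard `transformers` calls. A quick sketch using the GSM8K checkpoint from the table above (adjust dtype/device for your hardware; if the tokenizer is not bundled with the checkpoint, load it from the base student model instead):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Distilled GSM8K draft model (Reverse KL) from the table above.
draft_id = "rishabhrj11/distillspec-qwen600"
tokenizer = AutoTokenizer.from_pretrained(draft_id)
draft_model = AutoModelForCausalLM.from_pretrained(draft_id, torch_dtype=torch.bfloat16)

prompt = "A robe takes 2 bolts of blue fiber and half that much white fiber. How many bolts in total?"
inputs = tokenizer(prompt, return_tensors="pt")
output = draft_model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```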
We analyze performance relative to Teacher Entropy.
- Low Entropy: RKL dominates, maximizing precision.
- High Entropy: RKL degrades; JSD remains robust where diversity is required.
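The teacher-entropy profile referred to above can be measured directly from the teacher's next-token distributions. A small sketch (our own helper) for computing per-token entropy, which can then be related to acceptance behavior:

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def teacher_token_entropy(teacher_model, tokenizer, text):
    """Per-position entropy (in nats) of the teacher's next-token distribution."""
    inputs = tokenizer(text, return_tensors="pt")
    logits = teacher_model(**inputs).logits           # [1, seq, vocab]
    log_probs = F.log_softmax(logits, dim=-1)
    entropy = -(log_probs.exp() * log_probs).sum(-1)  # [1, seq]
    return entropy.squeeze(0)

# Low mean entropy (e.g. GSM8k-style reasoning) favors mode-seeking RKL;
# high mean entropy (e.g. CNN/DM summaries) favors the more balanced JSD.
```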