
Forward or Reverse KL? Exploring On-Policy Distillation for Speculative Decoding

This repository is based on DistillSpec: Improving Speculative Decoding via Knowledge Distillation (Zhou et al.).


Course: Columbia University COMS4705 (Natural Language Processing)

📖 Overview

This project investigates the impact of different divergence objectives, Forward KL (FKL), Reverse KL (RKL), and Jensen-Shannon Divergence (JSD), on training draft models for Speculative Decoding (SD).

Standard off-policy distillation suffers from exposure bias. We instead use On-Policy Distillation, where the student learns from its own generated trajectories scored by a frozen teacher. We show that the optimal divergence is highly dependent on the entropy profile of the downstream task.
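For reference, the three divergences can be computed per token from student and teacher logits as in the sketch below. The function is illustrative and not part of this repository's code.

```python
import torch.nn.functional as F

def token_divergence(student_logits, teacher_logits, kind="jsd"):
    """Per-token divergence between the student (q) and teacher (p) distributions.

    student_logits, teacher_logits: (batch, seq_len, vocab) tensors.
    Returns a (batch, seq_len) tensor of divergences.
    """
    log_q = F.log_softmax(student_logits, dim=-1)  # student log-probs
    log_p = F.log_softmax(teacher_logits, dim=-1)  # teacher log-probs
    p, q = log_p.exp(), log_q.exp()

    if kind == "fkl":  # forward KL(p || q): mean-seeking, covers all teacher modes
        return (p * (log_p - log_q)).sum(-1)
    if kind == "rkl":  # reverse KL(q || p): mode-seeking, concentrates on one mode
        return (q * (log_q - log_p)).sum(-1)
    if kind == "jsd":  # Jensen-Shannon: symmetric mixture of the two behaviors
        m = 0.5 * (p + q)
        log_m = m.clamp_min(1e-12).log()
        return 0.5 * (p * (log_p - log_m)).sum(-1) + 0.5 * (q * (log_q - log_m)).sum(-1)
    raise ValueError(f"unknown divergence: {kind}")
```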

Batched Speculative Decoding for Fast Eval

To enable fast evaluation, we first implemented a speculative decoding framework that supports batched inputs. The implementation is available at batched_specdec and is included as a submodule.
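For context, the per-token accept/reject rule at the core of speculative decoding looks roughly like the following single-sequence sketch; the batched harness generalizes this across a batch, and names and shapes here are illustrative rather than the actual interface.

```python
import torch

def verify_draft_tokens(draft_tokens, draft_probs, target_probs):
    """Accept/reject a block of drafted tokens for one sequence via rejection sampling.

    draft_tokens: (k,) token ids proposed by the draft model.
    draft_probs:  (k, vocab) draft-model probabilities at each drafted position.
    target_probs: (k, vocab) target-model probabilities at the same positions.
    Returns the number of accepted tokens (0..k).
    """
    accepted = 0
    for i, tok in enumerate(draft_tokens):
        p = target_probs[i, tok]  # target probability of the drafted token
        q = draft_probs[i, tok]   # draft probability of the same token
        # Accept with probability min(1, p / q); on rejection, stop and resample
        # from the residual distribution max(0, p - q) (omitted in this sketch).
        if torch.rand(()) < torch.clamp(p / q, max=1.0):
            accepted += 1
        else:
            break
    return accepted
```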

🚀 Key Findings

Our experiments show that Mode-Seeking behavior is preferred for reasoning, while Mean-Seeking/Balanced behavior is preferred for open-ended generation.

| Task Type | Dataset | Teacher Entropy | Best Metric | Insight |
|---|---|---|---|---|
| Math Reasoning | GSM8k | Low (Deterministic) | Reverse KL (RKL) | RKL forces the student to "snap" to the correct reasoning path, ignoring valid but unaligned tokens. |
| Summarization | CNN/DM | High (Ambiguous) | JSD | JSD balances coverage and precision, preventing collapse in high-entropy regions. |

Quantitative Results (Token Acceptance Rate)

  • Qwen3 (GSM8k): RKL achieved 53.93%, outperforming Baseline (49.52%) and FKL.

  • SmolLM (CNN/DM): JSD achieved 55.35%, outperforming RKL and FKL.

🛠️ Methodology

Models

  • Math: Teacher: Qwen3-4B-Instruct | Student: Qwen3-0.6B-Instruct
  • Summarization: Teacher: SmolLM-1.7B-Instruct | Student: SmolLM-360M-Instruct

Distillation Approach

We use a white-box, token-level distillation framework built on Hugging Face TRL's GKDTrainer; a minimal usage sketch follows the list below.

  • Zero-Label Training: No ground-truth dataset labels were used; training relied entirely on student-generated trajectories scored by the teacher.

  • Metrics: We evaluate using Token-Level and Sequence-Level Acceptance Rates.
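The sketch below shows how an on-policy GKD run with TRL can be set up. Model ids, the GSM8k formatting step, and hyperparameters are illustrative assumptions rather than this repository's exact configuration, and argument names may differ across TRL versions.

```python
# Minimal sketch of on-policy GKD training with TRL's GKDTrainer.
# Model ids, dataset formatting, and hyperparameters are illustrative
# assumptions, not this repository's exact configuration; older TRL
# versions take tokenizer= instead of processing_class=.
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import GKDConfig, GKDTrainer

student_id = "Qwen/Qwen3-0.6B"   # draft / student model (assumed Hub id)
teacher_id = "Qwen/Qwen3-4B"     # frozen teacher model (assumed Hub id)

tokenizer = AutoTokenizer.from_pretrained(student_id)
student = AutoModelForCausalLM.from_pretrained(student_id)
teacher = AutoModelForCausalLM.from_pretrained(teacher_id)

def to_messages(example):
    # GKDTrainer expects chat-formatted "messages"; with lmbda=1.0 the
    # assistant turns are replaced by on-policy student rollouts, so the
    # dataset answers are never used as training targets.
    return {"messages": [
        {"role": "user", "content": example["question"]},
        {"role": "assistant", "content": example["answer"]},
    ]}

train_dataset = load_dataset("openai/gsm8k", "main", split="train").map(
    to_messages, remove_columns=["question", "answer"]
)

args = GKDConfig(
    output_dir="gkd-qwen-rkl",
    lmbda=1.0,           # 1.0 = fully on-policy: train on student-generated rollouts
    beta=1.0,            # generalized JSD coefficient: 0.0 ~ forward KL, 1.0 ~ reverse KL, 0.5 = JSD
    max_new_tokens=256,  # length of the student rollouts scored by the teacher
)

trainer = GKDTrainer(
    model=student,
    teacher_model=teacher,
    args=args,
    train_dataset=train_dataset,
    processing_class=tokenizer,
)
trainer.train()
```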

💻 Implementation Details

The project includes a custom evaluation harness for Speculative Decoding with Dynamic Batching.
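The harness reports the token-level and sequence-level acceptance rates mentioned above. For illustration, both can be derived from per-step accept counts roughly as follows; the sequence-level definition used here (fraction of draft blocks accepted in full) is an assumption, and the harness's actual interface may differ.

```python
def acceptance_rates(accepted_per_step, drafted_per_step):
    """Token- and sequence-level acceptance rates from per-step counters.

    accepted_per_step: accepted-token count for each draft/verify step.
    drafted_per_step:  number of tokens drafted at each step (the block size).
    """
    total_accepted = sum(accepted_per_step)
    total_drafted = sum(drafted_per_step)
    # Token-level: fraction of drafted tokens that the target model accepts.
    token_rate = total_accepted / max(total_drafted, 1)
    # Sequence-level (assumed definition): fraction of draft blocks accepted in full.
    full_blocks = sum(a == d for a, d in zip(accepted_per_step, drafted_per_step))
    seq_rate = full_blocks / max(len(drafted_per_step), 1)
    return token_rate, seq_rate
```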

Requirements

  • Python 3.11+
  • PyTorch (CUDA or MPS)
  • Hugging Face transformers, trl, bitsandbytes

Checkpoints

Best performing checkpoints are available here:

| Model | Dataset | Divergence | Link |
|---|---|---|---|
| Qwen3-0.6B | GSM8k | Reverse KL | rishabhrj11/distillspec-qwen600 |
| SmolLM-360M | CNN/DM | JSD | rishabhrj11/distillspec-smollm-cnn-b5 |
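These checkpoints can be loaded directly from the Hugging Face Hub and plugged into transformers' built-in assisted generation. The snippet below is a usage sketch assuming standard causal-LM repositories; the target model id and prompt are assumptions.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the RKL-distilled Qwen draft checkpoint released with this project.
draft_id = "rishabhrj11/distillspec-qwen600"
tokenizer = AutoTokenizer.from_pretrained(draft_id)
draft = AutoModelForCausalLM.from_pretrained(draft_id)

# Use the draft with transformers' assisted generation; the target model id
# below is an assumption (the draft must share the target's tokenizer).
target = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-4B")
inputs = tokenizer("Natalia sold clips to 48 of her friends in April...", return_tensors="pt")
output = target.generate(**inputs, assistant_model=draft, max_new_tokens=128)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```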

📊 Analysis

We analyze performance relative to Teacher Entropy; a sketch for estimating it follows the list below.

  • Low Entropy: RKL dominates, maximizing precision.

  • High Entropy: RKL degrades; JSD remains robust where diversity is required.
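The teacher entropy referenced above can be estimated from the teacher's per-token logits; the helper below is a minimal sketch, not this repository's exact implementation.

```python
import torch.nn.functional as F

def teacher_token_entropy(teacher_logits):
    """Per-token Shannon entropy (in nats) of the teacher distribution.

    teacher_logits: (batch, seq_len, vocab) tensor of teacher logits.
    Low values indicate near-deterministic positions (e.g. GSM8k reasoning steps);
    high values indicate ambiguous positions (e.g. CNN/DM summaries).
    """
    log_p = F.log_softmax(teacher_logits, dim=-1)
    return -(log_p.exp() * log_p).sum(dim=-1)
```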
