This repository implements knowledge distillation for Automatic Speech Recognition (ASR) models using the NVIDIA NeMo framework. The project focuses on training smaller, more efficient FastConformer-Transducer models by distilling knowledge from larger teacher models.
Knowledge distillation is a technique for transferring knowledge from a large, complex model (teacher) to a smaller, more efficient model (student). This project implements knowledge distillation for speech recognition models, enabling the creation of compact models that maintain competitive performance while requiring fewer computational resources.
- FastConformer-Transducer Architecture: State-of-the-art streaming ASR model
- Knowledge Distillation: Transfer learning from teacher to student models
- ONNX Export: Model deployment with ONNX runtime for efficient inference
- LibriSpeech Training: Comprehensive training on LibriSpeech dataset
- Multiple Model Sizes: Support for different model architectures (Small, Medium, Large)
    knowledgedistill/
    ├── README.md                                   # This file
    ├── train.py                                    # Main training script
    ├── base.yaml                                   # Base configuration file
    ├── fast-conformer_transducer_bpe.yaml          # Large model configuration
    ├── fast-conformer_transducer_bpe_medium.yaml   # Medium model configuration (with KD)
    ├── manifest/                                   # Dataset manifest files
    │   ├── train_manifest.json
    │   ├── val_manifest.json
    │   ├── test_clean_manifest.json
    │   └── test_other_manifest.json
    ├── scripts/                                    # Utility scripts
    │   ├── generate_manifest.py                    # Generate LibriSpeech manifests
    │   ├── evaluate.py                             # Model evaluation
    │   ├── export_onnx.py                          # Export models to ONNX
    │   ├── inference_onnx.py                       # ONNX inference
    │   └── extract_tokenizer.py                    # Tokenizer utilities
    ├── tokenizer/                                  # BPE tokenizer files
    │   ├── tokenizer.model
    │   ├── tokenizer.vocab
    │   └── vocab.txt
    ├── teacher/                                    # Teacher model storage
    │   └── teacher_model.nemo
    ├── models/                                     # Trained models
    │   └── full_train.nemo
    ├── experiments/                                # Training experiment logs
    │   ├── base/                                   # Base model experiments
    │   ├── new/                                    # New model experiments
    │   ├── whole_train/                            # Full training experiments
    │   └── wandb/                                  # Weights & Biases logs
    └── Nemo-CM/                                    # NeMo framework (submodule)
- Python 3.8+
- NVIDIA GPU with CUDA support
- NeMo framework
- PyTorch
- LibriSpeech dataset
- Clone the repository:

      git clone <repository-url>
      cd knowledgedistill

- Install dependencies:

      pip install nemo_toolkit
      pip install -r requirements.txt  # if available

- Download the LibriSpeech dataset and update the dataset paths in the configuration files.
- Generate manifest files for LibriSpeech:
      python scripts/generate_manifest.py

This script processes the LibriSpeech directories and creates manifest files (in the format sketched below) for:
- Training data (train-clean-100, train-clean-360, train-other-500)
- Validation data (dev-clean, dev-other)
- Test data (test-clean, test-other)
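Each manifest follows NeMo's JSON-lines convention: one JSON object per utterance with the audio path, duration, and transcript. A minimal sketch of writing one entry (the example path and transcript are illustrative; the exact fields emitted by `scripts/generate_manifest.py` may differ):

```python
# Sketch of a NeMo-style JSON-lines manifest entry; the example path and the
# exact fields written by scripts/generate_manifest.py may differ.
import json

entry = {
    "audio_filepath": "/data/LibriSpeech/train-clean-100/19/198/19-198-0000.flac",
    "duration": 12.34,                    # clip length in seconds
    "text": "chapter one missus rachel lynde is surprised",
}
with open("manifest/train_manifest.json", "a", encoding="utf-8") as f:
    f.write(json.dumps(entry) + "\n")     # one JSON object per line
```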
Train a FastConformer model without knowledge distillation:
    python train.py --config-path=. --config-name=fast-conformer_transducer_bpe

Train a student model with knowledge distillation:
    python train.py --config-path=. --config-name=fast-conformer_transducer_bpe_medium \
        model.enable_kd=True \
        model.teacher_model_path=/path/to/teacher_model.nemo \
        model.kd_temperature=4.0 \
        model.kd_alpha=0.7

Knowledge distillation parameters:

- `enable_kd`: Enable/disable knowledge distillation (default: False)
- `teacher_model_path`: Path to the pre-trained teacher model
- `kd_temperature`: Temperature for softening the probability distributions (default: 1.0)
- `kd_alpha`: Weight balancing the distillation loss against the ground-truth loss (default: 0.5)
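As a rough illustration of how `kd_temperature` and `kd_alpha` interact (the actual loss lives in `train.py` and the model code and may differ), a temperature-scaled soft-target term is typically blended with the ground-truth transducer loss:

```python
# Illustrative sketch only -- not the repository's exact loss implementation.
import torch
import torch.nn.functional as F

def kd_loss(student_logits, teacher_logits, ground_truth_loss,
            kd_temperature=4.0, kd_alpha=0.7):
    """Blend a softened teacher-matching term with the ground-truth loss."""
    T = kd_temperature
    # Higher temperature spreads probability mass over more tokens,
    # exposing the teacher's relative preferences among similar outputs.
    log_p_student = F.log_softmax(student_logits / T, dim=-1)
    p_teacher = F.softmax(teacher_logits / T, dim=-1)
    # KL divergence scaled by T^2 to keep gradient magnitudes comparable.
    distill = F.kl_div(log_p_student, p_teacher, reduction="batchmean") * (T * T)
    return kd_alpha * distill + (1.0 - kd_alpha) * ground_truth_loss
```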
| Model Size | d_model | n_heads | n_layers | Parameters | Config File |
|---|---|---|---|---|---|
| Small | 176 | 4 | 16 | ~14M | Custom |
| Medium | 256 | 4 | 16 | ~32M | medium.yaml |
| Large | 512 | 8 | 17 | ~120M | base.yaml |
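The parameter counts above are approximate. One way to check a trained checkpoint (the path below is illustrative) is to restore it and count parameters:

```python
# Hedged sketch: count the parameters of a trained checkpoint to verify model size.
import nemo.collections.asr as nemo_asr

model = nemo_asr.models.ASRModel.restore_from("models/full_train.nemo")
n_params = sum(p.numel() for p in model.parameters())
print(f"{n_params / 1e6:.1f}M parameters")
```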
Evaluate trained models:
    python scripts/evaluate.py

This script:
- Loads the trained model
- Transcribes test audio files
- Computes Word Error Rate (WER) and Character Error Rate (CER)
- Saves transcriptions for analysis
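Once hypotheses and references are available as plain strings, WER/CER can be computed with NeMo's built-in metric. A minimal sketch (the strings are placeholders, and the transcription step itself varies across NeMo versions):

```python
# Minimal WER/CER sketch using NeMo's metric; hypotheses/references are placeholders.
from nemo.collections.asr.metrics.wer import word_error_rate

hypotheses = ["the cat sat on the mat"]   # model transcriptions
references = ["the cat sat on a mat"]     # ground-truth text from the test manifest
wer = word_error_rate(hypotheses=hypotheses, references=references)
cer = word_error_rate(hypotheses=hypotheses, references=references, use_cer=True)
print(f"WER: {wer:.2%}  CER: {cer:.2%}")
```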
Export trained models to ONNX format for efficient inference:
    python scripts/export_onnx.py

This creates:

- `encoder.onnx`: Encoder model
- `decoder.onnx`: Decoder model
- `joiner.onnx`: Joint network
- `preprocessor.ts`: TorchScript preprocessor
- `tokens.txt`: Vocabulary file
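A quick way to sanity-check the exported files is to open one with `onnxruntime` and inspect its expected inputs and outputs before wiring up full inference (the names and shapes depend on how `export_onnx.py` performs the export, so they are queried rather than hard-coded):

```python
# Hedged sketch: inspect an exported ONNX component with onnxruntime.
import onnxruntime as ort

sess = ort.InferenceSession("encoder.onnx", providers=["CPUExecutionProvider"])
for inp in sess.get_inputs():
    print("input: ", inp.name, inp.shape, inp.type)
for out in sess.get_outputs():
    print("output:", out.name, out.shape, out.type)
```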
Run inference with exported ONNX models:
    python scripts/inference_onnx.py --audio_file /path/to/audio.wav

The project uses Weights & Biases (wandb) for experiment tracking:
- Training metrics (loss, WER, learning rate)
- Model configurations
- Hyperparameter sweeps
- Experiment comparison
Configure wandb in the experiment manager section of config files:
    exp_manager:
      create_wandb_logger: true
      wandb_logger_kwargs:
        name: experiment_name
        project: project_name

Large model configuration (`fast-conformer_transducer_bpe.yaml`):

- Large FastConformer model (512 d_model, 17 layers)
- Standard training without knowledge distillation
- Optimized for high accuracy
Medium model configuration with knowledge distillation (`fast-conformer_transducer_bpe_medium.yaml`):

- Medium FastConformer model (256 d_model, 16 layers)
- Knowledge distillation support
- Balanced between efficiency and accuracy
Key differences:
- Smaller model architecture (32M vs 120M parameters)
- Knowledge distillation parameters
- Adjusted learning rates and batch sizes
- Batch Size: Adjust based on GPU memory:
  - 16GB GPU: `batch_size=8-16`
  - 32GB GPU: `batch_size=16-32`
  - 80GB GPU: `batch_size=32-64`
- Knowledge Distillation:
  - Use `kd_temperature=3-5` for better knowledge transfer
  - Balance `kd_alpha` between 0.3 and 0.7 depending on teacher quality
  - Ensure the teacher and student use the same tokenizer
- Training Duration:
  - Medium models: 50-100 epochs
  - Large models: 100-500 epochs
  - Monitor validation WER for early stopping
The knowledge distillation approach typically achieves:
- Model Size Reduction: 70-80% parameter reduction (120M → 32M)
- Performance Retention: 90-95% of teacher model accuracy
- Inference Speed: 2-3x faster inference
- Memory Usage: 60-70% reduction in GPU memory
- Fork the repository
- Create a feature branch
- Make your changes
- Add tests if applicable
- Submit a pull request
This project is licensed under the Apache License 2.0 - see the LICENSE file for details.
- NVIDIA NeMo team for the excellent ASR framework
- LibriSpeech corpus for training data
- FastConformer architecture contributors
For questions or issues, please open an issue in the repository or contact the maintainers.