This script trains a transfer learning model, prunes it using a selected method, and optionally recovers its performance using fine-tuning or neural-mimicking.
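At a high level, the workflow is train, then prune, then recover. The toy sketch below illustrates that flow with L1 magnitude pruning; it is not the script's implementation, and the model, data, and hyperparameters are placeholders chosen only to keep the example self-contained.

```python
# Toy sketch of the train -> prune -> recover workflow. This is NOT the
# script's implementation: the model, data, and hyperparameters below are
# stand-ins chosen only to keep the example self-contained.
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

model = nn.Sequential(nn.Linear(20, 64), nn.ReLU(), nn.Linear(64, 2))
loss_fn = nn.CrossEntropyLoss()
x, y = torch.randn(128, 20), torch.randint(0, 2, (128,))

# 1) Train (or transfer-learn) the dense model.
opt = torch.optim.SGD(model.parameters(), lr=1e-2)
for _ in range(50):
    opt.zero_grad()
    loss_fn(model(x), y).backward()
    opt.step()

# 2) Prune, e.g. by L1 magnitude (`l1` is one of the supported criteria).
for module in model:
    if isinstance(module, nn.Linear):
        prune.l1_unstructured(module, name="weight", amount=0.5)

# 3) Recover, here by briefly fine-tuning the surviving weights
#    (the pruning masks stay fixed while the unmasked weights are updated).
opt = torch.optim.SGD(model.parameters(), lr=1e-3)
for _ in range(20):
    opt.zero_grad()
    loss_fn(model(x), y).backward()
    opt.step()
```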
Install dependencies:

    pip install torch torchvision

Run the script using:
    python main.py --dataset_name DATASET --model_name MODEL --pruning_criteria CRITERIA --recovery_method RECOVERY [--train]

Arguments:

- `--dataset_name`: Dataset to use. Options: `mnist`, `cifar10`, `cifar100`, `imagenette`, `tiny_imagenet`, `imagenet1k`
- `--model_name`: Model architecture. Options: `resnet18`, `resnet50`, `alexnet`, `squeezenet`, `vit`
- `--pruning_criteria`: Pruning method. Options: `l1`, `random`, `lobs`, `nnrelief_paper`
- `--recovery_method`: Recovery method applied after pruning. Options: `none`, `fine_tuning`, `neural_mimicking_lbfgs`, `neural_mimicking_qr`, `neural_mimicking_svd`, `neural_mimicking_sgd`
- `--train`: If set, the model is trained before pruning; otherwise, weights are loaded from `./weights/MODEL_DATASET.pth` (see the loading sketch after this list).
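For reference, a minimal sketch of how such a checkpoint could be loaded is shown below. It assumes `MODEL_DATASET` expands to `{model_name}_{dataset_name}` and that the file stores a plain `state_dict`; the actual loading code in `main.py` may differ.

```python
# Minimal sketch of loading a saved checkpoint, assuming the file at
# ./weights/{model_name}_{dataset_name}.pth holds a plain state_dict.
# The actual loading logic inside main.py may differ.
import torch
from torchvision import models

model_name, dataset_name = "resnet18", "cifar10"   # example values
model = models.resnet18(num_classes=10)            # 10 classes assumed for cifar10

ckpt_path = f"./weights/{model_name}_{dataset_name}.pth"
state_dict = torch.load(ckpt_path, map_location="cpu")
model.load_state_dict(state_dict)
model.eval()
```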
Example runs:

    python main.py --dataset_name cifar100 --model_name vit --pruning_criteria l1 --recovery_method fine_tuning --train
    python main.py --dataset_name cifar100 --model_name vit --pruning_criteria lobs --recovery_method neural_mimicking_lbfgs

The script outputs:

- Training and evaluation logs
- Accuracy after each pruning rate (10% to 90%); a rough sketch of such a sweep is shown below
- Saved weights at `./weights/MODEL_DATASET.pth`
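For illustration, here is a rough sketch of a pruning-rate sweep like the one reported above (10% to 90% in 10% steps). The model, data, and `evaluate` helper are placeholders, not the script's own evaluation code.

```python
# Rough sketch of a pruning-rate sweep from 10% to 90%. The model, data, and
# `evaluate` helper are placeholders, not the script's own evaluation code.
import copy
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

def evaluate(model: nn.Module) -> float:
    """Placeholder accuracy on random data; swap in a real test loop."""
    x, y = torch.randn(256, 20), torch.randint(0, 2, (256,))
    with torch.no_grad():
        preds = model(x).argmax(dim=1)
    return (preds == y).float().mean().item()

base_model = nn.Sequential(nn.Linear(20, 64), nn.ReLU(), nn.Linear(64, 2))

for rate in [r / 10 for r in range(1, 10)]:   # 0.1, 0.2, ..., 0.9
    model = copy.deepcopy(base_model)         # prune a fresh copy at each rate
    for module in model:
        if isinstance(module, nn.Linear):
            prune.l1_unstructured(module, name="weight", amount=rate)
    print(f"pruning rate {rate:.0%}: accuracy {evaluate(model):.3f}")
```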