diff --git a/README.md b/README.md index ee4e0c7..bb6ccd0 100644 --- a/README.md +++ b/README.md @@ -12,7 +12,7 @@ Happy to announce that we received the best paper award of this workshop! **Abstract**: -*Compressing neural network architectures is important to allow the deployment of models to embedded or mobile devices, +_Compressing neural network architectures is important to allow the deployment of models to embedded or mobile devices, and pruning and quantization are the major approaches to compress neural networks nowadays. Both methods benefit when compression parameters are selected specifically for each layer. Finding good combinations of compression parameters, so-called compression policies, is hard as the problem spans an exponentially large search space. Effective compression @@ -24,7 +24,7 @@ compression of models specific to a given hardware target. We validate our appro learning agents for pruning, quantization and joint pruning and quantization. Besides proving the functionality of our approach we were able to compress a ResNet18 for CIFAR-10, on an embedded ARM processor, to 20% of the original inference latency without significant loss of accuracy. Moreover, we can demonstrate that a joint search and compression -using pruning and quantization is superior to an individual search for policies using a single compression method.* +using pruning and quantization is superior to an individual search for policies using a single compression method._ ![Algorithmic Schema](./figures/alg_schema.drawio.svg) @@ -47,7 +47,7 @@ without hardware feedback you could install using pip: ```shell # CPU build only -pip install apache-tvm +pip install apache-tvm ``` #### Manual installation @@ -69,7 +69,7 @@ cd build cmake .. # -j specifies the number of compile threads -make -j4 +make -j4 ``` To make the TVM python library usable on your system, add the following to your `.bashrc` (or `.zshrc` (...)): @@ -148,6 +148,8 @@ Finally, using the joint agent: bash ./scripts/search_pq.sh ``` +To deactivate measurement of hardware latency, add `enable_latency_eval=False` to the `--alg_config` argument when using the scripts. + [1] https://tvm.apache.org/docs/tutorial/cross_compilation_and_rpc.html ### Apache-TVM: Missing Property `kernel.data.shape` @@ -158,8 +160,33 @@ TVM python package. - navigate to the above cloned TVM repository on your machine - open the file `python/tvm/topi/arm_cpu/bitserial_conv2d.py` - - comment out the if statement in line 455 (`if len(kernel.data.shape) == 4:`) - - fix indention for line 456 to 467 + - comment out the if statement in line 455 (`if len(kernel.data.shape) == 4:`) + - fix indention for line 456 to 467 + +# Prune Pretrained Models + +## Method 1 (preferred) + +Define additional models at `tools/util/model_provider.py`. Provide a checkpoint file and reference it using the `--ckpt_load_path` argument when using scripts (see `scripts/search_p_custom_checkpoint.sh`). +The checkpoint is typically saved using `torch.save(model.state_dict(), PATH)` and contains only the network weights. + +## Method 2 + +Provide a pretrained model and reference it using the `--model` argument when using scripts (see `scripts/search_p_custom_model.sh`). +The model is typically saved using `torch.save(model, PATH)` and contains the whole model. + +## Retrain + +Do not forget to retrain your model after pruning using the `scripts/retrain.sh` script. Specify the model using the `--model` argument, the checkpoint using the `--ckpt_load_path` argument and the pruning policy generated during the pruning search using the `--policy` argument. + +## Reward function + +Multiple reward functions can be specified using e.g the `reward=r6` argument for the R6 reward. For more details, take a look at the definitions at `runtime/agent/reward.py` and the script at `scripts/search_p.sh`. + +Specify the `reward_target_cost_ratio=` as a target for model-complexity. A value of 0.25 means, that the algorithm should reduce the model-complexity to 25% of the original model-complexity. +For the reward functions, the beta value can be defined using e.g. the`r6_beta=` argument in case of the R6 reward. A beta value of -5 puts more emphasis on complexity reduction, a value of -1 puts more emphasis on accuracy. +## Dataset +Currently only cifar10 and imagenet are supported as datasets. More datasets can be added at `runtime/data/data_provider.py`. diff --git a/environment.yml b/environment.yml index a5b2031..6b41a47 100644 --- a/environment.yml +++ b/environment.yml @@ -28,7 +28,7 @@ dependencies: - matplotlib - pandas - pip: - - torch-pruning + - torch-pruning==0.2.8 - fvcore - numpy-ringbuffer - decorator diff --git a/runtime/agent/reward.py b/runtime/agent/reward.py index 762667d..e7a1eb9 100644 --- a/runtime/agent/reward.py +++ b/runtime/agent/reward.py @@ -14,7 +14,7 @@ def __init__(self, **kwargs): self.r6_beta = float(kwargs.pop("r6_beta", -5.0)) self.r7_beta = float(kwargs.pop("r7_beta", -5.0)) self.reward_target_cost_ratio = float(kwargs.pop("reward_target_cost_ratio", 0.5)) - self.episode_cost_key = kwargs.pop("reward_episode_cost_key", "lat") + self.episode_cost_key = kwargs.pop("reward_episode_cost_key", "lat") # Has to be overwritten in scripts with "reward_episode_cost_key=BOPs" if latency evaluation is switched off using "enable_latency_eval=False" self.step_cost_key = kwargs.pop("reward_step_cost_key", "BOPs") self.acc_key = kwargs.pop("reward_acc_key", "acc") self.reward = kwargs.pop("reward", "r6") diff --git a/runtime/compress/torch_compress/model_info.py b/runtime/compress/torch_compress/model_info.py index b35bf9f..7955a9f 100644 --- a/runtime/compress/torch_compress/model_info.py +++ b/runtime/compress/torch_compress/model_info.py @@ -16,7 +16,11 @@ def _get_info(self, layer_key: str, layer_list: list[LayerInfo], full_key, paren key_elements = layer_key.split(".") if len(key_elements) > 1: - parents = {info.var_name: info for info in layer_list if not info.is_leaf_layer} + if parent_info.parent_info: + parents = {info.var_name: info for info in layer_list if not info.is_leaf_layer and info.parent_info.var_name == parent_info.var_name} + else: + parents = {info.var_name: info for info in layer_list if not info.is_leaf_layer} + if key_elements[0] in parents: current_info = parents[key_elements[0]] return self._get_info(".".join(key_elements[1:]), current_info.children, full_key, current_info) @@ -26,4 +30,4 @@ def _get_info(self, layer_key: str, layer_list: list[LayerInfo], full_key, paren if key_elements[0] in leafs: return leafs[key_elements[0]] - raise Exception(f"Could not resolve layer info for {full_key} - step failed for part {'.'.join(key_elements)}") + raise Exception(f"Could not resolve layer info for {full_key} - step failed for part {'.'.join(key_elements)}") \ No newline at end of file diff --git a/runtime/compress/torch_compress/torch_executors.py b/runtime/compress/torch_compress/torch_executors.py index 3d30321..f5b3669 100644 --- a/runtime/compress/torch_compress/torch_executors.py +++ b/runtime/compress/torch_compress/torch_executors.py @@ -225,6 +225,8 @@ def _is_layer_compatible(self, layer_key, model_info) -> bool: return False return True if isinstance(layer_info.module, torch.nn.Conv2d): + if layer_info.output_size == []: + return False min_output_size = min(layer_info.output_size[2], layer_info.output_size[3]) if min_output_size < 2: return False diff --git a/runtime/log/logging.py b/runtime/log/logging.py index 0478545..18c1778 100644 --- a/runtime/log/logging.py +++ b/runtime/log/logging.py @@ -161,7 +161,7 @@ def _include_ratios(self, compression_evaluation, prefix): for key, value in compression_evaluation.items(): evaluation[f"{prefix}-{key}"] = value if key in self._initial_evaluation: - evaluation[f"{prefix}-{key}-ratio"] = value / self._initial_evaluation[key] + evaluation[f"{prefix}-{key}-ratio"] = value / self._initial_evaluation[key] # ratio describes which percentage of the original complexity the compressed model has return evaluation @staticmethod diff --git a/runtime/model/torch_model.py b/runtime/model/torch_model.py index 10986ec..7603acb 100644 --- a/runtime/model/torch_model.py +++ b/runtime/model/torch_model.py @@ -29,6 +29,9 @@ def __init__(self, self._frozen_layers = frozen_layers self._layer_dict = {layer_key: module for layer_key, module in self._reference_model.named_modules() if not [*module.children()]} + last_layer = list(self._layer_dict.keys())[-1] + if "p-lin" in self._frozen_layers: + self._frozen_layers["p-lin"].append(last_layer) # disable pruning of last layer def all_layer_keys(self) -> list[str]: return [*self._layer_dict] diff --git a/scripts/search_p.sh b/scripts/search_p.sh old mode 100644 new mode 100755 diff --git a/scripts/search_p_custom_checkpoint.sh b/scripts/search_p_custom_checkpoint.sh new file mode 100755 index 0000000..aa94674 --- /dev/null +++ b/scripts/search_p_custom_checkpoint.sh @@ -0,0 +1,10 @@ +#!/usr/bin/env bash + +python -m tools.search_policy \ + --model resnet18_pretrained \ + --ckpt_load_path ./results/checkpoints/resnet18_pretrained/resnet18_pretrained_pre-train_lr0.05_mom0.9_ep93.pth \ + --log_dir ./logs/resnet18_pretrained \ + --agent independent-single-layer-pruning \ + --episodes 410 \ + --add_search_identifier resnet18_pretrained \ + --alg_config num_workers=6 reward=r6 r6_beta=-5 mixed_reference_bits=6 reward_target_cost_ratio=0.25 enable_latency_eval=False reward_episode_cost_key=BOPs diff --git a/scripts/search_p_custom_model.sh b/scripts/search_p_custom_model.sh new file mode 100755 index 0000000..a5530a0 --- /dev/null +++ b/scripts/search_p_custom_model.sh @@ -0,0 +1,9 @@ +#!/usr/bin/env bash + +python -m tools.search_policy \ + --model ./pretrained_models/densenet121.pth \ + --log_dir ./logs/densenet121 \ + --agent independent-single-layer-pruning \ + --episodes 410 \ + --add_search_identifier densenet121 \ + --alg_config num_workers=6 reward=r6 r6_beta=-5 mixed_reference_bits=6 reward_target_cost_ratio=0.25 enable_latency_eval=False reward_episode_cost_key=BOPs diff --git a/tools/train.py b/tools/train.py index aea6877..bb84dce 100644 --- a/tools/train.py +++ b/tools/train.py @@ -208,7 +208,7 @@ def parse_arguments() -> Namespace: log_file_name=args.log_name ) - wandb.init(project=args.wandb_project, entity=args.wandb_entity, config=vars(args), + wandb.init(project=args.wandb_project, config=vars(args), name=trainer.create_identifier(args.epochs)) protocol = None diff --git a/tools/util/model_provider.py b/tools/util/model_provider.py index d9162e0..c7e7dfa 100644 --- a/tools/util/model_provider.py +++ b/tools/util/model_provider.py @@ -1,7 +1,8 @@ import torch import torch.hub from torch import nn - +from pathlib import Path +import torchvision class TestModel(nn.Module): def __init__(self): @@ -28,18 +29,37 @@ def resnet18_cifar(): model.maxpool = nn.Identity() return model +def mobile_net_pretrained(): + model = torchvision.models.mobilenet_v2(pretrained=True) + num_classes = 10 + model.classifier[1] = torch.nn.Linear(1280, num_classes) + return model + +def resnet18_pretrained(): + model = torchvision.models.resnet18(pretrained=True) + num_features = model.fc.in_features + model.fc = nn.Linear(num_features, 10) + return model provider = { "test_model": test_model, - "resnet18_cifar": resnet18_cifar + "resnet18_cifar": resnet18_cifar, + "mobile_net_pretrained": mobile_net_pretrained, + "resnet18_pretrained": resnet18_pretrained } def load_model(select_str, num_classes, checkpoint_path=None): if "@" in select_str: + # model on torch hub name, repo = select_str.split("@") model = torch.hub.load(repo, name, pretrained=True, num_classes=num_classes) + elif "/" in select_str: + # local path to pretrained model containing architecture definition + model = torch.load(select_str) + name = Path(select_str).stem else: + # predefined architecture definition model = provider[select_str]() name = select_str @@ -52,4 +72,6 @@ def load_checkpoint(checkpoint_path): if checkpoint_path.endswith('.lightning.ckpt'): state_dict = torch.load(checkpoint_path)['state_dict'] return {key[6:]: weight for key, weight in state_dict.items()} + if checkpoint_path.startswith('pretrained_checkpoints'): + return torch.load(checkpoint_path)['model_state_dict'] return torch.load(checkpoint_path)