39 changes: 33 additions & 6 deletions README.md
@@ -12,7 +12,7 @@ Happy to announce that we received the best paper award of this workshop!

**Abstract**:

*Compressing neural network architectures is important to allow the deployment of models to embedded or mobile devices,
_Compressing neural network architectures is important to allow the deployment of models to embedded or mobile devices,
and pruning and quantization are the major approaches to compress neural networks nowadays. Both methods benefit when
compression parameters are selected specifically for each layer. Finding good combinations of compression parameters,
so-called compression policies, is hard as the problem spans an exponentially large search space. Effective compression
@@ -24,7 +24,7 @@ compression of models specific to a given hardware target. We validate our appro
learning agents for pruning, quantization and joint pruning and quantization. Besides proving the functionality of our
approach we were able to compress a ResNet18 for CIFAR-10, on an embedded ARM processor, to 20% of the original
inference latency without significant loss of accuracy. Moreover, we can demonstrate that a joint search and compression
using pruning and quantization is superior to an individual search for policies using a single compression method.*
using pruning and quantization is superior to an individual search for policies using a single compression method._

![Algorithmic Schema](./figures/alg_schema.drawio.svg)

@@ -47,7 +47,7 @@ without hardware feedback you could install using pip:

```shell
# CPU build only
pip install apache-tvm
pip install apache-tvm
```

#### Manual installation
@@ -69,7 +69,7 @@ cd build
cmake ..

# -j specifies the number of compile threads
make -j4
make -j4
```

To make the TVM python library usable on your system, add the following to your `.bashrc` (or `.zshrc` (...)):
@@ -148,6 +148,8 @@ Finally, using the joint agent:
bash ./scripts/search_pq.sh
```

To deactivate measurement of hardware latency, add `enable_latency_eval=False` to the `--alg_config` argument when using the scripts.
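
If latency evaluation is disabled, the episode cost has to fall back to BOPs, so `reward_episode_cost_key=BOPs` should be set as well (see the comment added in `runtime/agent/reward.py`). A minimal sketch; the model and agent names are only examples taken from the scripts in this repository:

```shell
# Sketch: run a policy search without measuring hardware latency.
# Add the two keys to the --alg_config list of whichever script you use.
python -m tools.search_policy \
  --model resnet18_cifar \
  --agent independent-single-layer-pruning \
  --alg_config enable_latency_eval=False reward_episode_cost_key=BOPs
```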

[1] https://tvm.apache.org/docs/tutorial/cross_compilation_and_rpc.html

### Apache-TVM: Missing Property `kernel.data.shape`
@@ -158,8 +160,33 @@ TVM python package.

- navigate to the above cloned TVM repository on your machine
- open the file `python/tvm/topi/arm_cpu/bitserial_conv2d.py`
- comment out the if statement in line 455 (`if len(kernel.data.shape) == 4:`)
- fix the indentation of lines 456 to 467
- comment out the if statement in line 455 (`if len(kernel.data.shape) == 4:`)
- fix the indentation of lines 456 to 467

# Prune Pretrained Models

## Method 1 (preferred)

Define additional models at `tools/util/model_provider.py`. Provide a checkpoint file and reference it using the `--ckpt_load_path` argument when using scripts (see `scripts/search_p_custom_checkpoint.sh`).
The checkpoint is typically saved using `torch.save(model.state_dict(), PATH)` and contains only the network weights.
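
A minimal sketch of such a search call, mirroring `scripts/search_p_custom_checkpoint.sh` (the checkpoint path is a placeholder):

```shell
python -m tools.search_policy \
  --model resnet18_pretrained \
  --ckpt_load_path ./results/checkpoints/resnet18_pretrained/your_checkpoint.pth \
  --agent independent-single-layer-pruning \
  --episodes 410
```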

## Method 2

Provide a pretrained model and reference it using the `--model` argument when using scripts (see `scripts/search_p_custom_model.sh`).
The model is typically saved using `torch.save(model, PATH)` and contains the whole model.
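
A minimal sketch of such a search call, mirroring `scripts/search_p_custom_model.sh` (the model path is a placeholder):

```shell
python -m tools.search_policy \
  --model ./pretrained_models/your_model.pth \
  --agent independent-single-layer-pruning \
  --episodes 410
```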

## Retrain

Do not forget to retrain your model after pruning using the `scripts/retrain.sh` script. Specify the model using the `--model` argument, the checkpoint using the `--ckpt_load_path` argument, and the pruning policy generated during the pruning search using the `--policy` argument.
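
A sketch of such a retraining call; the file names are placeholders, and depending on your setup you may prefer to edit the arguments inside `scripts/retrain.sh` instead of passing them on the command line:

```shell
bash ./scripts/retrain.sh \
  --model resnet18_pretrained \
  --ckpt_load_path ./results/checkpoints/resnet18_pretrained/your_checkpoint.pth \
  --policy ./results/policies/your_pruning_policy.json
```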

## Reward function

Several reward functions are available; select one via the `--alg_config` argument, e.g. `reward=r6` for the R6 reward. For more details, take a look at the definitions in `runtime/agent/reward.py` and the script `scripts/search_p.sh`.

Specify `reward_target_cost_ratio=<n>` as the target for model complexity. A value of 0.25 means that the algorithm should reduce the model complexity to 25% of the original.

The beta value of a reward function can be set using, e.g., the `r6_beta=<n>` argument in the case of the R6 reward. A beta value of -5 puts more emphasis on complexity reduction, while a value of -1 puts more emphasis on accuracy.
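
Putting these options together, a search configuration might look like the following sketch (the values are taken from the example scripts in this repository):

```shell
python -m tools.search_policy \
  --model resnet18_cifar \
  --agent independent-single-layer-pruning \
  --alg_config reward=r6 r6_beta=-5 reward_target_cost_ratio=0.25
```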

## Dataset

Currently, only `cifar10` and `imagenet` are supported as datasets. More datasets can be added in `runtime/data/data_provider.py`.
2 changes: 1 addition & 1 deletion environment.yml
@@ -28,7 +28,7 @@ dependencies:
- matplotlib
- pandas
- pip:
- torch-pruning
- torch-pruning==0.2.8
Author comment: The torch-pruning API has changed since version 0.2.8; upgrading would require refactoring the galen code, so the version is pinned here.

- fvcore
- numpy-ringbuffer
- decorator
2 changes: 1 addition & 1 deletion runtime/agent/reward.py
@@ -14,7 +14,7 @@ def __init__(self, **kwargs):
self.r6_beta = float(kwargs.pop("r6_beta", -5.0))
self.r7_beta = float(kwargs.pop("r7_beta", -5.0))
self.reward_target_cost_ratio = float(kwargs.pop("reward_target_cost_ratio", 0.5))
self.episode_cost_key = kwargs.pop("reward_episode_cost_key", "lat")
self.episode_cost_key = kwargs.pop("reward_episode_cost_key", "lat") # Has to be overwritten in scripts with "reward_episode_cost_key=BOPs" if latency evaluation is switched off using "enable_latency_eval=False"
self.step_cost_key = kwargs.pop("reward_step_cost_key", "BOPs")
self.acc_key = kwargs.pop("reward_acc_key", "acc")
self.reward = kwargs.pop("reward", "r6")
8 changes: 6 additions & 2 deletions runtime/compress/torch_compress/model_info.py
@@ -16,7 +16,11 @@ def _get_info(self, layer_key: str, layer_list: list[LayerInfo], full_key, paren
key_elements = layer_key.split(".")

if len(key_elements) > 1:
parents = {info.var_name: info for info in layer_list if not info.is_leaf_layer}
if parent_info.parent_info:
Author comment: Some models threw an exception here. When listing the children of a parent during traversal of the model layers, children of children were not excluded, which led to a wrong traversal of the layer tree. I exclude them by adding the condition `and info.parent_info.var_name == parent_info.var_name`.

parents = {info.var_name: info for info in layer_list if not info.is_leaf_layer and info.parent_info.var_name == parent_info.var_name}
else:
parents = {info.var_name: info for info in layer_list if not info.is_leaf_layer}

if key_elements[0] in parents:
current_info = parents[key_elements[0]]
return self._get_info(".".join(key_elements[1:]), current_info.children, full_key, current_info)
@@ -26,4 +30,4 @@ def _get_info(self, layer_key: str, layer_list: list[LayerInfo], full_key, paren
if key_elements[0] in leafs:
return leafs[key_elements[0]]

raise Exception(f"Could not resolve layer info for {full_key} - step failed for part {'.'.join(key_elements)}")
raise Exception(f"Could not resolve layer info for {full_key} - step failed for part {'.'.join(key_elements)}")
2 changes: 2 additions & 0 deletions runtime/compress/torch_compress/torch_executors.py
@@ -225,6 +225,8 @@ def _is_layer_compatible(self, layer_key, model_info) -> bool:
return False
return True
if isinstance(layer_info.module, torch.nn.Conv2d):
if layer_info.output_size == []:
Author comment: This guards against an unhandled state of `output_size` that I ran into.

return False
min_output_size = min(layer_info.output_size[2], layer_info.output_size[3])
if min_output_size < 2:
return False
2 changes: 1 addition & 1 deletion runtime/log/logging.py
@@ -161,7 +161,7 @@ def _include_ratios(self, compression_evaluation, prefix):
for key, value in compression_evaluation.items():
evaluation[f"{prefix}-{key}"] = value
if key in self._initial_evaluation:
evaluation[f"{prefix}-{key}-ratio"] = value / self._initial_evaluation[key]
evaluation[f"{prefix}-{key}-ratio"] = value / self._initial_evaluation[key] # ratio describes which percentage of the original complexity the compressed model has
return evaluation

@staticmethod
3 changes: 3 additions & 0 deletions runtime/model/torch_model.py
@@ -29,6 +29,9 @@ def __init__(self,
self._frozen_layers = frozen_layers
self._layer_dict = {layer_key: module for layer_key, module in self._reference_model.named_modules() if
not [*module.children()]}
last_layer = list(self._layer_dict.keys())[-1]
Author comment (J-Gann, Jul 24, 2023): I discovered that galen assumes the last layer of the model is named "fc", as stated e.g. here. This leads to unexpected and difficult-to-resolve errors during pruning. I propose always adding the last layer of the network to the list of frozen layers; alternatively, this assumption should be documented.

if "p-lin" in self._frozen_layers:
self._frozen_layers["p-lin"].append(last_layer) # disable pruning of last layer

def all_layer_keys(self) -> list[str]:
return [*self._layer_dict]
scripts/search_p.sh: file mode changed 100644 → 100755 (no content changes)
10 changes: 10 additions & 0 deletions scripts/search_p_custom_checkpoint.sh
@@ -0,0 +1,10 @@
#!/usr/bin/env bash

python -m tools.search_policy \
--model resnet18_pretrained \
--ckpt_load_path ./results/checkpoints/resnet18_pretrained/resnet18_pretrained_pre-train_lr0.05_mom0.9_ep93.pth \
--log_dir ./logs/resnet18_pretrained \
--agent independent-single-layer-pruning \
--episodes 410 \
--add_search_identifier resnet18_pretrained \
--alg_config num_workers=6 reward=r6 r6_beta=-5 mixed_reference_bits=6 reward_target_cost_ratio=0.25 enable_latency_eval=False reward_episode_cost_key=BOPs
9 changes: 9 additions & 0 deletions scripts/search_p_custom_model.sh
@@ -0,0 +1,9 @@
#!/usr/bin/env bash

python -m tools.search_policy \
--model ./pretrained_models/densenet121.pth \
--log_dir ./logs/densenet121 \
--agent independent-single-layer-pruning \
--episodes 410 \
--add_search_identifier densenet121 \
--alg_config num_workers=6 reward=r6 r6_beta=-5 mixed_reference_bits=6 reward_target_cost_ratio=0.25 enable_latency_eval=False reward_episode_cost_key=BOPs
2 changes: 1 addition & 1 deletion tools/train.py
@@ -208,7 +208,7 @@ def parse_arguments() -> Namespace:
log_file_name=args.log_name
)

wandb.init(project=args.wandb_project, entity=args.wandb_entity, config=vars(args),
wandb.init(project=args.wandb_project, config=vars(args),
Author comment: Passing entity=args.wandb_entity leads to permission problems on the cluster when not running as superuser, because wandb then tries to access the /tmp folder. Dropping this argument resolves the problem.

name=trainer.create_identifier(args.epochs))

protocol = None
26 changes: 24 additions & 2 deletions tools/util/model_provider.py
@@ -1,7 +1,8 @@
import torch
import torch.hub
from torch import nn

from pathlib import Path
import torchvision

class TestModel(nn.Module):
def __init__(self):
@@ -28,18 +29,37 @@ def resnet18_cifar():
model.maxpool = nn.Identity()
return model

def mobile_net_pretrained():
model = torchvision.models.mobilenet_v2(pretrained=True)
num_classes = 10
model.classifier[1] = torch.nn.Linear(1280, num_classes)
return model

def resnet18_pretrained():
model = torchvision.models.resnet18(pretrained=True)
num_features = model.fc.in_features
model.fc = nn.Linear(num_features, 10)
return model

provider = {
"test_model": test_model,
"resnet18_cifar": resnet18_cifar
"resnet18_cifar": resnet18_cifar,
"mobile_net_pretrained": mobile_net_pretrained,
"resnet18_pretrained": resnet18_pretrained
}


def load_model(select_str, num_classes, checkpoint_path=None):
if "@" in select_str:
# model on torch hub
name, repo = select_str.split("@")
model = torch.hub.load(repo, name, pretrained=True, num_classes=num_classes)
elif "/" in select_str:
Author comment: Proposal for an additional way of loading pretrained models that were saved using torch.save(model, PATH).

# local path to pretrained model containing architecture definition
model = torch.load(select_str)
name = Path(select_str).stem
else:
# predefined architecture definition
model = provider[select_str]()
name = select_str

@@ -52,4 +72,6 @@ def load_checkpoint(checkpoint_path):
if checkpoint_path.endswith('.lightning.ckpt'):
state_dict = torch.load(checkpoint_path)['state_dict']
return {key[6:]: weight for key, weight in state_dict.items()}
if checkpoint_path.startswith('pretrained_checkpoints'):
return torch.load(checkpoint_path)['model_state_dict']
return torch.load(checkpoint_path)