diff --git a/datasets/base.py b/datasets/base.py index 5f23860c..eaac8dfc 100644 --- a/datasets/base.py +++ b/datasets/base.py @@ -12,7 +12,7 @@ from platforms.platform import get_platform -class Dataset(abc.ABC, torch.utils.data.Dataset): +class Dataset(torch.utils.data.Dataset, abc.ABC): """The base class for all datasets in this framework.""" @staticmethod diff --git a/models/base.py b/models/base.py index 6d277caf..a8cb5f92 100644 --- a/models/base.py +++ b/models/base.py @@ -13,7 +13,7 @@ from platforms.platform import get_platform -class Model(abc.ABC, torch.nn.Module): +class Model(torch.nn.Module, abc.ABC): """The base class used by all models in this codebase.""" @staticmethod diff --git a/platforms/base.py b/platforms/base.py index d3658eaf..785ae468 100644 --- a/platforms/base.py +++ b/platforms/base.py @@ -24,14 +24,7 @@ class Platform(Hparams): @property def device_str(self): - # GPU device. - if torch.cuda.is_available() and torch.cuda.device_count() > 0: - device_ids = ','.join([str(x) for x in range(torch.cuda.device_count())]) - return f'cuda:{device_ids}' - - # CPU device. 
- else: - return 'cpu' + return 'cuda' if torch.cuda.is_available() else 'cpu' @property def torch_device(self): diff --git a/pruning/base.py b/pruning/base.py index 559ede5b..3678e689 100644 --- a/pruning/base.py +++ b/pruning/base.py @@ -6,8 +6,11 @@ import abc from foundations.hparams import PruningHparams -from models import base -from pruning.mask import Mask + +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from pruning.mask import Mask + from models import base class Strategy(abc.ABC): @@ -18,5 +21,5 @@ def get_pruning_hparams() -> type: @staticmethod @abc.abstractmethod - def prune(pruning_hparams: PruningHparams, trained_model: base.Model, current_mask: Mask = None) -> Mask: + def prune(pruning_hparams: PruningHparams, trained_model: 'base.Model', current_mask: 'Mask' = None) -> 'Mask': pass diff --git a/pruning/mask.py b/pruning/mask.py index 13a5e63f..18ce7e95 100644 --- a/pruning/mask.py +++ b/pruning/mask.py @@ -8,9 +8,13 @@ import torch from foundations import paths -from models import base from platforms.platform import get_platform +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from models import base + + class Mask(dict): def __init__(self, other_dict=None): @@ -30,7 +34,7 @@ def __setitem__(self, key, value): super(Mask, self).__setitem__(key, value) @staticmethod - def ones_like(model: base.Model) -> 'Mask': + def ones_like(model: 'base.Model') -> 'Mask': mask = Mask() for name in model.prunable_layer_names: mask[name] = torch.ones(list(model.state_dict()[name].shape)) diff --git a/pruning/sparse_global.py b/pruning/sparse_global.py index 9da04b60..a7a977d9 100644 --- a/pruning/sparse_global.py +++ b/pruning/sparse_global.py @@ -7,10 +7,13 @@ import numpy as np from foundations import hparams -import models.base from pruning import base from pruning.mask import Mask +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from models import base as models_base + @dataclasses.dataclass class PruningHparams(hparams.PruningHparams): @@ -29,7 +32,7 @@ def get_pruning_hparams() -> type: return PruningHparams @staticmethod - def prune(pruning_hparams: PruningHparams, trained_model: models.base.Model, current_mask: Mask = None): + def prune(pruning_hparams: PruningHparams, trained_model: 'models_base.Model', current_mask: Mask = None): current_mask = Mask.ones_like(trained_model).numpy() if current_mask is None else current_mask.numpy() # Determine the number of weights that need to be pruned.