diff --git a/source/config/config.py b/source/config/config.py index 98068e8..601d895 100644 --- a/source/config/config.py +++ b/source/config/config.py @@ -7,6 +7,7 @@ from torch.utils.data import DataLoader from config_base import BaseConfig +from config.model_config import BeatGANsUNetConfig, BeatGANsAutoencConfig, MLPSkipNetConfig from dataset import * from diffusion import * from diffusion.base import GenerativeType, LossType, ModelMeanType, ModelVarType, get_named_beta_schedule diff --git a/source/config/model_config.py b/source/config/model_config.py new file mode 100644 index 0000000..72a4def --- /dev/null +++ b/source/config/model_config.py @@ -0,0 +1,61 @@ +from dataclasses import dataclass +from typing import Tuple, Optional + +from model.unet import ScaleAt +from model.latentnet import * +from choices import * + +@dataclass +class ModelConfigBase: + """Base class for all model configurations""" + pass + +@dataclass +class BeatGANsUNetConfig(ModelConfigBase): + attention_resolutions: Tuple[int] + channel_mult: Tuple[int] + conv_resample: bool + dims: int + dropout: float + embed_channels: int + image_size: int + in_channels: int + model_channels: int + num_classes: Optional[int] + num_head_channels: int + num_heads_upsample: int + num_heads: int + num_res_blocks: int + num_input_res_blocks: Optional[int] + out_channels: int + resblock_updown: bool + use_checkpoint: bool + use_new_attention_order: bool + resnet_two_cond: bool + resnet_use_zero_module: bool + resnet_cond_channels: Optional[int] = None + +@dataclass +class MLPSkipNetConfig(ModelConfigBase): + num_channels: int + skip_layers: Tuple[int] + num_hid_channels: int + num_layers: int + num_time_emb_channels: int + activation: Activation + use_norm: bool + condition_bias: float + dropout: float + last_act: Activation + num_time_layers: int + time_last_act: bool + +@dataclass +class BeatGANsAutoencConfig(BeatGANsUNetConfig): + enc_out_channels: int + enc_pool: str + enc_num_res_block: int + enc_channel_mult: Tuple[int] + enc_grad_checkpoint: bool + enc_attn_resolutions: Tuple[int] + latent_net_conf: Optional[MLPSkipNetConfig] = None diff --git a/source/experiment.py b/source/experiment.py index 51646cd..c9d4c43 100644 --- a/source/experiment.py +++ b/source/experiment.py @@ -19,7 +19,6 @@ def parse_args(): return parser.parse_args() def load_config(config_path): - # This is a placeholder - you'll need to implement config loading based on your actual config format from config import TrainConfig import importlib.util