Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions source/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from torch.utils.data import DataLoader

from config_base import BaseConfig
from config.model_config import BeatGANsUNetConfig, BeatGANsAutoencConfig, MLPSkipNetConfig
from dataset import *
from diffusion import *
from diffusion.base import GenerativeType, LossType, ModelMeanType, ModelVarType, get_named_beta_schedule
Expand Down
61 changes: 61 additions & 0 deletions source/config/model_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
from dataclasses import dataclass
from typing import Tuple, Optional

from model.unet import ScaleAt
from model.latentnet import *
from choices import *

@dataclass
class ModelConfigBase:
"""Base class for all model configurations"""
pass

@dataclass
class BeatGANsUNetConfig(ModelConfigBase):
attention_resolutions: Tuple[int]
channel_mult: Tuple[int]
conv_resample: bool
dims: int
dropout: float
embed_channels: int
image_size: int
in_channels: int
model_channels: int
num_classes: Optional[int]
num_head_channels: int
num_heads_upsample: int
num_heads: int
num_res_blocks: int
num_input_res_blocks: Optional[int]
out_channels: int
resblock_updown: bool
use_checkpoint: bool
use_new_attention_order: bool
resnet_two_cond: bool
resnet_use_zero_module: bool
resnet_cond_channels: Optional[int] = None

@dataclass
class MLPSkipNetConfig(ModelConfigBase):
num_channels: int
skip_layers: Tuple[int]
num_hid_channels: int
num_layers: int
num_time_emb_channels: int
activation: Activation
use_norm: bool
condition_bias: float
dropout: float
last_act: Activation
num_time_layers: int
time_last_act: bool

@dataclass
class BeatGANsAutoencConfig(BeatGANsUNetConfig):
enc_out_channels: int
enc_pool: str
enc_num_res_block: int
enc_channel_mult: Tuple[int]
enc_grad_checkpoint: bool
enc_attn_resolutions: Tuple[int]
latent_net_conf: Optional[MLPSkipNetConfig] = None
1 change: 0 additions & 1 deletion source/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ def parse_args():
return parser.parse_args()

def load_config(config_path):
# This is a placeholder - you'll need to implement config loading based on your actual config format
from config import TrainConfig
import importlib.util

Expand Down