Skip to content

Commit 2d5835a

Browse files
committed
ruff
1 parent fa522f3 commit 2d5835a

File tree

1 file changed

+10
-3
lines changed

1 file changed

+10
-3
lines changed

src/load_model.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,13 @@
1414

1515
def load_model(
1616
model: str | Path | int,
17-
folds: list[str | int] = ["all"],
17+
folds: list[str | int] | None = None,
1818
) -> Segmentation_Model:
1919
"""
2020
Load a model by its name, path, or version number.
2121
"""
22+
if folds is None:
23+
folds = ["all"]
2224
if isinstance(model, int):
2325
return load_model_by_version(model, folds=folds)
2426
elif isinstance(model, (str, Path)):
@@ -32,19 +34,24 @@ def load_model(
3234

3335
def load_model_by_path(
3436
path_dir: str | Path,
35-
folds: list[str | int] = ["all"],
37+
folds: list[str | int] | None = None,
3638
) -> Segmentation_Model:
3739
"""Load a model from a specified directory."""
40+
if folds is None:
41+
folds = ["all"]
42+
assert folds is not None
3843
return get_actual_model(in_config=path_dir).load(folds=folds)
3944

4045

4146
def load_model_by_version(
4247
version: int,
43-
folds: list[str | int] = ["all"],
48+
folds: list[str | int] | None = None,
4449
) -> Segmentation_Model:
4550
"""
4651
Load a model by its version number.
4752
"""
53+
if folds is None:
54+
folds = ["all"]
4855
modelname = f"Paraside_model_weights_v{version}"
4956
path = f"model_weights/{modelname}"
5057

0 commit comments

Comments
 (0)