Skip to content

Commit 8703b9f

Browse files
committed
Always fallback to the DEFAULT_MODEL_FAMILY
The code attempted to fallback to this, but if no model_path was given then we were not actually falling back to DEFAULT_MODEL_FAMILY (merlinite). This ensures we always fallback to it, even if no model_path is given. Signed-off-by: Ben Browning <bbrownin@redhat.com>
1 parent b17b08d commit 8703b9f

File tree

2 files changed

+12
-2
lines changed

2 files changed

+12
-2
lines changed

src/instructlab/sdg/utils/models.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,5 +22,10 @@ def get_model_family(model_family, model_path):
2222
return model_family
2323

2424
# Try to guess the model family based on the model's filename
25-
guess = re.match(r"^\w*", os.path.basename(model_path)).group(0).lower()
26-
return guess if guess in registry else DEFAULT_MODEL_FAMILY
25+
if model_path:
26+
guess = re.match(r"^\w*", os.path.basename(model_path)).group(0).lower()
27+
if guess in registry:
28+
return guess
29+
30+
# Nothing was found, so just return the default
31+
return DEFAULT_MODEL_FAMILY

tests/test_models.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,3 +64,8 @@ def test_unknown_model_family(self):
6464
"foobar", "./models/mixtral-8x7b-instruct-v0.1.Q4_K_M.gguf"
6565
)
6666
assert "Unknown model family: foobar" in str(exc.value)
67+
68+
def test_none_args(self):
69+
assert (
70+
models.get_model_family(None, None) == "merlinite"
71+
)

0 commit comments

Comments
 (0)