Right now, we have the following test:
```python
import pathlib
import urllib.request

import numpy as np
import pytest

from cebra import CEBRA

# MODEL_VARIANTS lists the legacy model variants; it is defined elsewhere
# in the test module (not shown here).


@pytest.mark.parametrize("model_variant", MODEL_VARIANTS)
def test_load_legacy_model(model_variant):
    """Test loading a legacy CEBRA model."""
    X = np.random.normal(0, 1, (1000, 30))
    model_path = pathlib.Path(
        __file__
    ).parent / "_build_legacy_model" / f"cebra_model_{model_variant}.pt"
    # Download the legacy checkpoint on first use.
    if not model_path.exists():
        url = f"https://cebra.fra1.digitaloceanspaces.com/cebra_model_{model_variant}.pt"
        urllib.request.urlretrieve(url, model_path)

    loaded_model = CEBRA.load(model_path)

    # Hyperparameters restored from the checkpoint.
    assert loaded_model.model_architecture == "offset10-model"
    assert loaded_model.output_dimension == 8
    assert loaded_model.num_hidden_units == 16
    assert loaded_model.time_offsets == 10

    # The model can still embed data, but the values are never checked.
    output = loaded_model.transform(X)
    assert isinstance(output, np.ndarray)
    assert output.shape[1] == loaded_model.output_dimension
    assert hasattr(loaded_model, "state_dict_")
    assert hasattr(loaded_model, "n_features_")
```

This test checks that the legacy models can be loaded, but not that they still produce the same output. To improve this test, let's:
- Compute reference outputs (embeddings) of the legacy models and store them in the S3 bucket alongside the model files.
- Adapt the test to include `assert_close` checks between the re-computed embeddings and the original reference embeddings.
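A minimal sketch of the two steps, assuming each reference is stored as an `.npz` file that holds both the input `X` and the embedding the legacy model produced for it. The helper names (`make_reference_input`, `save_reference`, `check_against_reference`), the file layout, and the tolerances are hypothetical; `numpy.testing.assert_allclose` stands in for the `assert_close` check (in a torch-based test, `torch.testing.assert_close` would play the same role). In the real test, `transform` would be `CEBRA.load(model_path).transform`.

```python
import pathlib

import numpy as np


def make_reference_input(seed: int = 0) -> np.ndarray:
    """Deterministic input with the same shape as in the current test (hypothetical helper)."""
    rng = np.random.default_rng(seed)
    return rng.normal(0, 1, (1000, 30))


def save_reference(transform, path: pathlib.Path, seed: int = 0) -> None:
    """Step 1 (run once, with the legacy code): store the input and its embedding.

    `path` should end in `.npz`; np.savez appends that suffix otherwise.
    """
    X = make_reference_input(seed)
    np.savez(path, X=X, embedding=transform(X))


def check_against_reference(transform, path: pathlib.Path,
                            rtol: float = 1e-5, atol: float = 1e-6) -> None:
    """Step 2 (in the test): recompute the embedding on the stored input and compare.

    Raises AssertionError if any element differs beyond the tolerances.
    """
    with np.load(path) as ref:
        X, expected = ref["X"], ref["embedding"]
    np.testing.assert_allclose(transform(X), expected, rtol=rtol, atol=atol)
```

Storing `X` together with the embedding means the comparison does not depend on the unseeded `np.random.normal` call in the current test. The non-zero tolerances are meant to absorb small floating-point differences across PyTorch versions or hardware; the exact values would need tuning against the real models.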