diff --git a/tests/conftest.py b/tests/conftest.py index d9ce55ff..0036e021 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,23 +1,36 @@ from typing import Any -import numpy as np import pytest import torch -import torch.distributions.weibull from ase import Atoms from ase.build import bulk, molecule -from ase.spacegroup import crystal from phonopy.structure.atoms import PhonopyAtoms from pymatgen.core import Structure import torch_sim as ts from torch_sim.models.lennard_jones import LennardJonesModel +from torch_sim.testing import SIMSTATE_GENERATORS DEVICE = torch.device("cpu") DTYPE = torch.float64 +def _make_simstate_fixture(name: str) -> pytest.fixture: + """Create a pytest fixture for a sim_state generator.""" + + @pytest.fixture(name=name) + def _fixture() -> ts.SimState: + return SIMSTATE_GENERATORS[name](DEVICE, DTYPE) + + return _fixture + + +# Programmatically generate fixtures for all sim_state generators +for _name in SIMSTATE_GENERATORS: + globals()[_name] = _make_simstate_fixture(_name) + + @pytest.fixture def lj_model() -> LennardJonesModel: """Create a Lennard-Jones model with reasonable parameters for Ar.""" @@ -98,190 +111,6 @@ def si_phonopy_atoms() -> Any: ) -@pytest.fixture -def si_sim_state(si_atoms: Any) -> Any: - """Create a basic state from si_structure.""" - return ts.io.atoms_to_state(si_atoms, DEVICE, DTYPE) - - -@pytest.fixture -def cu_sim_state() -> ts.SimState: - """Create crystalline copper using ASE.""" - atoms = bulk("Cu", "fcc", a=3.58, cubic=True) - return ts.io.atoms_to_state(atoms, DEVICE, DTYPE) - - -@pytest.fixture -def mg_sim_state() -> ts.SimState: - """Create crystalline magnesium using ASE.""" - atoms = bulk("Mg", "hcp", a=3.17, c=5.14) - return ts.io.atoms_to_state(atoms, DEVICE, DTYPE) - - -@pytest.fixture -def sb_sim_state() -> ts.SimState: - """Create crystalline antimony using ASE.""" - atoms = bulk("Sb", "rhombohedral", a=4.58, alpha=60) - return ts.io.atoms_to_state(atoms, DEVICE, DTYPE) - - -@pytest.fixture -def ti_sim_state() -> ts.SimState: - """Create crystalline titanium using ASE.""" - atoms = bulk("Ti", "hcp", a=2.94, c=4.64) - return ts.io.atoms_to_state(atoms, DEVICE, DTYPE) - - -@pytest.fixture -def tio2_sim_state() -> ts.SimState: - """Create crystalline TiO2 using ASE.""" - a, c = 4.60, 2.96 - basis = [("Ti", 0.5, 0.5, 0), ("O", 0.695679, 0.695679, 0.5)] - atoms = crystal( - symbols=[b[0] for b in basis], - basis=[b[1:] for b in basis], - spacegroup=136, # P4_2/mnm - cellpar=[a, a, c, 90, 90, 90], - ) - return ts.io.atoms_to_state(atoms, DEVICE, DTYPE) - - -@pytest.fixture -def ga_sim_state() -> ts.SimState: - """Create crystalline Ga using ASE.""" - a, b, c = 4.43, 7.60, 4.56 - basis = [("Ga", 0, 0.344304, 0.415401)] - atoms = crystal( - symbols=[b[0] for b in basis], - basis=[b[1:] for b in basis], - spacegroup=64, # Cmce - cellpar=[a, b, c, 90, 90, 90], - ) - return ts.io.atoms_to_state(atoms, DEVICE, DTYPE) - - -@pytest.fixture -def niti_sim_state() -> ts.SimState: - """Create crystalline NiTi using ASE.""" - a, b, c = 2.89, 3.97, 4.83 - alpha, beta, gamma = 90.00, 105.23, 90.00 - basis = [ - ("Ni", 0.369548, 0.25, 0.217074), - ("Ti", 0.076622, 0.25, 0.671102), - ] - atoms = crystal( - symbols=[b[0] for b in basis], - basis=[b[1:] for b in basis], - spacegroup=11, - cellpar=[a, b, c, alpha, beta, gamma], - ) - return ts.io.atoms_to_state(atoms, DEVICE, DTYPE) - - -@pytest.fixture -def sio2_sim_state() -> ts.SimState: - """Create an alpha-quartz SiO2 system for testing.""" - atoms = crystal( - symbols=["O", "Si"], - basis=[[0.413, 0.2711, 0.2172], [0.4673, 0, 0.3333]], - spacegroup=152, - cellpar=[4.9019, 4.9019, 5.3988, 90, 90, 120], - ) - return ts.io.atoms_to_state(atoms, DEVICE, DTYPE) - - -@pytest.fixture -def rattled_sio2_sim_state(sio2_sim_state: ts.SimState) -> ts.SimState: - """Create a rattled SiO2 system for testing.""" - sim_state = sio2_sim_state.clone() - - # Store the current RNG state - rng_state = torch.random.get_rng_state() - try: - # Temporarily set a fixed seed - torch.manual_seed(3) - weibull = torch.distributions.weibull.Weibull(scale=0.1, concentration=1) - rnd = torch.randn_like(sim_state.positions, device=DEVICE, dtype=DTYPE) - rnd = rnd / torch.norm(rnd, dim=-1, keepdim=True).to(device=DEVICE) - shifts = weibull.sample(rnd.shape).to(device=DEVICE) * rnd - sim_state.positions = sim_state.positions + shifts - finally: - # Restore the original RNG state - torch.random.set_rng_state(rng_state) - - return sim_state - - -@pytest.fixture -def rattled_si_sim_state(si_sim_state: ts.SimState) -> ts.SimState: - """Create a rattled Si system for testing.""" - sim_state = si_sim_state.clone() - - # Store the current RNG state - rng_state = torch.random.get_rng_state() - try: - # Temporarily set a fixed seed - torch.manual_seed(3) - weibull = torch.distributions.weibull.Weibull(scale=0.1, concentration=1) - rnd = torch.randn_like(sim_state.positions, device=DEVICE, dtype=DTYPE) - rnd = rnd / torch.norm(rnd, dim=-1, keepdim=True).to(device=DEVICE) - shifts = weibull.sample(rnd.shape).to(device=DEVICE) * rnd - sim_state.positions = sim_state.positions + shifts - finally: - # Restore the original RNG state - torch.random.set_rng_state(rng_state) - - return sim_state - - -@pytest.fixture -def casio3_sim_state() -> ts.SimState: - a, b, c = 7.9258, 7.3202, 7.0653 - alpha, beta, gamma = 90.055, 95.217, 103.426 - basis = [ - ("Ca", 0.19831, 0.42266, 0.76060), - ("Ca", 0.20241, 0.92919, 0.76401), - ("Ca", 0.50333, 0.75040, 0.52691), - ("Si", 0.1851, 0.3875, 0.2684), - ("Si", 0.1849, 0.9542, 0.2691), - ("Si", 0.3973, 0.7236, 0.0561), - ("O", 0.3034, 0.4616, 0.4628), - ("O", 0.3014, 0.9385, 0.4641), - ("O", 0.5705, 0.7688, 0.1988), - ("O", 0.9832, 0.3739, 0.2655), - ("O", 0.9819, 0.8677, 0.2648), - ("O", 0.4018, 0.7266, 0.8296), - ("O", 0.2183, 0.1785, 0.2254), - ("O", 0.2713, 0.8704, 0.0938), - ("O", 0.2735, 0.5126, 0.0931), - ] - atoms = crystal( - symbols=[b[0] for b in basis], - basis=[b[1:] for b in basis], - spacegroup=2, - cellpar=[a, b, c, alpha, beta, gamma], - ) - return ts.io.atoms_to_state(atoms, DEVICE, DTYPE) - - -@pytest.fixture -def benzene_sim_state(benzene_atoms: Any) -> Any: - """Create a basic state from benzene_atoms.""" - return ts.io.atoms_to_state(benzene_atoms, DEVICE, DTYPE) - - -@pytest.fixture -def fe_supercell_sim_state(fe_atoms: Atoms) -> Any: - """Create a face-centered cubic (FCC) iron structure with 4x4x4 supercell.""" - return ts.io.atoms_to_state(fe_atoms.repeat([4, 4, 4]), DEVICE, DTYPE) - - -@pytest.fixture -def ar_supercell_sim_state(ar_atoms: Atoms) -> ts.SimState: - """Create a face-centered cubic (FCC) Argon structure with 2x2x2 supercell.""" - return ts.io.atoms_to_state(ar_atoms.repeat([2, 2, 2]), DEVICE, DTYPE) - - @pytest.fixture def ar_double_sim_state(ar_supercell_sim_state: ts.SimState) -> ts.SimState: """Create a batched state from ar_fcc_sim_state.""" @@ -292,9 +121,12 @@ def ar_double_sim_state(ar_supercell_sim_state: ts.SimState) -> ts.SimState: @pytest.fixture -def si_double_sim_state(si_atoms: Atoms) -> Any: +def si_double_sim_state(si_sim_state: ts.SimState) -> ts.SimState: """Create a basic state from si_structure.""" - return ts.io.atoms_to_state([si_atoms, si_atoms], DEVICE, DTYPE) + return ts.concatenate_states( + [si_sim_state, si_sim_state], + device=si_sim_state.device, + ) @pytest.fixture @@ -306,40 +138,3 @@ def mixed_double_sim_state( [ar_supercell_sim_state, si_sim_state], device=ar_supercell_sim_state.device, ) - - -@pytest.fixture -def osn2_sim_state() -> ts.SimState: - """Provides an initial SimState for rhombohedral OsN2.""" - # For pymatgen Structure initialization - from pymatgen.core import Lattice, Structure - - a = 3.211996 - lattice = Lattice.from_parameters(a, a, a, 60, 60, 60) - species = ["Os", "N"] - frac_coords = [[0.75, 0.7501, -0.25], [0, 0, 0]] # Slightly perturbed - structure = Structure(lattice, species, frac_coords, coords_are_cartesian=False) - return ts.initialize_state(structure, dtype=DTYPE, device=DEVICE) - - -@pytest.fixture -def distorted_fcc_al_conventional_sim_state() -> ts.SimState: - """Initial SimState for a slightly distorted FCC Al conventional cell (4 atoms).""" - # Create a standard 4-atom conventional FCC Al cell - atoms_fcc = bulk("Al", crystalstructure="fcc", a=4.05, cubic=True) - - # Define a small triclinic strain matrix (deviations from identity) - strain_matrix = np.array([[1.0, 0.05, -0.03], [0.04, 1.0, 0.06], [-0.02, 0.03, 1.0]]) - - original_cell = atoms_fcc.get_cell() - new_cell = original_cell @ strain_matrix.T # Apply strain - atoms_fcc.set_cell(new_cell, scale_atoms=True) - - # Slightly perturb atomic positions to break perfect symmetry after strain - positions = atoms_fcc.get_positions() - np_rng = np.random.default_rng(seed=42) - positions += np_rng.normal(scale=0.01, size=positions.shape) - atoms_fcc.positions = positions - - # Convert the ASE Atoms object to SimState (will be a single batch with 4 atoms) - return ts.io.atoms_to_state(atoms_fcc, device=DEVICE, dtype=DTYPE) diff --git a/tests/models/conftest.py b/tests/models/conftest.py index ca1add6e..7679224c 100644 --- a/tests/models/conftest.py +++ b/tests/models/conftest.py @@ -1,39 +1,18 @@ +"""Pytest fixtures and test factories for model testing.""" + import typing -from typing import Final import pytest import torch -import torch_sim as ts -from tests.conftest import DEVICE -from torch_sim.elastic import full_3x3_to_voigt_6_stress +from tests.conftest import DEVICE, DTYPE +from torch_sim.testing import SIMSTATE_GENERATORS, assert_model_calculator_consistency if typing.TYPE_CHECKING: - from ase.calculators.calculator import Calculator - from torch_sim.models.interface import ModelInterface -consistency_test_simstate_fixtures: Final[tuple[str, ...]] = ( - "cu_sim_state", - "mg_sim_state", - "sb_sim_state", - "tio2_sim_state", - "ga_sim_state", - "niti_sim_state", - "ti_sim_state", - "si_sim_state", - "rattled_si_sim_state", - "sio2_sim_state", - "rattled_sio2_sim_state", - "ar_supercell_sim_state", - "fe_supercell_sim_state", - "casio3_sim_state", - "benzene_sim_state", -) - - def make_model_calculator_consistency_test( test_name: str, model_fixture_name: str, @@ -51,82 +30,55 @@ def make_model_calculator_consistency_test( """Factory function to create model-calculator consistency tests. Args: - test_name: Name of the test (used in the function name and messages) + test_name: Name of the test (used in the function name) model_fixture_name: Name of the model fixture calculator_fixture_name: Name of the calculator fixture sim_state_names: sim_state fixture names to test + device: Device to run tests on + dtype: Data type to use for tests energy_rtol: Relative tolerance for energy comparisons energy_atol: Absolute tolerance for energy comparisons force_rtol: Relative tolerance for force comparisons force_atol: Absolute tolerance for force comparisons stress_rtol: Relative tolerance for stress comparisons stress_atol: Absolute tolerance for stress comparisons + + Returns: + A pytest test function that can be assigned to a module-level variable """ @pytest.mark.parametrize("sim_state_name", sim_state_names) - def test_model_calculator_consistency( + def _model_calculator_consistency_test( sim_state_name: str, request: pytest.FixtureRequest ) -> None: """Test consistency between model and calculator implementations.""" - # Get the model and calculator fixtures dynamically - model: ModelInterface = request.getfixturevalue(model_fixture_name) - calculator: Calculator = request.getfixturevalue(calculator_fixture_name) - - # Get the sim_state fixture dynamically using the name - sim_state: ts.SimState = request.getfixturevalue(sim_state_name).to(device, dtype) - - # Set up ASE calculator - atoms = ts.io.state_to_atoms(sim_state)[0] - atoms.calc = calculator - - # Get model results - model_results = model(sim_state) - - # Get calculator results - calc_forces = torch.tensor( - atoms.get_forces(), - device=device, - dtype=model_results["forces"].dtype, - ) - - # Test consistency with specified tolerances - torch.testing.assert_close( - model_results["energy"].item(), - atoms.get_potential_energy(), - rtol=energy_rtol, - atol=energy_atol, + model = request.getfixturevalue(model_fixture_name) + calculator = request.getfixturevalue(calculator_fixture_name) + + # Generate sim_state from the generator + generator = SIMSTATE_GENERATORS[sim_state_name] + sim_state = generator(device, dtype) + + assert_model_calculator_consistency( + model=model, + calculator=calculator, + sim_state=sim_state, + energy_rtol=energy_rtol, + energy_atol=energy_atol, + force_rtol=force_rtol, + force_atol=force_atol, + stress_rtol=stress_rtol, + stress_atol=stress_atol, ) - torch.testing.assert_close( - model_results["forces"], - calc_forces, - rtol=force_rtol, - atol=force_atol, - ) - - if "stress" in model_results: - calc_stress = torch.tensor( - atoms.get_stress(), - device=device, - dtype=model_results["stress"].dtype, - ).unsqueeze(0) - torch.testing.assert_close( - full_3x3_to_voigt_6_stress(model_results["stress"]), - calc_stress, - rtol=stress_rtol, - atol=stress_atol, - equal_nan=True, - ) - - # Rename the function to include the test name - test_model_calculator_consistency.__name__ = f"test_{test_name}_consistency" - return test_model_calculator_consistency + _model_calculator_consistency_test.__name__ = f"test_{test_name}_consistency" + return _model_calculator_consistency_test def make_validate_model_outputs_test( model_fixture_name: str, device: torch.device = DEVICE, - dtype: torch.dtype = torch.float64, + dtype: torch.dtype = DTYPE, ): """Factory function to create model output validation tests. diff --git a/tests/models/test_fairchem_legacy.py b/tests/models/test_fairchem_legacy.py index 933ba002..02b642ab 100644 --- a/tests/models/test_fairchem_legacy.py +++ b/tests/models/test_fairchem_legacy.py @@ -7,10 +7,10 @@ import torch_sim as ts from tests.conftest import DEVICE from tests.models.conftest import ( - consistency_test_simstate_fixtures, make_model_calculator_consistency_test, make_validate_model_outputs_test, ) +from torch_sim.testing import SIMSTATE_BULK_GENERATORS, SIMSTATE_MOLECULE_GENERATORS try: @@ -69,7 +69,7 @@ def ocp_calculator(model_path_oc20: str) -> OCPCalculator: test_name="fairchem_ocp", model_fixture_name="eqv2_oc20_model_pbc", calculator_fixture_name="ocp_calculator", - sim_state_names=consistency_test_simstate_fixtures[:-1], + sim_state_names=tuple(SIMSTATE_BULK_GENERATORS.keys()), energy_rtol=5e-4, # NOTE: EqV2 doesn't pass at the 1e-5 level used for other models energy_atol=5e-4, force_rtol=5e-4, @@ -78,11 +78,11 @@ def ocp_calculator(model_path_oc20: str) -> OCPCalculator: stress_atol=5e-4, ) -test_fairchem_non_pbc_benzene = make_model_calculator_consistency_test( +test_fairchem_non_pbc = make_model_calculator_consistency_test( test_name="fairchem_non_pbc_benzene", model_fixture_name="eqv2_oc20_model_non_pbc", calculator_fixture_name="ocp_calculator", - sim_state_names=["benzene_sim_state"], + sim_state_names=tuple(SIMSTATE_MOLECULE_GENERATORS.keys()), energy_rtol=5e-4, # NOTE: EqV2 doesn't pass at the 1e-5 level used for other models energy_atol=5e-4, force_rtol=5e-4, diff --git a/tests/models/test_graphpes.py b/tests/models/test_graphpes.py index ff315e11..276b0fc5 100644 --- a/tests/models/test_graphpes.py +++ b/tests/models/test_graphpes.py @@ -7,11 +7,11 @@ import torch_sim as ts from tests.conftest import DEVICE from tests.models.conftest import ( - consistency_test_simstate_fixtures, make_model_calculator_consistency_test, make_validate_model_outputs_test, ) from torch_sim.models.graphpes import GraphPESWrapper +from torch_sim.testing import CONSISTENCY_SIMSTATES try: @@ -141,7 +141,7 @@ def ase_nequip_calculator(): test_name="graphpes-nequip", model_fixture_name="ts_nequip_model", calculator_fixture_name="ase_nequip_calculator", - sim_state_names=consistency_test_simstate_fixtures, + sim_state_names=CONSISTENCY_SIMSTATES, device=DEVICE, dtype=DTYPE, energy_rtol=1e-3, @@ -176,7 +176,7 @@ def ase_mace_calculator(): test_name="graphpes-mace", model_fixture_name="ts_mace_model", calculator_fixture_name="ase_mace_calculator", - sim_state_names=consistency_test_simstate_fixtures, + sim_state_names=CONSISTENCY_SIMSTATES, device=DEVICE, dtype=DTYPE, ) @@ -205,7 +205,7 @@ def ase_lj_calculator(): test_name="graphpes-lj", model_fixture_name="ts_lj_model", calculator_fixture_name="ase_lj_calculator", - sim_state_names=consistency_test_simstate_fixtures, + sim_state_names=CONSISTENCY_SIMSTATES, device=DEVICE, dtype=DTYPE, ) diff --git a/tests/models/test_mace.py b/tests/models/test_mace.py index 50f89697..ef7dc60d 100644 --- a/tests/models/test_mace.py +++ b/tests/models/test_mace.py @@ -7,25 +7,33 @@ import torch_sim as ts from tests.conftest import DEVICE from tests.models.conftest import ( - consistency_test_simstate_fixtures, make_model_calculator_consistency_test, make_validate_model_outputs_test, ) from torch_sim.models.mace import MaceUrls +from torch_sim.testing import SIMSTATE_BULK_GENERATORS, SIMSTATE_MOLECULE_GENERATORS try: from mace.calculators import MACECalculator - from mace.calculators.foundations_models import mace_mp, mace_off, mace_omol + from mace.calculators.foundations_models import mace_mp, mace_off from torch_sim.models.mace import MaceModel except (ImportError, ValueError): pytest.skip(f"MACE not installed: {traceback.format_exc()}", allow_module_level=True) +# mace_omol is optional (added in newer MACE versions) +try: + from mace.calculators.foundations_models import mace_omol + + raw_mace_omol = mace_omol(model="extra_large", return_raw_model=True) + HAS_MACE_OMOL = True +except ImportError: + raw_mace_omol = None + HAS_MACE_OMOL = False raw_mace_mp = mace_mp(model=MaceUrls.mace_mp_small, return_raw_model=True) raw_mace_off = mace_off(model=MaceUrls.mace_off_small, return_raw_model=True) -raw_mace_omol = mace_omol(model="extra_large", return_raw_model=True) DTYPE = torch.float64 @@ -52,18 +60,7 @@ def ts_mace_model() -> MaceModel: test_name="mace", model_fixture_name="ts_mace_model", calculator_fixture_name="ase_mace_calculator", - sim_state_names=tuple( - s for s in consistency_test_simstate_fixtures if s != "ti_sim_state" - ), - dtype=DTYPE, -) - - -test_mace_consistency_ti = make_model_calculator_consistency_test( - test_name="mace_ti", - model_fixture_name="ts_mace_model", - calculator_fixture_name="ase_mace_calculator", - sim_state_names=("ti_sim_state",), + sim_state_names=tuple(SIMSTATE_BULK_GENERATORS.keys()), dtype=DTYPE, ) @@ -81,21 +78,6 @@ def test_mace_dtype_working(si_atoms: Atoms, dtype: torch.dtype) -> None: model.forward(state) -@pytest.fixture -def benzene_system(benzene_atoms: Atoms) -> dict: - atomic_numbers = benzene_atoms.get_atomic_numbers() - - positions = torch.tensor(benzene_atoms.positions, device=DEVICE, dtype=DTYPE) - cell = torch.tensor(benzene_atoms.cell.array, device=DEVICE, dtype=DTYPE) - - return { - "positions": positions, - "cell": cell, - "atomic_numbers": atomic_numbers, - "ase_atoms": benzene_atoms, - } - - @pytest.fixture def ase_mace_off_calculator() -> MACECalculator: return mace_off( @@ -115,7 +97,7 @@ def ts_mace_off_model() -> MaceModel: test_name="mace_off", model_fixture_name="ts_mace_off_model", calculator_fixture_name="ase_mace_off_calculator", - sim_state_names=("benzene_sim_state",), + sim_state_names=tuple(SIMSTATE_MOLECULE_GENERATORS.keys()), dtype=DTYPE, ) @@ -125,12 +107,11 @@ def ts_mace_off_model() -> MaceModel: @pytest.mark.parametrize("dtype", [torch.float32, torch.float64]) -def test_mace_off_dtype_working(benzene_atoms: Atoms, dtype: torch.dtype) -> None: +def test_mace_off_dtype_working( + benzene_sim_state: ts.SimState, dtype: torch.dtype +) -> None: model = MaceModel(model=raw_mace_off, device=DEVICE, dtype=dtype, compute_forces=True) - - state = ts.io.atoms_to_state([benzene_atoms], DEVICE, dtype) - - model.forward(state) + model.forward(benzene_sim_state.to(DEVICE, dtype)) def test_mace_urls_enum() -> None: @@ -140,6 +121,7 @@ def test_mace_urls_enum() -> None: assert key.value.endswith((".model", ".model?raw=true")) +@pytest.mark.skipif(not HAS_MACE_OMOL, reason="mace_omol not available") @pytest.mark.parametrize( ("charge", "spin"), [ @@ -149,27 +131,28 @@ def test_mace_urls_enum() -> None: (0.0, 2.0), # Neutral, spin=2 (triplet) ], ) -def test_mace_charge_spin(benzene_atoms: Atoms, charge: float, spin: float) -> None: +def test_mace_charge_spin( + benzene_sim_state: ts.SimState, charge: float, spin: float +) -> None: """Test that MaceModel correctly handles charge and spin from atoms.info.""" - # Set charge and spin in ASE atoms.info - benzene_atoms.info["charge"] = charge - benzene_atoms.info["spin"] = spin - # Convert to SimState (should extract charge/spin) - state = ts.io.atoms_to_state([benzene_atoms], DEVICE, DTYPE) + benzene_sim_state.charge = torch.tensor([charge], device=DEVICE, dtype=DTYPE) + benzene_sim_state.spin = torch.tensor([spin], device=DEVICE, dtype=DTYPE) # Verify charge/spin were extracted correctly if charge != 0.0: - assert state.charge is not None - assert state.charge[0].item() == charge + assert benzene_sim_state.charge is not None + assert benzene_sim_state.charge[0].item() == charge else: - assert state.charge is None or state.charge[0].item() == 0.0 + assert ( + benzene_sim_state.charge is None or benzene_sim_state.charge[0].item() == 0.0 + ) if spin != 0.0: - assert state.spin is not None - assert state.spin[0].item() == spin + assert benzene_sim_state.spin is not None + assert benzene_sim_state.spin[0].item() == spin else: - assert state.spin is None or state.spin[0].item() == 0.0 + assert benzene_sim_state.spin is None or benzene_sim_state.spin[0].item() == 0.0 # Create model with MACE-OMOL (supports charge/spin for molecules) model = MaceModel( @@ -180,11 +163,11 @@ def test_mace_charge_spin(benzene_atoms: Atoms, charge: float, spin: float) -> N ) # This should not raise an error - result = model.forward(state) + result = model.forward(benzene_sim_state) # Verify outputs exist assert "energy" in result assert result["energy"].shape == (1,) if model.compute_forces: assert "forces" in result - assert result["forces"].shape == benzene_atoms.positions.shape + assert result["forces"].shape == benzene_sim_state.positions.shape diff --git a/tests/models/test_mattersim.py b/tests/models/test_mattersim.py index e97f277f..e53a5cdc 100644 --- a/tests/models/test_mattersim.py +++ b/tests/models/test_mattersim.py @@ -8,10 +8,10 @@ from tests.conftest import DEVICE from tests.models.conftest import ( - consistency_test_simstate_fixtures, make_model_calculator_consistency_test, make_validate_model_outputs_test, ) +from torch_sim.testing import SIMSTATE_GENERATORS try: @@ -62,7 +62,7 @@ def test_mattersim_initialization(pretrained_mattersim_model: Potential) -> None test_name="mattersim", model_fixture_name="mattersim_model", calculator_fixture_name="mattersim_calculator", - sim_state_names=consistency_test_simstate_fixtures, + sim_state_names=tuple(SIMSTATE_GENERATORS.keys()), ) test_mattersim_model_outputs = make_validate_model_outputs_test( diff --git a/tests/models/test_metatomic.py b/tests/models/test_metatomic.py index d7e2247a..034c1228 100644 --- a/tests/models/test_metatomic.py +++ b/tests/models/test_metatomic.py @@ -5,10 +5,10 @@ from tests.conftest import DEVICE from tests.models.conftest import ( - consistency_test_simstate_fixtures, make_model_calculator_consistency_test, make_validate_model_outputs_test, ) +from torch_sim.testing import SIMSTATE_GENERATORS try: @@ -51,7 +51,7 @@ def test_metatomic_initialization() -> None: test_name="metatomic", model_fixture_name="metatomic_model", calculator_fixture_name="metatomic_calculator", - sim_state_names=consistency_test_simstate_fixtures, + sim_state_names=tuple(SIMSTATE_GENERATORS.keys()), energy_atol=5e-5, dtype=torch.float32, device=DEVICE, diff --git a/tests/models/test_nequip_framework.py b/tests/models/test_nequip_framework.py index 265d1bb0..8e9a4000 100644 --- a/tests/models/test_nequip_framework.py +++ b/tests/models/test_nequip_framework.py @@ -6,10 +6,10 @@ from tests.conftest import DEVICE, DTYPE from tests.models.conftest import ( - consistency_test_simstate_fixtures, make_model_calculator_consistency_test, make_validate_model_outputs_test, ) +from torch_sim.testing import SIMSTATE_BULK_GENERATORS try: @@ -122,7 +122,7 @@ def nequip_calculator(compiled_ase_nequip_model_path: Path) -> NequIPCalculator: test_name="nequip", model_fixture_name="nequip_model", calculator_fixture_name="nequip_calculator", - sim_state_names=consistency_test_simstate_fixtures[:-1], + sim_state_names=tuple(SIMSTATE_BULK_GENERATORS.keys()), energy_atol=5e-5, dtype=DTYPE, device=DEVICE, diff --git a/tests/models/test_orb.py b/tests/models/test_orb.py index 31901e16..118f6d30 100644 --- a/tests/models/test_orb.py +++ b/tests/models/test_orb.py @@ -6,12 +6,12 @@ from tests.conftest import DEVICE from tests.models.conftest import ( - consistency_test_simstate_fixtures, make_model_calculator_consistency_test, make_validate_model_outputs_test, ) from torch_sim import SimState from torch_sim.models.orb import cell_to_cellpar +from torch_sim.testing import SIMSTATE_GENERATORS try: @@ -57,7 +57,7 @@ def orbv3_direct_20_omat_calculator() -> ORBCalculator: test_name="orbv3_conservative_inf_omat", model_fixture_name="orbv3_conservative_inf_omat_model", calculator_fixture_name="orbv3_conservative_inf_omat_calculator", - sim_state_names=consistency_test_simstate_fixtures, + sim_state_names=tuple(SIMSTATE_GENERATORS.keys()), energy_rtol=5e-5, energy_atol=5e-5, ) @@ -66,7 +66,7 @@ def orbv3_direct_20_omat_calculator() -> ORBCalculator: test_name="orbv3_direct_20_omat", model_fixture_name="orbv3_direct_20_omat_model", calculator_fixture_name="orbv3_direct_20_omat_calculator", - sim_state_names=consistency_test_simstate_fixtures, + sim_state_names=tuple(SIMSTATE_GENERATORS.keys()), energy_rtol=5e-5, energy_atol=5e-5, ) diff --git a/tests/models/test_sevennet.py b/tests/models/test_sevennet.py index 1e373b79..91a8aca7 100644 --- a/tests/models/test_sevennet.py +++ b/tests/models/test_sevennet.py @@ -5,10 +5,10 @@ from tests.conftest import DEVICE from tests.models.conftest import ( - consistency_test_simstate_fixtures, make_model_calculator_consistency_test, make_validate_model_outputs_test, ) +from torch_sim.testing import SIMSTATE_BULK_GENERATORS try: @@ -60,13 +60,13 @@ def test_sevennet_initialization(pretrained_sevenn_model: AtomGraphSequential) - assert model.device == DEVICE -# NOTE: we take [:-1] to skipbenzene due to eps volume giving numerically +# NOTE: we do not test on the molecule sim states due to eps volume giving numerically # unstable stress off diagonal in xy. See: https://github.com/MDIL-SNU/SevenNet/issues/212 test_sevennet_consistency = make_model_calculator_consistency_test( test_name="sevennet", model_fixture_name="sevenn_model", calculator_fixture_name="sevenn_calculator", - sim_state_names=consistency_test_simstate_fixtures[:-1], + sim_state_names=tuple(SIMSTATE_BULK_GENERATORS.keys()), dtype=DTYPE, ) diff --git a/torch_sim/io.py b/torch_sim/io.py index fdd75893..6c400780 100644 --- a/torch_sim/io.py +++ b/torch_sim/io.py @@ -189,7 +189,9 @@ def state_to_phonopy(state: "ts.SimState") -> list["PhonopyAtoms"]: def atoms_to_state( - atoms: "Atoms | list[Atoms]", device: torch.device, dtype: torch.dtype + atoms: "Atoms | list[Atoms]", + device: torch.device | None = None, + dtype: torch.dtype | None = None, ) -> "ts.SimState": """Convert an ASE Atoms object or list of Atoms objects to a SimState. @@ -264,7 +266,9 @@ def atoms_to_state( def structures_to_state( - structure: "Structure | list[Structure]", device: torch.device, dtype: torch.dtype + structure: "Structure | list[Structure]", + device: torch.device | None = None, + dtype: torch.dtype | None = None, ) -> "ts.SimState": """Create a SimState from pymatgen Structure(s). @@ -335,8 +339,8 @@ def structures_to_state( def phonopy_to_state( phonopy_atoms: "PhonopyAtoms | list[PhonopyAtoms]", - device: torch.device, - dtype: torch.dtype, + device: torch.device | None = None, + dtype: torch.dtype | None = None, ) -> "ts.SimState": """Create state tensors from a PhonopyAtoms object or list of PhonopyAtoms objects. diff --git a/torch_sim/state.py b/torch_sim/state.py index ec2c589d..318454c7 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -1061,8 +1061,8 @@ def concatenate_states[T: SimState]( # noqa: C901 def initialize_state( system: StateLike, - device: torch.device, - dtype: torch.dtype, + device: torch.device | None = None, + dtype: torch.dtype | None = None, ) -> SimState: """Initialize state tensors from a atomistic system representation. diff --git a/torch_sim/testing.py b/torch_sim/testing.py new file mode 100644 index 00000000..8ba6589e --- /dev/null +++ b/torch_sim/testing.py @@ -0,0 +1,406 @@ +"""Testing utilities for torch-sim models. + +This module provides reusable testing functions and SimState generators that can be +used to validate model implementations. These are designed to work both within +torch-sim's test suite and in external repositories that implement ModelInterface +models. + +Example usage in another repo:: + + import pytest + import torch + from torch_sim.testing import ( + assert_model_calculator_consistency, + SIMSTATE_GENERATORS, + CONSISTENCY_SIMSTATES, + ) + + DEVICE = torch.device("cpu") + DTYPE = torch.float64 + + + @pytest.mark.parametrize("sim_state_name", CONSISTENCY_SIMSTATES) + def test_my_model_consistency(sim_state_name, my_model, my_calculator): + sim_state = SIMSTATE_GENERATORS[sim_state_name](DEVICE, DTYPE) + assert_model_calculator_consistency(my_model, my_calculator, sim_state) +""" + +from collections.abc import Callable +from typing import TYPE_CHECKING, Final + +import torch + +import torch_sim as ts +from torch_sim.elastic import full_3x3_to_voigt_6_stress + + +if TYPE_CHECKING: + from ase.calculators.calculator import Calculator + + from torch_sim.models.interface import ModelInterface + + +def make_cu_sim_state( + device: torch.device | None = None, dtype: torch.dtype | None = None +) -> ts.SimState: + """Create crystalline copper (FCC).""" + from ase.build import bulk + + atoms = bulk("Cu", "fcc", a=3.58, cubic=True) + return ts.io.atoms_to_state(atoms, device, dtype) + + +def make_mg_sim_state( + device: torch.device | None = None, dtype: torch.dtype | None = None +) -> ts.SimState: + """Create crystalline magnesium (HCP).""" + from ase.build import bulk + + atoms = bulk("Mg", "hcp", a=3.17, c=5.14) + return ts.io.atoms_to_state(atoms, device, dtype) + + +def make_sb_sim_state( + device: torch.device | None = None, dtype: torch.dtype | None = None +) -> ts.SimState: + """Create crystalline antimony (rhombohedral).""" + from ase.build import bulk + + atoms = bulk("Sb", "rhombohedral", a=4.58, alpha=60) + return ts.io.atoms_to_state(atoms, device, dtype) + + +def make_ti_sim_state( + device: torch.device | None = None, dtype: torch.dtype | None = None +) -> ts.SimState: + """Create crystalline titanium (HCP).""" + from ase.build import bulk + + atoms = bulk("Ti", "hcp", a=2.94, c=4.64) + return ts.io.atoms_to_state(atoms, device, dtype) + + +def make_tio2_sim_state( + device: torch.device | None = None, dtype: torch.dtype | None = None +) -> ts.SimState: + """Create crystalline TiO2 (rutile).""" + from ase.spacegroup import crystal + + a, c = 4.60, 2.96 + basis = [("Ti", 0.5, 0.5, 0), ("O", 0.695679, 0.695679, 0.5)] + atoms = crystal( + symbols=[b[0] for b in basis], + basis=[b[1:] for b in basis], + spacegroup=136, + cellpar=[a, a, c, 90, 90, 90], + ) + return ts.io.atoms_to_state(atoms, device, dtype) + + +def make_ga_sim_state( + device: torch.device | None = None, dtype: torch.dtype | None = None +) -> ts.SimState: + """Create crystalline Ga (Cmce).""" + from ase.spacegroup import crystal + + a, b, c = 4.43, 7.60, 4.56 + basis = [("Ga", 0, 0.344304, 0.415401)] + atoms = crystal( + symbols=[ba[0] for ba in basis], + basis=[ba[1:] for ba in basis], + spacegroup=64, + cellpar=[a, b, c, 90, 90, 90], + ) + return ts.io.atoms_to_state(atoms, device, dtype) + + +def make_niti_sim_state( + device: torch.device | None = None, dtype: torch.dtype | None = None +) -> ts.SimState: + """Create crystalline NiTi (monoclinic).""" + from ase.spacegroup import crystal + + a, b, c = 2.89, 3.97, 4.83 + alpha, beta, gamma = 90.00, 105.23, 90.00 + basis = [ + ("Ni", 0.369548, 0.25, 0.217074), + ("Ti", 0.076622, 0.25, 0.671102), + ] + atoms = crystal( + symbols=[ba[0] for ba in basis], + basis=[ba[1:] for ba in basis], + spacegroup=11, + cellpar=[a, b, c, alpha, beta, gamma], + ) + return ts.io.atoms_to_state(atoms, device, dtype) + + +def make_si_sim_state( + device: torch.device | None = None, dtype: torch.dtype | None = None +) -> ts.SimState: + """Create crystalline silicon (diamond).""" + from ase.build import bulk + + atoms = bulk("Si", "diamond", a=5.43, cubic=True) + return ts.io.atoms_to_state(atoms, device, dtype) + + +def make_sio2_sim_state( + device: torch.device | None = None, dtype: torch.dtype | None = None +) -> ts.SimState: + """Create alpha-quartz SiO2.""" + from ase.spacegroup import crystal + + atoms = crystal( + symbols=["O", "Si"], + basis=[[0.413, 0.2711, 0.2172], [0.4673, 0, 0.3333]], + spacegroup=152, + cellpar=[4.9019, 4.9019, 5.3988, 90, 90, 120], + ) + return ts.io.atoms_to_state(atoms, device, dtype) + + +def _rattle_sim_state(sim_state: ts.SimState, seed: int = 3) -> ts.SimState: + """Apply Weibull-distributed random displacements to positions.""" + sim_state = sim_state.clone() + rng_state = torch.random.get_rng_state() + try: + torch.manual_seed(seed) + weibull = torch.distributions.weibull.Weibull(scale=0.1, concentration=1) + rnd = torch.randn_like(sim_state.positions) + rnd = rnd / torch.norm(rnd, dim=-1, keepdim=True) + shifts = weibull.sample(rnd.shape).to(device=sim_state.positions.device) * rnd + sim_state.positions = sim_state.positions + shifts + finally: + torch.random.set_rng_state(rng_state) + return sim_state + + +def make_rattled_si_sim_state( + device: torch.device | None = None, dtype: torch.dtype | None = None +) -> ts.SimState: + """Create rattled silicon.""" + return _rattle_sim_state(make_si_sim_state(device, dtype)) + + +def make_rattled_sio2_sim_state( + device: torch.device | None = None, dtype: torch.dtype | None = None +) -> ts.SimState: + """Create rattled alpha-quartz SiO2.""" + return _rattle_sim_state(make_sio2_sim_state(device, dtype)) + + +def make_ar_supercell_sim_state( + device: torch.device | None = None, dtype: torch.dtype | None = None +) -> ts.SimState: + """Create FCC Argon 2x2x2 supercell.""" + from ase.build import bulk + + atoms = bulk("Ar", "fcc", a=5.26, cubic=True).repeat([2, 2, 2]) + return ts.io.atoms_to_state(atoms, device, dtype) + + +def make_fe_supercell_sim_state( + device: torch.device | None = None, dtype: torch.dtype | None = None +) -> ts.SimState: + """Create FCC iron 4x4x4 supercell.""" + from ase.build import bulk + + atoms = bulk("Fe", "fcc", a=5.26, cubic=True).repeat([4, 4, 4]) + return ts.io.atoms_to_state(atoms, device, dtype) + + +def make_casio3_sim_state( + device: torch.device | None = None, dtype: torch.dtype | None = None +) -> ts.SimState: + """Create CaSiO3 (wollastonite).""" + from ase.spacegroup import crystal + + a, b, c = 7.9258, 7.3202, 7.0653 + alpha, beta, gamma = 90.055, 95.217, 103.426 + basis = [ + ("Ca", 0.19831, 0.42266, 0.76060), + ("Ca", 0.20241, 0.92919, 0.76401), + ("Ca", 0.50333, 0.75040, 0.52691), + ("Si", 0.1851, 0.3875, 0.2684), + ("Si", 0.1849, 0.9542, 0.2691), + ("Si", 0.3973, 0.7236, 0.0561), + ("O", 0.3034, 0.4616, 0.4628), + ("O", 0.3014, 0.9385, 0.4641), + ("O", 0.5705, 0.7688, 0.1988), + ("O", 0.9832, 0.3739, 0.2655), + ("O", 0.9819, 0.8677, 0.2648), + ("O", 0.4018, 0.7266, 0.8296), + ("O", 0.2183, 0.1785, 0.2254), + ("O", 0.2713, 0.8704, 0.0938), + ("O", 0.2735, 0.5126, 0.0931), + ] + atoms = crystal( + symbols=[ba[0] for ba in basis], + basis=[ba[1:] for ba in basis], + spacegroup=2, + cellpar=[a, b, c, alpha, beta, gamma], + ) + return ts.io.atoms_to_state(atoms, device, dtype) + + +def make_benzene_sim_state( + device: torch.device | None = None, dtype: torch.dtype | None = None +) -> ts.SimState: + """Create benzene molecule (non-periodic).""" + from ase.build import molecule + + atoms = molecule("C6H6") + return ts.io.atoms_to_state(atoms, device, dtype) + + +def make_osn2_sim_state( + device: torch.device | None = None, dtype: torch.dtype | None = None +) -> ts.SimState: + """Create rhombohedral OsN2.""" + import numpy as np + from ase import Atoms + from ase.geometry import cellpar_to_cell + + a = 3.211996 + atoms = Atoms( + symbols=["Os", "N"], + scaled_positions=[[0.75, 0.7501, -0.25], [0, 0, 0]], + cell=np.roll(cellpar_to_cell([a, a, a, 60, 60, 60]), -1, axis=(0, 1)), + pbc=True, + ) + return ts.io.atoms_to_state(atoms, device=device, dtype=dtype) + + +def make_distorted_fcc_al_conventional_sim_state( + device: torch.device | None = None, dtype: torch.dtype | None = None +) -> ts.SimState: + """Create a slightly distorted FCC Al conventional cell (4 atoms).""" + import numpy as np + from ase.build import bulk + + atoms_fcc = bulk("Al", crystalstructure="fcc", a=4.05, cubic=True) + + strain_matrix = np.array([[1.0, 0.05, -0.03], [0.04, 1.0, 0.06], [-0.02, 0.03, 1.0]]) + original_cell = atoms_fcc.get_cell() + new_cell = original_cell @ strain_matrix.T + atoms_fcc.set_cell(new_cell, scale_atoms=True) + + positions = atoms_fcc.get_positions() + np_rng = np.random.default_rng(seed=42) + positions += np_rng.normal(scale=0.01, size=positions.shape) + atoms_fcc.positions = positions + + return ts.io.atoms_to_state(atoms_fcc, device=device, dtype=dtype) + + +# Generator type alias +SimStateGenerator = Callable[[torch.device, torch.dtype], ts.SimState] + +# Dict mapping names to generator functions +SIMSTATE_BULK_GENERATORS: Final[dict[str, SimStateGenerator]] = { + "cu_sim_state": make_cu_sim_state, + "mg_sim_state": make_mg_sim_state, + "sb_sim_state": make_sb_sim_state, + "tio2_sim_state": make_tio2_sim_state, + "ga_sim_state": make_ga_sim_state, + "niti_sim_state": make_niti_sim_state, + "ti_sim_state": make_ti_sim_state, + "si_sim_state": make_si_sim_state, + "rattled_si_sim_state": make_rattled_si_sim_state, + "sio2_sim_state": make_sio2_sim_state, + "rattled_sio2_sim_state": make_rattled_sio2_sim_state, + "ar_supercell_sim_state": make_ar_supercell_sim_state, + "fe_supercell_sim_state": make_fe_supercell_sim_state, + "casio3_sim_state": make_casio3_sim_state, + "osn2_sim_state": make_osn2_sim_state, + "distorted_fcc_al_conventional_sim_state": ( + make_distorted_fcc_al_conventional_sim_state + ), +} + +SIMSTATE_MOLECULE_GENERATORS: Final[ + dict[str, Callable[[torch.device, torch.dtype], ts.SimState]] +] = { + "benzene_sim_state": make_benzene_sim_state, +} + + +SIMSTATE_GENERATORS: Final[dict[str, SimStateGenerator]] = { + **SIMSTATE_BULK_GENERATORS, + **SIMSTATE_MOLECULE_GENERATORS, +} + +# Tuple of names for backward compat / parametrize usage +CONSISTENCY_SIMSTATES: Final[tuple[str, ...]] = tuple(SIMSTATE_GENERATORS.keys()) + + +def assert_model_calculator_consistency( + model: "ModelInterface", + calculator: "Calculator", + sim_state: ts.SimState, + energy_rtol: float = 1e-5, + energy_atol: float = 1e-5, + force_rtol: float = 1e-5, + force_atol: float = 1e-5, + stress_rtol: float = 1e-5, + stress_atol: float = 1e-5, +) -> None: + """Assert consistency between model and calculator implementations. + + This function validates that a ModelInterface implementation produces + the same results as an ASE Calculator implementation for a given + simulation state. It compares energies, forces, and optionally stresses. + + Args: + model: ModelInterface instance to test + calculator: ASE Calculator instance to compare against + sim_state: Simulation state to test with + energy_rtol: Relative tolerance for energy comparisons + energy_atol: Absolute tolerance for energy comparisons + force_rtol: Relative tolerance for force comparisons + force_atol: Absolute tolerance for force comparisons + stress_rtol: Relative tolerance for stress comparisons + stress_atol: Absolute tolerance for stress comparisons + + Raises: + AssertionError: If model and calculator results don't match within tolerances + """ + atoms = ts.io.state_to_atoms(sim_state)[0] + atoms.calc = calculator + + model_results = model(sim_state) + + calc_forces = torch.tensor( + atoms.get_forces(), + device=sim_state.positions.device, + dtype=model_results["forces"].dtype, + ) + + torch.testing.assert_close( + model_results["energy"].item(), + atoms.get_potential_energy(), + rtol=energy_rtol, + atol=energy_atol, + ) + torch.testing.assert_close( + model_results["forces"], + calc_forces, + rtol=force_rtol, + atol=force_atol, + ) + + if "stress" in model_results: + calc_stress = torch.tensor( + atoms.get_stress(), + device=sim_state.positions.device, + dtype=model_results["stress"].dtype, + ).unsqueeze(0) + + torch.testing.assert_close( + full_3x3_to_voigt_6_stress(model_results["stress"]), + calc_stress, + rtol=stress_rtol, + atol=stress_atol, + equal_nan=True, + )