diff --git a/tests/test_io.py b/tests/test_io.py index a2c25ab4..737517b8 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -5,6 +5,7 @@ import pytest import torch from ase import Atoms +from ase.build import molecule from phonopy.structure.atoms import PhonopyAtoms from pymatgen.core import Structure @@ -88,6 +89,69 @@ def test_multiple_atoms_to_state(si_atoms: Atoms) -> None: ) +@pytest.mark.parametrize( + ("charge", "spin", "expected_charge", "expected_spin"), + [ + (1.0, 1.0, 1.0, 1.0), # Non-zero charge and spin + (0.0, 0.0, 0.0, 0.0), # Explicit zero charge and spin + (None, None, 0.0, 0.0), # No charge/spin set, defaults to zero + ], +) +def test_atoms_to_state_with_charge_spin( + charge: float | None, + spin: float | None, + expected_charge: float, + expected_spin: float, +) -> None: + """Test conversion from ASE Atoms with charge and spin to state tensors.""" + mol = molecule("H2O") + if charge is not None: + mol.info["charge"] = charge + if spin is not None: + mol.info["spin"] = spin + + state = ts.io.atoms_to_state([mol], DEVICE, DTYPE) + + # Check basic properties + assert isinstance(state, SimState) + assert state.charge is not None + assert state.spin is not None + assert state.charge.shape == (1,) + assert state.spin.shape == (1,) + assert state.charge[0].item() == expected_charge + assert state.spin[0].item() == expected_spin + + +def test_multiple_atoms_to_state_with_charge_spin() -> None: + """Test conversion from multiple ASE Atoms with different charge/spin values.""" + mol1 = molecule("H2O") + mol1.info["charge"] = 1.0 + mol1.info["spin"] = 1.0 + + mol2 = molecule("CH4") + mol2.info["charge"] = -1.0 + mol2.info["spin"] = 0.0 + + mol3 = molecule("NH3") + mol3.info["charge"] = 0.0 + mol3.info["spin"] = 2.0 + + state = ts.io.atoms_to_state([mol1, mol2, mol3], DEVICE, DTYPE) + + # Check basic properties + assert isinstance(state, SimState) + assert state.charge is not None + assert state.spin is not None + assert state.charge.shape == (3,) + assert state.spin.shape == (3,) + assert state.charge[0].item() == 1.0 + assert state.charge[1].item() == -1.0 + assert state.charge[2].item() == 0.0 + assert state.spin[0].item() == 1.0 + assert state.spin[1].item() == 0.0 + assert state.spin[2].item() == 2.0 + + def test_state_to_structure(ar_supercell_sim_state: SimState) -> None: """Test conversion from state tensors to list of pymatgen Structure.""" structures = ts.io.state_to_structures(ar_supercell_sim_state) @@ -114,6 +178,23 @@ def test_state_to_atoms(ar_supercell_sim_state: SimState) -> None: assert len(atoms[0]) == 32 +def test_state_to_atoms_with_charge_spin() -> None: + """Test conversion from state with charge/spin to ASE Atoms preserves charge/spin.""" + mol = molecule("H2O") + mol.info["charge"] = 1.0 + mol.info["spin"] = 1.0 + + state = ts.io.atoms_to_state([mol], DEVICE, DTYPE) + atoms = ts.io.state_to_atoms(state) + + assert len(atoms) == 1 + assert isinstance(atoms[0], Atoms) + assert "charge" in atoms[0].info + assert "spin" in atoms[0].info + assert atoms[0].info["charge"] == 1 + assert atoms[0].info["spin"] == 1 + + def test_state_to_multiple_atoms(ar_double_sim_state: SimState) -> None: """Test conversion from state tensors to list of ASE Atoms.""" atoms = ts.io.state_to_atoms(ar_double_sim_state) @@ -253,6 +334,9 @@ def test_state_round_trip( # since both use their own isotope masses based on species, # not the ones in the state assert torch.allclose(sim_state.masses, round_trip_state.masses) + # Check charge/spin round trip + assert torch.allclose(sim_state.charge, round_trip_state.charge) + assert torch.allclose(sim_state.spin, round_trip_state.spin) def test_state_to_atoms_importerror(monkeypatch: pytest.MonkeyPatch) -> None: diff --git a/torch_sim/models/interface.py b/torch_sim/models/interface.py index 5c6a243a..dbcd4448 100644 --- a/torch_sim/models/interface.py +++ b/torch_sim/models/interface.py @@ -200,7 +200,7 @@ def validate_model_outputs( # noqa: C901, PLR0915 This validator creates small test systems (silicon and iron) for validation. It tests both single and multi-batch processing capabilities. """ - from ase.build import bulk + from ase.build import bulk, molecule for attr in ("dtype", "device", "compute_stress", "compute_forces"): if not hasattr(model, attr): @@ -229,6 +229,8 @@ def validate_model_outputs( # noqa: C901, PLR0915 og_cell = sim_state.cell.clone() og_system_idx = sim_state.system_idx.clone() og_atomic_nums = sim_state.atomic_numbers.clone() + og_charge = sim_state.charge.clone() + og_spin = sim_state.spin.clone() model_output = model.forward(sim_state) @@ -241,6 +243,10 @@ def validate_model_outputs( # noqa: C901, PLR0915 raise ValueError(f"{og_system_idx=} != {sim_state.system_idx=}") if not torch.allclose(og_atomic_nums, sim_state.atomic_numbers): raise ValueError(f"{og_atomic_nums=} != {sim_state.atomic_numbers=}") + if not torch.allclose(og_charge, sim_state.charge): + raise ValueError(f"{og_charge=} != {sim_state.charge=}") + if not torch.allclose(og_spin, sim_state.spin): + raise ValueError(f"{og_spin=} != {sim_state.spin=}") # assert model output has the correct keys if "energy" not in model_output: @@ -300,3 +306,43 @@ def validate_model_outputs( # noqa: C901, PLR0915 raise ValueError(f"{fe_model_output['forces'].shape=} != (12, 3)") if stress_computed and fe_model_output["stress"].shape != (1, 3, 3): raise ValueError(f"{fe_model_output['stress'].shape=} != (1, 3, 3)") + + # Test that models can handle non-zero charge and spin + benzene_atoms = molecule("C6H6") + benzene_atoms.info["charge"] = 1.0 + benzene_atoms.info["spin"] = 1.0 + charged_state = ts.io.atoms_to_state([benzene_atoms], device, dtype) + + # Ensure state has charge/spin before testing model + if charged_state.charge is None or charged_state.spin is None: + raise ValueError( + "atoms_to_state did not extract charge/spin. " + "Cannot test model charge/spin handling." + ) + + # Test that model can handle charge/spin without crashing + og_charged_charge = charged_state.charge.clone() + og_charged_spin = charged_state.spin.clone() + try: + charged_output = model.forward(charged_state) + except Exception as e: + raise ValueError( + "Model failed to handle non-zero charge/spin. " + "Models must be able to process states with charge and spin values. " + ) from e + + # Verify model didn't mutate charge/spin + if not torch.allclose(og_charged_charge, charged_state.charge): + raise ValueError( + f"Model mutated charge: {og_charged_charge=} != {charged_state.charge=}" + ) + if not torch.allclose(og_charged_spin, charged_state.spin): + raise ValueError( + f"Model mutated spin: {og_charged_spin=} != {charged_state.spin=}" + ) + # Verify output shape is still correct + if charged_output["energy"].shape != (1,): + raise ValueError( + f"energy shape incorrect with charge/spin: " + f"{charged_output['energy'].shape=} != (1,)" + )