Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 84 additions & 0 deletions tests/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pytest
import torch
from ase import Atoms
from ase.build import molecule
from phonopy.structure.atoms import PhonopyAtoms
from pymatgen.core import Structure

Expand Down Expand Up @@ -88,6 +89,69 @@ def test_multiple_atoms_to_state(si_atoms: Atoms) -> None:
)


@pytest.mark.parametrize(
("charge", "spin", "expected_charge", "expected_spin"),
[
(1.0, 1.0, 1.0, 1.0), # Non-zero charge and spin
(0.0, 0.0, 0.0, 0.0), # Explicit zero charge and spin
(None, None, 0.0, 0.0), # No charge/spin set, defaults to zero
],
)
def test_atoms_to_state_with_charge_spin(
charge: float | None,
spin: float | None,
expected_charge: float,
expected_spin: float,
) -> None:
"""Test conversion from ASE Atoms with charge and spin to state tensors."""
mol = molecule("H2O")
if charge is not None:
mol.info["charge"] = charge
if spin is not None:
mol.info["spin"] = spin

state = ts.io.atoms_to_state([mol], DEVICE, DTYPE)

# Check basic properties
assert isinstance(state, SimState)
assert state.charge is not None
assert state.spin is not None
assert state.charge.shape == (1,)
assert state.spin.shape == (1,)
assert state.charge[0].item() == expected_charge
assert state.spin[0].item() == expected_spin


def test_multiple_atoms_to_state_with_charge_spin() -> None:
"""Test conversion from multiple ASE Atoms with different charge/spin values."""
mol1 = molecule("H2O")
mol1.info["charge"] = 1.0
mol1.info["spin"] = 1.0

mol2 = molecule("CH4")
mol2.info["charge"] = -1.0
mol2.info["spin"] = 0.0

mol3 = molecule("NH3")
mol3.info["charge"] = 0.0
mol3.info["spin"] = 2.0

state = ts.io.atoms_to_state([mol1, mol2, mol3], DEVICE, DTYPE)

# Check basic properties
assert isinstance(state, SimState)
assert state.charge is not None
assert state.spin is not None
assert state.charge.shape == (3,)
assert state.spin.shape == (3,)
assert state.charge[0].item() == 1.0
assert state.charge[1].item() == -1.0
assert state.charge[2].item() == 0.0
assert state.spin[0].item() == 1.0
assert state.spin[1].item() == 0.0
assert state.spin[2].item() == 2.0


def test_state_to_structure(ar_supercell_sim_state: SimState) -> None:
"""Test conversion from state tensors to list of pymatgen Structure."""
structures = ts.io.state_to_structures(ar_supercell_sim_state)
Expand All @@ -114,6 +178,23 @@ def test_state_to_atoms(ar_supercell_sim_state: SimState) -> None:
assert len(atoms[0]) == 32


def test_state_to_atoms_with_charge_spin() -> None:
"""Test conversion from state with charge/spin to ASE Atoms preserves charge/spin."""
mol = molecule("H2O")
mol.info["charge"] = 1.0
mol.info["spin"] = 1.0

state = ts.io.atoms_to_state([mol], DEVICE, DTYPE)
atoms = ts.io.state_to_atoms(state)

assert len(atoms) == 1
assert isinstance(atoms[0], Atoms)
assert "charge" in atoms[0].info
assert "spin" in atoms[0].info
assert atoms[0].info["charge"] == 1
assert atoms[0].info["spin"] == 1


def test_state_to_multiple_atoms(ar_double_sim_state: SimState) -> None:
"""Test conversion from state tensors to list of ASE Atoms."""
atoms = ts.io.state_to_atoms(ar_double_sim_state)
Expand Down Expand Up @@ -253,6 +334,9 @@ def test_state_round_trip(
# since both use their own isotope masses based on species,
# not the ones in the state
assert torch.allclose(sim_state.masses, round_trip_state.masses)
# Check charge/spin round trip
assert torch.allclose(sim_state.charge, round_trip_state.charge)
assert torch.allclose(sim_state.spin, round_trip_state.spin)


def test_state_to_atoms_importerror(monkeypatch: pytest.MonkeyPatch) -> None:
Expand Down
48 changes: 47 additions & 1 deletion torch_sim/models/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ def validate_model_outputs( # noqa: C901, PLR0915
This validator creates small test systems (silicon and iron) for validation.
It tests both single and multi-batch processing capabilities.
"""
from ase.build import bulk
from ase.build import bulk, molecule

for attr in ("dtype", "device", "compute_stress", "compute_forces"):
if not hasattr(model, attr):
Expand Down Expand Up @@ -229,6 +229,8 @@ def validate_model_outputs( # noqa: C901, PLR0915
og_cell = sim_state.cell.clone()
og_system_idx = sim_state.system_idx.clone()
og_atomic_nums = sim_state.atomic_numbers.clone()
og_charge = sim_state.charge.clone()
og_spin = sim_state.spin.clone()

model_output = model.forward(sim_state)

Expand All @@ -241,6 +243,10 @@ def validate_model_outputs( # noqa: C901, PLR0915
raise ValueError(f"{og_system_idx=} != {sim_state.system_idx=}")
if not torch.allclose(og_atomic_nums, sim_state.atomic_numbers):
raise ValueError(f"{og_atomic_nums=} != {sim_state.atomic_numbers=}")
if not torch.allclose(og_charge, sim_state.charge):
raise ValueError(f"{og_charge=} != {sim_state.charge=}")
if not torch.allclose(og_spin, sim_state.spin):
raise ValueError(f"{og_spin=} != {sim_state.spin=}")

# assert model output has the correct keys
if "energy" not in model_output:
Expand Down Expand Up @@ -300,3 +306,43 @@ def validate_model_outputs( # noqa: C901, PLR0915
raise ValueError(f"{fe_model_output['forces'].shape=} != (12, 3)")
if stress_computed and fe_model_output["stress"].shape != (1, 3, 3):
raise ValueError(f"{fe_model_output['stress'].shape=} != (1, 3, 3)")

# Test that models can handle non-zero charge and spin
benzene_atoms = molecule("C6H6")
benzene_atoms.info["charge"] = 1.0
benzene_atoms.info["spin"] = 1.0
charged_state = ts.io.atoms_to_state([benzene_atoms], device, dtype)

# Ensure state has charge/spin before testing model
if charged_state.charge is None or charged_state.spin is None:
raise ValueError(
"atoms_to_state did not extract charge/spin. "
"Cannot test model charge/spin handling."
)

# Test that model can handle charge/spin without crashing
og_charged_charge = charged_state.charge.clone()
og_charged_spin = charged_state.spin.clone()
try:
charged_output = model.forward(charged_state)
except Exception as e:
raise ValueError(
"Model failed to handle non-zero charge/spin. "
"Models must be able to process states with charge and spin values. "
) from e

# Verify model didn't mutate charge/spin
if not torch.allclose(og_charged_charge, charged_state.charge):
raise ValueError(
f"Model mutated charge: {og_charged_charge=} != {charged_state.charge=}"
)
if not torch.allclose(og_charged_spin, charged_state.spin):
raise ValueError(
f"Model mutated spin: {og_charged_spin=} != {charged_state.spin=}"
)
# Verify output shape is still correct
if charged_output["energy"].shape != (1,):
raise ValueError(
f"energy shape incorrect with charge/spin: "
f"{charged_output['energy'].shape=} != (1,)"
)
Loading