Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 73 additions & 0 deletions tests/models/test_lennard_jones.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,3 +234,76 @@ def test_stress_tensor_symmetry(
def test_validate_model_outputs(lj_model: LennardJonesModel) -> None:
"""Test that the model outputs are valid."""
validate_model_outputs(lj_model, DEVICE, torch.float64)


def test_unwrapped_positions_consistency() -> None:
"""Test that wrapped and unwrapped positions give identical results.

This tests that models correctly handle positions outside the unit cell
by wrapping them before neighbor list computation.
"""
# Create a periodic system
ar_atoms = bulk("Ar", "fcc", a=5.26, cubic=True).repeat([2, 2, 2])
cell = torch.tensor(ar_atoms.get_cell().array, dtype=torch.float64, device=DEVICE)

# Create wrapped state (positions inside unit cell)
state_wrapped = ts.io.atoms_to_state(ar_atoms, DEVICE, torch.float64)

# Create unwrapped state by shifting some atoms outside the cell
positions_unwrapped = state_wrapped.positions.clone()
# Shift first half of atoms by +1 cell vector in x direction
n_atoms = positions_unwrapped.shape[0]
positions_unwrapped[: n_atoms // 2] += cell[0]
# Shift some atoms by -1 cell vector in y direction
positions_unwrapped[n_atoms // 4 : n_atoms // 2] -= cell[1]

state_unwrapped = ts.SimState(
positions=positions_unwrapped,
masses=state_wrapped.masses,
cell=state_wrapped.cell,
pbc=state_wrapped.pbc,
atomic_numbers=state_wrapped.atomic_numbers,
)

# Create model
model = LennardJonesModel(
sigma=3.405,
epsilon=0.0104,
cutoff=2.5 * 3.405,
dtype=torch.float64,
device=DEVICE,
compute_forces=True,
compute_stress=True,
use_neighbor_list=True,
)

# Compute results
results_wrapped = model(state_wrapped)
results_unwrapped = model(state_unwrapped)

# Verify energy matches
torch.testing.assert_close(
results_wrapped["energy"],
results_unwrapped["energy"],
rtol=1e-10,
atol=1e-10,
msg="Energies should match for wrapped and unwrapped positions",
)

# Verify forces match
torch.testing.assert_close(
results_wrapped["forces"],
results_unwrapped["forces"],
rtol=1e-10,
atol=1e-10,
msg="Forces should match for wrapped and unwrapped positions",
)

# Verify stress matches
torch.testing.assert_close(
results_wrapped["stress"],
results_unwrapped["stress"],
rtol=1e-10,
atol=1e-10,
msg="Stress should match for wrapped and unwrapped positions",
)
31 changes: 20 additions & 11 deletions torch_sim/models/lennard_jones.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,24 +260,31 @@ def unbatched_forward(
cell = cell.squeeze()
pbc = state.pbc

# Ensure system_idx exists (create if None for single system)
system_idx = (
state.system_idx
if state.system_idx is not None
else torch.zeros(positions.shape[0], dtype=torch.long, device=self.device)
)

# Wrap positions into the unit cell
wrapped_positions = (
ts.transforms.pbc_wrap_batched(positions, state.cell, system_idx, pbc)
if pbc.any()
else positions
)

if self.use_neighbor_list:
# Get neighbor list using torchsim_nl
# Ensure system_idx exists (create if None for single system)
system_idx = (
state.system_idx
if state.system_idx is not None
else torch.zeros(positions.shape[0], dtype=torch.long, device=self.device)
)
mapping, system_mapping, shifts_idx = torchsim_nl(
positions=positions,
positions=wrapped_positions,
cell=cell,
pbc=pbc,
cutoff=self.cutoff,
system_idx=system_idx,
)
# Pass shifts_idx directly - get_pair_displacements will convert them
dr_vec, distances = transforms.get_pair_displacements(
positions=positions,
positions=wrapped_positions,
cell=cell,
pbc=pbc,
pairs=(mapping[0], mapping[1]),
Expand All @@ -286,10 +293,12 @@ def unbatched_forward(
else:
# Get all pairwise displacements
dr_vec, distances = transforms.get_pair_displacements(
positions=positions, cell=cell, pbc=pbc
positions=wrapped_positions, cell=cell, pbc=pbc
)
# Mask out self-interactions
mask = torch.eye(positions.shape[0], dtype=torch.bool, device=self.device)
mask = torch.eye(
wrapped_positions.shape[0], dtype=torch.bool, device=self.device
)
distances = distances.masked_fill(mask, float("inf"))
# Apply cutoff
mask = distances < self.cutoff
Expand Down
16 changes: 14 additions & 2 deletions torch_sim/models/mace.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,9 +300,21 @@ def forward( # noqa: C901
):
self.setup_from_system_idx(sim_state.atomic_numbers, sim_state.system_idx)

# Wrap positions into the unit cell
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what about the alternative of moving thing inside the nl function?

wrapped_positions = (
ts.transforms.pbc_wrap_batched(
sim_state.positions,
sim_state.cell,
sim_state.system_idx,
sim_state.pbc,
)
if sim_state.pbc.any()
else sim_state.positions
)

# Batched neighbor list using linked-cell algorithm
edge_index, mapping_system, unit_shifts = self.neighbor_list_fn(
sim_state.positions,
wrapped_positions,
sim_state.row_vector_cell,
sim_state.pbc,
self.r_max,
Expand All @@ -320,7 +332,7 @@ def forward( # noqa: C901
batch=sim_state.system_idx,
pbc=sim_state.pbc,
cell=sim_state.row_vector_cell,
positions=sim_state.positions,
positions=wrapped_positions,
edge_index=edge_index,
unit_shifts=unit_shifts,
shifts=shifts,
Expand Down
30 changes: 20 additions & 10 deletions torch_sim/models/morse.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,35 +266,45 @@ def unbatched_forward(
cell = cell.squeeze()
pbc = sim_state.pbc

# Ensure system_idx exists (create if None for single system)
system_idx = (
sim_state.system_idx
if sim_state.system_idx is not None
else torch.zeros(positions.shape[0], dtype=torch.long, device=self.device)
)

# Wrap positions into the unit cell
wrapped_positions = (
ts.transforms.pbc_wrap_batched(positions, sim_state.cell, system_idx, pbc)
if pbc.any()
else positions
)

if self.use_neighbor_list:
# Ensure system_idx exists (create if None for single system)
system_idx = (
sim_state.system_idx
if sim_state.system_idx is not None
else torch.zeros(positions.shape[0], dtype=torch.long, device=self.device)
)
mapping, system_mapping, shifts_idx = torchsim_nl(
positions=positions,
positions=wrapped_positions,
cell=cell,
pbc=pbc,
cutoff=self.cutoff,
system_idx=system_idx,
)
# Pass shifts_idx directly - get_pair_displacements will convert them
dr_vec, distances = transforms.get_pair_displacements(
positions=positions,
positions=wrapped_positions,
cell=cell,
pbc=pbc,
pairs=(mapping[0], mapping[1]),
shifts=shifts_idx,
)
else:
dr_vec, distances = transforms.get_pair_displacements(
positions=positions,
positions=wrapped_positions,
cell=cell,
pbc=pbc,
)
mask = torch.eye(positions.shape[0], dtype=torch.bool, device=self.device)
mask = torch.eye(
wrapped_positions.shape[0], dtype=torch.bool, device=self.device
)
distances = distances.masked_fill(mask, float("inf"))
mask = distances < self.cutoff
i, j = torch.where(mask)
Expand Down
27 changes: 17 additions & 10 deletions torch_sim/models/particle_life.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,24 +149,31 @@ def unbatched_forward(self, state: ts.SimState) -> dict[str, torch.Tensor]:
if cell.dim() == 3: # Check if there is an extra batch dimension
cell = cell.squeeze(0) # Squeeze the first dimension

# Ensure system_idx exists (create if None for single system)
system_idx = (
state.system_idx
if state.system_idx is not None
else torch.zeros(positions.shape[0], dtype=torch.long, device=self.device)
)

# Wrap positions into the unit cell
wrapped_positions = (
ts.transforms.pbc_wrap_batched(positions, state.cell, system_idx, pbc)
if pbc.any()
else positions
)

if self.use_neighbor_list:
# Get neighbor list using torchsim_nl
# Ensure system_idx exists (create if None for single system)
system_idx = (
state.system_idx
if state.system_idx is not None
else torch.zeros(positions.shape[0], dtype=torch.long, device=self.device)
)
mapping, system_mapping, shifts_idx = torchsim_nl(
positions=positions,
positions=wrapped_positions,
cell=cell,
pbc=pbc,
cutoff=self.cutoff,
system_idx=system_idx,
)
# Pass shifts_idx directly - get_pair_displacements will convert them
dr_vec, distances = transforms.get_pair_displacements(
positions=positions,
positions=wrapped_positions,
cell=cell,
pbc=pbc,
pairs=(mapping[0], mapping[1]),
Expand All @@ -175,7 +182,7 @@ def unbatched_forward(self, state: ts.SimState) -> dict[str, torch.Tensor]:
else:
# Get all pairwise displacements
dr_vec, distances = transforms.get_pair_displacements(
positions=positions,
positions=wrapped_positions,
cell=cell,
pbc=pbc,
)
Expand Down
27 changes: 17 additions & 10 deletions torch_sim/models/soft_sphere.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,24 +284,31 @@ def unbatched_forward(self, state: ts.SimState) -> dict[str, torch.Tensor]:
cell = cell.squeeze()
pbc = state.pbc

# Ensure system_idx exists (create if None for single system)
system_idx = (
state.system_idx
if state.system_idx is not None
else torch.zeros(positions.shape[0], dtype=torch.long, device=self.device)
)

# Wrap positions into the unit cell
wrapped_positions = (
ts.transforms.pbc_wrap_batched(positions, state.cell, system_idx, pbc)
if pbc.any()
else positions
)

if self.use_neighbor_list:
# Get neighbor list using torchsim_nl
# Ensure system_idx exists (create if None for single system)
system_idx = (
state.system_idx
if state.system_idx is not None
else torch.zeros(positions.shape[0], dtype=torch.long, device=self.device)
)
mapping, system_mapping, shifts_idx = torchsim_nl(
positions=positions,
positions=wrapped_positions,
cell=cell,
pbc=pbc,
cutoff=self.cutoff,
system_idx=system_idx,
)
# Pass shifts_idx directly - get_pair_displacements will convert them
dr_vec, distances = transforms.get_pair_displacements(
positions=positions,
positions=wrapped_positions,
cell=cell,
pbc=pbc,
pairs=(mapping[0], mapping[1]),
Expand All @@ -311,7 +318,7 @@ def unbatched_forward(self, state: ts.SimState) -> dict[str, torch.Tensor]:
else:
# Direct N^2 computation of all pairs
dr_vec, distances = transforms.get_pair_displacements(
positions=positions,
positions=wrapped_positions,
cell=cell,
pbc=pbc,
)
Expand Down
13 changes: 6 additions & 7 deletions torch_sim/neighbors/torch_nl.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,6 @@ def strict_nl(
return mapping, mapping_system, shifts_idx


@torch.jit.script
def torch_nl_n2(
positions: torch.Tensor,
cell: torch.Tensor,
Expand Down Expand Up @@ -172,6 +171,7 @@ def torch_nl_n2(
"""
n_systems = system_idx.max().item() + 1
cell, pbc = _normalize_inputs_jit(cell, pbc, n_systems)

n_atoms = torch.bincount(system_idx)
mapping, system_mapping, shifts_idx = transforms.build_naive_neighborhood(
positions, cell, pbc, cutoff.item(), n_atoms, self_interaction
Expand All @@ -182,28 +182,26 @@ def torch_nl_n2(
return mapping, mapping_system, shifts_idx


@torch.jit.script
def torch_nl_linked_cell(
positions: torch.Tensor,
cell: torch.Tensor,
pbc: torch.Tensor,
cutoff: torch.Tensor,
system_idx: torch.Tensor,
self_interaction: bool = False, # noqa: FBT001, FBT002 (*, not compatible with torch.jit.script)
self_interaction: bool = False, # noqa: FBT001, FBT002
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Compute the neighbor list for a set of atomic structures using the linked
cell algorithm before applying a strict `cutoff`.

The atoms positions `pos` should be wrapped inside their respective unit cells.
The atomic positions `pos` should be wrapped inside their respective unit cells.

This is the recommended default for batched neighbor list calculations as it
provides good performance for systems of various sizes using the linked cell
algorithm which has O(N) complexity.

Args:
positions (torch.Tensor [n_atom, 3]):
A tensor containing the positions of atoms wrapped inside
their respective unit cells.
positions (torch.Tensor [n_atom, 3]): A tensor containing the positions
of atoms wrapped inside their respective unit cells.
cell (torch.Tensor [n_systems, 3, 3]): Unit cell vectors.
pbc (torch.Tensor [n_systems, 3] bool):
A tensor indicating the periodic boundary conditions to apply.
Expand Down Expand Up @@ -245,6 +243,7 @@ def torch_nl_linked_cell(
"""
n_systems = system_idx.max().item() + 1
cell, pbc = _normalize_inputs_jit(cell, pbc, n_systems)

n_atoms = torch.bincount(system_idx)
mapping, system_mapping, shifts_idx = transforms.build_linked_cell_neighborhood(
positions, cell, pbc, cutoff.item(), n_atoms, self_interaction
Expand Down