From ebb791d9a60e46359ed9c91ca0c2f3dc1473c7bf Mon Sep 17 00:00:00 2001 From: abhijeetgangan Date: Sat, 31 Jan 2026 01:01:26 -0800 Subject: [PATCH 1/4] Proper wrapping, address #423 --- tests/models/test_lennard_jones.py | 73 ++++++++++++++++++++++++++++++ torch_sim/models/graphpes.py | 14 +++++- torch_sim/models/lennard_jones.py | 31 ++++++++----- torch_sim/models/mace.py | 16 ++++++- torch_sim/models/morse.py | 30 ++++++++---- torch_sim/models/particle_life.py | 27 +++++++---- torch_sim/models/sevennet.py | 16 ++++++- torch_sim/models/soft_sphere.py | 27 +++++++---- torch_sim/neighbors/torch_nl.py | 13 +++--- 9 files changed, 194 insertions(+), 53 deletions(-) diff --git a/tests/models/test_lennard_jones.py b/tests/models/test_lennard_jones.py index 52435a75..4df30130 100644 --- a/tests/models/test_lennard_jones.py +++ b/tests/models/test_lennard_jones.py @@ -234,3 +234,76 @@ def test_stress_tensor_symmetry( def test_validate_model_outputs(lj_model: LennardJonesModel) -> None: """Test that the model outputs are valid.""" validate_model_outputs(lj_model, DEVICE, torch.float64) + + +def test_unwrapped_positions_consistency() -> None: + """Test that wrapped and unwrapped positions give identical results. + + This tests that models correctly handle positions outside the unit cell + by wrapping them before neighbor list computation. + """ + # Create a periodic system + ar_atoms = bulk("Ar", "fcc", a=5.26, cubic=True).repeat([2, 2, 2]) + cell = torch.tensor(ar_atoms.get_cell().array, dtype=torch.float64, device=DEVICE) + + # Create wrapped state (positions inside unit cell) + state_wrapped = ts.io.atoms_to_state(ar_atoms, DEVICE, torch.float64) + + # Create unwrapped state by shifting some atoms outside the cell + positions_unwrapped = state_wrapped.positions.clone() + # Shift first half of atoms by +1 cell vector in x direction + n_atoms = positions_unwrapped.shape[0] + positions_unwrapped[: n_atoms // 2] += cell[0] + # Shift some atoms by -1 cell vector in y direction + positions_unwrapped[n_atoms // 4 : n_atoms // 2] -= cell[1] + + state_unwrapped = ts.SimState( + positions=positions_unwrapped, + masses=state_wrapped.masses, + cell=state_wrapped.cell, + pbc=state_wrapped.pbc, + atomic_numbers=state_wrapped.atomic_numbers, + ) + + # Create model + model = LennardJonesModel( + sigma=3.405, + epsilon=0.0104, + cutoff=2.5 * 3.405, + dtype=torch.float64, + device=DEVICE, + compute_forces=True, + compute_stress=True, + use_neighbor_list=True, + ) + + # Compute results + results_wrapped = model(state_wrapped) + results_unwrapped = model(state_unwrapped) + + # Verify energy matches + torch.testing.assert_close( + results_wrapped["energy"], + results_unwrapped["energy"], + rtol=1e-10, + atol=1e-10, + msg="Energies should match for wrapped and unwrapped positions", + ) + + # Verify forces match + torch.testing.assert_close( + results_wrapped["forces"], + results_unwrapped["forces"], + rtol=1e-10, + atol=1e-10, + msg="Forces should match for wrapped and unwrapped positions", + ) + + # Verify stress matches + torch.testing.assert_close( + results_wrapped["stress"], + results_unwrapped["stress"], + rtol=1e-10, + atol=1e-10, + msg="Stress should match for wrapped and unwrapped positions", + ) diff --git a/torch_sim/models/graphpes.py b/torch_sim/models/graphpes.py index e380cab4..db2b43c8 100644 --- a/torch_sim/models/graphpes.py +++ b/torch_sim/models/graphpes.py @@ -64,11 +64,23 @@ def state_to_atomic_graph(state: ts.SimState, cutoff: torch.Tensor) -> AtomicGra Returns: AtomicGraph object representing the batched structures """ + # Wrap positions into the unit cell + wrapped_positions = ( + ts.transforms.pbc_wrap_batched( + state.positions, + state.cell, + state.system_idx, + state.pbc, + ) + if state.pbc.any() + else state.positions + ) + graphs = [] for sys_idx in range(state.n_systems): system_mask = state.system_idx == sys_idx - R = state.positions[system_mask] + R = wrapped_positions[system_mask] Z = state.atomic_numbers[system_mask] cell = state.row_vector_cell[sys_idx] # graph-pes models internally trim the neighbor list to the diff --git a/torch_sim/models/lennard_jones.py b/torch_sim/models/lennard_jones.py index 22d1c088..bbf1678e 100644 --- a/torch_sim/models/lennard_jones.py +++ b/torch_sim/models/lennard_jones.py @@ -260,16 +260,23 @@ def unbatched_forward( cell = cell.squeeze() pbc = state.pbc + # Ensure system_idx exists (create if None for single system) + system_idx = ( + state.system_idx + if state.system_idx is not None + else torch.zeros(positions.shape[0], dtype=torch.long, device=self.device) + ) + + # Wrap positions into the unit cell + wrapped_positions = ( + ts.transforms.pbc_wrap_batched(positions, state.cell, system_idx, pbc) + if pbc.any() + else positions + ) + if self.use_neighbor_list: - # Get neighbor list using torchsim_nl - # Ensure system_idx exists (create if None for single system) - system_idx = ( - state.system_idx - if state.system_idx is not None - else torch.zeros(positions.shape[0], dtype=torch.long, device=self.device) - ) mapping, system_mapping, shifts_idx = torchsim_nl( - positions=positions, + positions=wrapped_positions, cell=cell, pbc=pbc, cutoff=self.cutoff, @@ -277,7 +284,7 @@ def unbatched_forward( ) # Pass shifts_idx directly - get_pair_displacements will convert them dr_vec, distances = transforms.get_pair_displacements( - positions=positions, + positions=wrapped_positions, cell=cell, pbc=pbc, pairs=(mapping[0], mapping[1]), @@ -286,10 +293,12 @@ def unbatched_forward( else: # Get all pairwise displacements dr_vec, distances = transforms.get_pair_displacements( - positions=positions, cell=cell, pbc=pbc + positions=wrapped_positions, cell=cell, pbc=pbc ) # Mask out self-interactions - mask = torch.eye(positions.shape[0], dtype=torch.bool, device=self.device) + mask = torch.eye( + wrapped_positions.shape[0], dtype=torch.bool, device=self.device + ) distances = distances.masked_fill(mask, float("inf")) # Apply cutoff mask = distances < self.cutoff diff --git a/torch_sim/models/mace.py b/torch_sim/models/mace.py index 889adfae..678405ef 100644 --- a/torch_sim/models/mace.py +++ b/torch_sim/models/mace.py @@ -300,9 +300,21 @@ def forward( # noqa: C901 ): self.setup_from_system_idx(sim_state.atomic_numbers, sim_state.system_idx) + # Wrap positions into the unit cell + wrapped_positions = ( + ts.transforms.pbc_wrap_batched( + sim_state.positions, + sim_state.cell, + sim_state.system_idx, + sim_state.pbc, + ) + if sim_state.pbc.any() + else sim_state.positions + ) + # Batched neighbor list using linked-cell algorithm edge_index, mapping_system, unit_shifts = self.neighbor_list_fn( - sim_state.positions, + wrapped_positions, sim_state.row_vector_cell, sim_state.pbc, self.r_max, @@ -320,7 +332,7 @@ def forward( # noqa: C901 batch=sim_state.system_idx, pbc=sim_state.pbc, cell=sim_state.row_vector_cell, - positions=sim_state.positions, + positions=wrapped_positions, edge_index=edge_index, unit_shifts=unit_shifts, shifts=shifts, diff --git a/torch_sim/models/morse.py b/torch_sim/models/morse.py index 513c3ed6..f10ca4bb 100644 --- a/torch_sim/models/morse.py +++ b/torch_sim/models/morse.py @@ -266,15 +266,23 @@ def unbatched_forward( cell = cell.squeeze() pbc = sim_state.pbc + # Ensure system_idx exists (create if None for single system) + system_idx = ( + sim_state.system_idx + if sim_state.system_idx is not None + else torch.zeros(positions.shape[0], dtype=torch.long, device=self.device) + ) + + # Wrap positions into the unit cell + wrapped_positions = ( + ts.transforms.pbc_wrap_batched(positions, sim_state.cell, system_idx, pbc) + if pbc.any() + else positions + ) + if self.use_neighbor_list: - # Ensure system_idx exists (create if None for single system) - system_idx = ( - sim_state.system_idx - if sim_state.system_idx is not None - else torch.zeros(positions.shape[0], dtype=torch.long, device=self.device) - ) mapping, system_mapping, shifts_idx = torchsim_nl( - positions=positions, + positions=wrapped_positions, cell=cell, pbc=pbc, cutoff=self.cutoff, @@ -282,7 +290,7 @@ def unbatched_forward( ) # Pass shifts_idx directly - get_pair_displacements will convert them dr_vec, distances = transforms.get_pair_displacements( - positions=positions, + positions=wrapped_positions, cell=cell, pbc=pbc, pairs=(mapping[0], mapping[1]), @@ -290,11 +298,13 @@ def unbatched_forward( ) else: dr_vec, distances = transforms.get_pair_displacements( - positions=positions, + positions=wrapped_positions, cell=cell, pbc=pbc, ) - mask = torch.eye(positions.shape[0], dtype=torch.bool, device=self.device) + mask = torch.eye( + wrapped_positions.shape[0], dtype=torch.bool, device=self.device + ) distances = distances.masked_fill(mask, float("inf")) mask = distances < self.cutoff i, j = torch.where(mask) diff --git a/torch_sim/models/particle_life.py b/torch_sim/models/particle_life.py index 270f79f3..d6e48711 100644 --- a/torch_sim/models/particle_life.py +++ b/torch_sim/models/particle_life.py @@ -149,16 +149,23 @@ def unbatched_forward(self, state: ts.SimState) -> dict[str, torch.Tensor]: if cell.dim() == 3: # Check if there is an extra batch dimension cell = cell.squeeze(0) # Squeeze the first dimension + # Ensure system_idx exists (create if None for single system) + system_idx = ( + state.system_idx + if state.system_idx is not None + else torch.zeros(positions.shape[0], dtype=torch.long, device=self.device) + ) + + # Wrap positions into the unit cell + wrapped_positions = ( + ts.transforms.pbc_wrap_batched(positions, state.cell, system_idx, pbc) + if pbc.any() + else positions + ) + if self.use_neighbor_list: - # Get neighbor list using torchsim_nl - # Ensure system_idx exists (create if None for single system) - system_idx = ( - state.system_idx - if state.system_idx is not None - else torch.zeros(positions.shape[0], dtype=torch.long, device=self.device) - ) mapping, system_mapping, shifts_idx = torchsim_nl( - positions=positions, + positions=wrapped_positions, cell=cell, pbc=pbc, cutoff=self.cutoff, @@ -166,7 +173,7 @@ def unbatched_forward(self, state: ts.SimState) -> dict[str, torch.Tensor]: ) # Pass shifts_idx directly - get_pair_displacements will convert them dr_vec, distances = transforms.get_pair_displacements( - positions=positions, + positions=wrapped_positions, cell=cell, pbc=pbc, pairs=(mapping[0], mapping[1]), @@ -175,7 +182,7 @@ def unbatched_forward(self, state: ts.SimState) -> dict[str, torch.Tensor]: else: # Get all pairwise displacements dr_vec, distances = transforms.get_pair_displacements( - positions=positions, + positions=wrapped_positions, cell=cell, pbc=pbc, ) diff --git a/torch_sim/models/sevennet.py b/torch_sim/models/sevennet.py index 81b76820..5d0e1176 100644 --- a/torch_sim/models/sevennet.py +++ b/torch_sim/models/sevennet.py @@ -191,10 +191,22 @@ def forward(self, state: ts.SimState | StateDict) -> dict[str, torch.Tensor]: # TODO: is this clone necessary? sim_state = sim_state.clone() + # Wrap positions into the unit cell + wrapped_positions = ( + ts.transforms.pbc_wrap_batched( + sim_state.positions, + sim_state.cell, + sim_state.system_idx, + sim_state.pbc, + ) + if sim_state.pbc.any() + else sim_state.positions + ) + # Batched neighbor list using linked-cell algorithm with row-vector cell n_systems = sim_state.system_idx.max().item() + 1 edge_index, mapping_system, unit_shifts = self.neighbor_list_fn( - sim_state.positions, + wrapped_positions, sim_state.row_vector_cell, sim_state.pbc, self.cutoff, @@ -215,7 +227,7 @@ def forward(self, state: ts.SimState | StateDict) -> dict[str, torch.Tensor]: sys_start = stride[sys_idx].item() sys_end = stride[sys_idx + 1].item() - pos = sim_state.positions[sys_start:sys_end] + pos = wrapped_positions[sys_start:sys_end] row_vector_cell = sim_state.row_vector_cell[sys_idx] atomic_nums = sim_state.atomic_numbers[sys_start:sys_end] diff --git a/torch_sim/models/soft_sphere.py b/torch_sim/models/soft_sphere.py index 6dc9d1ce..bf45973a 100644 --- a/torch_sim/models/soft_sphere.py +++ b/torch_sim/models/soft_sphere.py @@ -284,16 +284,23 @@ def unbatched_forward(self, state: ts.SimState) -> dict[str, torch.Tensor]: cell = cell.squeeze() pbc = state.pbc + # Ensure system_idx exists (create if None for single system) + system_idx = ( + state.system_idx + if state.system_idx is not None + else torch.zeros(positions.shape[0], dtype=torch.long, device=self.device) + ) + + # Wrap positions into the unit cell + wrapped_positions = ( + ts.transforms.pbc_wrap_batched(positions, state.cell, system_idx, pbc) + if pbc.any() + else positions + ) + if self.use_neighbor_list: - # Get neighbor list using torchsim_nl - # Ensure system_idx exists (create if None for single system) - system_idx = ( - state.system_idx - if state.system_idx is not None - else torch.zeros(positions.shape[0], dtype=torch.long, device=self.device) - ) mapping, system_mapping, shifts_idx = torchsim_nl( - positions=positions, + positions=wrapped_positions, cell=cell, pbc=pbc, cutoff=self.cutoff, @@ -301,7 +308,7 @@ def unbatched_forward(self, state: ts.SimState) -> dict[str, torch.Tensor]: ) # Pass shifts_idx directly - get_pair_displacements will convert them dr_vec, distances = transforms.get_pair_displacements( - positions=positions, + positions=wrapped_positions, cell=cell, pbc=pbc, pairs=(mapping[0], mapping[1]), @@ -311,7 +318,7 @@ def unbatched_forward(self, state: ts.SimState) -> dict[str, torch.Tensor]: else: # Direct N^2 computation of all pairs dr_vec, distances = transforms.get_pair_displacements( - positions=positions, + positions=wrapped_positions, cell=cell, pbc=pbc, ) diff --git a/torch_sim/neighbors/torch_nl.py b/torch_sim/neighbors/torch_nl.py index e8ddec4f..34db8f19 100644 --- a/torch_sim/neighbors/torch_nl.py +++ b/torch_sim/neighbors/torch_nl.py @@ -112,7 +112,6 @@ def strict_nl( return mapping, mapping_system, shifts_idx -@torch.jit.script def torch_nl_n2( positions: torch.Tensor, cell: torch.Tensor, @@ -172,6 +171,7 @@ def torch_nl_n2( """ n_systems = system_idx.max().item() + 1 cell, pbc = _normalize_inputs_jit(cell, pbc, n_systems) + n_atoms = torch.bincount(system_idx) mapping, system_mapping, shifts_idx = transforms.build_naive_neighborhood( positions, cell, pbc, cutoff.item(), n_atoms, self_interaction @@ -182,28 +182,26 @@ def torch_nl_n2( return mapping, mapping_system, shifts_idx -@torch.jit.script def torch_nl_linked_cell( positions: torch.Tensor, cell: torch.Tensor, pbc: torch.Tensor, cutoff: torch.Tensor, system_idx: torch.Tensor, - self_interaction: bool = False, # noqa: FBT001, FBT002 (*, not compatible with torch.jit.script) + self_interaction: bool = False, # noqa: FBT001, FBT002 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Compute the neighbor list for a set of atomic structures using the linked cell algorithm before applying a strict `cutoff`. - The atoms positions `pos` should be wrapped inside their respective unit cells. + The atomic positions `pos` should be wrapped inside their respective unit cells. This is the recommended default for batched neighbor list calculations as it provides good performance for systems of various sizes using the linked cell algorithm which has O(N) complexity. Args: - positions (torch.Tensor [n_atom, 3]): - A tensor containing the positions of atoms wrapped inside - their respective unit cells. + positions (torch.Tensor [n_atom, 3]): A tensor containing the positions + of atoms wrapped inside their respective unit cells. cell (torch.Tensor [n_systems, 3, 3]): Unit cell vectors. pbc (torch.Tensor [n_systems, 3] bool): A tensor indicating the periodic boundary conditions to apply. @@ -245,6 +243,7 @@ def torch_nl_linked_cell( """ n_systems = system_idx.max().item() + 1 cell, pbc = _normalize_inputs_jit(cell, pbc, n_systems) + n_atoms = torch.bincount(system_idx) mapping, system_mapping, shifts_idx = transforms.build_linked_cell_neighborhood( positions, cell, pbc, cutoff.item(), n_atoms, self_interaction From 9c1fed7d6d912f94bc1fc4180b39bf927fee56b0 Mon Sep 17 00:00:00 2001 From: abhijeetgangan Date: Sat, 31 Jan 2026 01:25:33 -0800 Subject: [PATCH 2/4] test --- torch_sim/models/graphpes.py | 4 ++-- torch_sim/models/sevennet.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/torch_sim/models/graphpes.py b/torch_sim/models/graphpes.py index db2b43c8..236cf34c 100644 --- a/torch_sim/models/graphpes.py +++ b/torch_sim/models/graphpes.py @@ -65,7 +65,7 @@ def state_to_atomic_graph(state: ts.SimState, cutoff: torch.Tensor) -> AtomicGra AtomicGraph object representing the batched structures """ # Wrap positions into the unit cell - wrapped_positions = ( + wrapped_positions = ( # noqa: F841 ts.transforms.pbc_wrap_batched( state.positions, state.cell, @@ -80,7 +80,7 @@ def state_to_atomic_graph(state: ts.SimState, cutoff: torch.Tensor) -> AtomicGra for sys_idx in range(state.n_systems): system_mask = state.system_idx == sys_idx - R = wrapped_positions[system_mask] + R = state.positions[system_mask] Z = state.atomic_numbers[system_mask] cell = state.row_vector_cell[sys_idx] # graph-pes models internally trim the neighbor list to the diff --git a/torch_sim/models/sevennet.py b/torch_sim/models/sevennet.py index 5d0e1176..df4ac343 100644 --- a/torch_sim/models/sevennet.py +++ b/torch_sim/models/sevennet.py @@ -227,7 +227,7 @@ def forward(self, state: ts.SimState | StateDict) -> dict[str, torch.Tensor]: sys_start = stride[sys_idx].item() sys_end = stride[sys_idx + 1].item() - pos = wrapped_positions[sys_start:sys_end] + pos = sim_state.positions[sys_start:sys_end] row_vector_cell = sim_state.row_vector_cell[sys_idx] atomic_nums = sim_state.atomic_numbers[sys_start:sys_end] From 6561ea8441e91b0f5befb1affe2820c06c416331 Mon Sep 17 00:00:00 2001 From: abhijeetgangan Date: Sat, 31 Jan 2026 01:48:38 -0800 Subject: [PATCH 3/4] use wrapped --- torch_sim/models/sevennet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_sim/models/sevennet.py b/torch_sim/models/sevennet.py index df4ac343..5d0e1176 100644 --- a/torch_sim/models/sevennet.py +++ b/torch_sim/models/sevennet.py @@ -227,7 +227,7 @@ def forward(self, state: ts.SimState | StateDict) -> dict[str, torch.Tensor]: sys_start = stride[sys_idx].item() sys_end = stride[sys_idx + 1].item() - pos = sim_state.positions[sys_start:sys_end] + pos = wrapped_positions[sys_start:sys_end] row_vector_cell = sim_state.row_vector_cell[sys_idx] atomic_nums = sim_state.atomic_numbers[sys_start:sys_end] From 4f8355262aa7dcb969a86d54e1305582f993a5f1 Mon Sep 17 00:00:00 2001 From: abhijeetgangan Date: Sat, 31 Jan 2026 02:05:15 -0800 Subject: [PATCH 4/4] revert --- torch_sim/models/graphpes.py | 12 ------------ torch_sim/models/sevennet.py | 16 ++-------------- 2 files changed, 2 insertions(+), 26 deletions(-) diff --git a/torch_sim/models/graphpes.py b/torch_sim/models/graphpes.py index 236cf34c..e380cab4 100644 --- a/torch_sim/models/graphpes.py +++ b/torch_sim/models/graphpes.py @@ -64,18 +64,6 @@ def state_to_atomic_graph(state: ts.SimState, cutoff: torch.Tensor) -> AtomicGra Returns: AtomicGraph object representing the batched structures """ - # Wrap positions into the unit cell - wrapped_positions = ( # noqa: F841 - ts.transforms.pbc_wrap_batched( - state.positions, - state.cell, - state.system_idx, - state.pbc, - ) - if state.pbc.any() - else state.positions - ) - graphs = [] for sys_idx in range(state.n_systems): diff --git a/torch_sim/models/sevennet.py b/torch_sim/models/sevennet.py index 5d0e1176..81b76820 100644 --- a/torch_sim/models/sevennet.py +++ b/torch_sim/models/sevennet.py @@ -191,22 +191,10 @@ def forward(self, state: ts.SimState | StateDict) -> dict[str, torch.Tensor]: # TODO: is this clone necessary? sim_state = sim_state.clone() - # Wrap positions into the unit cell - wrapped_positions = ( - ts.transforms.pbc_wrap_batched( - sim_state.positions, - sim_state.cell, - sim_state.system_idx, - sim_state.pbc, - ) - if sim_state.pbc.any() - else sim_state.positions - ) - # Batched neighbor list using linked-cell algorithm with row-vector cell n_systems = sim_state.system_idx.max().item() + 1 edge_index, mapping_system, unit_shifts = self.neighbor_list_fn( - wrapped_positions, + sim_state.positions, sim_state.row_vector_cell, sim_state.pbc, self.cutoff, @@ -227,7 +215,7 @@ def forward(self, state: ts.SimState | StateDict) -> dict[str, torch.Tensor]: sys_start = stride[sys_idx].item() sys_end = stride[sys_idx + 1].item() - pos = wrapped_positions[sys_start:sys_end] + pos = sim_state.positions[sys_start:sys_end] row_vector_cell = sim_state.row_vector_cell[sys_idx] atomic_nums = sim_state.atomic_numbers[sys_start:sys_end]