
Wrapped/unwrapped positions produce different results #423

@chlwjd1234

Description


Using SevenNet and MACE, I observed that simulation results change depending on whether atomic positions are wrapped into the cell.

This can be reproduced with the following script and two structure files that describe the same structure, one with positions wrapped into the cell and one without.

POSCAR_wrapped
POSCAR file written by OVITO Basic 3.14.1
1
12.1886482239 0.0 0.0
0.0 12.1886482239 0.0
0.0 0.0 12.1886482239
Si C N 
40 40 40 
Cartesian
1.62960804 7.81111288 11.52471447
0.10904183 11.84466934 10.90873718
9.15765572 0.42337838 10.97032547
9.93355465 3.08735299 10.84395313
0.22936374 2.51467252 0.00688695
6.10603428 2.66809821 6.50035667
9.18716908 5.29740238 0.84695166
3.63698077 10.76674652 3.37370563
10.13137341 1.99530399 4.41941452
11.22895622 0.70702899 1.9626857
6.85616112 8.43644238 2.80080009
1.92299962 0.67620301 4.25293064
7.53706646 4.58360958 11.76811314
11.18088341 4.62180805 4.59584332
0.04918295 3.16794014 8.98231316
0.84643483 9.23329067 6.38460732
5.9513917 11.80080032 3.85442424
10.62050819 9.05090523 4.90256119
2.18369842 4.12739229 3.88087988
10.20464706 10.77087212 8.8745594
2.48243117 7.23433065 1.58311677
7.89527082 3.53034306 4.04023361
4.13870525 9.32768822 11.5232563
1.26553833 4.66457653 6.83879137
3.13174224 6.05740261 11.15625477
11.0031147 6.96509409 11.2675209
3.01780796 4.68837023 7.95611
8.68793774 3.49162698 8.47819042
9.20871639 4.19576883 5.84572124
2.04565835 2.17499804 6.02265453
10.26705742 8.612607 8.92945766
6.72325945 1.3569802 10.16003227
8.28885269 9.26705647 0.35722363
11.58921146 3.34976077 2.98282909
5.72420597 6.34954786 10.00457859
4.19156075 6.19931269 3.00485206
6.07885361 11.66264629 8.30989075
3.57719469 6.17771053 6.01888514
6.44051933 6.75858593 5.59578705
1.38739586 4.79369783 11.62009716
9.65887547 6.37036753 4.19613981
10.39524269 5.04818583 9.42271328
1.96724594 11.54160213 1.81917989
8.80187607 9.85273933 10.80256844
8.26953983 5.94302654 4.6990962
7.56050301 1.8478893 5.42749596
6.90789747 9.31377316 7.34465885
1.12333548 11.18586445 9.67084885
1.94813681 9.85089111 8.8954134
2.24054623 5.07671165 5.53983402
11.36300659 0.72358876 5.13564205
1.41668057 9.62051868 10.24730873
8.60816193 11.18423367 4.90922403
2.57010317 4.53859425 1.94834638
10.0786562 4.60001898 8.04714775
5.05244875 3.99494553 0.90023482
6.15318489 6.10743618 7.32012415
4.29062653 0.64726686 6.45087337
11.62414932 2.72240663 10.64064312
4.03313446 8.04793167 4.59394741
8.20511055 8.65339088 10.93052387
7.15570307 6.48882389 10.92961407
2.23792076 0.19930504 10.74730301
11.4533968 4.5184865 7.89275408
8.6868391 2.07194257 9.86865807
4.770473 1.85908628 7.82702446
6.60151768 3.26577377 10.85124397
5.76976585 4.14985037 11.87876225
10.08633423 4.74804926 3.45640206
0.0846159 0.38691813 4.37168026
3.31751156 4.0094552 0.78795934
2.81696057 4.98822689 9.65145302
8.19404221 7.25284624 11.22687912
3.20952439 9.66926765 4.95219994
9.55809689 11.60817528 5.6843462
10.89503098 11.66533947 5.05719995
5.22580194 7.88416672 1.96151078
5.36859226 0.11606334 5.68553257
11.76816273 10.96197128 7.01056385
5.23203278 3.5338273 2.27720881
9.12883759 4.25872374 9.9096756
6.29451036 2.96083689 2.91030073
4.48034859 4.21144962 3.22216177
7.12371206 0.4596428 1.50043249
7.4607811 10.87577629 4.46407318
4.27106571 7.56306553 10.2054348
1.81227088 0.64554459 11.81356716
0.55391037 4.36665392 3.53468585
11.64720154 8.40726376 6.25737906
2.19480205 9.0409174 7.94733953
1.73580217 10.30322647 5.15625668
9.09574604 6.66356993 12.06156063
11.3837347 11.53053665 8.03412724
11.28249836 8.07641792 1.52284324
11.37964916 5.24373579 10.2876749
3.57487345 2.09956479 11.19370556
1.24165726 6.79926777 2.89407659
6.41543388 1.22874427 1.22649252
4.79091597 6.73067904 4.58354235
0.45052758 2.99550295 6.24490547
5.36062193 4.45093822 6.27340031
2.93633318 8.25693035 4.26812506
12.02402401 2.33154082 1.65588093
10.15573406 7.94815302 1.38282299
4.06420469 8.45156574 1.57161748
8.10657406 3.51286101 6.63001823
3.35943961 2.10928988 10.11426258
9.32555485 10.89941788 10.33946133
5.66962481 10.68447685 2.66697049
3.57186198 2.4982779 7.51835966
9.06202316 4.02586651 2.94311142
7.07251549 9.92278767 8.35058403
2.61854649 11.56851292 9.86461639
2.23146915 10.34848404 2.30889535
0.2191238 7.24786568 3.62774277
5.73681307 6.57291698 8.42938137
2.43016076 2.2052331 3.83330965
2.61398745 9.24248981 10.80654812
1.70004296 4.53833961 9.8129282
10.28528595 4.05022049 0.25640708
POSCAR_unwrapped
POSCAR file written by OVITO Basic 3.14.1
1
12.1886482239 0.0 0.0
0.0 12.1886482239 0.0
0.0 0.0 12.1886482239
Si C N 
40 40 40 
Cartesian
13.8182562639 32.1884093278 23.7133626939
0.10904183 11.84466934 10.90873718
9.15765572 0.42337838 -1.2183227539
-14.4437417978 15.2760012139 -1.3446950939
0.22936374 -9.6739757039 -24.3704094978
6.10603428 -9.5205500139 -5.6882915539
9.18716908 17.4860506039 0.84695166
-8.5516674539 10.76674652 3.37370563
10.13137341 14.1839522139 4.41941452
-0.9596920039 0.70702899 1.9626857
31.2334575678 8.43644238 -9.3878481339
1.92299962 0.67620301 16.4415788639
7.53706646 -7.6050386439 -0.4205350839
11.18088341 16.8104562739 4.59584332
0.04918295 3.16794014 -3.2063350639
13.0350830539 33.6105871178 6.38460732
18.1400399239 23.9894485439 3.85442424
10.62050819 9.05090523 -7.2860870339
2.18369842 28.5046887378 16.0695281039
10.20464706 10.77087212 -15.5027370478
2.48243117 7.23433065 1.58311677
7.89527082 -8.6583051639 4.04023361
4.13870525 -15.0496082278 11.5232563
1.26553833 4.66457653 6.83879137
-9.0569059839 -6.1312456139 11.15625477
-1.1855335239 6.96509409 -0.9211273239
3.01780796 29.0656666778 7.95611
-3.5007104839 15.6802752039 -3.7104578039
-2.9799318339 4.19576883 5.84572124
14.2343065739 14.3636462639 18.2113027539
-1.9215908039 -3.5760412239 -3.2591905639
6.72325945 1.3569802 22.3486804939
-3.8997955339 9.26705647 12.5458718539
-0.5994367639 3.34976077 -9.2058191339
-6.4644422539 -5.8391003639 10.00457859
28.5688571978 -5.9893355339 3.00485206
6.07885361 11.66264629 8.30989075
-8.6114535339 -6.0109376939 -6.1697630839
-17.9367771178 -5.4300622939 5.59578705
13.5760440839 4.79369783 -0.5685510639
9.65887547 6.37036753 4.19613981
10.39524269 -7.1404623939 -2.7659349439
14.1558941639 -0.6470460939 14.0078281139
8.80187607 9.85273933 22.9912166639
8.26953983 5.94302654 -7.4895520239
7.56050301 1.8478893 17.6161441839
-5.2807507539 -15.0635232878 -4.8439893739
13.3119837039 -1.0027837739 9.67084885
1.94813681 9.85089111 8.8954134
14.4291944539 17.2653598739 17.7284822439
-0.8256416339 12.9122369839 29.5129384978
1.41668057 9.62051868 -1.9413394939
-15.7691345178 -13.1930627778 4.90922403
2.57010317 4.53859425 1.94834638
10.0786562 16.7886672039 8.04714775
17.2410969739 16.1835937539 13.0888830439
18.3418331139 -6.0812120439 7.32012415
16.4792747539 0.64726686 6.45087337
23.8127975439 -9.4662415939 10.64064312
4.03313446 -4.1407165539 4.59394741
8.20511055 20.8420391039 10.93052387
7.15570307 -5.6998243339 -1.2590341539
2.23792076 0.19930504 -1.4413452139
-0.7352514239 16.7071347239 7.89275408
20.8754873239 2.07194257 9.86865807
4.770473 1.85908628 20.0156726839
6.60151768 3.26577377 10.85124397
17.9584140739 4.14985037 11.87876225
10.08633423 16.9366974839 15.6450502839
0.0846159 12.5755663539 16.5603284839
15.5061597839 28.3867516478 12.9766075639
15.0056087939 4.98822689 9.65145302
20.3826904339 7.25284624 -0.9617691039
3.20952439 9.66926765 4.95219994
21.7467451139 11.60817528 30.0616426478
10.89503098 -0.5233087539 17.2458481739
17.4144501639 -4.3044815039 14.1501590039
5.36859226 24.4933597878 5.68553257
-0.4204854939 -1.2266769439 19.1992120739
17.4206810039 -8.6548209239 26.6545052578
9.12883759 16.4473719639 22.0983238239
-5.8941378639 15.1494851139 15.0989489539
4.48034859 4.21144962 27.5994582178
7.12371206 12.6482910239 -10.6882157339
19.6494293239 -1.3128719339 16.6527214039
4.27106571 19.7517137539 -1.9832134239
14.0009191039 0.64554459 11.81356716
12.7425585939 4.36665392 15.7233340739
11.64720154 -3.7813844639 6.25737906
2.19480205 9.0409174 20.1359877539
-10.4528460539 10.30322647 -7.0323915439
9.09574604 18.8522181539 -0.1270875939
-0.8049135239 11.53053665 8.03412724
23.4711465839 -4.1122303039 13.7114914639
23.5682973839 -6.9449124339 10.2876749
15.7635216739 14.2882130139 -0.9949426639
13.4303054839 -5.3893804539 2.89407659
18.6040821039 -10.9599039539 1.22649252
16.9795641939 6.73067904 4.58354235
12.6391758039 2.99550295 6.24490547
5.36062193 -7.7377100039 -18.1038961378
2.93633318 -3.9317178739 16.4567732839
-0.1646242139 14.5201890439 -10.5327672939
-2.0329141639 20.1368012439 1.38282299
4.06420469 20.6402139639 13.7602657039
8.10657406 3.51286101 31.0073146778
3.35943961 14.2979381039 -2.0743856439
-15.0517415978 -1.2892303439 10.33946133
17.8582730339 10.68447685 -9.5216777339
15.7605102039 2.4982779 7.51835966
-3.1266250639 4.02586651 2.94311142
19.2611637139 22.1114358939 -3.8380641939
2.61854649 11.56851292 22.0532646139
14.4201173739 -1.8401641839 2.30889535
24.5964202478 7.24786568 3.62774277
5.73681307 6.57291698 8.42938137
14.6188089839 14.3938813239 3.83330965
-9.5746607739 -2.9461584139 10.80654812
13.8886911839 4.53833961 -2.3757200239
10.28528595 4.05022049 0.25640708
test.py
import torch

from ase.io import read as ase_read
from torch_sim.state import initialize_state

from mace.calculators.foundations_models import mace_mp
from torch_sim.models.mace import MaceModel


def main():
    # Same 120-atom Si/C/N cell; the files differ only in which periodic image of each atom is stored.
    atoms_wrapped = ase_read("POSCAR_wrapped", format="vasp")
    atoms_unwrapped = ase_read("POSCAR_unwrapped", format="vasp")

    state_wrapped = initialize_state(atoms_wrapped, device="cuda", dtype=torch.float64)
    state_unwrapped = initialize_state(atoms_unwrapped, device="cuda", dtype=torch.float64)

    mace = mace_mp(model="./mace-mh-1.model", device="cuda", return_raw_model=True, dtype=torch.float64)
    model = MaceModel(
        model=mace,
        device="cuda",
        dtype=torch.float64,
        compute_forces=True,
    )

    # Both states describe the same periodic structure, so the energies should agree.
    out_wrapped = model(state_wrapped)
    out_unwrapped = model(state_unwrapped)

    print(out_wrapped["energy"])
    print(out_unwrapped["energy"])


if __name__ == "__main__":
    main()

Output:

tensor([-957.6046], device='cuda:0', dtype=torch.float64)
tensor([-812.6858], device='cuda:0', dtype=torch.float64)

The root cause seems to be that the SevenNet/MACE graph construction assumes all positions lie inside the cell, so the wrapped and unwrapped inputs produce different neighbor lists. I have not tested other models.
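To illustrate the suspected mechanism (this is not a reading of the SevenNet or MACE source), here is a minimal NumPy sketch of a neighbor search that only checks the 27 nearest periodic images. That shortcut is complete only when every position lies inside the cell, so an unwrapped atom can silently lose its neighbors. Both `wrap_positions` and `count_pairs_27_images` are hypothetical helpers, and the toy cell and cutoff are assumptions chosen for clarity.

```python
import itertools

import numpy as np


def wrap_positions(positions, cell):
    """Map Cartesian positions into the primary cell (lattice vectors as rows)."""
    frac = positions @ np.linalg.inv(cell)
    frac -= np.floor(frac)  # fractional coordinates now lie in [0, 1)
    return frac @ cell


def count_pairs_27_images(positions, cell, cutoff):
    """Count directed (i, j, image) pairs closer than `cutoff`.

    Only the 27 images with shifts in {-1, 0, 1}^3 are searched. This shortcut
    is complete only if every position lies inside the cell and the cutoff is
    shorter than the cell lengths; unwrapped atoms can silently drop out.
    """
    n_atoms = len(positions)
    count = 0
    for shift in itertools.product((-1, 0, 1), repeat=3):
        offset = np.array(shift) @ cell
        for i in range(n_atoms):
            for j in range(n_atoms):
                if i == j and shift == (0, 0, 0):
                    continue  # skip self-interaction in the home cell
                if np.linalg.norm(positions[j] + offset - positions[i]) < cutoff:
                    count += 1
    return count


cell = 10.0 * np.eye(3)
wrapped = np.array([[1.0, 1.0, 1.0], [2.0, 1.0, 1.0]])
unwrapped = wrapped + np.array([[0.0, 0.0, 0.0], [30.0, 0.0, 0.0]])  # same structure

print(count_pairs_27_images(wrapped, cell, 5.0))    # 2
print(count_pairs_27_images(unwrapped, cell, 5.0))  # 0: the bonded pair is missed
print(count_pairs_27_images(wrap_positions(unwrapped, cell), cell, 5.0))  # 2
```

A fully PBC-aware search would instead derive the range of image shifts from each atom's actual fractional coordinates, or wrap first, as the last call does.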

This is not only an input-preprocessing issue. Even if the initial configuration is wrapped, atoms can cross a periodic boundary during MD or optimization. Once positions leave the primary cell, the same wrap dependence reappears, and every model evaluation from that point on is incorrect.
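One way to see when this starts happening in a run is to check which atoms have left the primary cell. A minimal NumPy sketch, assuming lattice vectors stored as rows (ASE convention); `atoms_outside_cell` is a hypothetical helper, and the two test positions are the first atom of `POSCAR_wrapped` and of `POSCAR_unwrapped` above.

```python
import numpy as np


def atoms_outside_cell(positions, cell, pbc=(True, True, True)):
    """Mask of atoms whose fractional coordinate leaves [0, 1) along a periodic axis.

    Assumes `cell` stores lattice vectors as rows (ASE convention).
    """
    frac = positions @ np.linalg.inv(cell)
    outside = (frac < 0.0) | (frac >= 1.0)
    return np.any(outside & np.asarray(pbc, dtype=bool), axis=1)


# First atom of POSCAR_wrapped and POSCAR_unwrapped from this issue.
cell = 12.1886482239 * np.eye(3)
positions = np.array([
    [1.62960804, 7.81111288, 11.52471447],
    [13.8182562639, 32.1884093278, 23.7133626939],
])
print(atoms_outside_cell(positions, cell))  # [False  True]
```

Calling a check like this on a CPU copy of the positions inside an MD loop would show the step at which atoms first cross the boundary.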

In high-temperature MD, where most atoms cross the boundary, I observed unphysical clustered structures, whereas the same simulation in LAMMPS shows no such behavior.

As a simple workaround for single-point calculations and NVT MD, I wrap positions during initialization and after every MD position step. I'm not sure this is the optimal fix, but it resolves the issue in my tests.
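Wrapping should be physically harmless because it changes only which periodic image is stored, never a minimum-image distance. A small NumPy check of that claim on two atoms taken from the POSCAR files above; `minimum_image_distance` is a hypothetical helper, and the rounding trick assumes an orthorhombic cell like the cubic one in this issue.

```python
import numpy as np


def minimum_image_distance(r_i, r_j, cell):
    """Interatomic distance under the minimum-image convention.

    Exact for orthorhombic cells with lattice vectors as rows; strongly skewed
    cells need an explicit search over neighboring images instead.
    """
    frac_d = (np.asarray(r_j) - np.asarray(r_i)) @ np.linalg.inv(cell)
    frac_d -= np.round(frac_d)  # nearest periodic image of the displacement
    return float(np.linalg.norm(frac_d @ cell))


cell = 12.1886482239 * np.eye(3)
# Atoms 1 and 3 of POSCAR_wrapped and their counterparts in POSCAR_unwrapped.
a_wrapped = [1.62960804, 7.81111288, 11.52471447]
b_wrapped = [9.15765572, 0.42337838, 10.97032547]
a_unwrapped = [13.8182562639, 32.1884093278, 23.7133626939]
b_unwrapped = [9.15765572, 0.42337838, -1.2183227539]

d_wrapped = minimum_image_distance(a_wrapped, b_wrapped, cell)
d_unwrapped = minimum_image_distance(a_unwrapped, b_unwrapped, cell)
print(np.isclose(d_wrapped, d_unwrapped))  # True
```

A correct PBC-aware model should depend only on such image-independent quantities, which is why the two energies in the output above are expected to match.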

Patch: torch_sim/state.py
--- old/torch_sim/state.py
+++ new/torch_sim/state.py
@@ -1079,8 +1079,25 @@ def initialize_state(
     """
     # TODO: create a way to pass velocities from pmg and ase

+    def _wrap_positions(state: "SimState") -> None:
+        pbc = getattr(state, "pbc", None)
+        if pbc is None:
+            return
+        do_wrap = bool(pbc) if isinstance(pbc, bool) else bool(torch.as_tensor(pbc).any())
+        if not do_wrap:
+            return
+
+        state.positions = ts.transforms.pbc_wrap_batched(
+            positions=state.positions,
+            cell=state.cell,
+            system_idx=state.system_idx,
+            pbc=state.pbc,
+        )
+
     if isinstance(system, SimState):
-        return system.clone().to(device, dtype)
+        state = system.clone().to(device, dtype)
+        _wrap_positions(state)
+        return state

     if isinstance(system, list | tuple) and all(isinstance(s, SimState) for s in system):
         if not all(state.n_systems == 1 for state in system):
@@ -1089,7 +1106,9 @@ def initialize_state(
                 "all states must have n_systems == 1. To fix this, you can split the "
                 "states into individual states with the split_state function."
             )
-        return ts.concatenate_states(system)
+        state = ts.concatenate_states(system)
+        _wrap_positions(state)
+        return state

     converters = [
         ("pymatgen.core", "Structure", ts.io.structures_to_state),
@@ -1107,7 +1126,9 @@ def initialize_state(
                 isinstance(system, list | tuple)
                 and all(isinstance(s, cls) for s in system)
             ):
-                return converter_func(system, device, dtype)
+                state = converter_func(system, device, dtype)
+                _wrap_positions(state)
+                return state
         except ImportError:
             continue
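For readers unfamiliar with `pbc_wrap_batched`, here is a NumPy sketch of what a batched wrap does conceptually: each atom is mapped into the cell of the system it belongs to, along periodic directions only. `wrap_batched` is a hypothetical analogue, not torch-sim's implementation; in particular, the row-vector cell convention is an assumption, and torch-sim's internal cell layout may differ.

```python
import numpy as np


def wrap_batched(positions, cells, system_idx, pbc):
    """Wrap atoms of several systems into their own cells (hypothetical helper).

    positions:  (n_atoms, 3) Cartesian coordinates
    cells:      (n_systems, 3, 3), lattice vectors as rows (ASE convention)
    system_idx: (n_atoms,) index of the system each atom belongs to
    pbc:        (3,) booleans; non-periodic directions are left untouched
    """
    atom_cells = cells[system_idx]  # (n_atoms, 3, 3): each atom's own cell
    frac = np.einsum("ni,nij->nj", positions, np.linalg.inv(atom_cells))
    frac = np.where(np.asarray(pbc, dtype=bool), frac - np.floor(frac), frac)
    return np.einsum("nj,nji->ni", frac, atom_cells)


cells = np.stack([10.0 * np.eye(3), 5.0 * np.eye(3)])
positions = np.array([[12.0, 1.0, -1.0], [6.0, -2.0, 3.0]])
system_idx = np.array([0, 1])

print(wrap_batched(positions, cells, system_idx, [True, True, True]))
print(wrap_batched(positions, cells, system_idx, [True, True, False]))
```

Applied to the first atom of `POSCAR_unwrapped`, this recovers the corresponding coordinates in `POSCAR_wrapped`.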
Patch: torch_sim/integrators/md.py
--- old/torch_sim/integrators/md.py
+++ new/torch_sim/integrators/md.py
@@ -9,7 +9,7 @@ from torch_sim.models.interface import ModelInterface
 from torch_sim.quantities import calc_kT
 from torch_sim.state import SimState
 from torch_sim.units import MetalUnits
-
+from torch_sim.transforms import pbc_wrap_batched

 @dataclass(kw_only=True)
 class MDState(SimState):
@@ -187,6 +187,20 @@ def position_step[T: MDState](state: T, dt: float | torch.Tensor) -> T:
     """
     new_positions = state.positions + state.velocities * dt
     state.set_constrained_positions(new_positions)
+
+    # Keep state.positions wrapped so model calls are always evaluated in the primary cell.
+    if getattr(state, "pbc", None) is not None:
+        # Wrap only if any periodic dimension exists.
+        pbc = state.pbc
+        do_wrap = bool(pbc) if isinstance(pbc, bool) else bool(torch.as_tensor(pbc).any())
+        if do_wrap:
+            state.positions = pbc_wrap_batched(
+                positions=state.positions,
+                cell=state.cell,
+                system_idx=state.system_idx,
+                pbc=state.pbc,
+            )
+
     return state
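The `pbc` check (`do_wrap = ...`) is repeated at each call site in these patches; one option is to factor it into a single helper. A minimal sketch for plain Python/NumPy inputs; `any_periodic` is a hypothetical name, and for CUDA tensors the patches' own `torch.as_tensor(pbc).any()` would be needed, since `np.asarray` cannot read GPU memory.

```python
import numpy as np


def any_periodic(pbc) -> bool:
    """True if at least one direction is periodic (hypothetical helper).

    Accepts None, a plain bool, or any array-like of bools; this mirrors the
    `do_wrap` expression that the patches repeat at each call site.
    """
    if pbc is None:
        return False
    if isinstance(pbc, (bool, np.bool_)):
        return bool(pbc)
    return bool(np.asarray(pbc, dtype=bool).any())


print(any_periodic(None))                     # False
print(any_periodic(True))                     # True
print(any_periodic([False, False, True]))     # True
print(any_periodic(np.zeros(3, dtype=bool)))  # False
```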
Patch: torch_sim/integrators/nvt.py
--- old/torch_sim/integrators/nvt.py
+++ new/torch_sim/integrators/nvt.py
@@ -19,7 +19,7 @@ from torch_sim.integrators.md import (
 from torch_sim.models.interface import ModelInterface
 from torch_sim.state import SimState
 from torch_sim.typing import StateDict
-
+from torch_sim.transforms import pbc_wrap_batched

 def _ou_step(
     state: MDState,
@@ -111,6 +111,16 @@ def nvt_langevin_init(
     if not isinstance(state, SimState):
         state = SimState(**state)

+    pbc = state.pbc
+    do_wrap = bool(pbc) if isinstance(pbc, bool) else bool(torch.as_tensor(pbc).any())
+    if do_wrap:
+        state.positions = pbc_wrap_batched(
+            positions=state.positions,
+            cell=state.cell,
+            system_idx=state.system_idx,
+            pbc=state.pbc,
+        )
+
     model_output = model(state)

     momenta = getattr(
@@ -288,6 +298,18 @@ def nvt_nose_hoover_init(
         - Chain variables evolve to maintain target temperature
         - Time-reversible when integrated with appropriate algorithms
     """
+
+    pbc = state.pbc
+    do_wrap = bool(pbc) if isinstance(pbc, bool) else bool(torch.as_tensor(pbc).any())
+    if do_wrap:
+        state.positions = pbc_wrap_batched(
+            positions=state.positions,
+            cell=state.cell,
+            system_idx=state.system_idx,
+            pbc=state.pbc,
+        )
+
     if tau is None:  # Set default tau if not provided
         tau = dt * 100.0

Unless I’m misunderstanding something, this looks like a serious issue that can lead to incorrect energies/forces under PBC.
