Skip to content

Commit f5096b0

Browse files
authored
Merge pull request #362 from libAtoms/MD_Langevin_BAOAB
Add support for Langevin_BAOAB in wfl.generate.md, and support for overriding arbitrary kwargs per-config
2 parents 1f06bfe + 4b38853 commit f5096b0

File tree

7 files changed

+408
-210
lines changed

7 files changed

+408
-210
lines changed

.github/workflows/pytests.yml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,10 @@ jobs:
4747
- name: Install latest ASE from pypi
4848
run: |
4949
echo PIP_CONSTRAINT $PIP_CONSTRAINT
50-
python3 -m pip install ase
50+
# avoid broken extxyz writing (3.25, fixed in 3.26)
51+
# avoid broken optimizer.converged() (3.26) https://gitlab.com/ase/ase/-/issues/1744
52+
python3 -m pip install 'ase<3.25'
53+
#
5154
echo -n "ASE VERSION "
5255
python3 -c "import ase; print(ase.__file__, ase.__version__)"
5356
python3 -c "import numpy; print('numpy version', numpy.__version__)"

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ dynamic = ["version"]
1515
reactions_iter_fit = "wfl.cli.reactions_iter_fit:cli"
1616

1717
[tool.setuptools.packages.find]
18-
exclude = [ "test*" ]
18+
include = [ "wfl*" ]
1919

2020
[tool.setuptools.dynamic]
2121
version = {attr = "wfl.__version__"}

tests/local_scripts/complete_pytest.tin

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
module purge
44
# module load compiler/gnu python/system python_extras/quippy lapack/mkl
55
module load compiler/gnu python python_extras/quippy lapack/mkl
6+
# for wfl dependencies
7+
module load python_extras/wif
68
module load python_extras/torch/cpu
79

810
if [ -z "$WFL_PYTEST_EXPYRE_INFO" ]; then
@@ -61,7 +63,7 @@ fi
6163

6264
mkdir -p $pytest_dir
6365

64-
pytest -v -s --basetemp $pytest_dir ${runremote} --runslow --runperf -rxXs "$@" >> complete_pytest.tin.out 2>&1
66+
python3 -m pytest -v -s --basetemp $pytest_dir ${runremote} --runslow --runperf -rxXs "$@" >> complete_pytest.tin.out 2>&1
6567

6668
l=`egrep '^=.*(passed|failed|skipped|xfailed|error).* in ' complete_pytest.tin.out`
6769

tests/test_md.py

Lines changed: 116 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
from pytest import approx
22
import pytest
33

4-
import numpy as np
54
import os
65
from pathlib import Path
6+
import json
7+
8+
import numpy as np
79

810
from ase import Atoms
911
import ase.io
@@ -17,6 +19,10 @@
1719
from wfl.configset import ConfigSet, OutputSpec
1820
from wfl.generate.md.abort import AbortOnCollision, AbortOnLowEnergy
1921

22+
try:
23+
from wif.Langevin_BAOAB import Langevin_BAOAB
24+
except ImportError:
25+
Langevin_BAOAB = None
2026

2127
def select_every_10_steps_for_tests_during(at):
2228
return at.info.get("MD_step", 1) % 10 == 0
@@ -52,7 +58,6 @@ def test_NVE(cu_slab):
5258
temperature=500.0, rng=np.random.default_rng(1))
5359

5460
atoms_traj = list(atoms_traj)
55-
atoms_final = atoms_traj[-1]
5661

5762
assert len(atoms_traj) == 301
5863

@@ -68,27 +73,131 @@ def test_NVT_const_T(cu_slab):
6873
temperature=500.0, temperature_tau=30.0, rng=np.random.default_rng(1))
6974

7075
atoms_traj = list(atoms_traj)
71-
atoms_final = atoms_traj[-1]
7276

7377
assert len(atoms_traj) == 301
7478
assert all([at.info['MD_temperature_K'] == 500.0 for at in atoms_traj])
79+
assert np.all(atoms_traj[0].cell == atoms_traj[-1].cell)
7580

7681

7782
def test_NVT_Langevin_const_T(cu_slab):
78-
7983
calc = EMT()
8084

8185
inputs = ConfigSet(cu_slab)
8286
outputs = OutputSpec()
8387

8488
atoms_traj = md.md(inputs, outputs, calculator=calc, integrator="Langevin", steps=300, dt=1.0,
85-
temperature=500.0, temperature_tau=100/fs, rng=np.random.default_rng(1))
89+
temperature=500.0, temperature_tau=100/fs, rng=np.random.default_rng(1))
8690

8791
atoms_traj = list(atoms_traj)
88-
atoms_final = atoms_traj[-1]
8992

9093
assert len(atoms_traj) == 301
9194
assert all([at.info['MD_temperature_K'] == 500.0 for at in atoms_traj])
95+
assert np.all(atoms_traj[0].cell == atoms_traj[-1].cell)
96+
97+
98+
def test_NPT_Langevin_fail(cu_slab):
99+
calc = EMT()
100+
101+
inputs = ConfigSet(cu_slab)
102+
outputs = OutputSpec()
103+
104+
with pytest.raises(ValueError):
105+
atoms_traj = md.md(inputs, outputs, calculator=calc, integrator="Langevin", steps=300, dt=1.0,
106+
temperature=500.0, temperature_tau=100/fs, pressure=0.0,
107+
rng=np.random.default_rng(1))
108+
109+
110+
def test_NPT_Berendsen_hydro_F_fail(cu_slab):
111+
calc = EMT()
112+
113+
inputs = ConfigSet(cu_slab)
114+
outputs = OutputSpec()
115+
116+
with pytest.raises(ValueError):
117+
atoms_traj = md.md(inputs, outputs, calculator=calc, integrator="Berendsen", steps=300, dt=1.0,
118+
temperature=500.0, temperature_tau=100/fs, pressure=0.0, hydrostatic=False,
119+
rng=np.random.default_rng(1))
120+
121+
122+
def test_NPT_Berendsen_NPH_fail(cu_slab):
123+
calc = EMT()
124+
125+
inputs = ConfigSet(cu_slab)
126+
outputs = OutputSpec()
127+
128+
with pytest.raises(ValueError):
129+
atoms_traj = md.md(inputs, outputs, calculator=calc, integrator="Berendsen", steps=300, dt=1.0,
130+
pressure=0.0,
131+
rng=np.random.default_rng(1))
132+
133+
134+
def test_NPT_Berendsen(cu_slab):
135+
calc = EMT()
136+
137+
inputs = ConfigSet(cu_slab)
138+
outputs = OutputSpec()
139+
140+
atoms_traj = md.md(inputs, outputs, calculator=calc, integrator="Berendsen", steps=300, dt=1.0,
141+
temperature=500.0, temperature_tau=100/fs, pressure=0.0,
142+
rng=np.random.default_rng(1))
143+
144+
atoms_traj = list(atoms_traj)
145+
print("I cell", atoms_traj[0].cell)
146+
print("F cell", atoms_traj[1].cell)
147+
148+
assert len(atoms_traj) == 301
149+
assert all([at.info['MD_temperature_K'] == 500.0 for at in atoms_traj])
150+
assert np.any(atoms_traj[0].cell != atoms_traj[-1].cell)
151+
152+
cell_f = atoms_traj[0].cell[0, 0] / atoms_traj[-1].cell[0, 0]
153+
assert np.allclose(atoms_traj[0].cell, atoms_traj[-1].cell * cell_f)
154+
155+
156+
@pytest.mark.skipif(Langevin_BAOAB is None, reason="No Langevin_BAOAB available")
157+
def test_NPT_Langevin_BAOAB(cu_slab):
158+
calc = EMT()
159+
160+
inputs = ConfigSet(cu_slab)
161+
outputs = OutputSpec()
162+
163+
atoms_traj = md.md(inputs, outputs, calculator=calc, integrator="Langevin_BAOAB", steps=300, dt=1.0,
164+
temperature=500.0, temperature_tau=100/fs, pressure=0.0,
165+
rng=np.random.default_rng(1))
166+
167+
atoms_traj = list(atoms_traj)
168+
print("I cell", atoms_traj[0].cell)
169+
print("F cell", atoms_traj[1].cell)
170+
171+
assert len(atoms_traj) == 301
172+
assert all([at.info['MD_temperature_K'] == 500.0 for at in atoms_traj])
173+
assert np.any(atoms_traj[0].cell != atoms_traj[-1].cell)
174+
175+
cell_f = atoms_traj[0].cell[0, 0] / atoms_traj[-1].cell[0, 0]
176+
assert np.allclose(atoms_traj[0].cell, atoms_traj[-1].cell * cell_f)
177+
178+
179+
@pytest.mark.skipif(Langevin_BAOAB is None, reason="No Langevin_BAOAB available")
180+
def test_NPT_Langevin_BAOAB_hydro_F(cu_slab):
181+
calc = EMT()
182+
183+
inputs = ConfigSet(cu_slab)
184+
outputs = OutputSpec()
185+
186+
atoms_traj = md.md(inputs, outputs, calculator=calc, integrator="Langevin_BAOAB", steps=300, dt=1.0,
187+
temperature=500.0, temperature_tau=100/fs, pressure=0.0, hydrostatic=False,
188+
rng=np.random.default_rng(1))
189+
190+
atoms_traj = list(atoms_traj)
191+
print("I cell", atoms_traj[0].cell)
192+
print("F cell", atoms_traj[1].cell)
193+
194+
assert len(atoms_traj) == 301
195+
assert all([at.info['MD_temperature_K'] == 500.0 for at in atoms_traj])
196+
assert np.any(atoms_traj[0].cell != atoms_traj[-1].cell)
197+
198+
cell_f = atoms_traj[0].cell[0, 0] / atoms_traj[-1].cell[0, 0]
199+
assert not np.allclose(atoms_traj[0].cell, atoms_traj[-1].cell * cell_f)
200+
92201

93202

94203
def test_NVT_Langevin_const_T_per_config(cu_slab):
@@ -99,7 +208,7 @@ def test_NVT_Langevin_const_T_per_config(cu_slab):
99208
outputs = OutputSpec()
100209

101210
for at_i, at in enumerate(inputs):
102-
at.info["WFL_MD_TEMPERATURE"] = 500 + at_i * 100
211+
at.info["WFL_MD_KWARGS"] = json.dumps({'temperature': 500 + at_i * 100})
103212

104213
n_steps = 30
105214

wfl/__version__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "0.3.3"
1+
__version__ = "0.3.4"

0 commit comments

Comments
 (0)