Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions diffrax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@
Ralston,
ReversibleHeun,
SemiImplicitEuler,
StormerVerlet,
Sil3,
StratonovichMilstein,
Tsit5,
Expand Down
1 change: 1 addition & 0 deletions diffrax/solver/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,5 +36,6 @@
MultiButcherTableau,
)
from .semi_implicit_euler import SemiImplicitEuler
from .stormer_verlet import StormerVerlet
from .sil3 import Sil3
from .tsit5 import Tsit5
78 changes: 78 additions & 0 deletions diffrax/solver/stormer_verlet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
from typing import Tuple

from equinox.internal import ω

from ..custom_types import Bool, DenseInfo, PyTree, Scalar
from ..local_interpolation import LocalLinearInterpolation
from ..solution import RESULTS
from ..term import AbstractTerm
from .base import AbstractSolver

_ErrorEstimate = None
_SolverState = None

class StormerVerlet(AbstractSolver):
""" Störmer-Verlet method.

Symplectic method. Does not support adaptive step sizing. Uses 1st order local
linear interpolation for dense/ts output.
"""

term_structure = (AbstractTerm, AbstractTerm)
interpolation_cls = LocalLinearInterpolation

def order(self, terms):
return 2

def init(
self,
terms: Tuple[AbstractTerm, AbstractTerm],
t0: Scalar,
t1: Scalar,
y0: PyTree,
args: PyTree,
) -> _SolverState:
return None

def step(
self,
terms: Tuple[AbstractTerm, AbstractTerm],
t0: Scalar,
t1: Scalar,
y0: Tuple[PyTree, PyTree],
args: PyTree,
solver_state: _SolverState,
made_jump: Bool,
) -> Tuple[Tuple[PyTree, PyTree], _ErrorEstimate, DenseInfo, _SolverState, RESULTS]:
del solver_state, made_jump

term_1, term_2 = terms
y0_1, y0_2 = y0
midpoint = (t1 + t0)/2

control1_half_1 = term_1.contr(t0, midpoint)
control1_half_2 = term_1.contr(midpoint, t1)
control2 = term_2.contr(t0, t1)

yhalf_1 = (y0_1 ** ω + term_1.vf_prod(t0, y0_2, args, control1_half_1) ** ω).ω
y1_2 = (y0_2 ** ω + term_2.vf_prod(midpoint, yhalf_1, args, control2) ** ω).ω
y1_1 = (yhalf_1 ** ω + term_1.vf_prod(t1, y1_2, args, control1_half_2 ** ω)).ω

y1 = (y1_1, y1_2)
dense_info = dict(y0=y0, y1=y1)
return y1, None, dense_info, None, RESULTS.successful

def func(
self,
terms: Tuple[AbstractTerm, AbstractTerm],
t0: Scalar,
y0: Tuple[PyTree, PyTree],
args: PyTree
) -> Tuple[PyTree, PyTree]:
term_1, term_2 = terms
y0_1, y0_2 = y0
f1 = term_1.func(t0, y0_2, args)
f2 = term_2.func(t0, y0_1, args)
return (f1, f2)


6 changes: 5 additions & 1 deletion test/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import jax.random as jrandom
import jax.tree_util as jtu


all_ode_solvers = (
diffrax.Bosh3(),
diffrax.Dopri5(),
Expand All @@ -32,6 +31,11 @@
diffrax.KenCarp5(),
)

all_symplectic_solvers = (
diffrax.SemiImplicitEuler(),
diffrax.StormerVerlet(),
)


def implicit_tol(solver):
if isinstance(solver, diffrax.AbstractImplicitSolver):
Expand Down
59 changes: 57 additions & 2 deletions test/test_integrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from .helpers import (
all_ode_solvers,
all_split_solvers,
all_symplectic_solvers,
implicit_tol,
random_pytree,
shaped_allclose,
Expand Down Expand Up @@ -165,6 +166,59 @@ def f(t, y, args):
assert -0.9 < order - solver.order(term) < 0.9


@pytest.mark.parametrize("solver", all_symplectic_solvers)
def test_symplectic_ode_order(solver):
solver = implicit_tol(solver)
key = jrandom.PRNGKey(17)
p_key, q_key, k_key = jrandom.split(key, 3)
p0 = jrandom.uniform(p_key, shape=(), minval=0, maxval=1)
q0 = jrandom.uniform(q_key, shape=(), minval=0, maxval=1)
k = jrandom.uniform(k_key, shape=(), minval=0.1, maxval=10)
y0 = (p0, q0)
t0 = 0
t1 = 4

def p_vector_field(t, q, k):
return q

def q_vector_field(t, p, k):
return -k * p

def analytic_solution(t, k, p0, q0):
φ = jnp.sqrt(k)
p_t = p0 * jnp.cos(φ * t) + (q0/φ) * jnp.sin(φ * t)
q_t = -p0 * φ * jnp.sin(φ * t) + q0 * jnp.cos(φ * t)
return p_t, q_t


term = (
diffrax.ODETerm(p_vector_field),
diffrax.ODETerm(q_vector_field),
)

true_pT, true_qT = analytic_solution(t1, k, p0, q0)
exponents = []
errors_p = []
errors_q = []
for exponent in [0, -1, -2, -3, -4, -6, -8, -12]:
dt0 = 2**exponent
sol = diffrax.diffeqsolve(term, solver, t0, t1, dt0, y0, k, max_steps=None)
pT, qT = sol.ys
error_p = jnp.sum(jnp.abs(pT - true_pT))
error_q = jnp.sum(jnp.abs(qT - true_qT))
if error_p < 2**-28 and error_q < 2**-28:
break
exponents.append(exponent)
errors_p.append(jnp.log2(error_q))
errors_q.append(jnp.log2(error_q))

order_p = scipy.stats.linregress(exponents, errors_p). slope
order_q = scipy.stats.linregress(exponents, errors_q). slope
# Same wide range as for general ODE solvers, but we
# require this approximate order both for `p` and `q`
assert -0.9 < order_p - solver.order(term) < 0.9
assert -0.9 < order_q - solver.order(term) < 0.9

def _squareplus(x):
return 0.5 * (x + jnp.sqrt(x**2 + 4))

Expand Down Expand Up @@ -338,14 +392,15 @@ def f(t, y, args):
assert shaped_allclose(sol1.derivative(ti), -sol2.derivative(-ti))


def test_semi_implicit_euler():
@pytest.mark.parametrize("solver", all_symplectic_solvers)
def test_symplectic_solvers(solver):
term1 = diffrax.ODETerm(lambda t, y, args: -y)
term2 = diffrax.ODETerm(lambda t, y, args: y)
y0 = (1.0, -0.5)
dt0 = 0.00001
sol1 = diffrax.diffeqsolve(
(term1, term2),
diffrax.SemiImplicitEuler(),
solver,
0,
1,
dt0,
Expand Down