diff --git a/src/galax/dynamics/_src/experimental/__init__.py b/src/galax/dynamics/_src/experimental/__init__.py index 8bca15b1..53cbc37e 100644 --- a/src/galax/dynamics/_src/experimental/__init__.py +++ b/src/galax/dynamics/_src/experimental/__init__.py @@ -1,6 +1,7 @@ """Experimental dynamics.""" -__all__ = ["integrate_orbit", "StreamSimulator"] +__all__ = ["integrate_orbit", "StreamSimulator", "Leapfrog"] from .integrate import integrate_orbit +from .leapfrog import Leapfrog from .stream import StreamSimulator diff --git a/src/galax/dynamics/_src/experimental/integrate.py b/src/galax/dynamics/_src/experimental/integrate.py index 4c0e70e8..20b51a6c 100644 --- a/src/galax/dynamics/_src/experimental/integrate.py +++ b/src/galax/dynamics/_src/experimental/integrate.py @@ -20,7 +20,6 @@ import galax.potential as gp import galax.utils.loop_strategies as lstrat from galax.dynamics._src.orbit.field_base import AbstractOrbitField -from galax.dynamics._src.orbit.field_hamiltonian import HamiltonianField BQParr: TypeAlias = tuple[Real[gdt.Qarr, "B"], Real[gdt.Parr, "B"]] @@ -376,6 +375,11 @@ def integrate_orbit( evaluation of the solution. """ + # Note: this is needed to prevent a circular import + from galax.dynamics._src.orbit.field_hamiltonian import ( + HamiltonianField, + ) + field = pot if isinstance(pot, AbstractOrbitField) else HamiltonianField(pot) terms = field.terms(solver) diff --git a/src/galax/dynamics/_src/experimental/leapfrog.py b/src/galax/dynamics/_src/experimental/leapfrog.py new file mode 100644 index 00000000..e34d4de1 --- /dev/null +++ b/src/galax/dynamics/_src/experimental/leapfrog.py @@ -0,0 +1,108 @@ +# ruff: noqa: ARG002 +""" +Note: This module implements a diffrax solver for Leapfrog integration. There is a +stalled PR to add a similar integrator to diffrax, so in the meantime we implement it +here. +""" # noqa: D205 + +from collections.abc import Callable +from typing import Any, ClassVar, TypeAlias + +from diffrax import SemiImplicitEuler +from diffrax._custom_types import VF, Args, BoolScalarLike, DenseInfo, RealScalarLike +from diffrax._local_interpolation import LocalLinearInterpolation +from diffrax._solution import RESULTS +from diffrax._solver.base import AbstractSolver +from diffrax._term import AbstractTerm +from equinox.internal import ω # noqa: PLC2403 +from jaxtyping import ArrayLike, Float, PyTree + +_ErrorEstimate: TypeAlias = None +_SolverState: TypeAlias = None + +Ya: TypeAlias = PyTree[Float[ArrayLike, "?*y"], " Y"] +Yb: TypeAlias = PyTree[Float[ArrayLike, "?*y"], " Y"] + + +class Leapfrog(AbstractSolver): # type: ignore[misc] + """Leapfrog (velocity Verlet) symplectic integrator. + + This is a 2nd order symplectic integration method. This integrator does not support + adaptive step sizing. This is either known as kick-drift-kick leapfrog or velocity + Verlet. + + Assuming that: + + x0, v0 = y0 + + and: + + f, g = terms + + This method computes the next step as: + + v_half = v0 + h/2 * g(t0, x0) + x1 = x0 + h * f(t0, v_half) + v1 = v_half + h/2 * g(t1, x1) + """ + + term_structure: ClassVar = (AbstractTerm, AbstractTerm) + interpolation_cls: ClassVar[Callable[..., LocalLinearInterpolation]] = ( + LocalLinearInterpolation + ) + + def order(self, _: Any) -> int: + return 2 + + def init( + self, + terms: tuple[AbstractTerm, AbstractTerm], + t0: RealScalarLike, + t1: RealScalarLike, + y0: tuple[Ya, Yb], + args: Args, + ) -> _SolverState: + return None + + def step( + self, + terms: tuple[AbstractTerm, AbstractTerm], + t0: RealScalarLike, + t1: RealScalarLike, + y0: tuple[Ya, Yb], + args: Args, + solver_state: _SolverState, + made_jump: BoolScalarLike, + ) -> tuple[tuple[Ya, Yb], _ErrorEstimate, DenseInfo, _SolverState, RESULTS]: + del solver_state, made_jump + + f, g = terms + q0, p0 = y0 + h = t1 - t0 + + v_half = (p0**ω + 0.5 * h * g.vf(t0, q0, args) ** ω).ω + q1 = (q0**ω + h * f.vf(t0, v_half, args) ** ω).ω + p1 = (v_half**ω + 0.5 * h * g.vf(t1, q1, args) ** ω).ω + + y1 = (q1, p1) + dense_info = {"y0": y0, "y1": y1} + return y1, None, dense_info, None, RESULTS.successful + + def func( + self, + terms: tuple[AbstractTerm, AbstractTerm], + t0: RealScalarLike, + y0: tuple[Ya, Yb], + args: Args, + ) -> VF: + f, g = terms + q0, p0 = y0 + qdot = f.vf(t0, p0, args) + pdot = g.vf(t0, q0, args) + return qdot, pdot + + +Leapfrog.__init__.__doc__ = """**Arguments:** None""" + + +SymplecticSolverT: TypeAlias = Leapfrog | SemiImplicitEuler diff --git a/src/galax/dynamics/_src/orbit/field_hamiltonian.py b/src/galax/dynamics/_src/orbit/field_hamiltonian.py index 060c2167..48d05cf0 100644 --- a/src/galax/dynamics/_src/orbit/field_hamiltonian.py +++ b/src/galax/dynamics/_src/orbit/field_hamiltonian.py @@ -14,6 +14,7 @@ import galax.dynamics._src.custom_types as gdt import galax.potential as gp from .field_base import AbstractOrbitField +from galax.dynamics._src.experimental.leapfrog import SymplecticSolverT from galax.dynamics._src.utils import parse_to_t_y @@ -255,7 +256,7 @@ def dv_dt(self, t: gt.BBtSz0, xyz: gdt.BBtQarr, _: Any, /) -> gdt.BtAarr: @AbstractOrbitField.terms.dispatch # type: ignore[misc] def terms( self: HamiltonianField, - _: dfx.SemiImplicitEuler, + _: SymplecticSolverT, /, ) -> tuple[dfx.ODETerm, dfx.ODETerm]: r"""Return the AbstractTerm terms for the SemiImplicitEuler solver. diff --git a/tests/unit/dynamics/experimental/__init__.py b/tests/unit/dynamics/experimental/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit/dynamics/experimental/test_leapfrog.py b/tests/unit/dynamics/experimental/test_leapfrog.py new file mode 100644 index 00000000..f5a9d6a7 --- /dev/null +++ b/tests/unit/dynamics/experimental/test_leapfrog.py @@ -0,0 +1,32 @@ +# ruff: noqa: ARG005 + +import diffrax as dfx +import jax.numpy as jnp +import pytest + +from galax.dynamics import experimental + + +def shaped_allclose(x, y, **kwargs): + return jnp.shape(x) == jnp.shape(y) and jnp.allclose(x, y, **kwargs) + + +@pytest.mark.parametrize("solver", [experimental.Leapfrog()]) +def test_symplectic_solvers(solver): + dq_dt = dfx.ODETerm(lambda t, p, args: p) + dp_dt = dfx.ODETerm(lambda t, q, args: -q) + y0 = (1.0, -0.5) + dt0 = 0.00001 + sol1 = dfx.diffeqsolve( + (dq_dt, dp_dt), + solver, + 0, + 1, + dt0, + y0, + max_steps=100000, + ) + term_combined = dfx.ODETerm(lambda t, y, args: (y[1], -y[0])) + sol2 = dfx.diffeqsolve(term_combined, dfx.Tsit5(), 0, 1, 0.001, y0) + assert shaped_allclose(sol1.ys[0], sol2.ys[0]) + assert shaped_allclose(sol1.ys[1], sol2.ys[1])