Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/galax/dynamics/_src/experimental/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Experimental dynamics."""

__all__ = ["integrate_orbit", "StreamSimulator"]
__all__ = ["integrate_orbit", "StreamSimulator", "Leapfrog"]

from .integrate import integrate_orbit
from .leapfrog import Leapfrog
from .stream import StreamSimulator
6 changes: 5 additions & 1 deletion src/galax/dynamics/_src/experimental/integrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import galax.potential as gp
import galax.utils.loop_strategies as lstrat
from galax.dynamics._src.orbit.field_base import AbstractOrbitField
from galax.dynamics._src.orbit.field_hamiltonian import HamiltonianField

BQParr: TypeAlias = tuple[Real[gdt.Qarr, "B"], Real[gdt.Parr, "B"]]

Expand Down Expand Up @@ -376,6 +375,11 @@ def integrate_orbit(
evaluation of the solution.

"""
# Note: this is needed to prevent a circular import
from galax.dynamics._src.orbit.field_hamiltonian import (
HamiltonianField,
)

field = pot if isinstance(pot, AbstractOrbitField) else HamiltonianField(pot)
terms = field.terms(solver)

Expand Down
108 changes: 108 additions & 0 deletions src/galax/dynamics/_src/experimental/leapfrog.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
# ruff: noqa: ARG002
"""
Note: This module implements a diffrax solver for Leapfrog integration. There is a
stalled PR to add a similar integrator to diffrax, so in the meantime we implement it
here.
""" # noqa: D205

from collections.abc import Callable
from typing import Any, ClassVar, TypeAlias

from diffrax import SemiImplicitEuler
from diffrax._custom_types import VF, Args, BoolScalarLike, DenseInfo, RealScalarLike
from diffrax._local_interpolation import LocalLinearInterpolation
from diffrax._solution import RESULTS
from diffrax._solver.base import AbstractSolver
from diffrax._term import AbstractTerm
from equinox.internal import ω # noqa: PLC2403
from jaxtyping import ArrayLike, Float, PyTree

_ErrorEstimate: TypeAlias = None
_SolverState: TypeAlias = None

Ya: TypeAlias = PyTree[Float[ArrayLike, "?*y"], " Y"]
Yb: TypeAlias = PyTree[Float[ArrayLike, "?*y"], " Y"]


class Leapfrog(AbstractSolver): # type: ignore[misc]
"""Leapfrog (velocity Verlet) symplectic integrator.

This is a 2nd order symplectic integration method. This integrator does not support
adaptive step sizing. This is either known as kick-drift-kick leapfrog or velocity
Verlet.

Assuming that:

x0, v0 = y0

and:

f, g = terms

This method computes the next step as:

v_half = v0 + h/2 * g(t0, x0)
x1 = x0 + h * f(t0, v_half)
v1 = v_half + h/2 * g(t1, x1)
"""

term_structure: ClassVar = (AbstractTerm, AbstractTerm)
interpolation_cls: ClassVar[Callable[..., LocalLinearInterpolation]] = (
LocalLinearInterpolation
)

def order(self, _: Any) -> int:
return 2

def init(
self,
terms: tuple[AbstractTerm, AbstractTerm],
t0: RealScalarLike,
t1: RealScalarLike,
y0: tuple[Ya, Yb],
args: Args,
) -> _SolverState:
return None

def step(
self,
terms: tuple[AbstractTerm, AbstractTerm],
t0: RealScalarLike,
t1: RealScalarLike,
y0: tuple[Ya, Yb],
args: Args,
solver_state: _SolverState,
made_jump: BoolScalarLike,
) -> tuple[tuple[Ya, Yb], _ErrorEstimate, DenseInfo, _SolverState, RESULTS]:
del solver_state, made_jump

f, g = terms
q0, p0 = y0
h = t1 - t0

v_half = (p0**ω + 0.5 * h * g.vf(t0, q0, args) ** ω).ω
q1 = (q0**ω + h * f.vf(t0, v_half, args) ** ω).ω
p1 = (v_half**ω + 0.5 * h * g.vf(t1, q1, args) ** ω).ω

y1 = (q1, p1)
dense_info = {"y0": y0, "y1": y1}
return y1, None, dense_info, None, RESULTS.successful

def func(
self,
terms: tuple[AbstractTerm, AbstractTerm],
t0: RealScalarLike,
y0: tuple[Ya, Yb],
args: Args,
) -> VF:
f, g = terms
q0, p0 = y0
qdot = f.vf(t0, p0, args)
pdot = g.vf(t0, q0, args)
return qdot, pdot


Leapfrog.__init__.__doc__ = """**Arguments:** None"""


SymplecticSolverT: TypeAlias = Leapfrog | SemiImplicitEuler
3 changes: 2 additions & 1 deletion src/galax/dynamics/_src/orbit/field_hamiltonian.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import galax.dynamics._src.custom_types as gdt
import galax.potential as gp
from .field_base import AbstractOrbitField
from galax.dynamics._src.experimental.leapfrog import SymplecticSolverT
from galax.dynamics._src.utils import parse_to_t_y


Expand Down Expand Up @@ -255,7 +256,7 @@ def dv_dt(self, t: gt.BBtSz0, xyz: gdt.BBtQarr, _: Any, /) -> gdt.BtAarr:
@AbstractOrbitField.terms.dispatch # type: ignore[misc]
def terms(
self: HamiltonianField,
_: dfx.SemiImplicitEuler,
_: SymplecticSolverT,
/,
) -> tuple[dfx.ODETerm, dfx.ODETerm]:
r"""Return the AbstractTerm terms for the SemiImplicitEuler solver.
Expand Down
Empty file.
32 changes: 32 additions & 0 deletions tests/unit/dynamics/experimental/test_leapfrog.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# ruff: noqa: ARG005

import diffrax as dfx
import jax.numpy as jnp
import pytest

from galax.dynamics import experimental


def shaped_allclose(x, y, **kwargs):
return jnp.shape(x) == jnp.shape(y) and jnp.allclose(x, y, **kwargs)


@pytest.mark.parametrize("solver", [experimental.Leapfrog()])
def test_symplectic_solvers(solver):
dq_dt = dfx.ODETerm(lambda t, p, args: p)
dp_dt = dfx.ODETerm(lambda t, q, args: -q)
y0 = (1.0, -0.5)
dt0 = 0.00001
sol1 = dfx.diffeqsolve(
(dq_dt, dp_dt),
solver,
0,
1,
dt0,
y0,
max_steps=100000,
)
term_combined = dfx.ODETerm(lambda t, y, args: (y[1], -y[0]))
sol2 = dfx.diffeqsolve(term_combined, dfx.Tsit5(), 0, 1, 0.001, y0)
assert shaped_allclose(sol1.ys[0], sol2.ys[0])
assert shaped_allclose(sol1.ys[1], sol2.ys[1])
Loading