Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions src/galax/potential/_src/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from galax.potential._src.params.utils import all_parameters, all_vars
from galax.utils._jax import vectorize_method
from galax.utils.dataclasses import ModuleMeta
from galax.utils.defaults import DEFAULT_TIME

if TYPE_CHECKING:
import galax.dynamics # noqa: ICN001
Expand Down Expand Up @@ -110,7 +111,9 @@ def from_(
# Potential energy

@abc.abstractmethod
def _potential(self, q: gt.BBtQorVSz3, t: gt.BBtQorVSz0, /) -> gt.BBtQorVSz0:
def _potential(
self, q: gt.BBtQorVSz3, t: gt.BBtQorVSz0 = DEFAULT_TIME, /
) -> gt.BBtQorVSz0:
"""Compute the potential energy at the given position(s).

This method MUST be implemented by subclasses.
Expand All @@ -123,9 +126,10 @@ def _potential(self, q: gt.BBtQorVSz3, t: gt.BBtQorVSz0, /) -> gt.BBtQorVSz0:
q : Quantity[float, (3,), 'length']
The Cartesian position at which to compute the value of the
potential. The units are the same as the potential's unit system.
t : Quantity[float, (), 'time']
t : Quantity[float, (), 'time'], optional
The time at which to compute the value of the potential.
The units are the same as the potential's unit system.
Defaults to 0 if not provided.

Returns
-------
Expand Down Expand Up @@ -169,7 +173,7 @@ def __call__(self, *args: Any) -> Any:
@vectorize_method(signature="(3),()->(3)")
@ft.partial(jax.jit)
def _gradient(
self, xyz: gt.FloatQuSz3 | gt.FloatSz3, t: gt.QuSz0, /
self, xyz: gt.FloatQuSz3 | gt.FloatSz3, t: gt.QuSz0 = DEFAULT_TIME, /
) -> gt.FloatSz3:
"""See ``gradient``."""
xyz = u.ustrip(AllowValue, self.units[DimL], xyz)
Expand All @@ -190,7 +194,7 @@ def gradient(self, *args: Any, **kwargs: Any) -> Any:
@vectorize_method(signature="(3),()->()")
@ft.partial(jax.jit)
def _laplacian(
self, xyz: gt.FloatQuSz3 | gt.FloatSz3, /, t: gt.QuSz0 | gt.Sz0
self, xyz: gt.FloatQuSz3 | gt.FloatSz3, /, t: gt.QuSz0 | gt.Sz0 = DEFAULT_TIME
) -> gt.FloatSz0:
"""See ``laplacian``."""
xyz = u.ustrip(AllowValue, self.units[DimL], xyz)
Expand All @@ -209,7 +213,9 @@ def laplacian(self, *args: Any, **kwargs: Any) -> u.Quantity["1/s^2"] | Array:
# Density

@ft.partial(jax.jit)
def _density(self, q: gt.BBtQuSz3, t: gt.BBtQuSz0 | gt.QuSz0, /) -> gt.BBtFloatSz0:
def _density(
self, q: gt.BBtQuSz3, t: gt.BBtQuSz0 | gt.QuSz0 = DEFAULT_TIME, /
) -> gt.BBtFloatSz0:
"""See ``density``."""
# Note: trace(jacobian(gradient)) is faster than trace(hessian(energy))
laplacian = self._laplacian(q, t)
Expand All @@ -228,7 +234,7 @@ def density(self, *args: Any, **kwargs: Any) -> gt.BBtFloatSz0 | gt.BBtFloatQuSz
@vectorize_method(signature="(3),()->(3,3)")
@ft.partial(jax.jit)
def _hessian(
self, xyz: gt.FloatQuSz3 | gt.FloatSz3, t: gt.QuSz0 | gt.Sz0, /
self, xyz: gt.FloatQuSz3 | gt.FloatSz3, t: gt.QuSz0 | gt.Sz0 = DEFAULT_TIME, /
) -> gt.Sz33:
"""See ``hessian``."""
xyz = u.ustrip(AllowValue, self.units[DimL], xyz)
Expand Down
163 changes: 157 additions & 6 deletions src/galax/potential/_src/register_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from .base import AbstractPotential
from .utils import parse_to_xyz_t
from galax.utils._shape import batched_shape, expand_arr_dims, expand_batch_dims
from galax.utils.defaults import DEFAULT_TIME

# =============================================================================
# Potential Energy
Expand Down Expand Up @@ -47,11 +48,45 @@ def potential(
@dispatch # special-case Array input to not return Quantity
@ft.partial(jax.jit, inline=True)
def potential(
pot: AbstractPotential, xyz: gt.XYZArrayLike, /, *, t: gt.BBtLikeSz0
pot: AbstractPotential,
xyz: gt.XYZArrayLike,
/,
*,
t: gt.BBtLikeSz0 = DEFAULT_TIME.value,
) -> gt.BBtSz0:
return api.potential(pot, xyz, t)


# ---------------------------
# Quantity


@dispatch
@ft.partial(jax.jit, inline=True)
def potential(
pot: AbstractPotential,
q: u.AbstractQuantity,
/,
*,
t: u.AbstractQuantity = DEFAULT_TIME,
) -> Real[u.Quantity["specific energy"], "*#batch"]:
"""Compute from a quantity."""
q, t = parse_to_xyz_t(None, q, t, ustrip=pot.units, dtype=float) # TODO: frame
phi = pot._potential(q, t) # noqa: SLF001
return u.Quantity(phi, pot.units["specific energy"])


@dispatch
@ft.partial(jax.jit, inline=True)
def potential(
pot: AbstractPotential, q: u.AbstractQuantity, t: u.AbstractQuantity, /
) -> Real[u.Quantity["specific energy"], "*#batch"]:
"""Compute the potential energy at the given position(s)."""
q, t = parse_to_xyz_t(None, q, t, ustrip=pot.units, dtype=float) # TODO: frame
phi = pot._potential(q, t) # noqa: SLF001
return u.Quantity(phi, pot.units["specific energy"])


# ---------------------------


Expand All @@ -60,6 +95,7 @@ def potential(
pot: AbstractPotential, tq: Any, /, *, t: Any = None
) -> Real[u.Quantity["specific energy"], "*#batch"]:
"""Compute from a q + t object."""
# Note: No default time here because if not specified, use the time on the tq object
return api.potential(pot, tq, t)


Expand Down Expand Up @@ -100,7 +136,11 @@ def gradient(
@dispatch # special-case Array input to not return Quantity
@ft.partial(jax.jit, inline=True)
def gradient(
pot: AbstractPotential, xyz: gt.XYZArrayLike, /, *, t: gt.BBtLikeSz0
pot: AbstractPotential,
xyz: gt.XYZArrayLike,
/,
*,
t: gt.BBtLikeSz0 = DEFAULT_TIME.value,
) -> gt.BBtSz3:
return api.gradient(pot, xyz, t)

Expand All @@ -112,7 +152,11 @@ def gradient(
@dispatch
@ft.partial(jax.jit, inline=True)
def gradient(
pot: AbstractPotential, xyz: u.AbstractQuantity, /, *, t: u.AbstractQuantity
pot: AbstractPotential,
xyz: u.AbstractQuantity,
/,
*,
t: u.AbstractQuantity = DEFAULT_TIME,
) -> Real[u.Quantity["acceleration"], "*#batch 3"]:
"""Compute from a q + t object."""
xyz, t = parse_to_xyz_t(None, xyz, t, ustrip=pot.units, dtype=float) # TODO: frame
Expand Down Expand Up @@ -140,6 +184,7 @@ def gradient(
pot: AbstractPotential, tq: Any, /, *, t: Any = None
) -> cx.vecs.CartesianAcc3D:
"""Compute from a q + t object."""
# Note: No default time here because if not specified, use the time on the tq object
xyz, t = parse_to_xyz_t(None, tq, t, ustrip=pot.units, dtype=float) # TODO: frame
grad = pot._gradient(xyz, t) # noqa: SLF001
return cx.vecs.CartesianAcc3D.from_(grad, pot.units["acceleration"])
Expand Down Expand Up @@ -180,11 +225,45 @@ def laplacian(
@dispatch # special-case Array input to not return Quantity
@ft.partial(jax.jit, inline=True)
def laplacian(
pot: AbstractPotential, xyz: gt.XYZArrayLike, /, *, t: gt.BBtLikeSz0
pot: AbstractPotential,
xyz: gt.XYZArrayLike,
/,
*,
t: gt.BBtLikeSz0 = DEFAULT_TIME.value,
) -> gt.BBtSz0:
return api.laplacian(pot, xyz, t)


# ---------------------------
# Quantity


@dispatch
@ft.partial(jax.jit, inline=True)
def laplacian(
pot: AbstractPotential,
xyz: u.AbstractQuantity,
/,
*,
t: u.AbstractQuantity = DEFAULT_TIME,
) -> Real[u.Quantity["frequency drift"], "*#batch"]:
"""Compute from a quantity object."""
xyz, t = parse_to_xyz_t(None, xyz, t, ustrip=pot.units, dtype=float) # TODO: frame
lapl = pot._laplacian(xyz, t) # noqa: SLF001
return u.Quantity.from_(lapl, pot.units["frequency drift"])


@dispatch
@ft.partial(jax.jit, inline=True)
def laplacian(
pot: AbstractPotential, q: u.AbstractQuantity, t: u.AbstractQuantity, /
) -> Real[u.Quantity["frequency drift"], "*#batch"]:
"""Compute from a quantity object and a time."""
xyz, t = parse_to_xyz_t(None, q, t, ustrip=pot.units, dtype=float) # TODO: frame
lapl = pot._laplacian(xyz, t) # noqa: SLF001
return u.Quantity.from_(lapl, pot.units["frequency drift"])


# ---------------------------


Expand Down Expand Up @@ -235,11 +314,45 @@ def density(
@dispatch # special-case Array input to not return Quantity
@ft.partial(jax.jit, inline=True)
def density(
pot: AbstractPotential, xyz: gt.XYZArrayLike, /, *, t: gt.BBtLikeSz0
pot: AbstractPotential,
xyz: gt.XYZArrayLike,
/,
*,
t: gt.BBtLikeSz0 = DEFAULT_TIME.value,
) -> gt.BBtSz0:
return api.density(pot, xyz, t)


# ---------------------------
# Quantity


@dispatch
@ft.partial(jax.jit, inline=True)
def density(
pot: AbstractPotential,
xyz: u.AbstractQuantity,
/,
*,
t: u.AbstractQuantity = DEFAULT_TIME,
) -> Real[u.Quantity["mass density"], "*#batch"]:
"""Compute from a quantity object."""
xyz, t = parse_to_xyz_t(None, xyz, t, ustrip=pot.units, dtype=float) # TODO: frame
rho = pot._density(xyz, t) # noqa: SLF001
return u.Quantity.from_(rho, pot.units["mass density"])


@dispatch
@ft.partial(jax.jit, inline=True)
def density(
pot: AbstractPotential, q: u.AbstractQuantity, t: u.AbstractQuantity, /
) -> Real[u.Quantity["mass density"], "*#batch"]:
"""Compute from a quantity object and a time."""
xyz, t = parse_to_xyz_t(None, q, t, ustrip=pot.units, dtype=float) # TODO: frame
rho = pot._density(xyz, t) # noqa: SLF001
return u.Quantity.from_(rho, pot.units["mass density"])


# ---------------------------


Expand Down Expand Up @@ -290,11 +403,49 @@ def hessian(
@dispatch # special-case Array input to not return Quantity
@ft.partial(jax.jit, inline=True)
def hessian(
pot: AbstractPotential, xyz: gt.XYZArrayLike, /, *, t: gt.BBtLikeSz0
pot: AbstractPotential,
xyz: gt.XYZArrayLike,
/,
*,
t: gt.BBtLikeSz0 = DEFAULT_TIME.value,
) -> gt.BBtSz33:
return api.hessian(pot, xyz, t)


# ---------------------------
# Quantity


@dispatch
@ft.partial(jax.jit, inline=True)
def hessian(
pot: AbstractPotential,
xyz: u.AbstractQuantity,
/,
*,
t: u.AbstractQuantity = DEFAULT_TIME,
) -> Real[u.Quantity["frequency drift"], "*#batch 3 3"]:
"""Compute from a quantity object."""
xyz, t = parse_to_xyz_t(None, xyz, t, ustrip=pot.units, dtype=float) # TODO: frame
phi = pot._hessian(xyz, t) # noqa: SLF001
return u.Quantity.from_(phi, pot.units["frequency drift"])


@dispatch
@ft.partial(jax.jit, inline=True)
def hessian(
pot: AbstractPotential,
xyz: u.AbstractQuantity,
/,
*,
t: u.AbstractQuantity = DEFAULT_TIME,
) -> Real[u.Quantity["frequency drift"], "*#batch 3 3"]:
"""Compute from a quantity object."""
xyz, t = parse_to_xyz_t(None, xyz, t, ustrip=pot.units, dtype=float) # TODO: frame
phi = pot._hessian(xyz, t) # noqa: SLF001
return u.Quantity.from_(phi, pot.units["frequency drift"])


# ---------------------------


Expand Down
19 changes: 14 additions & 5 deletions src/galax/potential/_src/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import galax._custom_types as gt
import galax.coordinates as gc
from galax.utils.defaults import DEFAULT_TIME

OptUSys: TypeAlias = u.AbstractUnitSystem | None

Expand Down Expand Up @@ -204,7 +205,7 @@ def parse_to_xyz_t(
def parse_to_xyz_t(
to_frame: cxf.AbstractReferenceFrame | None,
xyz: gt.XYZArrayLike,
t: gt.BBtLikeSz0, # TODO: consider also "*#batch 1"
t: gt.BBtLikeSz0 | None, # TODO: consider also "*#batch 1"
/,
*,
dtype: Any = None,
Expand All @@ -213,7 +214,11 @@ def parse_to_xyz_t(
"""Parse input arguments to position & time."""
# Process the input arguments into arrays
xyz = jnp.asarray(xyz, dtype=dtype)
t = jnp.asarray(t, dtype=dtype)
t = (
jnp.asarray(t, dtype=dtype)
if t is not None
else jnp.asarray(DEFAULT_TIME.value, dtype=dtype)
)

# The coordinates are assumed to be in the simulation frame and may need to
# be transformed to the target frame.
Expand Down Expand Up @@ -263,20 +268,24 @@ def parse_to_xyz_t(

@coord_dispatcher.multi(
(cxf.AbstractReferenceFrame | None, gt.BBtQuSz3, gt.BBtQuSz0),
(cxf.AbstractReferenceFrame | None, gt.BBtQuSz3, gt.BBtSz0 | float | int),
(cxf.AbstractReferenceFrame | None, gt.BBtQuSz3, gt.BBtSz0 | float | int | None),
)
def parse_to_xyz_t(
to_frame: cxf.AbstractReferenceFrame | None,
xyz: gt.BBtQorVSz3,
t: gt.BBtQorVSz0 | float | int,
t: gt.BBtQorVSz0 | float | int | None,
/,
*,
dtype: Any = None,
ustrip: OptUSys = None,
) -> tuple[gt.BBtQorVSz3, gt.BBtQorVSz0]:
"""Parse input arguments to position & time."""
xyz = jnp.asarray(xyz, dtype=dtype)
t = jnp.asarray(t, dtype=dtype)
t = (
jnp.asarray(t, dtype=dtype)
if t is not None
else jnp.asarray(DEFAULT_TIME, dtype=dtype)
)

if ustrip is not None:
xyz = u.ustrip(AllowValue, ustrip["length"], xyz)
Expand Down
5 changes: 4 additions & 1 deletion src/galax/potential/_src/xfm/xop.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import galax._custom_types as gt
from .base import AbstractTransformedPotential
from galax.potential._src.base import AbstractPotential
from galax.utils.defaults import DEFAULT_TIME


@final
Expand Down Expand Up @@ -189,7 +190,9 @@ class TransformedPotential(AbstractTransformedPotential):

"""

def _potential(self, xyz: gt.BBtQorVSz3, t: gt.BBtQorVSz0, /) -> gt.BBtSz0:
def _potential(
self, xyz: gt.BBtQorVSz3, t: gt.BBtQorVSz0 = DEFAULT_TIME, /
) -> gt.BBtSz0:
"""Compute the potential energy at the given position(s).

This method applies the operators to the coordinates and then evaluates
Expand Down
8 changes: 8 additions & 0 deletions src/galax/utils/defaults.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
"""Globally assumed default values across Galax."""

__all__ = ["DEFAULT_TIME"]

import unxt as u

# Default time value used when time is optional
DEFAULT_TIME = u.Quantity(0.0, "Myr")
Loading
Loading