diff --git a/src/galax/potential/_src/base.py b/src/galax/potential/_src/base.py index 99b0bdcd0..35460aea0 100644 --- a/src/galax/potential/_src/base.py +++ b/src/galax/potential/_src/base.py @@ -27,6 +27,7 @@ from galax.potential._src.params.utils import all_parameters, all_vars from galax.utils._jax import vectorize_method from galax.utils.dataclasses import ModuleMeta +from galax.utils.defaults import DEFAULT_TIME if TYPE_CHECKING: import galax.dynamics # noqa: ICN001 @@ -110,7 +111,9 @@ def from_( # Potential energy @abc.abstractmethod - def _potential(self, q: gt.BBtQorVSz3, t: gt.BBtQorVSz0, /) -> gt.BBtQorVSz0: + def _potential( + self, q: gt.BBtQorVSz3, t: gt.BBtQorVSz0 = DEFAULT_TIME, / + ) -> gt.BBtQorVSz0: """Compute the potential energy at the given position(s). This method MUST be implemented by subclasses. @@ -123,9 +126,10 @@ def _potential(self, q: gt.BBtQorVSz3, t: gt.BBtQorVSz0, /) -> gt.BBtQorVSz0: q : Quantity[float, (3,), 'length'] The Cartesian position at which to compute the value of the potential. The units are the same as the potential's unit system. - t : Quantity[float, (), 'time'] + t : Quantity[float, (), 'time'], optional The time at which to compute the value of the potential. The units are the same as the potential's unit system. + Defaults to 0 if not provided. Returns ------- @@ -169,7 +173,7 @@ def __call__(self, *args: Any) -> Any: @vectorize_method(signature="(3),()->(3)") @ft.partial(jax.jit) def _gradient( - self, xyz: gt.FloatQuSz3 | gt.FloatSz3, t: gt.QuSz0, / + self, xyz: gt.FloatQuSz3 | gt.FloatSz3, t: gt.QuSz0 = DEFAULT_TIME, / ) -> gt.FloatSz3: """See ``gradient``.""" xyz = u.ustrip(AllowValue, self.units[DimL], xyz) @@ -190,7 +194,7 @@ def gradient(self, *args: Any, **kwargs: Any) -> Any: @vectorize_method(signature="(3),()->()") @ft.partial(jax.jit) def _laplacian( - self, xyz: gt.FloatQuSz3 | gt.FloatSz3, /, t: gt.QuSz0 | gt.Sz0 + self, xyz: gt.FloatQuSz3 | gt.FloatSz3, /, t: gt.QuSz0 | gt.Sz0 = DEFAULT_TIME ) -> gt.FloatSz0: """See ``laplacian``.""" xyz = u.ustrip(AllowValue, self.units[DimL], xyz) @@ -209,7 +213,9 @@ def laplacian(self, *args: Any, **kwargs: Any) -> u.Quantity["1/s^2"] | Array: # Density @ft.partial(jax.jit) - def _density(self, q: gt.BBtQuSz3, t: gt.BBtQuSz0 | gt.QuSz0, /) -> gt.BBtFloatSz0: + def _density( + self, q: gt.BBtQuSz3, t: gt.BBtQuSz0 | gt.QuSz0 = DEFAULT_TIME, / + ) -> gt.BBtFloatSz0: """See ``density``.""" # Note: trace(jacobian(gradient)) is faster than trace(hessian(energy)) laplacian = self._laplacian(q, t) @@ -228,7 +234,7 @@ def density(self, *args: Any, **kwargs: Any) -> gt.BBtFloatSz0 | gt.BBtFloatQuSz @vectorize_method(signature="(3),()->(3,3)") @ft.partial(jax.jit) def _hessian( - self, xyz: gt.FloatQuSz3 | gt.FloatSz3, t: gt.QuSz0 | gt.Sz0, / + self, xyz: gt.FloatQuSz3 | gt.FloatSz3, t: gt.QuSz0 | gt.Sz0 = DEFAULT_TIME, / ) -> gt.Sz33: """See ``hessian``.""" xyz = u.ustrip(AllowValue, self.units[DimL], xyz) diff --git a/src/galax/potential/_src/register_funcs.py b/src/galax/potential/_src/register_funcs.py index e0a9891ec..6f75527cd 100644 --- a/src/galax/potential/_src/register_funcs.py +++ b/src/galax/potential/_src/register_funcs.py @@ -18,6 +18,7 @@ from .base import AbstractPotential from .utils import parse_to_xyz_t from galax.utils._shape import batched_shape, expand_arr_dims, expand_batch_dims +from galax.utils.defaults import DEFAULT_TIME # ============================================================================= # Potential Energy @@ -47,11 +48,45 @@ def potential( @dispatch # special-case Array input to not return Quantity @ft.partial(jax.jit, inline=True) def potential( - pot: AbstractPotential, xyz: gt.XYZArrayLike, /, *, t: gt.BBtLikeSz0 + pot: AbstractPotential, + xyz: gt.XYZArrayLike, + /, + *, + t: gt.BBtLikeSz0 = DEFAULT_TIME.value, ) -> gt.BBtSz0: return api.potential(pot, xyz, t) +# --------------------------- +# Quantity + + +@dispatch +@ft.partial(jax.jit, inline=True) +def potential( + pot: AbstractPotential, + q: u.AbstractQuantity, + /, + *, + t: u.AbstractQuantity = DEFAULT_TIME, +) -> Real[u.Quantity["specific energy"], "*#batch"]: + """Compute from a quantity.""" + q, t = parse_to_xyz_t(None, q, t, ustrip=pot.units, dtype=float) # TODO: frame + phi = pot._potential(q, t) # noqa: SLF001 + return u.Quantity(phi, pot.units["specific energy"]) + + +@dispatch +@ft.partial(jax.jit, inline=True) +def potential( + pot: AbstractPotential, q: u.AbstractQuantity, t: u.AbstractQuantity, / +) -> Real[u.Quantity["specific energy"], "*#batch"]: + """Compute the potential energy at the given position(s).""" + q, t = parse_to_xyz_t(None, q, t, ustrip=pot.units, dtype=float) # TODO: frame + phi = pot._potential(q, t) # noqa: SLF001 + return u.Quantity(phi, pot.units["specific energy"]) + + # --------------------------- @@ -60,6 +95,7 @@ def potential( pot: AbstractPotential, tq: Any, /, *, t: Any = None ) -> Real[u.Quantity["specific energy"], "*#batch"]: """Compute from a q + t object.""" + # Note: No default time here because if not specified, use the time on the tq object return api.potential(pot, tq, t) @@ -100,7 +136,11 @@ def gradient( @dispatch # special-case Array input to not return Quantity @ft.partial(jax.jit, inline=True) def gradient( - pot: AbstractPotential, xyz: gt.XYZArrayLike, /, *, t: gt.BBtLikeSz0 + pot: AbstractPotential, + xyz: gt.XYZArrayLike, + /, + *, + t: gt.BBtLikeSz0 = DEFAULT_TIME.value, ) -> gt.BBtSz3: return api.gradient(pot, xyz, t) @@ -112,7 +152,11 @@ def gradient( @dispatch @ft.partial(jax.jit, inline=True) def gradient( - pot: AbstractPotential, xyz: u.AbstractQuantity, /, *, t: u.AbstractQuantity + pot: AbstractPotential, + xyz: u.AbstractQuantity, + /, + *, + t: u.AbstractQuantity = DEFAULT_TIME, ) -> Real[u.Quantity["acceleration"], "*#batch 3"]: """Compute from a q + t object.""" xyz, t = parse_to_xyz_t(None, xyz, t, ustrip=pot.units, dtype=float) # TODO: frame @@ -140,6 +184,7 @@ def gradient( pot: AbstractPotential, tq: Any, /, *, t: Any = None ) -> cx.vecs.CartesianAcc3D: """Compute from a q + t object.""" + # Note: No default time here because if not specified, use the time on the tq object xyz, t = parse_to_xyz_t(None, tq, t, ustrip=pot.units, dtype=float) # TODO: frame grad = pot._gradient(xyz, t) # noqa: SLF001 return cx.vecs.CartesianAcc3D.from_(grad, pot.units["acceleration"]) @@ -180,11 +225,45 @@ def laplacian( @dispatch # special-case Array input to not return Quantity @ft.partial(jax.jit, inline=True) def laplacian( - pot: AbstractPotential, xyz: gt.XYZArrayLike, /, *, t: gt.BBtLikeSz0 + pot: AbstractPotential, + xyz: gt.XYZArrayLike, + /, + *, + t: gt.BBtLikeSz0 = DEFAULT_TIME.value, ) -> gt.BBtSz0: return api.laplacian(pot, xyz, t) +# --------------------------- +# Quantity + + +@dispatch +@ft.partial(jax.jit, inline=True) +def laplacian( + pot: AbstractPotential, + xyz: u.AbstractQuantity, + /, + *, + t: u.AbstractQuantity = DEFAULT_TIME, +) -> Real[u.Quantity["frequency drift"], "*#batch"]: + """Compute from a quantity object.""" + xyz, t = parse_to_xyz_t(None, xyz, t, ustrip=pot.units, dtype=float) # TODO: frame + lapl = pot._laplacian(xyz, t) # noqa: SLF001 + return u.Quantity.from_(lapl, pot.units["frequency drift"]) + + +@dispatch +@ft.partial(jax.jit, inline=True) +def laplacian( + pot: AbstractPotential, q: u.AbstractQuantity, t: u.AbstractQuantity, / +) -> Real[u.Quantity["frequency drift"], "*#batch"]: + """Compute from a quantity object and a time.""" + xyz, t = parse_to_xyz_t(None, q, t, ustrip=pot.units, dtype=float) # TODO: frame + lapl = pot._laplacian(xyz, t) # noqa: SLF001 + return u.Quantity.from_(lapl, pot.units["frequency drift"]) + + # --------------------------- @@ -235,11 +314,45 @@ def density( @dispatch # special-case Array input to not return Quantity @ft.partial(jax.jit, inline=True) def density( - pot: AbstractPotential, xyz: gt.XYZArrayLike, /, *, t: gt.BBtLikeSz0 + pot: AbstractPotential, + xyz: gt.XYZArrayLike, + /, + *, + t: gt.BBtLikeSz0 = DEFAULT_TIME.value, ) -> gt.BBtSz0: return api.density(pot, xyz, t) +# --------------------------- +# Quantity + + +@dispatch +@ft.partial(jax.jit, inline=True) +def density( + pot: AbstractPotential, + xyz: u.AbstractQuantity, + /, + *, + t: u.AbstractQuantity = DEFAULT_TIME, +) -> Real[u.Quantity["mass density"], "*#batch"]: + """Compute from a quantity object.""" + xyz, t = parse_to_xyz_t(None, xyz, t, ustrip=pot.units, dtype=float) # TODO: frame + rho = pot._density(xyz, t) # noqa: SLF001 + return u.Quantity.from_(rho, pot.units["mass density"]) + + +@dispatch +@ft.partial(jax.jit, inline=True) +def density( + pot: AbstractPotential, q: u.AbstractQuantity, t: u.AbstractQuantity, / +) -> Real[u.Quantity["mass density"], "*#batch"]: + """Compute from a quantity object and a time.""" + xyz, t = parse_to_xyz_t(None, q, t, ustrip=pot.units, dtype=float) # TODO: frame + rho = pot._density(xyz, t) # noqa: SLF001 + return u.Quantity.from_(rho, pot.units["mass density"]) + + # --------------------------- @@ -290,11 +403,49 @@ def hessian( @dispatch # special-case Array input to not return Quantity @ft.partial(jax.jit, inline=True) def hessian( - pot: AbstractPotential, xyz: gt.XYZArrayLike, /, *, t: gt.BBtLikeSz0 + pot: AbstractPotential, + xyz: gt.XYZArrayLike, + /, + *, + t: gt.BBtLikeSz0 = DEFAULT_TIME.value, ) -> gt.BBtSz33: return api.hessian(pot, xyz, t) +# --------------------------- +# Quantity + + +@dispatch +@ft.partial(jax.jit, inline=True) +def hessian( + pot: AbstractPotential, + xyz: u.AbstractQuantity, + /, + *, + t: u.AbstractQuantity = DEFAULT_TIME, +) -> Real[u.Quantity["frequency drift"], "*#batch 3 3"]: + """Compute from a quantity object.""" + xyz, t = parse_to_xyz_t(None, xyz, t, ustrip=pot.units, dtype=float) # TODO: frame + phi = pot._hessian(xyz, t) # noqa: SLF001 + return u.Quantity.from_(phi, pot.units["frequency drift"]) + + +@dispatch +@ft.partial(jax.jit, inline=True) +def hessian( + pot: AbstractPotential, + xyz: u.AbstractQuantity, + /, + *, + t: u.AbstractQuantity = DEFAULT_TIME, +) -> Real[u.Quantity["frequency drift"], "*#batch 3 3"]: + """Compute from a quantity object.""" + xyz, t = parse_to_xyz_t(None, xyz, t, ustrip=pot.units, dtype=float) # TODO: frame + phi = pot._hessian(xyz, t) # noqa: SLF001 + return u.Quantity.from_(phi, pot.units["frequency drift"]) + + # --------------------------- diff --git a/src/galax/potential/_src/utils.py b/src/galax/potential/_src/utils.py index 0401dfda9..5f44e0cf2 100644 --- a/src/galax/potential/_src/utils.py +++ b/src/galax/potential/_src/utils.py @@ -19,6 +19,7 @@ import galax._custom_types as gt import galax.coordinates as gc +from galax.utils.defaults import DEFAULT_TIME OptUSys: TypeAlias = u.AbstractUnitSystem | None @@ -204,7 +205,7 @@ def parse_to_xyz_t( def parse_to_xyz_t( to_frame: cxf.AbstractReferenceFrame | None, xyz: gt.XYZArrayLike, - t: gt.BBtLikeSz0, # TODO: consider also "*#batch 1" + t: gt.BBtLikeSz0 | None, # TODO: consider also "*#batch 1" /, *, dtype: Any = None, @@ -213,7 +214,11 @@ def parse_to_xyz_t( """Parse input arguments to position & time.""" # Process the input arguments into arrays xyz = jnp.asarray(xyz, dtype=dtype) - t = jnp.asarray(t, dtype=dtype) + t = ( + jnp.asarray(t, dtype=dtype) + if t is not None + else jnp.asarray(DEFAULT_TIME.value, dtype=dtype) + ) # The coordinates are assumed to be in the simulation frame and may need to # be transformed to the target frame. @@ -263,12 +268,12 @@ def parse_to_xyz_t( @coord_dispatcher.multi( (cxf.AbstractReferenceFrame | None, gt.BBtQuSz3, gt.BBtQuSz0), - (cxf.AbstractReferenceFrame | None, gt.BBtQuSz3, gt.BBtSz0 | float | int), + (cxf.AbstractReferenceFrame | None, gt.BBtQuSz3, gt.BBtSz0 | float | int | None), ) def parse_to_xyz_t( to_frame: cxf.AbstractReferenceFrame | None, xyz: gt.BBtQorVSz3, - t: gt.BBtQorVSz0 | float | int, + t: gt.BBtQorVSz0 | float | int | None, /, *, dtype: Any = None, @@ -276,7 +281,11 @@ def parse_to_xyz_t( ) -> tuple[gt.BBtQorVSz3, gt.BBtQorVSz0]: """Parse input arguments to position & time.""" xyz = jnp.asarray(xyz, dtype=dtype) - t = jnp.asarray(t, dtype=dtype) + t = ( + jnp.asarray(t, dtype=dtype) + if t is not None + else jnp.asarray(DEFAULT_TIME, dtype=dtype) + ) if ustrip is not None: xyz = u.ustrip(AllowValue, ustrip["length"], xyz) diff --git a/src/galax/potential/_src/xfm/xop.py b/src/galax/potential/_src/xfm/xop.py index bdd71465f..b674fbdc5 100644 --- a/src/galax/potential/_src/xfm/xop.py +++ b/src/galax/potential/_src/xfm/xop.py @@ -15,6 +15,7 @@ import galax._custom_types as gt from .base import AbstractTransformedPotential from galax.potential._src.base import AbstractPotential +from galax.utils.defaults import DEFAULT_TIME @final @@ -189,7 +190,9 @@ class TransformedPotential(AbstractTransformedPotential): """ - def _potential(self, xyz: gt.BBtQorVSz3, t: gt.BBtQorVSz0, /) -> gt.BBtSz0: + def _potential( + self, xyz: gt.BBtQorVSz3, t: gt.BBtQorVSz0 = DEFAULT_TIME, / + ) -> gt.BBtSz0: """Compute the potential energy at the given position(s). This method applies the operators to the coordinates and then evaluates diff --git a/src/galax/utils/defaults.py b/src/galax/utils/defaults.py new file mode 100644 index 000000000..688f9123e --- /dev/null +++ b/src/galax/utils/defaults.py @@ -0,0 +1,8 @@ +"""Globally assumed default values across Galax.""" + +__all__ = ["DEFAULT_TIME"] + +import unxt as u + +# Default time value used when time is optional +DEFAULT_TIME = u.Quantity(0.0, "Myr") diff --git a/tests/unit/potential/test_base.py b/tests/unit/potential/test_base.py index 1fbefaba8..9a798a0c5 100644 --- a/tests/unit/potential/test_base.py +++ b/tests/unit/potential/test_base.py @@ -18,6 +18,7 @@ import galax.potential.params as gpp from .io.test_gala import GalaIOMixin from galax.potential._src.base import default_constants +from galax.utils.defaults import DEFAULT_TIME class AbstractPotential_Test(GalaIOMixin, metaclass=ABCMeta): @@ -192,6 +193,34 @@ def test_evaluate_orbit_batch(self, pot: gp.AbstractPotential, xv: gt.Sz6) -> No assert orbits.shape == (2, len(ts)) assert jnp.allclose(orbits.t, ts, atol=u.Quantity(1e-16, "Myr")) + @pytest.mark.parametrize( + "func_name", + [ + "potential", + "gradient", + "hessian", + "laplacian", + "acceleration", + "tidal_tensor", + "local_circular_velocity", + "dpotential_dr", + "d2potential_dr2", + ], + ) + def test_optional_time( + self, func_name: str, pot: gp.AbstractPotential, x: gt.QuSz3 + ): + """Test that potentials work with optional time.""" + # Test with explicit time + with_time = getattr(pot, func_name)(x, t=DEFAULT_TIME) + + # Tests with default time (should be equivalent) + default_time = getattr(pot, func_name)(x) + default_time_arr = getattr(pot, func_name)(x.value) + + assert jnp.allclose(with_time.value, default_time.value) + assert jnp.allclose(default_time.value, default_time_arr) + ##############################################################################