diff --git a/src/galax/potential/__init__.py b/src/galax/potential/__init__.py index f50951cb..d8779719 100644 --- a/src/galax/potential/__init__.py +++ b/src/galax/potential/__init__.py @@ -43,6 +43,7 @@ "TriaxialHernquistPotential", "HardCutoffNFWPotential", "gNFWPotential", + "GaussianDensityPotential", # Pre-composited "AbstractPreCompositedPotential", "BovyMWPotential2014", @@ -98,6 +99,7 @@ AbstractMultipolePotential, BovyMWPotential2014, BurkertPotential, + GaussianDensityPotential, HardCutoffNFWPotential, HarmonicOscillatorPotential, HenonHeilesPotential, diff --git a/src/galax/potential/_src/builtin/__init__.py b/src/galax/potential/_src/builtin/__init__.py index bd3108d7..3ed8e514 100644 --- a/src/galax/potential/_src/builtin/__init__.py +++ b/src/galax/potential/_src/builtin/__init__.py @@ -41,6 +41,7 @@ "StoneOstriker15Potential", "TriaxialHernquistPotential", "HardCutoffNFWPotential", + "GaussianDensityPotential", ] from .burkert import BurkertPotential @@ -54,6 +55,7 @@ HarmonicOscillatorPotential, HenonHeilesPotential, ) +from .gaussian import GaussianDensityPotential from .hernquist import HernquistPotential, TriaxialHernquistPotential from .isochrone import IsochronePotential from .jaffe import JaffePotential diff --git a/src/galax/potential/_src/builtin/gaussian.py b/src/galax/potential/_src/builtin/gaussian.py new file mode 100644 index 00000000..cabd9e05 --- /dev/null +++ b/src/galax/potential/_src/builtin/gaussian.py @@ -0,0 +1,85 @@ +"""galax: Galactic Dynamix in Jax.""" + +__all__ = ["GaussianDensityPotential"] + +import functools as ft +from dataclasses import KW_ONLY +from typing import final + +import equinox as eqx +import jax +import jax.scipy.special as jsp + +import quaxed.numpy as jnp +import unxt as u +from xmmutablemap import ImmutableMap + +import galax._custom_types as gt +from galax.potential._src.base import default_constants +from galax.potential._src.base_single import AbstractSinglePotential +from galax.potential._src.params.base import AbstractParameter +from galax.potential._src.params.field import ParameterField +from galax.potential._src.utils import r_spherical + + +@final +class GaussianDensityPotential(AbstractSinglePotential): + r"""Potential of a spherical Gaussian density profile. + + The gravitational potential corresponding to a spherical Gaussian density profile: + + .. math:: + + \rho(r) = \frac{M}{(2 \pi)^{3/2} \, r_s^3}\exp\left(-\frac{r^2}{2 r_s^2}\right) + + """ + + m_tot: AbstractParameter = ParameterField( # type: ignore[assignment] + dimensions="mass", doc="Total mass of the potential." + ) + + r_s: AbstractParameter = ParameterField( # type: ignore[assignment] + dimensions="length", doc="Scale radius (standard deviation of the Gaussian)." + ) + + _: KW_ONLY + units: u.AbstractUnitSystem = eqx.field(converter=u.unitsystem, static=True) + constants: ImmutableMap[str, u.AbstractQuantity] = eqx.field( + default=default_constants, converter=ImmutableMap + ) + + @ft.partial(jax.jit) + def _potential(self, xyz: gt.BBtQorVSz3, t: gt.BBtQorVSz0, /) -> gt.BBtSz0: + r = r_spherical(xyz, self.units["length"]) + t = u.Quantity.from_(t, self.units["time"]) + + params = { + "G": self.constants["G"].value, + "m_tot": self.m_tot(t, ustrip=self.units["mass"]), + "r_s": self.r_s(t, ustrip=self.units["length"]), + } + return potential(params, r) + + @ft.partial(jax.jit) + def _density(self, xyz: gt.BBtQorVSz3, t: gt.BBtQorVSz0, /) -> gt.BtFloatSz0: + r = r_spherical(xyz, self.units["length"]) + t = u.Quantity.from_(t, self.units["time"]) + + params = { + "m_tot": self.m_tot(t, ustrip=self.units["mass"]), + "r_s": self.r_s(t, ustrip=self.units["length"]), + } + return density(params, r) + + +@ft.partial(jax.jit) +def density(p: gt.Params, r: gt.Sz0, /) -> gt.FloatSz0: + r"""Gaussian density profile.""" + rho0 = p["m_tot"] / ((2 * jnp.pi) ** (3 / 2) * p["r_s"] ** 3) + return rho0 * jnp.exp(-(r**2) / (2 * p["r_s"] ** 2)) + + +@ft.partial(jax.jit) +def potential(p: gt.Params, r: gt.Sz0, /) -> gt.Sz0: + r"""Potential corresponding to a spherical Gaussian density profile in 3D.""" + return -p["G"] * p["m_tot"] / r * jsp.erf(r / (jnp.sqrt(2) * p["r_s"])) diff --git a/src/galax/potential/_src/builtin/hernquist.py b/src/galax/potential/_src/builtin/hernquist.py index d8340586..90a35e86 100644 --- a/src/galax/potential/_src/builtin/hernquist.py +++ b/src/galax/potential/_src/builtin/hernquist.py @@ -160,7 +160,7 @@ def _potential(self, xyz: gt.BBtQorVSz3, t: gt.BBtQorVSz0, /) -> gt.BBtSz0: @ft.partial(jax.jit) def density(p: gt.Params, r: gt.Sz0, /) -> gt.FloatSz0: - r"""Density profile for the Kepler potential.""" + r"""Density profile for the Hernquist potential.""" s = r / p["r_s"] rho0 = p["m_tot"] / (2 * jnp.pi * p["r_s"] ** 3) return rho0 / (s * (1 + s) ** 3) diff --git a/tests/unit/potential/builtin/test_gaussian.py b/tests/unit/potential/builtin/test_gaussian.py new file mode 100644 index 00000000..719b147b --- /dev/null +++ b/tests/unit/potential/builtin/test_gaussian.py @@ -0,0 +1,77 @@ +from typing import Any + +import pytest + +import quaxed.numpy as jnp +import unxt as u + +import galax._custom_types as gt +import galax.potential as gp +from ..test_core import AbstractSinglePotential_Test +from .test_common import ParameterMTotMixin, ParameterRSMixin + + +class TestGaussianDensityPotential( + AbstractSinglePotential_Test, + # Parameters + ParameterMTotMixin, + ParameterRSMixin, +): + @pytest.fixture(scope="class") + def pot_cls(self) -> type[gp.GaussianDensityPotential]: + return gp.GaussianDensityPotential + + @pytest.fixture(scope="class") + def fields_(self, field_m_tot, field_r_s, field_units) -> dict[str, Any]: + return {"m_tot": field_m_tot, "r_s": field_r_s, "units": field_units} + + # ========================================================================== + + def test_potential(self, pot: gp.GaussianDensityPotential, x: gt.QuSz3) -> None: + expect = u.Quantity(-1.20205548, pot.units["specific energy"]) + assert jnp.isclose( + pot.potential(x, t=0), expect, atol=u.Quantity(1e-8, expect.unit) + ) + + def test_gradient(self, pot: gp.GaussianDensityPotential, x: gt.QuSz3) -> None: + expect = u.Quantity( + [0.08562732, 0.17125464, 0.25688196], pot.units["acceleration"] + ) + got = pot.gradient(x, t=0) + assert jnp.allclose(got, expect, atol=u.Quantity(1e-8, expect.unit)) + + def test_density(self, pot: gp.GaussianDensityPotential, x: gt.QuSz3) -> None: + expect = u.Quantity(57898701.53591853, pot.units["mass density"]) + assert jnp.isclose( + pot.density(x, t=0), expect, atol=u.Quantity(1e-8, expect.unit) + ) + + def test_hessian(self, pot: gp.GaussianDensityPotential, x: gt.QuSz3) -> None: + expect = u.Quantity( + [ + [0.06751239, -0.03622985, -0.05434478], + [-0.03622985, 0.01316762, -0.10868955], + [-0.05434478, -0.10868955, -0.07740701], + ], + "1/Myr2", + ) + assert jnp.allclose( + pot.hessian(x, t=0), expect, atol=u.Quantity(1e-8, expect.unit) + ) + + # --------------------------------- + # Convenience methods + + def test_tidal_tensor(self, pot: gp.AbstractPotential, x: gt.QuSz3) -> None: + """Test the `AbstractPotential.tidal_tensor` method.""" + expect = u.Quantity( + [ + [0.06642139, -0.03622985, -0.05434478], + [-0.03622985, 0.01207662, -0.10868955], + [-0.05434478, -0.10868955, -0.07849801], + ], + "1/Myr2", + ) + assert jnp.allclose( + pot.tidal_tensor(x, t=0), expect, atol=u.Quantity(1e-8, expect.unit) + )