From 02643e471af625c44e13c3be0464eb125e1da1ac Mon Sep 17 00:00:00 2001 From: Adrian Price-Whelan <583379+adrn@users.noreply.github.com> Date: Mon, 11 Aug 2025 15:07:39 -0400 Subject: [PATCH 1/5] fix typo in hernquist --- src/galax/potential/_src/builtin/hernquist.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/galax/potential/_src/builtin/hernquist.py b/src/galax/potential/_src/builtin/hernquist.py index d8340586..90a35e86 100644 --- a/src/galax/potential/_src/builtin/hernquist.py +++ b/src/galax/potential/_src/builtin/hernquist.py @@ -160,7 +160,7 @@ def _potential(self, xyz: gt.BBtQorVSz3, t: gt.BBtQorVSz0, /) -> gt.BBtSz0: @ft.partial(jax.jit) def density(p: gt.Params, r: gt.Sz0, /) -> gt.FloatSz0: - r"""Density profile for the Kepler potential.""" + r"""Density profile for the Hernquist potential.""" s = r / p["r_s"] rho0 = p["m_tot"] / (2 * jnp.pi * p["r_s"] ** 3) return rho0 / (s * (1 + s) ** 3) From ca610c46b267b69efa8f212727629ca37e48305f Mon Sep 17 00:00:00 2001 From: Adrian Price-Whelan <583379+adrn@users.noreply.github.com> Date: Mon, 11 Aug 2025 15:08:20 -0400 Subject: [PATCH 2/5] add implementation of potential from a Gaussian density profile --- src/galax/potential/_src/builtin/gaussian.py | 85 ++++++++++++++++++++ 1 file changed, 85 insertions(+) create mode 100644 src/galax/potential/_src/builtin/gaussian.py diff --git a/src/galax/potential/_src/builtin/gaussian.py b/src/galax/potential/_src/builtin/gaussian.py new file mode 100644 index 00000000..42fe9f59 --- /dev/null +++ b/src/galax/potential/_src/builtin/gaussian.py @@ -0,0 +1,85 @@ +"""galax: Galactic Dynamix in Jax.""" + +__all__ = ["GaussianPotential"] + +import functools as ft +from dataclasses import KW_ONLY +from typing import final + +import equinox as eqx +import jax +import jax.scipy.special as jsp + +import quaxed.numpy as jnp +import unxt as u +from xmmutablemap import ImmutableMap + +import galax._custom_types as gt +from galax.potential._src.base import default_constants +from galax.potential._src.base_single import AbstractSinglePotential +from galax.potential._src.params.base import AbstractParameter +from galax.potential._src.params.field import ParameterField +from galax.potential._src.utils import r_spherical + + +@final +class GaussianPotential(AbstractSinglePotential): + r"""A spherical Gaussian potential. + + The gravitational potential corresponding to a spherical Gaussian density profile: + + .. math:: + + \rho(r) = \frac{M}{(2 \pi)^{3/2} \, r_s^3}\exp\left(-\frac{r^2}{2 r_s^2}\right) + + """ + + m_tot: AbstractParameter = ParameterField( # type: ignore[assignment] + dimensions="mass", doc="Total mass of the potential." + ) + + r_s: AbstractParameter = ParameterField( # type: ignore[assignment] + dimensions="length", doc="Scale radius (standard deviation of the Gaussian)." + ) + + _: KW_ONLY + units: u.AbstractUnitSystem = eqx.field(converter=u.unitsystem, static=True) + constants: ImmutableMap[str, u.AbstractQuantity] = eqx.field( + default=default_constants, converter=ImmutableMap + ) + + @ft.partial(jax.jit) + def _potential(self, xyz: gt.BBtQorVSz3, t: gt.BBtQorVSz0, /) -> gt.BBtSz0: + r = r_spherical(xyz, self.units["length"]) + t = u.Quantity.from_(t, self.units["time"]) + + params = { + "G": self.constants["G"].value, + "m_tot": self.m_tot(t, ustrip=self.units["mass"]), + "r_s": self.r_s(t, ustrip=self.units["length"]), + } + return potential(params, r) + + @ft.partial(jax.jit) + def _density(self, xyz: gt.BBtQorVSz3, t: gt.BBtQorVSz0, /) -> gt.BtFloatSz0: + r = r_spherical(xyz, self.units["length"]) + t = u.Quantity.from_(t, self.units["time"]) + + params = { + "m_tot": self.m_tot(t, ustrip=self.units["mass"]), + "r_s": self.r_s(t, ustrip=self.units["length"]), + } + return density(params, r) + + +@ft.partial(jax.jit) +def density(p: gt.Params, r: gt.Sz0, /) -> gt.FloatSz0: + r"""Gaussian density profile.""" + rho0 = p["m_tot"] / ((2 * jnp.pi) ** (3 / 2) * p["r_s"] ** 3) + return rho0 * jnp.exp(-(r**2) / (2 * p["r_s"] ** 2)) + + +@ft.partial(jax.jit) +def potential(p: gt.Params, r: gt.Sz0, /) -> gt.Sz0: + r"""Potential corresponding to a spherical Gaussian density profile in 3D.""" + return -p["G"] * p["m_tot"] / r * jsp.erf(r / (jnp.sqrt(2) * p["r_s"])) From e520a45545563c40544f2b172d5d0181b8ac5817 Mon Sep 17 00:00:00 2001 From: Adrian Price-Whelan <583379+adrn@users.noreply.github.com> Date: Mon, 11 Aug 2025 15:19:55 -0400 Subject: [PATCH 3/5] expose GaussianPotential to public API --- src/galax/potential/__init__.py | 2 ++ src/galax/potential/_src/builtin/__init__.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/src/galax/potential/__init__.py b/src/galax/potential/__init__.py index f50951cb..91669a10 100644 --- a/src/galax/potential/__init__.py +++ b/src/galax/potential/__init__.py @@ -43,6 +43,7 @@ "TriaxialHernquistPotential", "HardCutoffNFWPotential", "gNFWPotential", + "GaussianPotential", # Pre-composited "AbstractPreCompositedPotential", "BovyMWPotential2014", @@ -98,6 +99,7 @@ AbstractMultipolePotential, BovyMWPotential2014, BurkertPotential, + GaussianPotential, HardCutoffNFWPotential, HarmonicOscillatorPotential, HenonHeilesPotential, diff --git a/src/galax/potential/_src/builtin/__init__.py b/src/galax/potential/_src/builtin/__init__.py index bd3108d7..c756cd2b 100644 --- a/src/galax/potential/_src/builtin/__init__.py +++ b/src/galax/potential/_src/builtin/__init__.py @@ -41,6 +41,7 @@ "StoneOstriker15Potential", "TriaxialHernquistPotential", "HardCutoffNFWPotential", + "GaussianPotential", ] from .burkert import BurkertPotential @@ -54,6 +55,7 @@ HarmonicOscillatorPotential, HenonHeilesPotential, ) +from .gaussian import GaussianPotential from .hernquist import HernquistPotential, TriaxialHernquistPotential from .isochrone import IsochronePotential from .jaffe import JaffePotential From a22a680dc0eee1e37a3cce690d62a660ec03590b Mon Sep 17 00:00:00 2001 From: Adrian Price-Whelan <583379+adrn@users.noreply.github.com> Date: Mon, 11 Aug 2025 15:20:02 -0400 Subject: [PATCH 4/5] add unit tests --- tests/unit/potential/builtin/test_gaussian.py | 77 +++++++++++++++++++ 1 file changed, 77 insertions(+) create mode 100644 tests/unit/potential/builtin/test_gaussian.py diff --git a/tests/unit/potential/builtin/test_gaussian.py b/tests/unit/potential/builtin/test_gaussian.py new file mode 100644 index 00000000..8d1b6b01 --- /dev/null +++ b/tests/unit/potential/builtin/test_gaussian.py @@ -0,0 +1,77 @@ +from typing import Any + +import pytest + +import quaxed.numpy as jnp +import unxt as u + +import galax._custom_types as gt +import galax.potential as gp +from ..test_core import AbstractSinglePotential_Test +from .test_common import ParameterMTotMixin, ParameterRSMixin + + +class TestGaussianPotential( + AbstractSinglePotential_Test, + # Parameters + ParameterMTotMixin, + ParameterRSMixin, +): + @pytest.fixture(scope="class") + def pot_cls(self) -> type[gp.GaussianPotential]: + return gp.GaussianPotential + + @pytest.fixture(scope="class") + def fields_(self, field_m_tot, field_r_s, field_units) -> dict[str, Any]: + return {"m_tot": field_m_tot, "r_s": field_r_s, "units": field_units} + + # ========================================================================== + + def test_potential(self, pot: gp.GaussianPotential, x: gt.QuSz3) -> None: + expect = u.Quantity(-1.20205548, pot.units["specific energy"]) + assert jnp.isclose( + pot.potential(x, t=0), expect, atol=u.Quantity(1e-8, expect.unit) + ) + + def test_gradient(self, pot: gp.GaussianPotential, x: gt.QuSz3) -> None: + expect = u.Quantity( + [0.08562732, 0.17125464, 0.25688196], pot.units["acceleration"] + ) + got = pot.gradient(x, t=0) + assert jnp.allclose(got, expect, atol=u.Quantity(1e-8, expect.unit)) + + def test_density(self, pot: gp.GaussianPotential, x: gt.QuSz3) -> None: + expect = u.Quantity(57898701.53591853, pot.units["mass density"]) + assert jnp.isclose( + pot.density(x, t=0), expect, atol=u.Quantity(1e-8, expect.unit) + ) + + def test_hessian(self, pot: gp.GaussianPotential, x: gt.QuSz3) -> None: + expect = u.Quantity( + [ + [0.06751239, -0.03622985, -0.05434478], + [-0.03622985, 0.01316762, -0.10868955], + [-0.05434478, -0.10868955, -0.07740701], + ], + "1/Myr2", + ) + assert jnp.allclose( + pot.hessian(x, t=0), expect, atol=u.Quantity(1e-8, expect.unit) + ) + + # --------------------------------- + # Convenience methods + + def test_tidal_tensor(self, pot: gp.AbstractPotential, x: gt.QuSz3) -> None: + """Test the `AbstractPotential.tidal_tensor` method.""" + expect = u.Quantity( + [ + [0.06642139, -0.03622985, -0.05434478], + [-0.03622985, 0.01207662, -0.10868955], + [-0.05434478, -0.10868955, -0.07849801], + ], + "1/Myr2", + ) + assert jnp.allclose( + pot.tidal_tensor(x, t=0), expect, atol=u.Quantity(1e-8, expect.unit) + ) From 1db1c762924b2f1115ea80c57dde104660d06e60 Mon Sep 17 00:00:00 2001 From: Adrian Price-Whelan <583379+adrn@users.noreply.github.com> Date: Wed, 13 Aug 2025 22:14:15 -0400 Subject: [PATCH 5/5] change name to GaussianDensityPotential --- src/galax/potential/__init__.py | 4 ++-- src/galax/potential/_src/builtin/__init__.py | 4 ++-- src/galax/potential/_src/builtin/gaussian.py | 6 +++--- tests/unit/potential/builtin/test_gaussian.py | 14 +++++++------- 4 files changed, 14 insertions(+), 14 deletions(-) diff --git a/src/galax/potential/__init__.py b/src/galax/potential/__init__.py index 91669a10..d8779719 100644 --- a/src/galax/potential/__init__.py +++ b/src/galax/potential/__init__.py @@ -43,7 +43,7 @@ "TriaxialHernquistPotential", "HardCutoffNFWPotential", "gNFWPotential", - "GaussianPotential", + "GaussianDensityPotential", # Pre-composited "AbstractPreCompositedPotential", "BovyMWPotential2014", @@ -99,7 +99,7 @@ AbstractMultipolePotential, BovyMWPotential2014, BurkertPotential, - GaussianPotential, + GaussianDensityPotential, HardCutoffNFWPotential, HarmonicOscillatorPotential, HenonHeilesPotential, diff --git a/src/galax/potential/_src/builtin/__init__.py b/src/galax/potential/_src/builtin/__init__.py index c756cd2b..3ed8e514 100644 --- a/src/galax/potential/_src/builtin/__init__.py +++ b/src/galax/potential/_src/builtin/__init__.py @@ -41,7 +41,7 @@ "StoneOstriker15Potential", "TriaxialHernquistPotential", "HardCutoffNFWPotential", - "GaussianPotential", + "GaussianDensityPotential", ] from .burkert import BurkertPotential @@ -55,7 +55,7 @@ HarmonicOscillatorPotential, HenonHeilesPotential, ) -from .gaussian import GaussianPotential +from .gaussian import GaussianDensityPotential from .hernquist import HernquistPotential, TriaxialHernquistPotential from .isochrone import IsochronePotential from .jaffe import JaffePotential diff --git a/src/galax/potential/_src/builtin/gaussian.py b/src/galax/potential/_src/builtin/gaussian.py index 42fe9f59..cabd9e05 100644 --- a/src/galax/potential/_src/builtin/gaussian.py +++ b/src/galax/potential/_src/builtin/gaussian.py @@ -1,6 +1,6 @@ """galax: Galactic Dynamix in Jax.""" -__all__ = ["GaussianPotential"] +__all__ = ["GaussianDensityPotential"] import functools as ft from dataclasses import KW_ONLY @@ -23,8 +23,8 @@ @final -class GaussianPotential(AbstractSinglePotential): - r"""A spherical Gaussian potential. +class GaussianDensityPotential(AbstractSinglePotential): + r"""Potential of a spherical Gaussian density profile. The gravitational potential corresponding to a spherical Gaussian density profile: diff --git a/tests/unit/potential/builtin/test_gaussian.py b/tests/unit/potential/builtin/test_gaussian.py index 8d1b6b01..719b147b 100644 --- a/tests/unit/potential/builtin/test_gaussian.py +++ b/tests/unit/potential/builtin/test_gaussian.py @@ -11,15 +11,15 @@ from .test_common import ParameterMTotMixin, ParameterRSMixin -class TestGaussianPotential( +class TestGaussianDensityPotential( AbstractSinglePotential_Test, # Parameters ParameterMTotMixin, ParameterRSMixin, ): @pytest.fixture(scope="class") - def pot_cls(self) -> type[gp.GaussianPotential]: - return gp.GaussianPotential + def pot_cls(self) -> type[gp.GaussianDensityPotential]: + return gp.GaussianDensityPotential @pytest.fixture(scope="class") def fields_(self, field_m_tot, field_r_s, field_units) -> dict[str, Any]: @@ -27,26 +27,26 @@ def fields_(self, field_m_tot, field_r_s, field_units) -> dict[str, Any]: # ========================================================================== - def test_potential(self, pot: gp.GaussianPotential, x: gt.QuSz3) -> None: + def test_potential(self, pot: gp.GaussianDensityPotential, x: gt.QuSz3) -> None: expect = u.Quantity(-1.20205548, pot.units["specific energy"]) assert jnp.isclose( pot.potential(x, t=0), expect, atol=u.Quantity(1e-8, expect.unit) ) - def test_gradient(self, pot: gp.GaussianPotential, x: gt.QuSz3) -> None: + def test_gradient(self, pot: gp.GaussianDensityPotential, x: gt.QuSz3) -> None: expect = u.Quantity( [0.08562732, 0.17125464, 0.25688196], pot.units["acceleration"] ) got = pot.gradient(x, t=0) assert jnp.allclose(got, expect, atol=u.Quantity(1e-8, expect.unit)) - def test_density(self, pot: gp.GaussianPotential, x: gt.QuSz3) -> None: + def test_density(self, pot: gp.GaussianDensityPotential, x: gt.QuSz3) -> None: expect = u.Quantity(57898701.53591853, pot.units["mass density"]) assert jnp.isclose( pot.density(x, t=0), expect, atol=u.Quantity(1e-8, expect.unit) ) - def test_hessian(self, pot: gp.GaussianPotential, x: gt.QuSz3) -> None: + def test_hessian(self, pot: gp.GaussianDensityPotential, x: gt.QuSz3) -> None: expect = u.Quantity( [ [0.06751239, -0.03622985, -0.05434478],