From c0dc2708b5581d6676531a61d612a0c6ef9cde5e Mon Sep 17 00:00:00 2001 From: Adrian Price-Whelan <583379+adrn@users.noreply.github.com> Date: Wed, 13 Aug 2025 15:22:32 -0400 Subject: [PATCH 1/8] initial implementation of Zhao model --- src/galax/potential/_src/builtin/zhao.py | 157 ++++++++++++++++++ .../functional/potential/builtin/test_zhao.py | 49 ++++++ 2 files changed, 206 insertions(+) create mode 100644 src/galax/potential/_src/builtin/zhao.py create mode 100644 tests/functional/potential/builtin/test_zhao.py diff --git a/src/galax/potential/_src/builtin/zhao.py b/src/galax/potential/_src/builtin/zhao.py new file mode 100644 index 000000000..fa35edcd1 --- /dev/null +++ b/src/galax/potential/_src/builtin/zhao.py @@ -0,0 +1,157 @@ +__all__ = ["ZhaoPotential"] + +import functools as ft +from dataclasses import KW_ONLY +from typing import final + +import equinox as eqx +import jax +import jax.scipy.special as jsp + +import quaxed.numpy as jnp +import unxt as u +from xmmutablemap import ImmutableMap + +import galax._custom_types as gt +from galax.potential._src.base import default_constants +from galax.potential._src.base_single import AbstractSinglePotential +from galax.potential._src.params.base import AbstractParameter +from galax.potential._src.params.field import ParameterField +from galax.potential._src.utils import r_spherical + + +@final +class ZhaoPotential(AbstractSinglePotential): + r"""Zhao (1996) double power-law potential. + + This model represents a double power law in the density, with an inner slope + :math:`\gamma` and an outer slope :math:`\beta`, but with a third parameter + :math:`\alpha` that controls the width of the transition region between the two + power laws. + + This model has a finite total mass for :math:`\beta > 3`. The other power-law + parameters should satisfy :math:`\alpha > 0` and :math:`0 \leq \gamma < 3`. + """ + + m: AbstractParameter = ParameterField( + dimensions="mass", + doc=( + "Mass parameter. When beta > 3, this is the total mass. When beta <= 3, " + "this is a scale mass." + ), + ) # type: ignore[assignment] + r_s: AbstractParameter = ParameterField(dimensions="length", doc="Scale radius.") # type: ignore[assignment] + + alpha: AbstractParameter = ParameterField( + dimensions="dimensionless", doc="Transition width (alpha > 0)." + ) # type: ignore[assignment] + beta: AbstractParameter = ParameterField( + dimensions="dimensionless", doc="Outer slope (finite mass when beta > 3)." + ) # type: ignore[assignment] + gamma: AbstractParameter = ParameterField( + dimensions="dimensionless", doc="Inner slope (0 <= gamma < 3)." + ) # type: ignore[assignment] + + _: KW_ONLY + units: u.AbstractUnitSystem = eqx.field(converter=u.unitsystem, static=True) + constants: ImmutableMap[str, u.AbstractQuantity] = eqx.field( + default=default_constants, converter=ImmutableMap + ) + + @ft.partial(jax.jit) + def _potential(self, xyz: gt.BBtQorVSz3, t: gt.BBtQorVSz0, /) -> gt.BBtSz0: + r = r_spherical(xyz, self.units["length"]) + t = u.Quantity.from_(t, self.units["time"]) + + ulen = self.units["length"] + umass = self.units["mass"] + udim = self.units["dimensionless"] + p = { + "G": self.constants["G"].value, + "m": self.m(t, ustrip=umass), + "r_s": self.r_s(t, ustrip=ulen), + "alpha": self.alpha(t, ustrip=udim), + "beta": self.beta(t, ustrip=udim), + "gamma": self.gamma(t, ustrip=udim), + } + return potential(p, r) + + @ft.partial(jax.jit) + def _density(self, xyz: gt.BBtQorVSz3, t: gt.BBtQorVSz0, /) -> gt.BtFloatSz0: + r = r_spherical(xyz, self.units["length"]) + t = u.Quantity.from_(t, self.units["time"]) + + ulen = self.units["length"] + umass = self.units["mass"] + udim = self.units["dimensionless"] + p = { + "m": self.m(t, ustrip=umass), + "r_s": self.r_s(t, ustrip=ulen), + "alpha": self.alpha(t, ustrip=udim), + "beta": self.beta(t, ustrip=udim), + "gamma": self.gamma(t, ustrip=udim), + } + return density(p, r) + + +@ft.partial(jax.jit) +def _rho0(m: gt.Sz0, alpha: gt.Sz0, beta: gt.Sz0, gamma: gt.Sz0, /) -> gt.Sz0: + r"""Scale density. + + Define the scale density (called "C" in the Zhao paper) such that the integral: + + \int_0^\infty 4\pi r^2 \rho(r) \, dr = m + + for models with finite mass. + + See Eq. 44 in Zhao (1996). + """ + denom = alpha * jsp.beta(alpha * (3.0 - gamma), alpha * (beta - 3.0)) + return m / (4.0 * jnp.pi * denom) + + +@ft.partial(jax.jit) +def density(p: gt.Params, r: gt.Sz0, /) -> gt.FloatSz0: + """Spherical density profile for double power-law Zhao model.""" + x = r / p["r_s"] + alpha, beta, gamma = p["alpha"], p["beta"], p["gamma"] + rho0 = _rho0(p["m"], alpha, beta, gamma) + + b = (beta - gamma) * alpha + return (rho0 / (p["r_s"] ** 3)) / x**gamma / (1.0 + x ** (1.0 / alpha)) ** b + + +@ft.partial(jax.jit) +def potential(p: gt.Params, r: gt.Sz0, /) -> gt.Sz0: + r"""Spherical potential for double power-law Zhao model. + + See Eq. 6 and 7 in Zhao (1996). + + This function uses the variable z for what Zhao called :math:`\chi`. + """ + x = r / p["r_s"] + alpha, beta, gamma = p["alpha"], p["beta"], p["gamma"] + + # What Zhao calls "chi": + z = x ** (1.0 / alpha) / (1.0 + x ** (1.0 / alpha)) + + rho0 = _rho0(p["m"], alpha, beta, gamma) + + # Constants defined in appendix of Zhao (1996) + p0 = alpha * (2.0 - gamma) + q0 = alpha * (beta - 3.0) + c0 = alpha * (beta - gamma) + + # Left term in Eq. 7 + term_l = jsp.beta(c0 - q0, q0) * jsp.betainc(c0 - q0, q0, z) + + # Right term in Eq. 7 + # This uses a trick (from chatgpt) for avoiding the pole as p0 -> 0 (gamma -> 2) + # Trick: compute as exp(betaln + log(1 - I_z)) + eps = jnp.sqrt(jnp.finfo(r.dtype).eps) + p0_safe = jnp.where(p0 <= 0, eps, p0) + logB = jsp.betaln(p0_safe, c0 - p0) + log1mI = jnp.log1p(-jsp.betainc(p0_safe, c0 - p0, z)) + term_r = jnp.exp(logB + log1mI) + + return -4.0 * jnp.pi * p["G"] * rho0 * alpha * (term_l / r + term_r / p["r_s"]) diff --git a/tests/functional/potential/builtin/test_zhao.py b/tests/functional/potential/builtin/test_zhao.py new file mode 100644 index 000000000..32fbe9c50 --- /dev/null +++ b/tests/functional/potential/builtin/test_zhao.py @@ -0,0 +1,49 @@ +import jax +import pytest + +import quaxed.numpy as jnp +import unxt as u + +import galax.potential as gp +from galax.potential._src.builtin.zhao import ZhaoPotential + +check_funcs = ["potential", "gradient", "density", "hessian"] + +# settings of (alpha, beta, gamma) and correspondence to known models: +abg_pot = { + (1, 4, 1): gp.HernquistPotential, + (1, 4, 2): gp.JaffePotential, + (1 / 2, 5, 0): gp.PlummerPotential, + # (1, 3, 1): gp.NFWPotential, +} + + +@pytest.fixture +def xyz(): + test_r = jnp.geomspace(1e-3, 1e2, 128) + rand_uvecs = jax.random.normal(jax.random.key(42), shape=(test_r.size, 3)) + rand_uvecs = rand_uvecs / jnp.linalg.norm(rand_uvecs, axis=-1, keepdims=True) + return u.Quantity(test_r[:, None] * rand_uvecs, "kpc") + + +@pytest.mark.parametrize("func_name", check_funcs) +@pytest.mark.parametrize(("abg", "OtherPotential"), list(abg_pot.items())) +def test_zhao_against_correspondences(func_name, abg, OtherPotential, xyz): + zhao = ZhaoPotential( + m=u.Quantity(1.3e11, "Msun"), + r_s=u.Quantity(8.1, "kpc"), + alpha=abg[0], + beta=abg[1], + gamma=abg[2], + units="galactic", + ) + other = OtherPotential( + m_tot=zhao.parameters["m"], r_s=zhao.parameters["r_s"], units="galactic" + ) + + zhao_result = getattr(zhao, func_name)(xyz, u.Quantity(0.0, "Myr")) + other_result = getattr(other, func_name)(xyz, u.Quantity(0.0, "Myr")) + + assert jnp.allclose( + zhao_result, other_result, rtol=1e-8, atol=u.Quantity(1e-6, zhao_result.unit) + ) From c13ef5d6286b80a597103cecc2e5cb14508a5035 Mon Sep 17 00:00:00 2001 From: Adrian Price-Whelan <583379+adrn@users.noreply.github.com> Date: Wed, 13 Aug 2025 15:27:42 -0400 Subject: [PATCH 2/8] add note of potential correspondences --- src/galax/potential/_src/builtin/zhao.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/galax/potential/_src/builtin/zhao.py b/src/galax/potential/_src/builtin/zhao.py index fa35edcd1..8e14f1af8 100644 --- a/src/galax/potential/_src/builtin/zhao.py +++ b/src/galax/potential/_src/builtin/zhao.py @@ -31,6 +31,14 @@ class ZhaoPotential(AbstractSinglePotential): This model has a finite total mass for :math:`\beta > 3`. The other power-law parameters should satisfy :math:`\alpha > 0` and :math:`0 \leq \gamma < 3`. + + This model also reduces to a number of well-known analytic forms for certain values + of the parameters (reproduced from Table 1 of Zhao 1996): + - :math:`(\alpha, \beta, \gamma) = (1, 4, 1)`: Hernquist model + - :math:`(\alpha, \beta, \gamma) = (1, 4, 2)`: Jaffe model + - :math:`(\alpha, \beta, \gamma) = (1/2, 5, 0)`: Plummer model + - :math:`(\alpha, \beta, \gamma) = (1, 3, 1)`: NFW model + - :math:`(\alpha, \beta, \gamma) = (1, 3, \gamma)`: Generalized NFW model """ m: AbstractParameter = ParameterField( From 8824769242df3400a16c16d6b5e4aab88b9e8a5c Mon Sep 17 00:00:00 2001 From: Adrian Price-Whelan <583379+adrn@users.noreply.github.com> Date: Wed, 13 Aug 2025 22:04:22 -0400 Subject: [PATCH 3/8] some refactor and add a special case for Jaffe to avoid bad gradient --- src/galax/potential/_src/builtin/zhao.py | 55 ++++++++++++++++++------ 1 file changed, 41 insertions(+), 14 deletions(-) diff --git a/src/galax/potential/_src/builtin/zhao.py b/src/galax/potential/_src/builtin/zhao.py index 8e14f1af8..438b73922 100644 --- a/src/galax/potential/_src/builtin/zhao.py +++ b/src/galax/potential/_src/builtin/zhao.py @@ -118,6 +118,22 @@ def _rho0(m: gt.Sz0, alpha: gt.Sz0, beta: gt.Sz0, gamma: gt.Sz0, /) -> gt.Sz0: return m / (4.0 * jnp.pi * denom) +@ft.partial(jax.jit) +def _cpq(a: gt.Sz0, b: gt.Sz0, g: gt.Sz0) -> tuple[gt.Sz0, gt.Sz0, gt.Sz0]: + """Constants defined in appendix of Zhao (1996).""" + c0 = a * (b - g) + p0 = a * (2.0 - g) + q0 = a * (b - 3.0) + return c0, p0, q0 + + +@ft.partial(jax.jit) +def _r_to_xz(p: gt.Params, r: gt.Sz0) -> gt.FloatSz0: + """Convert radius to the variable z used in the potential.""" + x = r / p["r_s"] + return x, x ** (1.0 / p["alpha"]) / (1.0 + x ** (1.0 / p["alpha"])) + + @ft.partial(jax.jit) def density(p: gt.Params, r: gt.Sz0, /) -> gt.FloatSz0: """Spherical density profile for double power-law Zhao model.""" @@ -129,6 +145,15 @@ def density(p: gt.Params, r: gt.Sz0, /) -> gt.FloatSz0: return (rho0 / (p["r_s"] ** 3)) / x**gamma / (1.0 + x ** (1.0 / alpha)) ** b +@ft.partial(jax.jit) +def mass_enclosed(p: gt.Params, z: gt.Sz0) -> gt.Sz0: + a, b, g = p["alpha"], p["beta"], p["gamma"] + rho0 = _rho0(p["m"], a, b, g) + *_, q0 = _cpq(a, b, g) + aM = a * (3.0 - g) + return 4.0 * jnp.pi * rho0 * a * (jsp.beta(aM, q0) * jsp.betainc(aM, q0, z)) + + @ft.partial(jax.jit) def potential(p: gt.Params, r: gt.Sz0, /) -> gt.Sz0: r"""Spherical potential for double power-law Zhao model. @@ -137,29 +162,31 @@ def potential(p: gt.Params, r: gt.Sz0, /) -> gt.Sz0: This function uses the variable z for what Zhao called :math:`\chi`. """ - x = r / p["r_s"] - alpha, beta, gamma = p["alpha"], p["beta"], p["gamma"] + a, b, g = p["alpha"], p["beta"], p["gamma"] - # What Zhao calls "chi": - z = x ** (1.0 / alpha) / (1.0 + x ** (1.0 / alpha)) + # z here is what Zhao calls "chi": + x, z = _r_to_xz(p, r) - rho0 = _rho0(p["m"], alpha, beta, gamma) + # Special case the Jaffe potential, where there is an "inf - inf" below + is_jaffe = (a == 1.0) & (b == 4.0) & (g == 2.0) + + def Phi_jaffe() -> gt.Sz0: + return -p["G"] * p["m"] / p["r_s"] * jnp.log1p(1.0 / x) - # Constants defined in appendix of Zhao (1996) - p0 = alpha * (2.0 - gamma) - q0 = alpha * (beta - 3.0) - c0 = alpha * (beta - gamma) + rho0 = _rho0(p["m"], a, b, g) + c0, p0, _ = _cpq(a, b, g) # Left term in Eq. 7 - term_l = jsp.beta(c0 - q0, q0) * jsp.betainc(c0 - q0, q0, z) + term_l = mass_enclosed(p, z) # Right term in Eq. 7 - # This uses a trick (from chatgpt) for avoiding the pole as p0 -> 0 (gamma -> 2) - # Trick: compute as exp(betaln + log(1 - I_z)) eps = jnp.sqrt(jnp.finfo(r.dtype).eps) p0_safe = jnp.where(p0 <= 0, eps, p0) logB = jsp.betaln(p0_safe, c0 - p0) log1mI = jnp.log1p(-jsp.betainc(p0_safe, c0 - p0, z)) - term_r = jnp.exp(logB + log1mI) + term_r = 4.0 * jnp.pi * rho0 * a / p["r_s"] * jnp.exp(logB + log1mI) + + def Phi_general() -> gt.Sz0: + return -p["G"] * (term_l / r + term_r) - return -4.0 * jnp.pi * p["G"] * rho0 * alpha * (term_l / r + term_r / p["r_s"]) + return jax.lax.cond(is_jaffe, Phi_jaffe, Phi_general) From 578e53598e0c7d4029527e048b43249fad8f2841 Mon Sep 17 00:00:00 2001 From: Adrian Price-Whelan <583379+adrn@users.noreply.github.com> Date: Wed, 13 Aug 2025 22:13:08 -0400 Subject: [PATCH 4/8] add test module for Zhao potential --- tests/unit/potential/builtin/test_zhao.py | 178 ++++++++++++++++++++++ 1 file changed, 178 insertions(+) create mode 100644 tests/unit/potential/builtin/test_zhao.py diff --git a/tests/unit/potential/builtin/test_zhao.py b/tests/unit/potential/builtin/test_zhao.py new file mode 100644 index 000000000..0708a008e --- /dev/null +++ b/tests/unit/potential/builtin/test_zhao.py @@ -0,0 +1,178 @@ +from typing import Any, ClassVar + +import pytest + +import quaxed.numpy as jnp +import unxt as u + +import galax._custom_types as gt +import galax.potential as gp +from ..param.test_field import ParameterFieldMixin +from ..test_core import AbstractSinglePotential_Test +from .test_common import ParameterMMixin, ParameterRSMixin +from galax.potential._src.builtin.zhao import ZhaoPotential + + +class AlphaParameterMixin(ParameterFieldMixin): + """Test the alpha parameter.""" + + @pytest.fixture(scope="class") + def field_alpha(self) -> u.Quantity["dimensionless"]: + return u.Quantity(0.9, "") + + # ===================================================== + + def test_alpha_constant(self, pot_cls, fields): + """Test the `alpha` parameter.""" + fields["alpha"] = u.Quantity(1.0, "") + pot = pot_cls(**fields) + assert pot.alpha(t=u.Quantity(0, "Myr")) == u.Quantity(1.0, "") + + def test_alpha_userfunc(self, pot_cls, fields): + """Test the `alpha` parameter.""" + + def cos_alpha(t: u.Quantity["time"]) -> u.Quantity[""]: + return u.Quantity(0.5 * jnp.cos(t.ustrip("Myr")) ** 2 + 0.5, "") + + fields["alpha"] = cos_alpha + pot = pot_cls(**fields) + assert pot.alpha(t=u.Quantity(0, "Myr")) == u.Quantity(1.0, "") + + +class BetaParameterMixin(ParameterFieldMixin): + """Test the beta parameter.""" + + @pytest.fixture(scope="class") + def field_beta(self) -> u.Quantity["dimensionless"]: + return u.Quantity(4.31, "") + + # ===================================================== + + def test_beta_constant(self, pot_cls, fields): + """Test the `beta` parameter.""" + fields["beta"] = u.Quantity(3.5, "") + pot = pot_cls(**fields) + assert pot.beta(t=u.Quantity(0, "Myr")) == u.Quantity(3.5, "") + + def test_beta_userfunc(self, pot_cls, fields): + """Test the `beta` parameter.""" + + def cos_beta(t: u.Quantity["time"]) -> u.Quantity[""]: + return u.Quantity(jnp.cos(t.ustrip("Myr")) + 4.2, "") + + fields["beta"] = cos_beta + pot = pot_cls(**fields) + assert pot.beta(t=u.Quantity(0, "Myr")) == u.Quantity(5.2, "") + + +class GammaParameterMixin(ParameterFieldMixin): + """Test the gamma parameter.""" + + @pytest.fixture(scope="class") + def field_gamma(self) -> u.Quantity["dimensionless"]: + return u.Quantity(1.2, "") + + # ===================================================== + + def test_gamma_constant(self, pot_cls, fields): + """Test the `gamma` parameter.""" + fields["gamma"] = u.Quantity(1.5, "") + pot = pot_cls(**fields) + assert pot.gamma(t=u.Quantity(0, "Myr")) == u.Quantity(1.5, "") + + def test_gamma_userfunc(self, pot_cls, fields): + """Test the `gamma` parameter.""" + + def cos_gamma(t: u.Quantity["time"]) -> u.Quantity[""]: + return u.Quantity(0.5 * jnp.cos(t.ustrip("Myr")) + 1.5, "") + + fields["gamma"] = cos_gamma + pot = pot_cls(**fields) + assert pot.gamma(t=u.Quantity(0, "Myr")) == u.Quantity(2.0, "") + + +class TestZhaoPotential( + AbstractSinglePotential_Test, + # Parameters + ParameterMMixin, + ParameterRSMixin, + AlphaParameterMixin, + BetaParameterMixin, + GammaParameterMixin, +): + """Test the `galax.potential.ZhaoPotential` class.""" + + HAS_GALA_COUNTERPART: ClassVar[bool] = False + + @pytest.fixture(scope="class") + def pot_cls(self) -> type[ZhaoPotential]: + return ZhaoPotential + + @pytest.fixture(scope="class") + def fields_( + self, + field_m: u.Quantity, + field_r_s: u.Quantity, + field_alpha: u.Quantity, + field_beta: u.Quantity, + field_gamma: u.Quantity, + field_units: u.AbstractUnitSystem, + ) -> dict[str, Any]: + return { + "m": field_m, + "r_s": field_r_s, + "alpha": field_alpha, + "beta": field_beta, + "gamma": field_gamma, + "units": field_units, + } + + # ========================================================================== + + def test_potential(self, pot: ZhaoPotential, x: gt.QuSz3) -> None: + expect = u.Quantity(-1.07419698, unit="kpc2 / Myr2") + print(pot.parameters) + assert jnp.isclose( + pot.potential(x, t=0), expect, atol=u.Quantity(1e-8, expect.unit) + ) + + def test_gradient(self, pot: ZhaoPotential, x: gt.QuSz3) -> None: + expect = u.Quantity([0.06671604, 0.13343208, 0.20014812], "kpc / Myr2") + got = pot.gradient(x, t=0) + assert jnp.allclose(got, expect, atol=u.Quantity(1e-8, expect.unit)) + + def test_density(self, pot: ZhaoPotential, x: gt.QuSz3) -> None: + expect = u.Quantity(3.39060592e08, "solMass / kpc3") + assert jnp.isclose( + pot.density(x, t=0), expect, atol=u.Quantity(1e-8, expect.unit) + ) + + def test_hessian(self, pot: ZhaoPotential, x: gt.QuSz3) -> None: + expect = u.Quantity( + [ + [0.05378882, -0.02585444, -0.03878166], + [-0.02585444, 0.01500716, -0.07756332], + [-0.03878166, -0.07756332, -0.04962894], + ], + "1/Myr2", + ) + assert jnp.allclose( + pot.hessian(x, t=0), expect, atol=u.Quantity(1e-8, expect.unit) + ) + + # --------------------------------- + # Convenience methods + + def test_tidal_tensor(self, pot: gp.AbstractPotential, x: gt.QuSz3) -> None: + """Test the `AbstractPotential.tidal_tensor` method.""" + expect = u.Quantity( + [ + [0.04739981, -0.02585444, -0.03878166], + [-0.02585444, 0.00861815, -0.07756332], + [-0.03878166, -0.07756332, -0.05601795], + ], + "1/Myr2", + ) + assert jnp.allclose( + pot.tidal_tensor(x, t=0), expect, atol=u.Quantity(1e-8, expect.unit) + ) From a4ad0bff45159e85179e09eb0e811a3a1ab40636 Mon Sep 17 00:00:00 2001 From: Adrian Price-Whelan <583379+adrn@users.noreply.github.com> Date: Thu, 14 Aug 2025 16:12:25 -0400 Subject: [PATCH 5/8] switch interpretation of mass parameter to be mass enclosed within scale radius --- src/galax/potential/_src/builtin/zhao.py | 95 ++++++++++++++++-------- 1 file changed, 65 insertions(+), 30 deletions(-) diff --git a/src/galax/potential/_src/builtin/zhao.py b/src/galax/potential/_src/builtin/zhao.py index 438b73922..7e2a285ab 100644 --- a/src/galax/potential/_src/builtin/zhao.py +++ b/src/galax/potential/_src/builtin/zhao.py @@ -44,8 +44,9 @@ class ZhaoPotential(AbstractSinglePotential): m: AbstractParameter = ParameterField( dimensions="mass", doc=( - "Mass parameter. When beta > 3, this is the total mass. When beta <= 3, " - "this is a scale mass." + "Scale mass parameter. This is equivalent to the mass enclosed within the " + "scale radius. When beta > 3, the model has finite mass, but when beta <= 3" + " the total mass is infinite." ), ) # type: ignore[assignment] r_s: AbstractParameter = ParameterField(dimensions="length", doc="Scale radius.") # type: ignore[assignment] @@ -102,20 +103,41 @@ def _density(self, xyz: gt.BBtQorVSz3, t: gt.BBtQorVSz0, /) -> gt.BtFloatSz0: return density(p, r) -@ft.partial(jax.jit) -def _rho0(m: gt.Sz0, alpha: gt.Sz0, beta: gt.Sz0, gamma: gt.Sz0, /) -> gt.Sz0: - r"""Scale density. - - Define the scale density (called "C" in the Zhao paper) such that the integral: - \int_0^\infty 4\pi r^2 \rho(r) \, dr = m +@ft.partial(jax.jit) +def _total_mass_factor(p: gt.Params, r_ref: gt.Sz0) -> gt.FloatSz0: + """Compute the total mass factor for the Zhao profile.""" + c0, _, q0 = _cpq(p["alpha"], p["beta"], p["gamma"]) + x = r_ref / p["r_s"] + chi = x ** (1.0 / p["alpha"]) / (1.0 + x ** (1.0 / p["alpha"])) + # chi = jnp.power(x, 1.0 / p["alpha"]) / (1.0 + jnp.power(x, 1.0 / p["alpha"])) + return jsp.betainc(c0 - q0, q0, chi) - for models with finite mass. - See Eq. 44 in Zhao (1996). +@ft.partial(jax.jit) +def _rho0(p: gt.Params, r_ref: gt.Sz0 | None = None) -> gt.FloatSz0: + """Compute the normalization density for the Zhao profile. + + This computes the normalization constant rho_0 (called C in Zhao 1996) for the Zhao + density profile. The normalization is set that the mass parameter is the mass + enclosed within ``r_ref``. If no r_ref is specified to this function (as happens in + the default initializer for the ``ZhaoPotential``), it is set to the scale radius, + so the mass parameter is interpreted to be the mass enclosed within the scale + radius. + + This implementation uses the hyp2f1 hypergeometric function instead of the + incomplete beta function because jax (and scipy) only provide the *regularized* + version of the incomplete beta function. This means that it blows up when b <= 0 in + B(a, b, z) because the complete beta function B(a, b) is undefined when b <= 0. The + hyp2f1 function is defined for all values of a, b, and z, so it can handle the case + where b <= 0. """ - denom = alpha * jsp.beta(alpha * (3.0 - gamma), alpha * (beta - 3.0)) - return m / (4.0 * jnp.pi * denom) + r_ref = r_ref if r_ref is not None else p["r_s"] + chi_norm = _r_to_u_chi(p, r_ref)[1] + a = p["alpha"] * (3.0 - p["gamma"]) + b = p["alpha"] * (p["beta"] - 3.0) + denom = (chi_norm**a / a) * jsp.hyp2f1(a, 1.0 - b, a + 1.0, chi_norm) + return p["m"] / (4.0 * jnp.pi * p["alpha"] * denom) @ft.partial(jax.jit) @@ -128,30 +150,43 @@ def _cpq(a: gt.Sz0, b: gt.Sz0, g: gt.Sz0) -> tuple[gt.Sz0, gt.Sz0, gt.Sz0]: @ft.partial(jax.jit) -def _r_to_xz(p: gt.Params, r: gt.Sz0) -> gt.FloatSz0: - """Convert radius to the variable z used in the potential.""" - x = r / p["r_s"] - return x, x ** (1.0 / p["alpha"]) / (1.0 + x ** (1.0 / p["alpha"])) +def _r_to_u_chi(p: gt.Params, r: gt.Sz0) -> gt.FloatSz0: + r"""Convert radius to u and chi variables defined below. + + .. math:: + + u = r / r_s + chi = \frac{u^{1/\alpha}}{1 + u^{1/\alpha}} + + """ + uu = r / p["r_s"] + return uu, uu ** (1.0 / p["alpha"]) / (1.0 + uu ** (1.0 / p["alpha"])) @ft.partial(jax.jit) def density(p: gt.Params, r: gt.Sz0, /) -> gt.FloatSz0: """Spherical density profile for double power-law Zhao model.""" - x = r / p["r_s"] + uu = r / p["r_s"] alpha, beta, gamma = p["alpha"], p["beta"], p["gamma"] - rho0 = _rho0(p["m"], alpha, beta, gamma) + rho0 = _rho0(p) b = (beta - gamma) * alpha - return (rho0 / (p["r_s"] ** 3)) / x**gamma / (1.0 + x ** (1.0 / alpha)) ** b + return rho0 / (p["r_s"] ** 3) / uu**gamma / (1.0 + uu ** (1.0 / alpha)) ** b @ft.partial(jax.jit) -def mass_enclosed(p: gt.Params, z: gt.Sz0) -> gt.Sz0: +def mass_enclosed(p: gt.Params, r: gt.Sz0) -> gt.Sz0: a, b, g = p["alpha"], p["beta"], p["gamma"] - rho0 = _rho0(p["m"], a, b, g) - *_, q0 = _cpq(a, b, g) - aM = a * (3.0 - g) - return 4.0 * jnp.pi * rho0 * a * (jsp.beta(aM, q0) * jsp.betainc(aM, q0, z)) + _, chi = _r_to_u_chi(p, r) + rho0 = _rho0(p) + c0, _, q0 = _cpq(a, b, g) + return ( + 4.0 + * jnp.pi + * rho0 + * a + * (jsp.beta(c0 - q0, q0) * jsp.betainc(c0 - q0, q0, chi)) + ) @ft.partial(jax.jit) @@ -164,26 +199,26 @@ def potential(p: gt.Params, r: gt.Sz0, /) -> gt.Sz0: """ a, b, g = p["alpha"], p["beta"], p["gamma"] - # z here is what Zhao calls "chi": - x, z = _r_to_xz(p, r) + uu, chi = _r_to_u_chi(p, r) # Special case the Jaffe potential, where there is an "inf - inf" below is_jaffe = (a == 1.0) & (b == 4.0) & (g == 2.0) def Phi_jaffe() -> gt.Sz0: - return -p["G"] * p["m"] / p["r_s"] * jnp.log1p(1.0 / x) + # Note: the extra factor of 2 is because m is mass enclosed in r_s, not total + return -p["G"] * 2 * p["m"] / p["r_s"] * jnp.log1p(1.0 / uu) - rho0 = _rho0(p["m"], a, b, g) + rho0 = _rho0(p) c0, p0, _ = _cpq(a, b, g) # Left term in Eq. 7 - term_l = mass_enclosed(p, z) + term_l = mass_enclosed(p, r) # Right term in Eq. 7 eps = jnp.sqrt(jnp.finfo(r.dtype).eps) p0_safe = jnp.where(p0 <= 0, eps, p0) logB = jsp.betaln(p0_safe, c0 - p0) - log1mI = jnp.log1p(-jsp.betainc(p0_safe, c0 - p0, z)) + log1mI = jnp.log1p(-jsp.betainc(p0_safe, c0 - p0, chi)) term_r = 4.0 * jnp.pi * rho0 * a / p["r_s"] * jnp.exp(logB + log1mI) def Phi_general() -> gt.Sz0: From 5456da94a1a178ede33f5601ac6abc7165515593 Mon Sep 17 00:00:00 2001 From: Adrian Price-Whelan <583379+adrn@users.noreply.github.com> Date: Thu, 14 Aug 2025 16:13:02 -0400 Subject: [PATCH 6/8] add a constructor to enable creating from a total mass (for finite mass settings of power-law indices) --- src/galax/potential/_src/builtin/zhao.py | 50 ++++++++++++++++++- .../functional/potential/builtin/test_zhao.py | 29 +++++------ 2 files changed, 64 insertions(+), 15 deletions(-) diff --git a/src/galax/potential/_src/builtin/zhao.py b/src/galax/potential/_src/builtin/zhao.py index 7e2a285ab..48446a8de 100644 --- a/src/galax/potential/_src/builtin/zhao.py +++ b/src/galax/potential/_src/builtin/zhao.py @@ -102,6 +102,55 @@ def _density(self, xyz: gt.BBtQorVSz3, t: gt.BBtQorVSz0, /) -> gt.BtFloatSz0: } return density(p, r) + # =========================================== + # Constructors + + @classmethod + def from_m_tot( + cls, + m_tot: gt.Sz0 | u.Quantity["mass"], + r_s: gt.Sz0 | u.Quantity["length"], + alpha: gt.Sz0 | u.Quantity["dimensionless"], + beta: gt.Sz0 | u.Quantity["dimensionless"], + gamma: gt.Sz0 | u.Quantity["dimensionless"], + *, + units: u.AbstractUnitSystem | str = "galactic", + constants: ImmutableMap[str, u.AbstractQuantity] = default_constants, + ) -> "ZhaoPotential": + """Create a Zhao potential from a total mass and scale radius. + + Note: This is only possible when beta > 3, when the model has finite mass. + + Parameters + ---------- + m_tot + Total mass of the halo. + r_s + Scale radius of the halo. + alpha + Inner slope of the halo density profile. + beta + Outer slope of the halo density profile. + gamma + Transition slope of the halo density profile. + units (optional) + Unit system to use for the potential. + constants (optional) + Physical constants to use for the potential. + """ + usys = u.unitsystem(units) + params = { + "r_s": u.ustrip(usys["length"], r_s) if hasattr(r_s, "unit") else r_s, + "alpha": alpha, + "beta": beta, + "gamma": gamma, + } + return cls( + m=m_tot * _total_mass_factor(params, params["r_s"]), + **params, + units=units, + constants=constants, + ) @ft.partial(jax.jit) @@ -110,7 +159,6 @@ def _total_mass_factor(p: gt.Params, r_ref: gt.Sz0) -> gt.FloatSz0: c0, _, q0 = _cpq(p["alpha"], p["beta"], p["gamma"]) x = r_ref / p["r_s"] chi = x ** (1.0 / p["alpha"]) / (1.0 + x ** (1.0 / p["alpha"])) - # chi = jnp.power(x, 1.0 / p["alpha"]) / (1.0 + jnp.power(x, 1.0 / p["alpha"])) return jsp.betainc(c0 - q0, q0, chi) diff --git a/tests/functional/potential/builtin/test_zhao.py b/tests/functional/potential/builtin/test_zhao.py index 32fbe9c50..a5ae7e69c 100644 --- a/tests/functional/potential/builtin/test_zhao.py +++ b/tests/functional/potential/builtin/test_zhao.py @@ -9,13 +9,15 @@ check_funcs = ["potential", "gradient", "density", "hessian"] -# settings of (alpha, beta, gamma) and correspondence to known models: -abg_pot = { - (1, 4, 1): gp.HernquistPotential, - (1, 4, 2): gp.JaffePotential, - (1 / 2, 5, 0): gp.PlummerPotential, - # (1, 3, 1): gp.NFWPotential, -} +# Settings of (alpha, beta, gamma) and correspondence to known models +abg_pot_finite = [ + ((1, 4, 1), gp.HernquistPotential), + ((1, 4, 2), gp.JaffePotential), + ((1 / 2, 5, 0), gp.PlummerPotential), +] +abg_pot_infinite = [ + ((1, 3, 1), gp.NFWPotential), +] @pytest.fixture @@ -27,19 +29,18 @@ def xyz(): @pytest.mark.parametrize("func_name", check_funcs) -@pytest.mark.parametrize(("abg", "OtherPotential"), list(abg_pot.items())) -def test_zhao_against_correspondences(func_name, abg, OtherPotential, xyz): - zhao = ZhaoPotential( - m=u.Quantity(1.3e11, "Msun"), +@pytest.mark.parametrize(("abg", "OtherPotential"), abg_pot_finite) +def test_zhao_against_finite_mass_correspondences(func_name, abg, OtherPotential, xyz): + m_tot = u.Quantity(1.3e11, "Msun") + zhao = ZhaoPotential.from_m_tot( + m_tot=m_tot, r_s=u.Quantity(8.1, "kpc"), alpha=abg[0], beta=abg[1], gamma=abg[2], units="galactic", ) - other = OtherPotential( - m_tot=zhao.parameters["m"], r_s=zhao.parameters["r_s"], units="galactic" - ) + other = OtherPotential(m_tot=m_tot, r_s=zhao.parameters["r_s"], units="galactic") zhao_result = getattr(zhao, func_name)(xyz, u.Quantity(0.0, "Myr")) other_result = getattr(other, func_name)(xyz, u.Quantity(0.0, "Myr")) From 87bd690088dc7907c033f62a442a419b21d255a3 Mon Sep 17 00:00:00 2001 From: Adrian Price-Whelan <583379+adrn@users.noreply.github.com> Date: Thu, 14 Aug 2025 16:28:08 -0400 Subject: [PATCH 7/8] error if beta <= 3 in mtot constructor --- src/galax/potential/_src/builtin/zhao.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/galax/potential/_src/builtin/zhao.py b/src/galax/potential/_src/builtin/zhao.py index 48446a8de..8bd23a92d 100644 --- a/src/galax/potential/_src/builtin/zhao.py +++ b/src/galax/potential/_src/builtin/zhao.py @@ -138,6 +138,7 @@ def from_m_tot( constants (optional) Physical constants to use for the potential. """ + beta = eqx.error_if(beta, beta <= 3.0, "Beta must be >3 to have finite mass.") usys = u.unitsystem(units) params = { "r_s": u.ustrip(usys["length"], r_s) if hasattr(r_s, "unit") else r_s, From c2b05f88d5f91106ae3c69b16e80f56e73a3219b Mon Sep 17 00:00:00 2001 From: Adrian Price-Whelan <583379+adrn@users.noreply.github.com> Date: Thu, 14 Aug 2025 18:59:00 -0400 Subject: [PATCH 8/8] fix test values --- tests/unit/potential/builtin/test_zhao.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/tests/unit/potential/builtin/test_zhao.py b/tests/unit/potential/builtin/test_zhao.py index 0708a008e..329e711cb 100644 --- a/tests/unit/potential/builtin/test_zhao.py +++ b/tests/unit/potential/builtin/test_zhao.py @@ -130,19 +130,18 @@ def fields_( # ========================================================================== def test_potential(self, pot: ZhaoPotential, x: gt.QuSz3) -> None: - expect = u.Quantity(-1.07419698, unit="kpc2 / Myr2") - print(pot.parameters) + expect = u.Quantity(-2.83144346, unit="kpc2 / Myr2") assert jnp.isclose( pot.potential(x, t=0), expect, atol=u.Quantity(1e-8, expect.unit) ) def test_gradient(self, pot: ZhaoPotential, x: gt.QuSz3) -> None: - expect = u.Quantity([0.06671604, 0.13343208, 0.20014812], "kpc / Myr2") + expect = u.Quantity([0.1758548, 0.35170961, 0.52756441], "kpc / Myr2") got = pot.gradient(x, t=0) assert jnp.allclose(got, expect, atol=u.Quantity(1e-8, expect.unit)) def test_density(self, pot: ZhaoPotential, x: gt.QuSz3) -> None: - expect = u.Quantity(3.39060592e08, "solMass / kpc3") + expect = u.Quantity(8.93719599e08, "solMass / kpc3") assert jnp.isclose( pot.density(x, t=0), expect, atol=u.Quantity(1e-8, expect.unit) ) @@ -150,9 +149,9 @@ def test_density(self, pot: ZhaoPotential, x: gt.QuSz3) -> None: def test_hessian(self, pot: ZhaoPotential, x: gt.QuSz3) -> None: expect = u.Quantity( [ - [0.05378882, -0.02585444, -0.03878166], - [-0.02585444, 0.01500716, -0.07756332], - [-0.03878166, -0.07756332, -0.04962894], + [0.14178033, -0.06814894, -0.10222341], + [-0.06814894, 0.03955692, -0.20444682], + [-0.10222341, -0.20444682, -0.13081543], ], "1/Myr2", ) @@ -167,9 +166,9 @@ def test_tidal_tensor(self, pot: gp.AbstractPotential, x: gt.QuSz3) -> None: """Test the `AbstractPotential.tidal_tensor` method.""" expect = u.Quantity( [ - [0.04739981, -0.02585444, -0.03878166], - [-0.02585444, 0.00861815, -0.07756332], - [-0.03878166, -0.07756332, -0.05601795], + [0.12493972, -0.06814894, -0.10222341], + [-0.06814894, 0.02271631, -0.20444682], + [-0.10222341, -0.20444682, -0.14765604], ], "1/Myr2", )