diff --git a/src/galax/potential/_src/builtin/zhao.py b/src/galax/potential/_src/builtin/zhao.py new file mode 100644 index 000000000..8bd23a92d --- /dev/null +++ b/src/galax/potential/_src/builtin/zhao.py @@ -0,0 +1,276 @@ +__all__ = ["ZhaoPotential"] + +import functools as ft +from dataclasses import KW_ONLY +from typing import final + +import equinox as eqx +import jax +import jax.scipy.special as jsp + +import quaxed.numpy as jnp +import unxt as u +from xmmutablemap import ImmutableMap + +import galax._custom_types as gt +from galax.potential._src.base import default_constants +from galax.potential._src.base_single import AbstractSinglePotential +from galax.potential._src.params.base import AbstractParameter +from galax.potential._src.params.field import ParameterField +from galax.potential._src.utils import r_spherical + + +@final +class ZhaoPotential(AbstractSinglePotential): + r"""Zhao (1996) double power-law potential. + + This model represents a double power law in the density, with an inner slope + :math:`\gamma` and an outer slope :math:`\beta`, but with a third parameter + :math:`\alpha` that controls the width of the transition region between the two + power laws. + + This model has a finite total mass for :math:`\beta > 3`. The other power-law + parameters should satisfy :math:`\alpha > 0` and :math:`0 \leq \gamma < 3`. + + This model also reduces to a number of well-known analytic forms for certain values + of the parameters (reproduced from Table 1 of Zhao 1996): + - :math:`(\alpha, \beta, \gamma) = (1, 4, 1)`: Hernquist model + - :math:`(\alpha, \beta, \gamma) = (1, 4, 2)`: Jaffe model + - :math:`(\alpha, \beta, \gamma) = (1/2, 5, 0)`: Plummer model + - :math:`(\alpha, \beta, \gamma) = (1, 3, 1)`: NFW model + - :math:`(\alpha, \beta, \gamma) = (1, 3, \gamma)`: Generalized NFW model + """ + + m: AbstractParameter = ParameterField( + dimensions="mass", + doc=( + "Scale mass parameter. This is equivalent to the mass enclosed within the " + "scale radius. When beta > 3, the model has finite mass, but when beta <= 3" + " the total mass is infinite." + ), + ) # type: ignore[assignment] + r_s: AbstractParameter = ParameterField(dimensions="length", doc="Scale radius.") # type: ignore[assignment] + + alpha: AbstractParameter = ParameterField( + dimensions="dimensionless", doc="Transition width (alpha > 0)." + ) # type: ignore[assignment] + beta: AbstractParameter = ParameterField( + dimensions="dimensionless", doc="Outer slope (finite mass when beta > 3)." + ) # type: ignore[assignment] + gamma: AbstractParameter = ParameterField( + dimensions="dimensionless", doc="Inner slope (0 <= gamma < 3)." + ) # type: ignore[assignment] + + _: KW_ONLY + units: u.AbstractUnitSystem = eqx.field(converter=u.unitsystem, static=True) + constants: ImmutableMap[str, u.AbstractQuantity] = eqx.field( + default=default_constants, converter=ImmutableMap + ) + + @ft.partial(jax.jit) + def _potential(self, xyz: gt.BBtQorVSz3, t: gt.BBtQorVSz0, /) -> gt.BBtSz0: + r = r_spherical(xyz, self.units["length"]) + t = u.Quantity.from_(t, self.units["time"]) + + ulen = self.units["length"] + umass = self.units["mass"] + udim = self.units["dimensionless"] + p = { + "G": self.constants["G"].value, + "m": self.m(t, ustrip=umass), + "r_s": self.r_s(t, ustrip=ulen), + "alpha": self.alpha(t, ustrip=udim), + "beta": self.beta(t, ustrip=udim), + "gamma": self.gamma(t, ustrip=udim), + } + return potential(p, r) + + @ft.partial(jax.jit) + def _density(self, xyz: gt.BBtQorVSz3, t: gt.BBtQorVSz0, /) -> gt.BtFloatSz0: + r = r_spherical(xyz, self.units["length"]) + t = u.Quantity.from_(t, self.units["time"]) + + ulen = self.units["length"] + umass = self.units["mass"] + udim = self.units["dimensionless"] + p = { + "m": self.m(t, ustrip=umass), + "r_s": self.r_s(t, ustrip=ulen), + "alpha": self.alpha(t, ustrip=udim), + "beta": self.beta(t, ustrip=udim), + "gamma": self.gamma(t, ustrip=udim), + } + return density(p, r) + + # =========================================== + # Constructors + + @classmethod + def from_m_tot( + cls, + m_tot: gt.Sz0 | u.Quantity["mass"], + r_s: gt.Sz0 | u.Quantity["length"], + alpha: gt.Sz0 | u.Quantity["dimensionless"], + beta: gt.Sz0 | u.Quantity["dimensionless"], + gamma: gt.Sz0 | u.Quantity["dimensionless"], + *, + units: u.AbstractUnitSystem | str = "galactic", + constants: ImmutableMap[str, u.AbstractQuantity] = default_constants, + ) -> "ZhaoPotential": + """Create a Zhao potential from a total mass and scale radius. + + Note: This is only possible when beta > 3, when the model has finite mass. + + Parameters + ---------- + m_tot + Total mass of the halo. + r_s + Scale radius of the halo. + alpha + Inner slope of the halo density profile. + beta + Outer slope of the halo density profile. + gamma + Transition slope of the halo density profile. + units (optional) + Unit system to use for the potential. + constants (optional) + Physical constants to use for the potential. + """ + beta = eqx.error_if(beta, beta <= 3.0, "Beta must be >3 to have finite mass.") + usys = u.unitsystem(units) + params = { + "r_s": u.ustrip(usys["length"], r_s) if hasattr(r_s, "unit") else r_s, + "alpha": alpha, + "beta": beta, + "gamma": gamma, + } + return cls( + m=m_tot * _total_mass_factor(params, params["r_s"]), + **params, + units=units, + constants=constants, + ) + + +@ft.partial(jax.jit) +def _total_mass_factor(p: gt.Params, r_ref: gt.Sz0) -> gt.FloatSz0: + """Compute the total mass factor for the Zhao profile.""" + c0, _, q0 = _cpq(p["alpha"], p["beta"], p["gamma"]) + x = r_ref / p["r_s"] + chi = x ** (1.0 / p["alpha"]) / (1.0 + x ** (1.0 / p["alpha"])) + return jsp.betainc(c0 - q0, q0, chi) + + +@ft.partial(jax.jit) +def _rho0(p: gt.Params, r_ref: gt.Sz0 | None = None) -> gt.FloatSz0: + """Compute the normalization density for the Zhao profile. + + This computes the normalization constant rho_0 (called C in Zhao 1996) for the Zhao + density profile. The normalization is set that the mass parameter is the mass + enclosed within ``r_ref``. If no r_ref is specified to this function (as happens in + the default initializer for the ``ZhaoPotential``), it is set to the scale radius, + so the mass parameter is interpreted to be the mass enclosed within the scale + radius. + + This implementation uses the hyp2f1 hypergeometric function instead of the + incomplete beta function because jax (and scipy) only provide the *regularized* + version of the incomplete beta function. This means that it blows up when b <= 0 in + B(a, b, z) because the complete beta function B(a, b) is undefined when b <= 0. The + hyp2f1 function is defined for all values of a, b, and z, so it can handle the case + where b <= 0. + """ + r_ref = r_ref if r_ref is not None else p["r_s"] + chi_norm = _r_to_u_chi(p, r_ref)[1] + a = p["alpha"] * (3.0 - p["gamma"]) + b = p["alpha"] * (p["beta"] - 3.0) + denom = (chi_norm**a / a) * jsp.hyp2f1(a, 1.0 - b, a + 1.0, chi_norm) + return p["m"] / (4.0 * jnp.pi * p["alpha"] * denom) + + +@ft.partial(jax.jit) +def _cpq(a: gt.Sz0, b: gt.Sz0, g: gt.Sz0) -> tuple[gt.Sz0, gt.Sz0, gt.Sz0]: + """Constants defined in appendix of Zhao (1996).""" + c0 = a * (b - g) + p0 = a * (2.0 - g) + q0 = a * (b - 3.0) + return c0, p0, q0 + + +@ft.partial(jax.jit) +def _r_to_u_chi(p: gt.Params, r: gt.Sz0) -> gt.FloatSz0: + r"""Convert radius to u and chi variables defined below. + + .. math:: + + u = r / r_s + chi = \frac{u^{1/\alpha}}{1 + u^{1/\alpha}} + + """ + uu = r / p["r_s"] + return uu, uu ** (1.0 / p["alpha"]) / (1.0 + uu ** (1.0 / p["alpha"])) + + +@ft.partial(jax.jit) +def density(p: gt.Params, r: gt.Sz0, /) -> gt.FloatSz0: + """Spherical density profile for double power-law Zhao model.""" + uu = r / p["r_s"] + alpha, beta, gamma = p["alpha"], p["beta"], p["gamma"] + rho0 = _rho0(p) + + b = (beta - gamma) * alpha + return rho0 / (p["r_s"] ** 3) / uu**gamma / (1.0 + uu ** (1.0 / alpha)) ** b + + +@ft.partial(jax.jit) +def mass_enclosed(p: gt.Params, r: gt.Sz0) -> gt.Sz0: + a, b, g = p["alpha"], p["beta"], p["gamma"] + _, chi = _r_to_u_chi(p, r) + rho0 = _rho0(p) + c0, _, q0 = _cpq(a, b, g) + return ( + 4.0 + * jnp.pi + * rho0 + * a + * (jsp.beta(c0 - q0, q0) * jsp.betainc(c0 - q0, q0, chi)) + ) + + +@ft.partial(jax.jit) +def potential(p: gt.Params, r: gt.Sz0, /) -> gt.Sz0: + r"""Spherical potential for double power-law Zhao model. + + See Eq. 6 and 7 in Zhao (1996). + + This function uses the variable z for what Zhao called :math:`\chi`. + """ + a, b, g = p["alpha"], p["beta"], p["gamma"] + + uu, chi = _r_to_u_chi(p, r) + + # Special case the Jaffe potential, where there is an "inf - inf" below + is_jaffe = (a == 1.0) & (b == 4.0) & (g == 2.0) + + def Phi_jaffe() -> gt.Sz0: + # Note: the extra factor of 2 is because m is mass enclosed in r_s, not total + return -p["G"] * 2 * p["m"] / p["r_s"] * jnp.log1p(1.0 / uu) + + rho0 = _rho0(p) + c0, p0, _ = _cpq(a, b, g) + + # Left term in Eq. 7 + term_l = mass_enclosed(p, r) + + # Right term in Eq. 7 + eps = jnp.sqrt(jnp.finfo(r.dtype).eps) + p0_safe = jnp.where(p0 <= 0, eps, p0) + logB = jsp.betaln(p0_safe, c0 - p0) + log1mI = jnp.log1p(-jsp.betainc(p0_safe, c0 - p0, chi)) + term_r = 4.0 * jnp.pi * rho0 * a / p["r_s"] * jnp.exp(logB + log1mI) + + def Phi_general() -> gt.Sz0: + return -p["G"] * (term_l / r + term_r) + + return jax.lax.cond(is_jaffe, Phi_jaffe, Phi_general) diff --git a/tests/functional/potential/builtin/test_zhao.py b/tests/functional/potential/builtin/test_zhao.py new file mode 100644 index 000000000..a5ae7e69c --- /dev/null +++ b/tests/functional/potential/builtin/test_zhao.py @@ -0,0 +1,50 @@ +import jax +import pytest + +import quaxed.numpy as jnp +import unxt as u + +import galax.potential as gp +from galax.potential._src.builtin.zhao import ZhaoPotential + +check_funcs = ["potential", "gradient", "density", "hessian"] + +# Settings of (alpha, beta, gamma) and correspondence to known models +abg_pot_finite = [ + ((1, 4, 1), gp.HernquistPotential), + ((1, 4, 2), gp.JaffePotential), + ((1 / 2, 5, 0), gp.PlummerPotential), +] +abg_pot_infinite = [ + ((1, 3, 1), gp.NFWPotential), +] + + +@pytest.fixture +def xyz(): + test_r = jnp.geomspace(1e-3, 1e2, 128) + rand_uvecs = jax.random.normal(jax.random.key(42), shape=(test_r.size, 3)) + rand_uvecs = rand_uvecs / jnp.linalg.norm(rand_uvecs, axis=-1, keepdims=True) + return u.Quantity(test_r[:, None] * rand_uvecs, "kpc") + + +@pytest.mark.parametrize("func_name", check_funcs) +@pytest.mark.parametrize(("abg", "OtherPotential"), abg_pot_finite) +def test_zhao_against_finite_mass_correspondences(func_name, abg, OtherPotential, xyz): + m_tot = u.Quantity(1.3e11, "Msun") + zhao = ZhaoPotential.from_m_tot( + m_tot=m_tot, + r_s=u.Quantity(8.1, "kpc"), + alpha=abg[0], + beta=abg[1], + gamma=abg[2], + units="galactic", + ) + other = OtherPotential(m_tot=m_tot, r_s=zhao.parameters["r_s"], units="galactic") + + zhao_result = getattr(zhao, func_name)(xyz, u.Quantity(0.0, "Myr")) + other_result = getattr(other, func_name)(xyz, u.Quantity(0.0, "Myr")) + + assert jnp.allclose( + zhao_result, other_result, rtol=1e-8, atol=u.Quantity(1e-6, zhao_result.unit) + ) diff --git a/tests/unit/potential/builtin/test_zhao.py b/tests/unit/potential/builtin/test_zhao.py new file mode 100644 index 000000000..329e711cb --- /dev/null +++ b/tests/unit/potential/builtin/test_zhao.py @@ -0,0 +1,177 @@ +from typing import Any, ClassVar + +import pytest + +import quaxed.numpy as jnp +import unxt as u + +import galax._custom_types as gt +import galax.potential as gp +from ..param.test_field import ParameterFieldMixin +from ..test_core import AbstractSinglePotential_Test +from .test_common import ParameterMMixin, ParameterRSMixin +from galax.potential._src.builtin.zhao import ZhaoPotential + + +class AlphaParameterMixin(ParameterFieldMixin): + """Test the alpha parameter.""" + + @pytest.fixture(scope="class") + def field_alpha(self) -> u.Quantity["dimensionless"]: + return u.Quantity(0.9, "") + + # ===================================================== + + def test_alpha_constant(self, pot_cls, fields): + """Test the `alpha` parameter.""" + fields["alpha"] = u.Quantity(1.0, "") + pot = pot_cls(**fields) + assert pot.alpha(t=u.Quantity(0, "Myr")) == u.Quantity(1.0, "") + + def test_alpha_userfunc(self, pot_cls, fields): + """Test the `alpha` parameter.""" + + def cos_alpha(t: u.Quantity["time"]) -> u.Quantity[""]: + return u.Quantity(0.5 * jnp.cos(t.ustrip("Myr")) ** 2 + 0.5, "") + + fields["alpha"] = cos_alpha + pot = pot_cls(**fields) + assert pot.alpha(t=u.Quantity(0, "Myr")) == u.Quantity(1.0, "") + + +class BetaParameterMixin(ParameterFieldMixin): + """Test the beta parameter.""" + + @pytest.fixture(scope="class") + def field_beta(self) -> u.Quantity["dimensionless"]: + return u.Quantity(4.31, "") + + # ===================================================== + + def test_beta_constant(self, pot_cls, fields): + """Test the `beta` parameter.""" + fields["beta"] = u.Quantity(3.5, "") + pot = pot_cls(**fields) + assert pot.beta(t=u.Quantity(0, "Myr")) == u.Quantity(3.5, "") + + def test_beta_userfunc(self, pot_cls, fields): + """Test the `beta` parameter.""" + + def cos_beta(t: u.Quantity["time"]) -> u.Quantity[""]: + return u.Quantity(jnp.cos(t.ustrip("Myr")) + 4.2, "") + + fields["beta"] = cos_beta + pot = pot_cls(**fields) + assert pot.beta(t=u.Quantity(0, "Myr")) == u.Quantity(5.2, "") + + +class GammaParameterMixin(ParameterFieldMixin): + """Test the gamma parameter.""" + + @pytest.fixture(scope="class") + def field_gamma(self) -> u.Quantity["dimensionless"]: + return u.Quantity(1.2, "") + + # ===================================================== + + def test_gamma_constant(self, pot_cls, fields): + """Test the `gamma` parameter.""" + fields["gamma"] = u.Quantity(1.5, "") + pot = pot_cls(**fields) + assert pot.gamma(t=u.Quantity(0, "Myr")) == u.Quantity(1.5, "") + + def test_gamma_userfunc(self, pot_cls, fields): + """Test the `gamma` parameter.""" + + def cos_gamma(t: u.Quantity["time"]) -> u.Quantity[""]: + return u.Quantity(0.5 * jnp.cos(t.ustrip("Myr")) + 1.5, "") + + fields["gamma"] = cos_gamma + pot = pot_cls(**fields) + assert pot.gamma(t=u.Quantity(0, "Myr")) == u.Quantity(2.0, "") + + +class TestZhaoPotential( + AbstractSinglePotential_Test, + # Parameters + ParameterMMixin, + ParameterRSMixin, + AlphaParameterMixin, + BetaParameterMixin, + GammaParameterMixin, +): + """Test the `galax.potential.ZhaoPotential` class.""" + + HAS_GALA_COUNTERPART: ClassVar[bool] = False + + @pytest.fixture(scope="class") + def pot_cls(self) -> type[ZhaoPotential]: + return ZhaoPotential + + @pytest.fixture(scope="class") + def fields_( + self, + field_m: u.Quantity, + field_r_s: u.Quantity, + field_alpha: u.Quantity, + field_beta: u.Quantity, + field_gamma: u.Quantity, + field_units: u.AbstractUnitSystem, + ) -> dict[str, Any]: + return { + "m": field_m, + "r_s": field_r_s, + "alpha": field_alpha, + "beta": field_beta, + "gamma": field_gamma, + "units": field_units, + } + + # ========================================================================== + + def test_potential(self, pot: ZhaoPotential, x: gt.QuSz3) -> None: + expect = u.Quantity(-2.83144346, unit="kpc2 / Myr2") + assert jnp.isclose( + pot.potential(x, t=0), expect, atol=u.Quantity(1e-8, expect.unit) + ) + + def test_gradient(self, pot: ZhaoPotential, x: gt.QuSz3) -> None: + expect = u.Quantity([0.1758548, 0.35170961, 0.52756441], "kpc / Myr2") + got = pot.gradient(x, t=0) + assert jnp.allclose(got, expect, atol=u.Quantity(1e-8, expect.unit)) + + def test_density(self, pot: ZhaoPotential, x: gt.QuSz3) -> None: + expect = u.Quantity(8.93719599e08, "solMass / kpc3") + assert jnp.isclose( + pot.density(x, t=0), expect, atol=u.Quantity(1e-8, expect.unit) + ) + + def test_hessian(self, pot: ZhaoPotential, x: gt.QuSz3) -> None: + expect = u.Quantity( + [ + [0.14178033, -0.06814894, -0.10222341], + [-0.06814894, 0.03955692, -0.20444682], + [-0.10222341, -0.20444682, -0.13081543], + ], + "1/Myr2", + ) + assert jnp.allclose( + pot.hessian(x, t=0), expect, atol=u.Quantity(1e-8, expect.unit) + ) + + # --------------------------------- + # Convenience methods + + def test_tidal_tensor(self, pot: gp.AbstractPotential, x: gt.QuSz3) -> None: + """Test the `AbstractPotential.tidal_tensor` method.""" + expect = u.Quantity( + [ + [0.12493972, -0.06814894, -0.10222341], + [-0.06814894, 0.02271631, -0.20444682], + [-0.10222341, -0.20444682, -0.14765604], + ], + "1/Myr2", + ) + assert jnp.allclose( + pot.tidal_tensor(x, t=0), expect, atol=u.Quantity(1e-8, expect.unit) + )