Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
276 changes: 276 additions & 0 deletions src/galax/potential/_src/builtin/zhao.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,276 @@
__all__ = ["ZhaoPotential"]

import functools as ft
from dataclasses import KW_ONLY
from typing import final

import equinox as eqx
import jax
import jax.scipy.special as jsp

import quaxed.numpy as jnp
import unxt as u
from xmmutablemap import ImmutableMap

import galax._custom_types as gt
from galax.potential._src.base import default_constants
from galax.potential._src.base_single import AbstractSinglePotential
from galax.potential._src.params.base import AbstractParameter
from galax.potential._src.params.field import ParameterField
from galax.potential._src.utils import r_spherical


@final
class ZhaoPotential(AbstractSinglePotential):
r"""Zhao (1996) double power-law potential.

This model represents a double power law in the density, with an inner slope
:math:`\gamma` and an outer slope :math:`\beta`, but with a third parameter
:math:`\alpha` that controls the width of the transition region between the two
power laws.

This model has a finite total mass for :math:`\beta > 3`. The other power-law
parameters should satisfy :math:`\alpha > 0` and :math:`0 \leq \gamma < 3`.

This model also reduces to a number of well-known analytic forms for certain values
of the parameters (reproduced from Table 1 of Zhao 1996):
- :math:`(\alpha, \beta, \gamma) = (1, 4, 1)`: Hernquist model
- :math:`(\alpha, \beta, \gamma) = (1, 4, 2)`: Jaffe model
- :math:`(\alpha, \beta, \gamma) = (1/2, 5, 0)`: Plummer model
- :math:`(\alpha, \beta, \gamma) = (1, 3, 1)`: NFW model
- :math:`(\alpha, \beta, \gamma) = (1, 3, \gamma)`: Generalized NFW model
"""

m: AbstractParameter = ParameterField(
dimensions="mass",
doc=(
"Scale mass parameter. This is equivalent to the mass enclosed within the "
"scale radius. When beta > 3, the model has finite mass, but when beta <= 3"
" the total mass is infinite."
),
) # type: ignore[assignment]
r_s: AbstractParameter = ParameterField(dimensions="length", doc="Scale radius.") # type: ignore[assignment]

alpha: AbstractParameter = ParameterField(
dimensions="dimensionless", doc="Transition width (alpha > 0)."
) # type: ignore[assignment]
beta: AbstractParameter = ParameterField(
dimensions="dimensionless", doc="Outer slope (finite mass when beta > 3)."
) # type: ignore[assignment]
gamma: AbstractParameter = ParameterField(
dimensions="dimensionless", doc="Inner slope (0 <= gamma < 3)."
) # type: ignore[assignment]

_: KW_ONLY
units: u.AbstractUnitSystem = eqx.field(converter=u.unitsystem, static=True)
constants: ImmutableMap[str, u.AbstractQuantity] = eqx.field(
default=default_constants, converter=ImmutableMap
)

@ft.partial(jax.jit)
def _potential(self, xyz: gt.BBtQorVSz3, t: gt.BBtQorVSz0, /) -> gt.BBtSz0:
r = r_spherical(xyz, self.units["length"])
t = u.Quantity.from_(t, self.units["time"])

ulen = self.units["length"]
umass = self.units["mass"]
udim = self.units["dimensionless"]
p = {
"G": self.constants["G"].value,
"m": self.m(t, ustrip=umass),
"r_s": self.r_s(t, ustrip=ulen),
"alpha": self.alpha(t, ustrip=udim),
"beta": self.beta(t, ustrip=udim),
"gamma": self.gamma(t, ustrip=udim),
}
return potential(p, r)

@ft.partial(jax.jit)
def _density(self, xyz: gt.BBtQorVSz3, t: gt.BBtQorVSz0, /) -> gt.BtFloatSz0:
r = r_spherical(xyz, self.units["length"])
t = u.Quantity.from_(t, self.units["time"])

ulen = self.units["length"]
umass = self.units["mass"]
udim = self.units["dimensionless"]
p = {
"m": self.m(t, ustrip=umass),
"r_s": self.r_s(t, ustrip=ulen),
"alpha": self.alpha(t, ustrip=udim),
"beta": self.beta(t, ustrip=udim),
"gamma": self.gamma(t, ustrip=udim),
}
return density(p, r)

# ===========================================
# Constructors

@classmethod
def from_m_tot(
cls,
m_tot: gt.Sz0 | u.Quantity["mass"],
r_s: gt.Sz0 | u.Quantity["length"],
alpha: gt.Sz0 | u.Quantity["dimensionless"],
beta: gt.Sz0 | u.Quantity["dimensionless"],
gamma: gt.Sz0 | u.Quantity["dimensionless"],
*,
units: u.AbstractUnitSystem | str = "galactic",
constants: ImmutableMap[str, u.AbstractQuantity] = default_constants,
) -> "ZhaoPotential":
"""Create a Zhao potential from a total mass and scale radius.

Note: This is only possible when beta > 3, when the model has finite mass.

Parameters
----------
m_tot
Total mass of the halo.
r_s
Scale radius of the halo.
alpha
Inner slope of the halo density profile.
beta
Outer slope of the halo density profile.
gamma
Transition slope of the halo density profile.
units (optional)
Unit system to use for the potential.
constants (optional)
Physical constants to use for the potential.
"""
beta = eqx.error_if(beta, beta <= 3.0, "Beta must be >3 to have finite mass.")
usys = u.unitsystem(units)
params = {
"r_s": u.ustrip(usys["length"], r_s) if hasattr(r_s, "unit") else r_s,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can use u.ustrip(AllowValue

"alpha": alpha,
"beta": beta,
"gamma": gamma,
}
return cls(
m=m_tot * _total_mass_factor(params, params["r_s"]),
**params,
units=units,
constants=constants,
)


@ft.partial(jax.jit)
def _total_mass_factor(p: gt.Params, r_ref: gt.Sz0) -> gt.FloatSz0:
"""Compute the total mass factor for the Zhao profile."""
c0, _, q0 = _cpq(p["alpha"], p["beta"], p["gamma"])
x = r_ref / p["r_s"]
chi = x ** (1.0 / p["alpha"]) / (1.0 + x ** (1.0 / p["alpha"]))
return jsp.betainc(c0 - q0, q0, chi)


@ft.partial(jax.jit)
def _rho0(p: gt.Params, r_ref: gt.Sz0 | None = None) -> gt.FloatSz0:
"""Compute the normalization density for the Zhao profile.

This computes the normalization constant rho_0 (called C in Zhao 1996) for the Zhao
density profile. The normalization is set that the mass parameter is the mass
enclosed within ``r_ref``. If no r_ref is specified to this function (as happens in
the default initializer for the ``ZhaoPotential``), it is set to the scale radius,
so the mass parameter is interpreted to be the mass enclosed within the scale
radius.

This implementation uses the hyp2f1 hypergeometric function instead of the
incomplete beta function because jax (and scipy) only provide the *regularized*
version of the incomplete beta function. This means that it blows up when b <= 0 in
B(a, b, z) because the complete beta function B(a, b) is undefined when b <= 0. The
hyp2f1 function is defined for all values of a, b, and z, so it can handle the case
where b <= 0.
"""
r_ref = r_ref if r_ref is not None else p["r_s"]
chi_norm = _r_to_u_chi(p, r_ref)[1]
a = p["alpha"] * (3.0 - p["gamma"])
b = p["alpha"] * (p["beta"] - 3.0)
denom = (chi_norm**a / a) * jsp.hyp2f1(a, 1.0 - b, a + 1.0, chi_norm)
return p["m"] / (4.0 * jnp.pi * p["alpha"] * denom)


@ft.partial(jax.jit)
def _cpq(a: gt.Sz0, b: gt.Sz0, g: gt.Sz0) -> tuple[gt.Sz0, gt.Sz0, gt.Sz0]:
"""Constants defined in appendix of Zhao (1996)."""
c0 = a * (b - g)
p0 = a * (2.0 - g)
q0 = a * (b - 3.0)
return c0, p0, q0


@ft.partial(jax.jit)
def _r_to_u_chi(p: gt.Params, r: gt.Sz0) -> gt.FloatSz0:
r"""Convert radius to u and chi variables defined below.

.. math::

u = r / r_s
chi = \frac{u^{1/\alpha}}{1 + u^{1/\alpha}}

"""
uu = r / p["r_s"]
return uu, uu ** (1.0 / p["alpha"]) / (1.0 + uu ** (1.0 / p["alpha"]))


@ft.partial(jax.jit)
def density(p: gt.Params, r: gt.Sz0, /) -> gt.FloatSz0:
"""Spherical density profile for double power-law Zhao model."""
uu = r / p["r_s"]
alpha, beta, gamma = p["alpha"], p["beta"], p["gamma"]
rho0 = _rho0(p)

b = (beta - gamma) * alpha
return rho0 / (p["r_s"] ** 3) / uu**gamma / (1.0 + uu ** (1.0 / alpha)) ** b


@ft.partial(jax.jit)
def mass_enclosed(p: gt.Params, r: gt.Sz0) -> gt.Sz0:
a, b, g = p["alpha"], p["beta"], p["gamma"]
_, chi = _r_to_u_chi(p, r)
rho0 = _rho0(p)
c0, _, q0 = _cpq(a, b, g)
return (
4.0
* jnp.pi
* rho0
* a
* (jsp.beta(c0 - q0, q0) * jsp.betainc(c0 - q0, q0, chi))
)


@ft.partial(jax.jit)
def potential(p: gt.Params, r: gt.Sz0, /) -> gt.Sz0:
r"""Spherical potential for double power-law Zhao model.

See Eq. 6 and 7 in Zhao (1996).

This function uses the variable z for what Zhao called :math:`\chi`.
"""
a, b, g = p["alpha"], p["beta"], p["gamma"]

uu, chi = _r_to_u_chi(p, r)

# Special case the Jaffe potential, where there is an "inf - inf" below
is_jaffe = (a == 1.0) & (b == 4.0) & (g == 2.0)

def Phi_jaffe() -> gt.Sz0:
# Note: the extra factor of 2 is because m is mass enclosed in r_s, not total
return -p["G"] * 2 * p["m"] / p["r_s"] * jnp.log1p(1.0 / uu)

rho0 = _rho0(p)
c0, p0, _ = _cpq(a, b, g)

# Left term in Eq. 7
term_l = mass_enclosed(p, r)

# Right term in Eq. 7
eps = jnp.sqrt(jnp.finfo(r.dtype).eps)
p0_safe = jnp.where(p0 <= 0, eps, p0)
logB = jsp.betaln(p0_safe, c0 - p0)
log1mI = jnp.log1p(-jsp.betainc(p0_safe, c0 - p0, chi))
term_r = 4.0 * jnp.pi * rho0 * a / p["r_s"] * jnp.exp(logB + log1mI)

def Phi_general() -> gt.Sz0:
return -p["G"] * (term_l / r + term_r)

return jax.lax.cond(is_jaffe, Phi_jaffe, Phi_general)
50 changes: 50 additions & 0 deletions tests/functional/potential/builtin/test_zhao.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import jax
import pytest

import quaxed.numpy as jnp
import unxt as u

import galax.potential as gp
from galax.potential._src.builtin.zhao import ZhaoPotential

check_funcs = ["potential", "gradient", "density", "hessian"]

# Settings of (alpha, beta, gamma) and correspondence to known models
abg_pot_finite = [
((1, 4, 1), gp.HernquistPotential),
((1, 4, 2), gp.JaffePotential),
((1 / 2, 5, 0), gp.PlummerPotential),
]
abg_pot_infinite = [
((1, 3, 1), gp.NFWPotential),
]


@pytest.fixture
def xyz():
test_r = jnp.geomspace(1e-3, 1e2, 128)
rand_uvecs = jax.random.normal(jax.random.key(42), shape=(test_r.size, 3))
rand_uvecs = rand_uvecs / jnp.linalg.norm(rand_uvecs, axis=-1, keepdims=True)
return u.Quantity(test_r[:, None] * rand_uvecs, "kpc")


@pytest.mark.parametrize("func_name", check_funcs)
@pytest.mark.parametrize(("abg", "OtherPotential"), abg_pot_finite)
def test_zhao_against_finite_mass_correspondences(func_name, abg, OtherPotential, xyz):
m_tot = u.Quantity(1.3e11, "Msun")
zhao = ZhaoPotential.from_m_tot(
m_tot=m_tot,
r_s=u.Quantity(8.1, "kpc"),
alpha=abg[0],
beta=abg[1],
gamma=abg[2],
units="galactic",
)
other = OtherPotential(m_tot=m_tot, r_s=zhao.parameters["r_s"], units="galactic")

zhao_result = getattr(zhao, func_name)(xyz, u.Quantity(0.0, "Myr"))
other_result = getattr(other, func_name)(xyz, u.Quantity(0.0, "Myr"))

assert jnp.allclose(
zhao_result, other_result, rtol=1e-8, atol=u.Quantity(1e-6, zhao_result.unit)
)
Loading
Loading