-
Notifications
You must be signed in to change notification settings - Fork 8
Add Zhao potential model #761
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
adrn
wants to merge
8
commits into
GalacticDynamics:main
Choose a base branch
from
adrn:zhao-potentia
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
8 commits
Select commit
Hold shift + click to select a range
c0dc270
initial implementation of Zhao model
adrn c13ef5d
add note of potential correspondences
adrn 8824769
some refactor and add a special case for Jaffe to avoid bad gradient
adrn 578e535
add test module for Zhao potential
adrn a4ad0bf
switch interpretation of mass parameter to be mass enclosed within sc…
adrn 5456da9
add a constructor to enable creating from a total mass (for finite ma…
adrn 87bd690
error if beta <= 3 in mtot constructor
adrn c2b05f8
fix test values
adrn File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,276 @@ | ||
| __all__ = ["ZhaoPotential"] | ||
|
|
||
| import functools as ft | ||
| from dataclasses import KW_ONLY | ||
| from typing import final | ||
|
|
||
| import equinox as eqx | ||
| import jax | ||
| import jax.scipy.special as jsp | ||
|
|
||
| import quaxed.numpy as jnp | ||
| import unxt as u | ||
| from xmmutablemap import ImmutableMap | ||
|
|
||
| import galax._custom_types as gt | ||
| from galax.potential._src.base import default_constants | ||
| from galax.potential._src.base_single import AbstractSinglePotential | ||
| from galax.potential._src.params.base import AbstractParameter | ||
| from galax.potential._src.params.field import ParameterField | ||
| from galax.potential._src.utils import r_spherical | ||
|
|
||
|
|
||
| @final | ||
| class ZhaoPotential(AbstractSinglePotential): | ||
| r"""Zhao (1996) double power-law potential. | ||
|
|
||
| This model represents a double power law in the density, with an inner slope | ||
| :math:`\gamma` and an outer slope :math:`\beta`, but with a third parameter | ||
| :math:`\alpha` that controls the width of the transition region between the two | ||
| power laws. | ||
|
|
||
| This model has a finite total mass for :math:`\beta > 3`. The other power-law | ||
| parameters should satisfy :math:`\alpha > 0` and :math:`0 \leq \gamma < 3`. | ||
|
|
||
| This model also reduces to a number of well-known analytic forms for certain values | ||
| of the parameters (reproduced from Table 1 of Zhao 1996): | ||
| - :math:`(\alpha, \beta, \gamma) = (1, 4, 1)`: Hernquist model | ||
| - :math:`(\alpha, \beta, \gamma) = (1, 4, 2)`: Jaffe model | ||
| - :math:`(\alpha, \beta, \gamma) = (1/2, 5, 0)`: Plummer model | ||
| - :math:`(\alpha, \beta, \gamma) = (1, 3, 1)`: NFW model | ||
| - :math:`(\alpha, \beta, \gamma) = (1, 3, \gamma)`: Generalized NFW model | ||
| """ | ||
|
|
||
| m: AbstractParameter = ParameterField( | ||
| dimensions="mass", | ||
| doc=( | ||
| "Scale mass parameter. This is equivalent to the mass enclosed within the " | ||
| "scale radius. When beta > 3, the model has finite mass, but when beta <= 3" | ||
| " the total mass is infinite." | ||
| ), | ||
| ) # type: ignore[assignment] | ||
| r_s: AbstractParameter = ParameterField(dimensions="length", doc="Scale radius.") # type: ignore[assignment] | ||
|
|
||
| alpha: AbstractParameter = ParameterField( | ||
| dimensions="dimensionless", doc="Transition width (alpha > 0)." | ||
| ) # type: ignore[assignment] | ||
| beta: AbstractParameter = ParameterField( | ||
| dimensions="dimensionless", doc="Outer slope (finite mass when beta > 3)." | ||
| ) # type: ignore[assignment] | ||
| gamma: AbstractParameter = ParameterField( | ||
| dimensions="dimensionless", doc="Inner slope (0 <= gamma < 3)." | ||
| ) # type: ignore[assignment] | ||
|
|
||
| _: KW_ONLY | ||
| units: u.AbstractUnitSystem = eqx.field(converter=u.unitsystem, static=True) | ||
| constants: ImmutableMap[str, u.AbstractQuantity] = eqx.field( | ||
| default=default_constants, converter=ImmutableMap | ||
| ) | ||
|
|
||
| @ft.partial(jax.jit) | ||
| def _potential(self, xyz: gt.BBtQorVSz3, t: gt.BBtQorVSz0, /) -> gt.BBtSz0: | ||
| r = r_spherical(xyz, self.units["length"]) | ||
| t = u.Quantity.from_(t, self.units["time"]) | ||
|
|
||
| ulen = self.units["length"] | ||
| umass = self.units["mass"] | ||
| udim = self.units["dimensionless"] | ||
| p = { | ||
| "G": self.constants["G"].value, | ||
| "m": self.m(t, ustrip=umass), | ||
| "r_s": self.r_s(t, ustrip=ulen), | ||
| "alpha": self.alpha(t, ustrip=udim), | ||
| "beta": self.beta(t, ustrip=udim), | ||
| "gamma": self.gamma(t, ustrip=udim), | ||
| } | ||
| return potential(p, r) | ||
|
|
||
| @ft.partial(jax.jit) | ||
| def _density(self, xyz: gt.BBtQorVSz3, t: gt.BBtQorVSz0, /) -> gt.BtFloatSz0: | ||
| r = r_spherical(xyz, self.units["length"]) | ||
| t = u.Quantity.from_(t, self.units["time"]) | ||
|
|
||
| ulen = self.units["length"] | ||
| umass = self.units["mass"] | ||
| udim = self.units["dimensionless"] | ||
| p = { | ||
| "m": self.m(t, ustrip=umass), | ||
| "r_s": self.r_s(t, ustrip=ulen), | ||
| "alpha": self.alpha(t, ustrip=udim), | ||
| "beta": self.beta(t, ustrip=udim), | ||
| "gamma": self.gamma(t, ustrip=udim), | ||
| } | ||
| return density(p, r) | ||
|
|
||
| # =========================================== | ||
| # Constructors | ||
|
|
||
| @classmethod | ||
| def from_m_tot( | ||
| cls, | ||
| m_tot: gt.Sz0 | u.Quantity["mass"], | ||
| r_s: gt.Sz0 | u.Quantity["length"], | ||
| alpha: gt.Sz0 | u.Quantity["dimensionless"], | ||
| beta: gt.Sz0 | u.Quantity["dimensionless"], | ||
| gamma: gt.Sz0 | u.Quantity["dimensionless"], | ||
| *, | ||
| units: u.AbstractUnitSystem | str = "galactic", | ||
| constants: ImmutableMap[str, u.AbstractQuantity] = default_constants, | ||
| ) -> "ZhaoPotential": | ||
| """Create a Zhao potential from a total mass and scale radius. | ||
|
|
||
| Note: This is only possible when beta > 3, when the model has finite mass. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| m_tot | ||
| Total mass of the halo. | ||
| r_s | ||
| Scale radius of the halo. | ||
| alpha | ||
| Inner slope of the halo density profile. | ||
| beta | ||
| Outer slope of the halo density profile. | ||
| gamma | ||
| Transition slope of the halo density profile. | ||
| units (optional) | ||
| Unit system to use for the potential. | ||
| constants (optional) | ||
| Physical constants to use for the potential. | ||
| """ | ||
| beta = eqx.error_if(beta, beta <= 3.0, "Beta must be >3 to have finite mass.") | ||
| usys = u.unitsystem(units) | ||
| params = { | ||
| "r_s": u.ustrip(usys["length"], r_s) if hasattr(r_s, "unit") else r_s, | ||
| "alpha": alpha, | ||
| "beta": beta, | ||
| "gamma": gamma, | ||
| } | ||
| return cls( | ||
| m=m_tot * _total_mass_factor(params, params["r_s"]), | ||
| **params, | ||
| units=units, | ||
| constants=constants, | ||
| ) | ||
|
|
||
|
|
||
| @ft.partial(jax.jit) | ||
| def _total_mass_factor(p: gt.Params, r_ref: gt.Sz0) -> gt.FloatSz0: | ||
| """Compute the total mass factor for the Zhao profile.""" | ||
| c0, _, q0 = _cpq(p["alpha"], p["beta"], p["gamma"]) | ||
| x = r_ref / p["r_s"] | ||
| chi = x ** (1.0 / p["alpha"]) / (1.0 + x ** (1.0 / p["alpha"])) | ||
| return jsp.betainc(c0 - q0, q0, chi) | ||
|
|
||
|
|
||
| @ft.partial(jax.jit) | ||
| def _rho0(p: gt.Params, r_ref: gt.Sz0 | None = None) -> gt.FloatSz0: | ||
| """Compute the normalization density for the Zhao profile. | ||
|
|
||
| This computes the normalization constant rho_0 (called C in Zhao 1996) for the Zhao | ||
| density profile. The normalization is set that the mass parameter is the mass | ||
| enclosed within ``r_ref``. If no r_ref is specified to this function (as happens in | ||
| the default initializer for the ``ZhaoPotential``), it is set to the scale radius, | ||
| so the mass parameter is interpreted to be the mass enclosed within the scale | ||
| radius. | ||
|
|
||
| This implementation uses the hyp2f1 hypergeometric function instead of the | ||
| incomplete beta function because jax (and scipy) only provide the *regularized* | ||
| version of the incomplete beta function. This means that it blows up when b <= 0 in | ||
| B(a, b, z) because the complete beta function B(a, b) is undefined when b <= 0. The | ||
| hyp2f1 function is defined for all values of a, b, and z, so it can handle the case | ||
| where b <= 0. | ||
| """ | ||
| r_ref = r_ref if r_ref is not None else p["r_s"] | ||
| chi_norm = _r_to_u_chi(p, r_ref)[1] | ||
| a = p["alpha"] * (3.0 - p["gamma"]) | ||
| b = p["alpha"] * (p["beta"] - 3.0) | ||
| denom = (chi_norm**a / a) * jsp.hyp2f1(a, 1.0 - b, a + 1.0, chi_norm) | ||
| return p["m"] / (4.0 * jnp.pi * p["alpha"] * denom) | ||
|
|
||
|
|
||
| @ft.partial(jax.jit) | ||
| def _cpq(a: gt.Sz0, b: gt.Sz0, g: gt.Sz0) -> tuple[gt.Sz0, gt.Sz0, gt.Sz0]: | ||
| """Constants defined in appendix of Zhao (1996).""" | ||
| c0 = a * (b - g) | ||
| p0 = a * (2.0 - g) | ||
| q0 = a * (b - 3.0) | ||
| return c0, p0, q0 | ||
|
|
||
|
|
||
| @ft.partial(jax.jit) | ||
| def _r_to_u_chi(p: gt.Params, r: gt.Sz0) -> gt.FloatSz0: | ||
| r"""Convert radius to u and chi variables defined below. | ||
|
|
||
| .. math:: | ||
|
|
||
| u = r / r_s | ||
| chi = \frac{u^{1/\alpha}}{1 + u^{1/\alpha}} | ||
|
|
||
| """ | ||
| uu = r / p["r_s"] | ||
| return uu, uu ** (1.0 / p["alpha"]) / (1.0 + uu ** (1.0 / p["alpha"])) | ||
|
|
||
|
|
||
| @ft.partial(jax.jit) | ||
| def density(p: gt.Params, r: gt.Sz0, /) -> gt.FloatSz0: | ||
| """Spherical density profile for double power-law Zhao model.""" | ||
| uu = r / p["r_s"] | ||
| alpha, beta, gamma = p["alpha"], p["beta"], p["gamma"] | ||
| rho0 = _rho0(p) | ||
|
|
||
| b = (beta - gamma) * alpha | ||
| return rho0 / (p["r_s"] ** 3) / uu**gamma / (1.0 + uu ** (1.0 / alpha)) ** b | ||
|
|
||
|
|
||
| @ft.partial(jax.jit) | ||
| def mass_enclosed(p: gt.Params, r: gt.Sz0) -> gt.Sz0: | ||
| a, b, g = p["alpha"], p["beta"], p["gamma"] | ||
| _, chi = _r_to_u_chi(p, r) | ||
| rho0 = _rho0(p) | ||
| c0, _, q0 = _cpq(a, b, g) | ||
| return ( | ||
| 4.0 | ||
| * jnp.pi | ||
| * rho0 | ||
| * a | ||
| * (jsp.beta(c0 - q0, q0) * jsp.betainc(c0 - q0, q0, chi)) | ||
| ) | ||
|
|
||
|
|
||
| @ft.partial(jax.jit) | ||
| def potential(p: gt.Params, r: gt.Sz0, /) -> gt.Sz0: | ||
| r"""Spherical potential for double power-law Zhao model. | ||
|
|
||
| See Eq. 6 and 7 in Zhao (1996). | ||
|
|
||
| This function uses the variable z for what Zhao called :math:`\chi`. | ||
| """ | ||
| a, b, g = p["alpha"], p["beta"], p["gamma"] | ||
|
|
||
| uu, chi = _r_to_u_chi(p, r) | ||
|
|
||
| # Special case the Jaffe potential, where there is an "inf - inf" below | ||
| is_jaffe = (a == 1.0) & (b == 4.0) & (g == 2.0) | ||
|
|
||
| def Phi_jaffe() -> gt.Sz0: | ||
| # Note: the extra factor of 2 is because m is mass enclosed in r_s, not total | ||
| return -p["G"] * 2 * p["m"] / p["r_s"] * jnp.log1p(1.0 / uu) | ||
|
|
||
| rho0 = _rho0(p) | ||
| c0, p0, _ = _cpq(a, b, g) | ||
|
|
||
| # Left term in Eq. 7 | ||
| term_l = mass_enclosed(p, r) | ||
|
|
||
| # Right term in Eq. 7 | ||
| eps = jnp.sqrt(jnp.finfo(r.dtype).eps) | ||
| p0_safe = jnp.where(p0 <= 0, eps, p0) | ||
| logB = jsp.betaln(p0_safe, c0 - p0) | ||
| log1mI = jnp.log1p(-jsp.betainc(p0_safe, c0 - p0, chi)) | ||
| term_r = 4.0 * jnp.pi * rho0 * a / p["r_s"] * jnp.exp(logB + log1mI) | ||
|
|
||
| def Phi_general() -> gt.Sz0: | ||
| return -p["G"] * (term_l / r + term_r) | ||
|
|
||
| return jax.lax.cond(is_jaffe, Phi_jaffe, Phi_general) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,50 @@ | ||
| import jax | ||
| import pytest | ||
|
|
||
| import quaxed.numpy as jnp | ||
| import unxt as u | ||
|
|
||
| import galax.potential as gp | ||
| from galax.potential._src.builtin.zhao import ZhaoPotential | ||
|
|
||
| check_funcs = ["potential", "gradient", "density", "hessian"] | ||
|
|
||
| # Settings of (alpha, beta, gamma) and correspondence to known models | ||
| abg_pot_finite = [ | ||
| ((1, 4, 1), gp.HernquistPotential), | ||
| ((1, 4, 2), gp.JaffePotential), | ||
| ((1 / 2, 5, 0), gp.PlummerPotential), | ||
| ] | ||
| abg_pot_infinite = [ | ||
| ((1, 3, 1), gp.NFWPotential), | ||
| ] | ||
|
|
||
|
|
||
| @pytest.fixture | ||
| def xyz(): | ||
| test_r = jnp.geomspace(1e-3, 1e2, 128) | ||
| rand_uvecs = jax.random.normal(jax.random.key(42), shape=(test_r.size, 3)) | ||
| rand_uvecs = rand_uvecs / jnp.linalg.norm(rand_uvecs, axis=-1, keepdims=True) | ||
| return u.Quantity(test_r[:, None] * rand_uvecs, "kpc") | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("func_name", check_funcs) | ||
| @pytest.mark.parametrize(("abg", "OtherPotential"), abg_pot_finite) | ||
| def test_zhao_against_finite_mass_correspondences(func_name, abg, OtherPotential, xyz): | ||
| m_tot = u.Quantity(1.3e11, "Msun") | ||
| zhao = ZhaoPotential.from_m_tot( | ||
| m_tot=m_tot, | ||
| r_s=u.Quantity(8.1, "kpc"), | ||
| alpha=abg[0], | ||
| beta=abg[1], | ||
| gamma=abg[2], | ||
| units="galactic", | ||
| ) | ||
| other = OtherPotential(m_tot=m_tot, r_s=zhao.parameters["r_s"], units="galactic") | ||
|
|
||
| zhao_result = getattr(zhao, func_name)(xyz, u.Quantity(0.0, "Myr")) | ||
| other_result = getattr(other, func_name)(xyz, u.Quantity(0.0, "Myr")) | ||
|
|
||
| assert jnp.allclose( | ||
| zhao_result, other_result, rtol=1e-8, atol=u.Quantity(1e-6, zhao_result.unit) | ||
| ) |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This can use
u.ustrip(AllowValue