diff --git a/pyproject.toml b/pyproject.toml index 093590ed..2b37ff38 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,6 +53,9 @@ "wadler_lindig>=0.1.6", "xmmutablemap>=0.1", "zeroth>=1.0", + "spexial @ git+https://github.com/JAXtronomy/spexial.git@main", + "hypothesis>=6.135.14", + "gala>=1.9.1", ] [project.optional-dependencies] diff --git a/src/galax/potential/_src/params/base.py b/src/galax/potential/_src/params/base.py index 4b289207..0c6ec6c9 100644 --- a/src/galax/potential/_src/params/base.py +++ b/src/galax/potential/_src/params/base.py @@ -26,7 +26,7 @@ def __call__( Parameters ---------- - t : `~galax.typing.BBtQuSz0` + t : `~galax._custom_types.BBtQuSz0` Time(s) at which to compute the parameter value. ustrip : Unit | None Unit to strip from the parameter value. @@ -62,7 +62,7 @@ def __call__( Parameters ---------- - t : `~galax.typing.BBtQuSz0` + t : `~galax._custom_types.BBtQuSz0` The time(s) at which to compute the parameter value. ustrip : Unit | None The unit to strip from the parameter value. If None, the diff --git a/src/galax/potential/_src/params/constant.py b/src/galax/potential/_src/params/constant.py index 630e2279..71d6b311 100644 --- a/src/galax/potential/_src/params/constant.py +++ b/src/galax/potential/_src/params/constant.py @@ -135,7 +135,7 @@ def __call__( Parameters ---------- - t : `~galax.typing.BBtQuSz0`, optional + t : `~galax._custom_types.BBtQuSz0`, optional This is ignored and is thus optional. Note that for most :class:`~galax.potential.AbstractParameter` the time is required. ustrip : Unit | None diff --git a/src/galax/potential/_src/scf/__init__.py b/src/galax/potential/_src/scf/__init__.py new file mode 100644 index 00000000..a62b0ea6 --- /dev/null +++ b/src/galax/potential/_src/scf/__init__.py @@ -0,0 +1,11 @@ +from . import bfe, bfe_helper, coeffs, coeffs_helper +from .bfe import * +from .bfe_helper import * +from .coeffs import * +from .coeffs_helper import * + +__all__: list[str] = [] +__all__ += bfe.__all__ +__all__ += bfe_helper.__all__ +__all__ += coeffs.__all__ +__all__ += coeffs_helper.__all__ diff --git a/src/galax/potential/_src/scf/bfe.py b/src/galax/potential/_src/scf/bfe.py new file mode 100644 index 00000000..16dc3740 --- /dev/null +++ b/src/galax/potential/_src/scf/bfe.py @@ -0,0 +1,189 @@ +"""Self-Consistent Field Potential.""" + +__all__ = ["SCFPotential", "STnlmSnapshotParameter"] + +from collections.abc import Callable +from functools import partial +from typing import Any + +import equinox as eqx +import jax +import jax.numpy as jnp +from jaxtyping import Array, Float + +import galax._custom_types as gt +from .bfe_helper import phi_nl_vec, rho_nl as calculate_rho_nl +from .coeffs import compute_coeffs_discrete +from .utils import cartesian_to_spherical, real_Ylm +from galax.potential import AbstractPotential +from galax.potential._src.params.base import AbstractParameter +from galax.potential._src.params.field import ParameterField + +############################################################################## + + +class SCFPotential(AbstractPotential): + r"""Self-Consistent Field (SCF) potential. + + A gravitational potential represented as a basis function expansion. This + uses the self-consistent field (SCF) method of Hernquist & Ostriker (1992) + and Lowing et al. (2011), and represents all coefficients as real + quantities. + + Parameters + ---------- + m : numeric + Scale mass. + r_s : numeric + Scale length. + Snlm : Array[float, (nmax+1, lmax+1, lmax+1)] | Callable + Array of coefficients for the cos() terms of the expansion. This should + be a 3D array with shape `(nmax+1, lmax+1, lmax+1)`, where `nmax` is the + number of radial expansion terms and `lmax` is the number of spherical + harmonic `l` terms. If a callable is provided, it should accept a + single argument `t` and return the array of coefficients for that time. + Tnlm : Array[float, (nmax+1, lmax+1, lmax+1)] | Callable + Array of coefficients for the sin() terms of the expansion. This should + be a 3D array with shape `(nmax+1, lmax+1, lmax+1)`, where `nmax` is the + number of radial expansion terms and `lmax` is the number of spherical + harmonic `l` terms. If a callable is provided, it should accept a + single argument `t` and return the array of coefficients for that time. + units : iterable + Unique list of non-reducable units that specify (at minimum) the length, + mass, time, and angle units. + """ + + m: AbstractParameter = ParameterField(dimensions="mass") + r_s: AbstractParameter = ParameterField(dimensions="length") + Snlm: AbstractParameter = ParameterField(dimensions="dimensionless") + Tnlm: AbstractParameter = ParameterField(dimensions="dimensionless") + + nmax: int = eqx.field(init=False, static=True, repr=False) + lmax: int = eqx.field(init=False, static=True, repr=False) + + def __post_init__(self) -> None: + super().__post_init__() + + # shape parameters + shape = self.Snlm(0).shape + object.__setattr__(self, "nmax", shape[0] - 1) + object.__setattr__(self, "lmax", shape[1] - 1) + + # ========================================================================== + + @partial(jax.jit, inline=True) + def _potential( + self, xyz: gt.BtQuSz3, t: gt.BtQuSz0, / + ) -> gt.SzN | gt.FloatSz0: + r_s = self.r_s(t) + r, theta, phi = cartesian_to_spherical(xyz).T + + s = jnp.atleast_1d(r / r_s) # ([n],[l],[m],[N]) + theta = jnp.atleast_1d(theta)[None, None, None] # ([n],[l],[m],[N]) + phi = jnp.atleast_1d(phi)[None, None, None] # ([n],[l],[m],[N]) + + ns = jnp.arange(self.nmax + 1)[:, None, None] # (n, [l], [m]) + ls = jnp.arange(self.lmax + 1)[None, :, None] # ([n], l, [m]) + phi_nl = phi_nl_vec(s, ns, ls) # (n, l, [m], N) + + li, mi = jnp.tril_indices(self.lmax + 1) # (l*(l+1)//2,) + shape = (1, self.lmax + 1, self.lmax + 1, 1) # ([n], l, m, [N]) + midx = jnp.zeros(shape, dtype=int).at[:, li, mi, 0].set(mi) # ([n], l, m, [N]) + + Ylm = jnp.zeros(shape[:-1] + (len(s),)) + Ylm = Ylm.at[0, li, mi, :].set( + real_Ylm(theta[:, 0, 0, :], li[..., None], mi[..., None]) + ) + + Snlm = self.Snlm(t, r_s=r_s)[..., None] + Tnlm = self.Tnlm(t, r_s=r_s)[..., None] + + out = (self._G * self.m(t) / r_s) * jnp.sum( + Ylm * phi_nl * (Snlm * jnp.cos(midx * phi) + Tnlm * jnp.sin(midx * phi)), + axis=(0, 1, 2), + ) + return out[0] if len(xyz.shape) == 1 else out + + @partial(jax.jit, inline=True) + @eqx.filter_vmap(in_axes=(None, 1, None)) # type: ignore[misc] # on `q` axis 1 + def _density(self, q: gt.QuSz3, /, t: gt.QuSz0) -> Float[Array, "N"]: # type: ignore[name-defined] + """Compute the density at the given position(s).""" + r, theta, phi = cartesian_to_spherical(q) + r_s = self.r_s(t) + s = jnp.atleast_1d(r / r_s)[:, None, None, None] + theta = jnp.atleast_1d(theta)[:, None, None, None] + phi = jnp.atleast_1d(phi)[:, None, None, None] + + ns = jnp.arange(self.nmax + 1)[:, None, None] # (n, [l], [m]) + ls = jnp.arange(self.lmax + 1)[None, :, None] # ([n], l, [m]) + + phi_nl = calculate_rho_nl(s, ns[None], ls[None]) + + li, mi = jnp.tril_indices(self.lmax + 1) # (l*(l+1)//2,) + shape = (1, 1, self.lmax + 1, self.lmax + 1) + midx = jnp.zeros(shape, dtype=int).at[:, :, li, mi].set(mi) + Ylm = jnp.zeros((len(theta), 1, self.lmax + 1, self.lmax + 1)) + Ylm = Ylm.at[:, li, mi, :].set(real_Ylm(li[None], mi[None], theta[:, :, 0, 0])) + + Snlm = self.Snlm(t, r_s=r_s)[None] + Tnlm = self.Tnlm(t, r_s=r_s)[None] + + out = (self._G * self.m(t) / r_s) * jnp.sum( + Ylm * phi_nl * (Snlm * jnp.cos(midx * phi) + Tnlm * jnp.sin(midx * phi)), + axis=(1, 2, 3), + ) + return out[0] if len(q.shape) == 1 else out + + +# ============================================================================= + + +class STnlmSnapshotParameter(AbstractParameter): # type: ignore[misc] + """Parameter for the STnlm coefficients.""" + + snapshot: Callable[ # type: ignore[name-defined] + [Float[Array, "N"]], + tuple[Float[Array, "3 N"], Float[Array, "N"]], + ] + """Cartesian coordinates of the snapshot. + + This should be a callable that accepts a single argument `t` and returns + the cartesian coordinates and the masses of the snapshot at that time. + """ + + nmax: int = eqx.field(static=True, converter=int) + """Radial expansion term.""" + + lmax: int = eqx.field(static=True, converter=int) + """Spherical harmonic term.""" + + def __call__( + self, t: gt.QuSz0, *, r_s: gt.QuSz0, **_: Any + ) -> tuple[ + Float[Array, "{self.nmax}+1 {self.lmax}+1 {self.lmax}+1"], + Float[Array, "{self.nmax}+1 {self.lmax}+1 {self.lmax}+1"], + ]: + """Return the coefficients at the given time(s). + + TODO: are the types correct here? Should they be quantity specific? + Parameters + ---------- + t : float | Array[float, ()] + Time at which to evaluate the coefficients. + r_s : float | Array[float, ()] + Scale length of the potential at the given time(s. + **kwargs : Any + Additional keyword arguments are ignored. + + Returns + ------- + Snlm : Array[float, (nmax+1, lmax+1, lmax+1)] + The value of the cosine expansion coefficient. + Tnlm : Array[float, (nmax+1, lmax+1, lmax+1)] + The value of the sine expansion coefficient. + """ + xyz, m = self.snapshot(t) + coeffs: tuple[Array, Array] = compute_coeffs_discrete( + xyz, m, nmax=self.nmax, lmax=self.lmax, r_s=r_s + ) + return coeffs diff --git a/src/galax/potential/_src/scf/bfe_helper.py b/src/galax/potential/_src/scf/bfe_helper.py new file mode 100644 index 00000000..cf5ebf2a --- /dev/null +++ b/src/galax/potential/_src/scf/bfe_helper.py @@ -0,0 +1,81 @@ +"""Self-Consistent Field Potential.""" + +__all__: list[str] = [] + +from functools import partial + +import jax +import jax.numpy as jnp +from jaxtyping import Array, Float + +from .coeffs_helper import normalization_Knl +from .utils import psi_of_r +from spexial import eval_gegenbauers +import galax._custom_types as gt + + +def rho_nl(n: gt.IntSz0, l: gt.IntSz0, s: gt.FloatSz0, +) -> gt.FloatSz0: + r"""Radial density expansion terms. + + Parameters + ---------- + s : Array[float, (n,)] + Scaled radius :math:`r/r_s`. + n : int + Radial expansion term. + l : int + Spherical harmonic term. + + Returns + ------- + Array[float, (n,)] + """ + return ( + jnp.sqrt(4 * jnp.pi) + * (normalization_Knl(n=n, l=l) / (2 * jnp.pi)) + * (s**l / (s * (1 + s) ** (2 * l + 3))) + * eval_gegenbauers(n, 2 * l + 1.5, psi_of_r(s)) + ) + +rho_nl_jit_vec = jax.jit( + jax.vmap( jax.vmap(rho_nl, in_axes=(None, 0, None),), in_axes=(None, None, 0)), static_argnames="n" +) + +# ====================================================================== + + +def phi_nl(n: gt.IntSz0, l: gt.IntSz0, s: gt.FloatSz0, +) -> gt.FloatSz0: + r"""Angular density expansion terms. + + Parameters + ---------- + n : int + Max Radial expansion term. + l : int + Spherical harmonic term. + s : Float + Scaled radius :math:`r/r_s`. + + Returns + ------- + Array[float, (n + 1,)] + + Examples + -------- + >>> import jax.numpy as jnp + >>> phi_nl(0.5, 1, 1) + Array(0.5, dtype=float32) + >>> phi_nl(jnp.array([0.5, 0.5]), 1, 1) + Array([0.5, 0.5], dtype=float32) + """ + return ( + -jnp.sqrt(4 * jnp.pi) + * (s**l / (1.0 + s) ** (2 * l + 1)) + * eval_gegenbauers(n, 2 * l + 1.5, psi_of_r(s)) + ) + +phi_nl_jit_vec = jax.jit( + jax.vmap( jax.vmap(phi_nl, in_axes=(None, 0, None),), in_axes=(None, None, 0)), static_argnames="n" +) \ No newline at end of file diff --git a/src/galax/potential/_src/scf/coeffs.py b/src/galax/potential/_src/scf/coeffs.py new file mode 100644 index 00000000..5d354cd5 --- /dev/null +++ b/src/galax/potential/_src/scf/coeffs.py @@ -0,0 +1,94 @@ +"""Self-Consistent Field Potential.""" + +__all__ = ["compute_coeffs_discrete"] + + +from functools import partial + +import jax +import jax.numpy as jnp +from jaxtyping import Array, Float + +import galax._custom_types as gt +from .bfe_helper import phi_nl_vec +from .coeffs_helper import expansion_coeffs_Anl_discrete +from .utils import cartesian_to_spherical, real_Ylm + + +@partial(jax.jit, static_argnames=("nmax", "lmax")) +def compute_coeffs_discrete( + xyz: Float[Array, "samples 3"], + mass: Float[Array, "samples"], # type: ignore[name-defined] + *, + nmax: gt.IntSz0, + lmax: gt.IntSz0, + r_s: gt.FloatQuSz0, +) -> tuple[ + Float[Array, "{nmax}+1 {lmax}+1 {lmax}+1"], + Float[Array, "{nmax}+1 {lmax}+1 {lmax}+1"], +]: + """Compute expansion coefficients for the SCF potential. + + Compute the expansion coefficients for representing the density distribution + of input points as a basis function expansion. The points, ``xyz``, are + assumed to be samples from the density distribution. + + This is Equation 15 of Lowing et al. (2011). + + Parameters + ---------- + xyz : Array[float, (n_samples, 3)] + Samples from the density distribution. + :todo:`unit support` + mass : Array[float, (n_samples,)] + Mass of each sample. + :todo:`unit support` + nmax : int + Maximum value of ``n`` for the radial expansion. + lmax : int + Maximum value of ``l`` for the spherical harmonics. + r_s : numeric + Scale radius. + :todo:`unit support` + + Returns + ------- + Snlm : Array[float, (nmax+1, lmax+1, lmax+1)] + The value of the cosine expansion coefficient. + Tnlm : Array[float, (nmax+1, lmax+1, lmax+1)] + The value of the sine expansion coefficient. + """ + + rthetaphi = cartesian_to_spherical(xyz) + r = rthetaphi[..., 0] + theta = rthetaphi[..., 1] + phi = rthetaphi[..., 2] + s = r / r_s + + ns = jnp.arange(nmax + 1)[:, None] # (n, l) + ls = jnp.arange(lmax + 1)[None, :] # (n, l) + + Anl_til = expansion_coeffs_Anl_discrete(ns, ls) # (n, l) + phinl = phi_nl_vec(s, ns, ls) # (n, l, N) + + li, mi = jnp.tril_indices(lmax + 1) # (l*(l+1)//2,) + lm = jnp.zeros((lmax + 1, lmax + 1), dtype=int).at[li, mi].set(li) # (l, m) + ms = jnp.zeros((lmax + 1, lmax + 1), dtype=int).at[li, mi].set(mi) # (l, m) + # TODO: this is VERY SLOW. Can we do better? + Ylm = real_Ylm(theta, lm, ms, m_max=100) # (l, m, N) + + delta = jax.lax.select(ms == 0, jnp.ones_like(ms), jnp.zeros_like(ms)) # (l, m) + mvalid = jnp.zeros((lmax + 1, lmax + 1)).at[li, mi].set(1) # select m <= l + + tmp = ( # (n, l, m, N) using broadcasting + mvalid[None, :, :, None] + * (2 - delta[None, :, :, None]) + * Anl_til[:, :, None, None] + * mass[None, None, None, :] + * phinl[:, :, None, :] + * Ylm[None, :, :, :] + ) + Snlm = jnp.sum(tmp * jnp.cos(ms[None, :, :, None] * phi[None, None, None]), axis=-1) + Tnlm = jnp.sum(tmp * jnp.sin(ms[None, :, :, None] * phi[None, None, None]), axis=-1) + + return Snlm, Tnlm diff --git a/src/galax/potential/_src/scf/coeffs_helper.py b/src/galax/potential/_src/scf/coeffs_helper.py new file mode 100644 index 00000000..a073255b --- /dev/null +++ b/src/galax/potential/_src/scf/coeffs_helper.py @@ -0,0 +1,61 @@ +"""Self-Consistent Field Potential.""" + +__all__: list[str] = [] + +from typing import overload + +from jax.scipy.special import gamma, factorial +from jaxtyping import Array, Float, Integer + +import quaxed.numpy as jnp + +@overload +def normalization_Knl(n: int, l: int) -> float: ... + + +@overload +def normalization_Knl(n: Array, l: Array) -> Array: ... + + +def normalization_Knl( + n: Integer[Array, "*#shape"] | int, l: Integer[Array, "*#shape"] | int +) -> Float[Array, "*shape"] | float: + """SCF normalization factor. + + Parameters + ---------- + n : int + Radial expansion term. + l : int + Spherical harmonic term. + + Returns + ------- + float + """ + return 0.5 * n * (n + 4 * l + 3.0) + (l + 1) * (2 * l + 1) + + +def expansion_coeffs_Anl_discrete( + n: Integer[Array, "*#shape"], l: Integer[Array, "*#shape"] +) -> Float[Array, "*shape"]: + """Return normalization factor for the coefficients. + + Equation 16 of Lowing et al. (2011). + + Parameters + ---------- + n : int + Radial expansion term. + l : int + spherical harmonic term. + + Returns + ------- + float + """ + Knl = normalization_Knl(n=n, l=l) + prefac = -(2 ** (8.0 * l + 6)) / (4 * jnp.pi * Knl) + numerator = factorial(n) * (n + 2 * l + 1.5) * gamma(2 * l + 1.5) ** 2 + denominator = gamma(n + 4.0 * l + 3.0) + return prefac * (numerator / denominator) diff --git a/src/galax/potential/_src/scf/utils.py b/src/galax/potential/_src/scf/utils.py new file mode 100644 index 00000000..3c40a541 --- /dev/null +++ b/src/galax/potential/_src/scf/utils.py @@ -0,0 +1,83 @@ +"""Utility Functions.""" + +from functools import partial +from typing import TypeAlias, TypeVar, cast + +import jax +from jaxtyping import ArrayLike, Shaped + +import quaxed.numpy as jnp + +import galax._custom_types as gt +from galax.potential._src.builtin.multipole import compute_Ylm + +BatchableIntSz0: TypeAlias = Shaped[gt.IntSz0, "*#batch"] + +T = TypeVar("T", bound=ArrayLike) + + +@partial(jax.jit) +@partial(jax.numpy.vectorize, signature="(3)->(3)") +def cartesian_to_spherical(xyz: gt.FloatSz3, /) -> gt.FloatSz3: + """Convert Cartesian coordinates to spherical coordinates. + + Parameters + ---------- + xyz : Array[float, (3,)] + Cartesian coordinates in the form (x, y, z). + + Returns + ------- + r_theta_phi : Array[float, (3,)] + Spherical radius. + Inclination (polar) angle in [0, pi] from North to South pole. + Azimuthal angle in [-pi, pi] + """ + r = jnp.sqrt(jnp.sum(xyz**2, axis=0)) # spherical radius + # TODO: this is a hack to avoid the ambiguity at r==0. This should be done better. + theta = jax.lax.select( + r == 0, jnp.zeros_like(r), jnp.arccos(xyz[2] / r) + ) # inclination angle + phi = jnp.arctan2(xyz[1], xyz[0]) # azimuthal angle + return jnp.array([r, theta, phi]) + +def psi_of_r(r: T) -> T: + r""":math:`\psi(r) = (r-1)/(r+1)`. + + Equation 9 of Lowing et al. (2011). + """ + return cast("T", (r - 1.0) / (r + 1.0)) + + +# ============================================================================= + + +@partial(jax.jit, static_argnames=("m_max",)) +def real_Ylm( + theta: gt.SzAny, l: BatchableIntSz0, m: BatchableIntSz0, m_max: int = 100 +) -> gt.SzAny: + r"""Get the spherical harmonic :math:`Y_{lm}(\theta)` of the polar angle. + + This is different than the scipy (and thus JAX) convention, which is + :math:`Y_{lm}(\theta, \phi)`. + Note that scipy also uses the opposite convention for theta, phi where + theta is the azimuthal angle and phi is the polar angle. + + Parameters + ---------- + theta : Array[float, (*shape, N)] + Polar angle in [0, pi]. + l, m : int | Array[int, ()] + Spherical harmonic terms. l in [0,lmax], m in [0,l]. + m_max : int, optional + Maximum order of the spherical harmonic expansion. + + Returns + ------- + Array[float, (l, m, N)] + Spherical harmonic. The shape is batched using the vectorization signature + ``(n),(),()->(n)``. + """ + # TODO: raise an error if m > m_max + _real_Ylm, _ = compute_Ylm(l, m, theta=theta, phi=jnp.zeros_like(theta), l_max=m_max) + return _real_Ylm diff --git a/tests/potential/scf/test_coeff_helper.py b/tests/potential/scf/test_coeff_helper.py new file mode 100644 index 00000000..ef496a4e --- /dev/null +++ b/tests/potential/scf/test_coeff_helper.py @@ -0,0 +1,73 @@ +"""Test Gegenbauer utils.""" + +import jax +import numpy as np +import pytest + +import quaxed.numpy as jnp + +from galax.potential._src.scf.bfe_helper import rho_nl +from galax.potential._src.scf.coeffs_helper import ( + expansion_coeffs_Anl_discrete, + normalization_Knl, +) + +def test_normalization_Knl(): + """Test the ``normalization_Knl`` function. + + .. todo:: + + This test is not very good. It should be improved. + """ + assert normalization_Knl(0, 0) == 1 + assert normalization_Knl(1, 0) == 3 + assert normalization_Knl(2, 0) == 6 + assert normalization_Knl(0, 1) == 6 + assert normalization_Knl(0, 2) == 15 + assert normalization_Knl(1, 1) == 10 + + +# ============================================================================= + + +def test_expansion_coeffs_Anl_discrete(): + """Test the ``expansion_coeffs_Anl_discrete`` function. + + .. todo:: + + This test is not very good. It should be improved. + """ + np.testing.assert_allclose(expansion_coeffs_Anl_discrete(0, 0), -3) + np.testing.assert_allclose( + expansion_coeffs_Anl_discrete(1, 0), -0.555555, rtol=1e-5 + ) + + +# ============================================================================= + + +@jax.jit +def compare_rho_nl(s, n, l): + """Compare the ``rho_nl`` function.""" + + mock = rho_nl(s, n, l) + observed = jax.lax.stop_gradient(mock) + + return -jnp.sum((observed - mock) ** 2) + + +@pytest.mark.skip(reason="TODO") +def test_rho_nl(): + """Test the ``rho_nl`` function.""" + s = jnp.linspace(0, 4, 100, dtype=float) + n = jnp.array([1.0]) + l = jnp.array([2.0]) + + first_deriv = jax.jacfwd(compare_rho_nl)(s, n, l) + assert first_deriv == 0 + + +@pytest.mark.skip(reason="TODO") +def test_phi_nl(): + """Test the ``phi_nl`` function.""" + raise NotImplementedError diff --git a/tests/potential/scf/test_coeffs.py b/tests/potential/scf/test_coeffs.py new file mode 100644 index 00000000..3f0f82b6 --- /dev/null +++ b/tests/potential/scf/test_coeffs.py @@ -0,0 +1,42 @@ +"""Test the Coefficient Calculations.""" + +import gala.potential as gp +import numpy as np + +import quaxed.numpy as jnp + +import galax.potential._src as gpx + + +def test_compute_coeffs_discrete(): + """Test the ``normalization_Knl`` function. + + .. todo:: + + This test is not very good. It should be improved. + """ + # Setup + rng = np.random.default_rng(42) + particle_xyz = rng.normal(0.0, 5.0, size=(3, 10_000)) + particle_xyz[2] = np.abs(particle_xyz[2]) + particle_xyz = jnp.array(particle_xyz) + + particle_mass = jnp.ones(particle_xyz.shape[1]) + particle_mass = 1e12 * particle_mass / particle_mass.sum() + + nmax = 2 + lmax = 3 + r_s = 10 + + # Gala + gala_Snlm, gala_Tnlm = gp.scf.compute_coeffs_discrete( + np.array(particle_xyz), np.array(particle_mass), nmax=nmax, lmax=lmax, r_s=r_s + ) + + # Galdynamix + Snlm, Tnlm = gpx.scf.compute_coeffs_discrete( + particle_xyz, particle_mass, nmax=nmax, lmax=lmax, r_s=r_s + ) + + np.testing.assert_allclose(Snlm, gala_Snlm, rtol=1e-7) + np.testing.assert_allclose(Tnlm, gala_Tnlm, rtol=1e-7) diff --git a/tests/potential/scf/test_utils.py b/tests/potential/scf/test_utils.py new file mode 100644 index 00000000..5c19768e --- /dev/null +++ b/tests/potential/scf/test_utils.py @@ -0,0 +1,129 @@ +"""Test :mod:`galax.potential._src.scf.utils`.""" + +import hypothesis +import hypothesis.extra.numpy as hnp +import jax +import numpy as np +import scipy.special as sp +from hypothesis import assume, given, strategies as st + +import quaxed.numpy as jnp + +from galax.potential._src.scf.utils import ( + cartesian_to_spherical, + psi_of_r, + real_Ylm, +) + + +# TODO: use hnp.floating_dtypes() +# TODO: test more batch dimensions +def xyz_strategy() -> st.SearchStrategy[np.ndarray]: + return hnp.arrays( + dtype=np.float64, + shape=st.tuples(st.integers(1, 100), st.integers(3, 3)), + elements=st.floats(-10, 10, allow_subnormal=False, allow_nan=False), + ) + + +@given(xyz_strategy()) +def test_cartesian_to_spherical(xyz): + """Test the ``cartesian_to_spherical`` function.""" + assume(np.all(xyz.sum(axis=1) != 0)) + + n = len(xyz) + xyz = jnp.asarray(xyz) + + rthetaphi = cartesian_to_spherical(xyz) + r = rthetaphi[..., 0] + theta = rthetaphi[..., 1] + phi = rthetaphi[..., 2] + + # Check + assert r.shape == (n,) + assert theta.shape == (n,) + assert phi.shape == (n,) + + assert jnp.all(r >= 0) + assert jnp.all(theta >= 0) & jnp.all(theta <= jnp.pi) + assert jnp.all(phi >= -jnp.pi) & jnp.all(phi <= jnp.pi) + + +def test_cartesian_to_spherical_jac(): + """Test the ``cartesian_to_spherical`` function.""" + # Scalar + xyz = jnp.asarray([1, 0, 0], dtype=float) + assert xyz.shape == (3,) + + output = jax.jacfwd(cartesian_to_spherical)(xyz) + np.testing.assert_array_equal( + output, [[1.0, 0.0, 0.0], [-0.0, -0.0, -1.0], [0.0, 1.0, 0.0]] + ) + + # Vector + xyz = jnp.asarray([[1, 0, 0], [0, 1, 0]], dtype=float) + assert xyz.shape == (2, 3) + + output = jax.jacfwd(cartesian_to_spherical)(xyz) + assert output.shape == (2, 3, 2, 3) # WTF? + np.testing.assert_array_equal( + output[0, :, 0, :], [[1.0, 0.0, 0.0], [-0.0, -0.0, -1.0], [0.0, 1.0, 0.0]] + ) + np.testing.assert_array_equal( + output[1, :, 1, :], [[0.0, 1.0, 0.0], [0.0, 0.0, -1.0], [-1.0, 0.0, 0.0]] + ) + + +# ============================================================================= + + +@given( + r=st.floats(0, 100) + | hnp.arrays(dtype=float, shape=hnp.array_shapes(), elements=st.floats(0, 100)), +) +def test_psi_of_r(r): + """Test the ``psi_of_r`` function.""" + got = psi_of_r(r) + expected = (r - 1) / (r + 1) + np.testing.assert_allclose(got, expected) + + +# ============================================================================= + + +def test_Ylm_jitting(): + """Test the ``real_Ylm`` function.""" + got = real_Ylm(5, 0, np.pi) + expected = np.real(sp.sph_harm(0, 5, 0, np.pi)) + np.testing.assert_allclose(got, expected) + + +@hypothesis.settings(deadline=500) +@given(l=st.integers(1, 25), m=st.integers(1, 25), theta=st.floats(0, np.pi)) +def test_real_Ylm(l, m, theta): + """Test the ``real_Ylm`` function.""" + assume(theta != 0) + assume(m <= l) + got = real_Ylm(l, m, theta) + expected = np.real(sp.sph_harm(m, l, 0, theta)) + np.testing.assert_allclose(got, expected) + + +# # TODO: mark as slow +# # TODO: test batch dimensions of l, m, theta +# @hypothesis.settings(deadline=500) +# @given( +# l=st.integers(1, 25), +# m=st.integers(1, 25), +# theta=hnp.arrays( +# dtype=np.float64, +# shape=st.integers(1, 100), +# elements=st.floats(1e-5, np.pi, allow_subnormal=False, allow_nan=False), +# ), +# ) +# def test_real_Ylm_vec(l, m, theta): +# """Test the ``real_Ylm`` function.""" +# assume(m <= l) +# got = real_Ylm(l, m, theta) +# expected = np.real(sp.sph_harm(m, l, 0, theta)) +# np.testing.assert_allclose(got, expected) diff --git a/uv.lock b/uv.lock index 4d7cc850..1450454f 100644 --- a/uv.lock +++ b/uv.lock @@ -2481,4 +2481,4 @@ source = { registry = "https://pypi.org/simple" } sdist = { url = "https://files.pythonhosted.org/packages/4a/ec/df218ada17364453e8b9c18e7c6095ce76baad8f6700eb0a666cd8ef5c00/zeroth-1.0.0.tar.gz", hash = "sha256:6a6f9e517f9e473264e4ec9a91026f7a18f346f237e27ccfccb9a90c3811948c", size = 9604 } wheels = [ { url = "https://files.pythonhosted.org/packages/86/5c/4bf0c19637217a1e925f1f5a0c390046d9c1926572fbd1ef8cbc6d1d4919/zeroth-1.0.0-py3-none-any.whl", hash = "sha256:4ba39e99954848a63c200d4fe2c2b65707609bbb82ed62a1a5d61880bfa0a502", size = 4236 }, -] +] \ No newline at end of file