Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/galax/dynamics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,10 @@
"AbstractSolver",
"DynamicsSolver",
# mockstream
"StreamSimulator",
"MockStreamArm",
"MockStream",
"MockStreamGenerator",
"MockStreamGenerator", # TODO: deprecate
# mockstream.df
"AbstractStreamDF",
"FardalStreamDF",
Expand Down Expand Up @@ -47,6 +48,7 @@
MockStream,
MockStreamArm,
MockStreamGenerator,
StreamSimulator,
)
from .solve import AbstractSolver, DynamicsSolver

Expand Down
10 changes: 10 additions & 0 deletions src/galax/dynamics/_src/mockstream/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,19 @@
"""

__all__ = [
"StreamSimulator",
# Coordinates
"MockStream",
"MockStreamArm",
# Phase-Space Distribution
"AbstractStreamDF",
"Fardal15StreamDF",
"Chen24StreamDF",
]

from .arm import MockStreamArm
from .core import MockStream
from .df_base import AbstractStreamDF
from .df_chen24 import Chen24StreamDF
from .df_fardal15 import Fardal15StreamDF
from .simulate import StreamSimulator
37 changes: 37 additions & 0 deletions src/galax/dynamics/_src/mockstream/arm.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@

from typing import Any, ClassVar, Protocol, cast, final, runtime_checkable

import diffrax as dfx
import equinox as eqx
from plum import dispatch

import coordinax as cx
import quaxed.numpy as jnp
import unxt as u
from unxt.quantity import BareQuantity

import galax._custom_types as gt
import galax.coordinates as gc
Expand Down Expand Up @@ -67,6 +69,41 @@ def _shape_tuple(self) -> tuple[gt.Shape, gc.ComponentShapeTuple]:

#####################################################################


@gc.AbstractPhaseSpaceObject.from_.dispatch # type: ignore[attr-defined,misc]
def from_(
cls: type[MockStreamArm],
soln: dfx.Solution,
/,
*,
release_time: gt.BBtQuSz0,
frame: cx.frames.AbstractReferenceFrame,
units: u.AbstractUnitSystem, # not dispatched on, but required
unbatch_time: bool = True,
) -> MockStreamArm:
"""Create a new instance of the class."""
# Reshape (*tbatch, T, *ybatch) to (*tbatch, *ybatch, T)
t = soln.ts # already in the shape (*tbatch, T)
n_tbatch = soln.t0.ndim
q = jnp.moveaxis(soln.ys[0], n_tbatch, -2)
p = jnp.moveaxis(soln.ys[1], n_tbatch, -2)

# Reshape (*tbatch, *ybatch, T) to (*tbatch, *ybatch) if T == 1
if unbatch_time and t.shape[-1] == 1:
t = t[..., -1]
q = q[..., -1, :]
p = p[..., -1, :]

# Convert the solution to a phase-space position
return cls(
q=cx.CartesianPos3D.from_(q, units["length"]),
p=cx.CartesianVel3D.from_(p, units["speed"]),
t=BareQuantity(t, units["time"]),
release_time=release_time,
frame=frame,
)


# =========================================================
# `__getitem__`

Expand Down
55 changes: 55 additions & 0 deletions src/galax/dynamics/_src/mockstream/df_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
"""Stream Distribution Functions for ejecting mock stream particles."""

__all__ = ["AbstractStreamDF"]

import abc
from typing import TypeAlias

import equinox as eqx
from jaxtyping import PRNGKeyArray

import galax._custom_types as gt
import galax.potential as gp

Carry: TypeAlias = tuple[gt.QuSz3, gt.QuSz3, gt.QuSz3, gt.QuSz3]


class AbstractStreamDF(eqx.Module, strict=True): # type: ignore[call-arg, misc]
"""Abstract base class of Stream Distribution Functions."""

# TODO: keep units and PSP through this func
@abc.abstractmethod
def sample(
self,
key: PRNGKeyArray,
potential: gp.AbstractPotential,
x: gt.BBtQuSz3,
v: gt.BBtQuSz3,
prog_mass: gt.BBtFloatQuSz0,
t: gt.BBtFloatQuSz0,
) -> tuple[gt.BtQuSz3, gt.BtQuSz3, gt.BtQuSz3, gt.BtQuSz3]:
"""Generate stream particle initial conditions.

Parameters
----------
rng : :class:`jaxtyping.PRNGKeyArray`
Pseudo-random number generator.
potential : :class:`galax.potential.AbstractPotential`
The potential of the host galaxy.
x : Quantity[float, (*#batch, 3), "length"]
3d position (x, y, z)
v : Quantity[float, (*#batch, 3), "speed"]
3d velocity (v_x, v_y, v_z)
prog_mass : Quantity[float, (*#batch), "mass"]
Mass of the progenitor.
t : Quantity[float, (*#batch), "time"]
The release time of the stream particles.

Returns
-------
x_lead, v_lead: Quantity[float, (*batch, 3), "length" | "speed"]
Position and velocity of the leading arm.
x_trail, v_trail : Quantity[float, (*batch, 3), "length" | "speed"]
Position and velocity of the trailing arm.
"""
...
136 changes: 136 additions & 0 deletions src/galax/dynamics/_src/mockstream/df_chen24.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
"""galax: Galactic Dynamix in Jax."""

__all__ = ["Chen24StreamDF"]


import warnings
from functools import partial
from typing import final

import jax
import jax.random as jr
from jaxtyping import PRNGKeyArray

import coordinax as cx
import quaxed.numpy as jnp

import galax._custom_types as gt
import galax.potential as gp
from .df_base import AbstractStreamDF
from galax.dynamics._src.cluster.radius import tidal_radius
from galax.dynamics._src.register_api import specific_angular_momentum

# ============================================================
# Constants

mean = jnp.array([1.6, -30, 0, 1, 20, 0])

cov = jnp.array(
[
[0.1225, 0, 0, 0, -4.9, 0],
[0, 529, 0, 0, 0, 0],
[0, 0, 144, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[-4.9, 0, 0, 0, 400, 0],
[0, 0, 0, 0, 0, 484],
]
)

# ============================================================


@final
class Chen24StreamDF(AbstractStreamDF):
"""Chen Stream Distribution Function.

A class for representing the Chen+2024 distribution function for
generating stellar streams based on Chen et al. 2024
https://ui.adsabs.harvard.edu/abs/2024arXiv240801496C/abstract
"""

def __init__(self) -> None:
super().__init__()
warnings.warn(
'Currently only the "no progenitor" version '
"of the Chen+24 model is supported!",
RuntimeWarning,
stacklevel=1,
)

@partial(jax.jit, inline=True)
def sample(
self,
key: PRNGKeyArray,
potential: gp.AbstractPotential,
x: gt.BBtQuSz3,
v: gt.BBtQuSz3,
prog_mass: gt.BBtFloatQuSz0,
t: gt.BBtFloatQuSz0,
) -> tuple[gt.BtQuSz3, gt.BtQuSz3, gt.BtQuSz3, gt.BtQuSz3]:
"""Generate stream particle initial conditions."""
# Random number generation

# x_new-hat
r = jnp.linalg.vector_norm(x, axis=-1, keepdims=True)
x_new_hat = x / r

# z_new-hat
L_vec = specific_angular_momentum(x, v)
z_new_hat = cx.vecs.normalize_vector(L_vec)

# y_new-hat
phi_vec = v - jnp.sum(v * x_new_hat, axis=-1, keepdims=True) * x_new_hat
y_new_hat = cx.vecs.normalize_vector(phi_vec)

r_tidal = tidal_radius(potential, x, v, mass=prog_mass, t=t)

# Bill Chen: method="cholesky" doesn't work here!
posvel = jr.multivariate_normal(
key, mean, cov, shape=r_tidal.shape, method="svd"
)

Dr = posvel[:, 0] * r_tidal

v_esc = jnp.sqrt(2 * potential.constants["G"] * prog_mass / Dr)
Dv = posvel[:, 3] * v_esc

# convert degrees to radians
phi = posvel[:, 1] * 0.017453292519943295
theta = posvel[:, 2] * 0.017453292519943295
alpha = posvel[:, 4] * 0.017453292519943295
beta = posvel[:, 5] * 0.017453292519943295

ctheta, stheta = jnp.cos(theta), jnp.sin(theta)
cphi, sphi = jnp.cos(phi), jnp.sin(phi)
calpha, salpha = jnp.cos(alpha), jnp.sin(alpha)
cbeta, sbeta = jnp.cos(beta), jnp.sin(beta)

# Trailing arm
x_trail = (
x
+ (Dr * ctheta * cphi)[:, None] * x_new_hat
+ (Dr * ctheta * sphi)[:, None] * y_new_hat
+ (Dr * stheta)[:, None] * z_new_hat
)
v_trail = (
v
+ (Dv * cbeta * calpha)[:, None] * x_new_hat
+ (Dv * cbeta * salpha)[:, None] * y_new_hat
+ (Dv * sbeta)[:, None] * z_new_hat
)

# Leading arm
x_lead = (
x
- (Dr * ctheta * cphi)[:, None] * x_new_hat
- (Dr * ctheta * sphi)[:, None] * y_new_hat
+ (Dr * stheta)[:, None] * z_new_hat
)
v_lead = (
v
- (Dv * cbeta * calpha)[:, None] * x_new_hat
- (Dv * cbeta * salpha)[:, None] * y_new_hat
+ (Dv * sbeta)[:, None] * z_new_hat
)

return x_lead, v_lead, x_trail, v_trail
93 changes: 93 additions & 0 deletions src/galax/dynamics/_src/mockstream/df_fardal15.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
"""galax: Galactic Dynamix in Jax."""

__all__ = ["Fardal15StreamDF"]


from functools import partial
from typing import final

import jax
import jax.random as jr
from jaxtyping import PRNGKeyArray

import coordinax as cx
import quaxed.numpy as jnp

import galax._custom_types as gt
import galax.potential as gp
from .df_base import AbstractStreamDF
from galax.dynamics._src.api import omega
from galax.dynamics._src.cluster.radius import tidal_radius

# ============================================================
# Constants

kr_bar = 2.0
kvphi_bar = 0.3

kz_bar = 0.0
kvz_bar = 0.0

sigma_kr = 0.5 # TODO: use actual Fardal values
sigma_kvphi = 0.5 # TODO: use actual Fardal values
sigma_kz = 0.5
sigma_kvz = 0.5

# ============================================================


@final
class Fardal15StreamDF(AbstractStreamDF):
"""Fardal Stream Distribution Function.

A class for representing the Fardal+2015 distribution function for
generating stellar streams based on Fardal et al. 2015
https://ui.adsabs.harvard.edu/abs/2015MNRAS.452..301F/abstract
"""

@partial(jax.jit)
def sample(
self,
key: PRNGKeyArray,
potential: gp.AbstractPotential,
x: gt.BBtQuSz3,
v: gt.BBtQuSz3,
prog_mass: gt.BBtFloatQuSz0,
t: gt.BBtFloatQuSz0,
) -> tuple[gt.BtQuSz3, gt.BtQuSz3, gt.BtQuSz3, gt.BtQuSz3]:
"""Generate stream particle initial conditions."""
# Random number generation
key1, key2, key3, key4 = jr.split(key, 4)

om = omega(x, v)[..., None]

# r-hat
r_hat = cx.vecs.normalize_vector(x)

r_tidal = tidal_radius(potential, x, v, mass=prog_mass, t=t)[..., None]
v_circ = om * r_tidal # relative velocity

# z-hat
L_vec = jnp.linalg.cross(x, v)
z_hat = cx.vecs.normalize_vector(L_vec)

# phi-hat
phi_vec = v - jnp.sum(v * r_hat, axis=-1, keepdims=True) * r_hat
phi_hat = cx.vecs.normalize_vector(phi_vec)

# k vals
shape = r_tidal.shape
kr_samp = kr_bar + jr.normal(key1, shape) * sigma_kr
kvphi_samp = kr_samp * (kvphi_bar + jr.normal(key2, shape) * sigma_kvphi)
kz_samp = kz_bar + jr.normal(key3, shape) * sigma_kz
kvz_samp = kvz_bar + jr.normal(key4, shape) * sigma_kvz

# Trailing arm
x_trail = x + r_tidal * (kr_samp * r_hat + kz_samp * z_hat)
v_trail = v + v_circ * (kvphi_samp * phi_hat + kvz_samp * z_hat)

# Leading arm
x_lead = x - r_tidal * (kr_samp * r_hat - kz_samp * z_hat)
v_lead = v - v_circ * (kvphi_samp * phi_hat - kvz_samp * z_hat)

return x_lead, v_lead, x_trail, v_trail
Loading
Loading