diff --git a/src/galax/dynamics/_src/cluster/api.py b/src/galax/dynamics/_src/cluster/api.py index de378d99..ad49be8e 100644 --- a/src/galax/dynamics/_src/cluster/api.py +++ b/src/galax/dynamics/_src/cluster/api.py @@ -139,31 +139,37 @@ def relaxation_time(*args: Any, **kwargs: Any) -> u.AbstractQuantity: >>> M = u.Quantity(1e4, "Msun") >>> r_hm = u.Quantity(2, "pc") >>> m_avg = u.Quantity(0.5, "Msun") - >>> G = u.Quantity(4.30091e-3, "pc km2 / (s2 Msun)") - >>> gdc.relaxation_time(M, r_hm, m_avg, G=G).uconvert("Myr") - Quantity(Array(129.50788873, dtype=float64, weak_type=True), unit='Myr') + >>> gdc.relaxation_time(M, r_hm, m_avg=m_avg).uconvert("Myr") + Quantity(Array(129.50777927, dtype=float64), unit='Myr') There are many different definitions of the relaxation time. By passing a flag object you can choose the one you want. Let's work through the built-in options: + >>> flags = gdc.relax_time # (not only flags) + - Baumgardt (1998) (the default): - >>> gdc.relaxation_time(gdc.relax_time.Baumgardt1998, M, r_hm, m_avg, G=G).uconvert("Myr") - Quantity(Array(129.50788873, dtype=float64, ...), unit='Myr') + >>> gdc.relaxation_time(flags.Baumgardt1998, M, r_hm, m_avg=m_avg).uconvert("Myr") + Quantity(Array(129.50777927, dtype=float64), unit='Myr') + + - Spitzer and Hart (1971): + + >>> gdc.relaxation_time(flags.SpitzerHart1971, M, r_hm, m_avg=m_avg).uconvert("Myr") + Quantity(Array(151.23177551, dtype=float64), unit='Myr') - Spitzer (1987) half-mass: >>> lnLambda = 10 # very approximate - >>> gdc.relaxation_time(gdc.relax_time.Spitzer1987HalfMass, M, r_hm, m_avg, lnLambda=lnLambda, G=G).uconvert("Myr") - Quantity(Array(143.38057289, dtype=float64, weak_type=True), unit='Myr') + >>> gdc.relaxation_time(flags.Spitzer1987HalfMass, M, r_hm, m_avg=m_avg, lnLambda=lnLambda).uconvert("Myr") + Quantity(Array(143.38045171, dtype=float64), unit='Myr') - Spitzer (1987) core: >>> Mcore, r_c = M / 5, r_hm / 5 # very approximate - >>> gdc.relaxation_time(gdc.relax_time.Spitzer1987Core, Mcore, r_c, m_avg, lnLambda=lnLambda, G=G).uconvert("Myr") - Quantity(Array(11.47044583, dtype=float64, weak_type=True), unit='Myr') + >>> gdc.relaxation_time(flags.Spitzer1987Core, Mcore, r_c, m_avg=m_avg, lnLambda=lnLambda).uconvert("Myr") + Quantity(Array(11.47043614, dtype=float64), unit='Myr') Using multiple-dispatch, you can register your own relaxation time definition. diff --git a/src/galax/dynamics/_src/cluster/dmdt.py b/src/galax/dynamics/_src/cluster/dmdt.py index 8bbb6795..f8a4bfd0 100644 --- a/src/galax/dynamics/_src/cluster/dmdt.py +++ b/src/galax/dynamics/_src/cluster/dmdt.py @@ -332,8 +332,8 @@ def __call__(self, t: Time, M: ClusterMass, args: Any, /, **kw: Any) -> Array: self.relaxation_time_flag, Mq, args["r_hm"], - args["m_avg"], - G=pot.constants["G"], + m_avg=args["m_avg"], + constants=pot.constants, ).uconvert("Myr") r_ratio = u.ustrip("", args["r_hm"] / r_t) diff --git a/src/galax/dynamics/_src/cluster/relax_time.py b/src/galax/dynamics/_src/cluster/relax_time.py index 6e34313c..6f4d1c1b 100644 --- a/src/galax/dynamics/_src/cluster/relax_time.py +++ b/src/galax/dynamics/_src/cluster/relax_time.py @@ -9,27 +9,29 @@ "AbstractRelaxationTimeMethod", # specific methods "Baumgardt1998", - "relaxation_time_baumgardt1998", "SpitzerHart1971", - "relaxation_time_spitzer_hart_1971", "Spitzer1987HalfMass", - "half_mass_relaxation_time_spitzer1987", "Spitzer1987Core", - "core_relaxation_time_spitzer1987", ] +import abc import functools as ft -from typing import Annotated as Antd, Any, NoReturn, TypeAlias, TypeVar, cast, final +from dataclasses import KW_ONLY +from typing import Annotated as Antd, Any, TypeAlias, TypeVar, cast, final from typing_extensions import Doc import equinox as eqx import jax +from jaxtyping import ArrayLike from plum import dispatch import quaxed.numpy as jnp +import unxt as u from unxt.quantity import is_any_quantity +from xmmutablemap import ImmutableMap import galax._custom_types as gt +from galax.potential._src.base import default_constants BBtAorQSz0: TypeAlias = gt.BBtSz0 | gt.BBtQuSz0 T = TypeVar("T", bound=gt.BBtSz0 | gt.BBtQuSz0) @@ -48,33 +50,20 @@ def _check_types_match(obj: T, comparator: object, /, name: str) -> T: ##################################################################### -class AbstractRelaxationTimeMethod: - """Abstract base class for relaxation time flags. +class AbstractRelaxationTimeMethod(eqx.Module): # type: ignore[misc] + """Abstract base class for relaxation time flags.""" - Examples - -------- - >>> import galax.dynamics.cluster as gdc - - >>> try: gdc.relax_time.AbstractRelaxationTimeMethod() - ... except TypeError as e: print(e) - Cannot instantiate AbstractRelaxationTimeMethod - - """ + @abc.abstractmethod + def __call__(self, *args: Any, **kwds: Any) -> BBtAorQSz0: + pass - def __new__(cls) -> NoReturn: - msg = "Cannot instantiate AbstractRelaxationTimeMethod" - raise TypeError(msg) - -@dispatch.multi( - (gt.BBtSz0, gt.BBtSz0, gt.BBtSz0), - (gt.BBtQuSz0, gt.BBtQuSz0, gt.BBtQuSz0), -) +@dispatch.multi((gt.BBtSz0, gt.BBtSz0), (gt.BBtQuSz0, gt.BBtQuSz0)) def relaxation_time( - M: BBtAorQSz0, r_hm: BBtAorQSz0, m_avg: BBtAorQSz0, /, **kw: Any + M: BBtAorQSz0, r_hm: BBtAorQSz0, /, *, m_avg: BBtAorQSz0, **kw: Any ) -> BBtAorQSz0: """Compute relaxation time, defaulting to Baumgardt (1998) formula.""" - return relaxation_time_baumgardt1998(M, r_hm, m_avg, **kw) + return relaxation_time(Baumgardt1998, M, r_hm, m_avg=m_avg, **kw) ###################################################################### @@ -99,14 +88,64 @@ class SpitzerHart1971(AbstractRelaxationTimeMethod): >>> M = u.Quantity(1e4, "Msun") >>> r_hm = u.Quantity(2, "pc") >>> m_avg = u.Quantity(0.42, "Msun") - >>> G = u.Quantity(0.00449, "pc3 / (Myr2 Msun)") - >>> trh = gdc.relaxation_time(gdc.relax_time.SpitzerHart1971, M, r_hm, - ... m_avg=m_avg, gamma=0.11, G=G) - >>> print(trh) - Quantity['time'](176.21612725, unit='Myr') + >>> trh = gdc.relax_time.SpitzerHart1971(m_avg=m_avg, gamma=0.11)(M, r_hm) + >>> print(trh.uconvert("Myr")) + Quantity['time'](176.0495246, unit='Myr') """ + m_avg: u.AbstractQuantity | ArrayLike + """Average stellar mass.""" + + gamma: float = 0.11 + """Constant in the Coulomb logarithm.""" + + _: KW_ONLY + constants: ImmutableMap[str, u.AbstractQuantity] = eqx.field( + default=default_constants, converter=ImmutableMap + ) + + @ft.partial(jax.jit) + def __call__( + self, + M: Antd[BBtAorQSz0, Doc("mass of the cluster")], + r_hm: Antd[BBtAorQSz0, Doc("half-mass radius of the cluster")], + /, + ) -> BBtAorQSz0: + r"""Compute relaxation time using Spitzer and Hart (1971) formula. + + $$ t_{\mathrm{rh}} = \frac{0.138 \sqrt{N} r_{\mathrm{h}}^{3/2}} + {\sqrt{M G} \ln(\gamma N)} + $$ + + where: + + - $N = m / \bar{M}$ is the mean number of stars in the cluster, + - $r_h$ is the half-mass radius of the cluster, + - $\bar{M}$ is the mean stellar mass. For a Chabrier (2005) IMF between 0.08 + and 100 $M_{\odot}$ this is approximately 0.42 $M_{\odot}$, + - $G$ is the gravitational constant, + - $\ln(\gamma N)$ is the Coulomb logarithm. For equal-mass clusters (Giersz + & Heggie 1994) $\gamma \sim 0.11$. + + Examples + -------- + >>> import unxt as u + >>> import galax.dynamics.cluster as gdc + + >>> M = u.Quantity(1e4, "Msun") + >>> r_hm = u.Quantity(2, "pc") + >>> m_avg = u.Quantity(0.42, "Msun") + >>> trh = gdc.relax_time.SpitzerHart1971(m_avg=m_avg, gamma=0.11)(M, r_hm) + >>> print(trh.uconvert("Myr")) + Quantity['time'](176.0495246, unit='Myr') + + """ + G = self.constants["G"] # TODO: unit detection + N = M / self.m_avg + coulomb_log = jnp.log(self.gamma * N) + return 0.138 * jnp.sqrt(N * r_hm**3 / (G * self.m_avg)) / coulomb_log + @dispatch.multi( (type[SpitzerHart1971], gt.BBtSz0, gt.BBtSz0), @@ -117,7 +156,8 @@ def relaxation_time( M: BBtAorQSz0, r_hm: BBtAorQSz0, /, - m_avg: BBtAorQSz0, + m_avg: u.AbstractQuantity | ArrayLike, + gamma: float = 0.11, **kw: Any, ) -> BBtAorQSz0: r"""Compute relaxation time using Spitzer and Hart (1971) formula. @@ -137,220 +177,200 @@ def relaxation_time( >>> m_avg = u.Quantity(0.42, "Msun") >>> G = u.Quantity(0.00449, "pc3 / (Myr2 Msun)") >>> trh = gdc.relaxation_time(gdc.relax_time.SpitzerHart1971, M, r_hm, - ... m_avg=m_avg, gamma=0.11, G=G) - >>> print(trh) - Quantity['time'](176.21612725, unit='Myr') + ... m_avg=m_avg, gamma=0.11) + >>> print(trh.uconvert("Myr")) + Quantity['time'](176.0495246, unit='Myr') """ - return relaxation_time_spitzer_hart_1971(M, r_hm, m_avg=m_avg, **kw) + return SpitzerHart1971(m_avg=m_avg, gamma=gamma, **kw)(M, r_hm) -# --------------------------- +###################################################################### +# Spitzer 1987 relaxation time -@ft.partial(jax.jit) -def relaxation_time_spitzer_hart_1971( - M: Antd[BBtAorQSz0, Doc("mass of the cluster")], - r_hm: Antd[BBtAorQSz0, Doc("half-mass radius of the cluster")], - /, - *, - m_avg: Antd[float, Doc("mean stellar mass.")] = 0.42, - gamma: Antd[float, Doc("Coulomb logarithm term.")] = 0.11, - G: Antd[BBtAorQSz0, Doc("gravitational constant")], +def _relaxation_time_spitzer1987( + M: BBtAorQSz0, + r: BBtAorQSz0, + m_avg: BBtAorQSz0, + prefactor: float, + lnLambda: gt.RealScalarLike, + G: BBtAorQSz0, ) -> BBtAorQSz0: - r"""Compute relaxation time using Spitzer and Hart (1971) formula. + r = _check_types_match(r, M, name="r") + m_avg = _check_types_match(m_avg, M, name="m_avg") + G = _check_types_match(G, M, name="G") + N = M / m_avg + return jnp.sqrt(r**3 / G / M) * prefactor * N / lnLambda - $$ t_{\mathrm{rh}} = \frac{0.138 \sqrt{N} r_{\mathrm{h}}^{3/2}} - {\sqrt{M G} \ln(\gamma N)} - $$ - where: +@final +class Spitzer1987HalfMass(AbstractRelaxationTimeMethod): + r"""Half-mass relaxation time from Spitzer (1987). - - $N = m / \bar{M}$ is the mean number of stars in the cluster, - - $r_h$ is the half-mass radius of the cluster, - - $\bar{M}$ is the mean stellar mass. For a Chabrier (2005) IMF between 0.08 - and 100 $M_{\odot}$ this is approximately 0.42 $M_{\odot}$, - - $G$ is the gravitational constant, - - $\ln(\gamma N)$ is the Coulomb logarithm. For equal-mass clusters (Giersz - & Heggie 1994) $\gamma \sim 0.11$. + $$ t_{rh} \approx \frac{0.17 N}{\ln(\Lambda)} \sqrt{\frac{r_h^3}{G M}} $$ - Examples - -------- - >>> import unxt as u - >>> import galax.dynamics.cluster as gdc + """ - >>> M = u.Quantity(1e4, "Msun") - >>> r_hm = u.Quantity(2, "pc") - >>> m_avg = u.Quantity(0.42, "Msun") - >>> G = u.Quantity(0.00449, "pc3 / (Myr2 Msun)") - >>> trh = gdc.relax_time.relaxation_time_spitzer_hart_1971( - ... M, r_hm, m_avg=m_avg, gamma=0.11, G=G) - >>> print(trh) - Quantity['time'](176.21612725, unit='Myr') + m_avg: u.AbstractQuantity | ArrayLike + """Average stellar mass.""" - """ - N = M / m_avg - return 0.138 * jnp.sqrt(N * r_hm**3 / (G * m_avg)) / jnp.log(gamma * N) + lnLambda: gt.RealScalarLike # noqa: N815 + """Coulomb logarithm.""" + _: KW_ONLY + constants: ImmutableMap[str, u.AbstractQuantity] = eqx.field( + default=default_constants, converter=ImmutableMap + ) -###################################################################### -# Spitzer 1987 relaxation time + @ft.partial(jax.jit) + def __call__( + self, + M: Antd[BBtAorQSz0, Doc("mass of the cluster")], + r_hm: Antd[BBtAorQSz0, Doc("half-mass radius of the cluster")], + /, + ) -> BBtAorQSz0: + r"""Compute the cluster's relaxation time. + Spitzer 1987 Equation 1. -@final -class Spitzer1987HalfMass(AbstractRelaxationTimeMethod): - r"""Half-mass relaxation time from Spitzer (1987). + .. math:: - $$ t_{rh} \approx \frac{0.17 N}{\ln(\Lambda)} \sqrt{\frac{r_h^3}{G M}} $$ + t_r = \frac{0.1 N}{\ln(0.4 N)} \frac{r_{hm}^3}{G M} - """ + Examples + -------- + >>> import unxt as u + >>> import galax.dynamics.cluster as gdc + >>> M = u.Quantity(1e4, "Msun") + >>> r_hm = u.Quantity(2, "pc") + >>> m_avg = u.Quantity(0.5, "Msun") + >>> lnLambda = 10 -@final -class Spitzer1987Core(AbstractRelaxationTimeMethod): - r"""Core relaxation time from Spitzer (1987). + >>> func = gdc.relax_time.Spitzer1987HalfMass(m_avg, lnLambda=lnLambda) + >>> func(M, r_hm).uconvert("Myr") + Quantity(Array(143.38045171, dtype=float64), unit='Myr') - $$ t_{rc} \approx \frac{0.34 N}{\ln(\Lambda)} \sqrt{\frac{r_c^3}{G M_c}} $$ + The function also works with raw JAX arrays, in which case the + inputs are assumed to be in compatible units: - """ + >>> func = gdc.relax_time.Spitzer1987HalfMass(m_avg.value, lnLambda=lnLambda, constants={"G": 0.00449}) + >>> func(M.value, r_hm.value) + Array(143.51613833, dtype=float64, ...) + + """ # noqa: E501 + return _relaxation_time_spitzer1987( + M, + r_hm, + self.m_avg, + prefactor=0.17, + lnLambda=self.lnLambda, + G=self.constants["G"], + ) @dispatch.multi( - (type[Spitzer1987HalfMass], gt.BBtSz0, gt.BBtSz0, gt.BBtSz0), - (type[Spitzer1987HalfMass], gt.BBtQuSz0, gt.BBtQuSz0, gt.BBtQuSz0), + (type[Spitzer1987HalfMass], gt.BBtSz0, gt.BBtSz0), + (type[Spitzer1987HalfMass], gt.BBtQuSz0, gt.BBtQuSz0), ) def relaxation_time( _: type[Spitzer1987HalfMass], M: BBtAorQSz0, r_hm: BBtAorQSz0, - m_avg: BBtAorQSz0, /, - **kw: Any, -) -> BBtAorQSz0: - """Compute relaxation time using Spitzer (1987) formula.""" - return half_mass_relaxation_time_spitzer1987(M, r_hm, m_avg, **kw) - - -@dispatch.multi( - (type[Spitzer1987Core], gt.BBtSz0, gt.BBtSz0, gt.BBtSz0), - (type[Spitzer1987Core], gt.BBtQuSz0, gt.BBtQuSz0, gt.BBtQuSz0), -) -def relaxation_time( - _: type[Spitzer1987Core], - M_core: BBtAorQSz0, - r_core: BBtAorQSz0, + *, m_avg: BBtAorQSz0, - /, **kw: Any, ) -> BBtAorQSz0: """Compute relaxation time using Spitzer (1987) formula.""" - return core_relaxation_time_spitzer1987(M_core, r_core, m_avg, **kw) - - -# --------------------------- - - -def _relaxation_time_spitzer1987( - M: BBtAorQSz0, - r: BBtAorQSz0, - m_avg: BBtAorQSz0, - prefactor: float, - lnLambda: gt.RealScalarLike, - G: BBtAorQSz0, -) -> BBtAorQSz0: - r = _check_types_match(r, M, name="r") - m_avg = _check_types_match(m_avg, M, name="m_avg") - G = _check_types_match(G, M, name="G") - N = M / m_avg - return jnp.sqrt(r**3 / G / M) * prefactor * N / lnLambda - - -@ft.partial(jax.jit) -def half_mass_relaxation_time_spitzer1987( - M: Antd[BBtAorQSz0, Doc("mass of the cluster")], - r_hm: Antd[BBtAorQSz0, Doc("half-mass radius of the cluster")], - m_avg: Antd[BBtAorQSz0, Doc("average stellar mass")], - /, - *, - G: Antd[BBtAorQSz0, Doc("gravitational constant")], - lnLambda: Antd[gt.RealScalarLike, Doc("Coulomb logarithm")], -) -> BBtAorQSz0: - r"""Compute the cluster's relaxation time. + return Spitzer1987HalfMass(m_avg=m_avg, **kw)(M, r_hm) - Spitzer 1987 Equation 1. - .. math:: +# --------------------------------------------------------- - t_r = \frac{0.1 N}{\ln(0.4 N)} \frac{r_{hm}^3}{G M} - Examples - -------- - >>> import unxt as u - >>> import galax.dynamics.cluster as gdc +@final +class Spitzer1987Core(AbstractRelaxationTimeMethod): + r"""Core relaxation time from Spitzer (1987). - >>> M = u.Quantity(1e4, "Msun") - >>> r_hm = u.Quantity(2, "pc") - >>> m_avg = u.Quantity(0.5, "Msun") - >>> G = u.Quantity(0.00449, "pc3 / (Myr2 Msun)") - >>> lnLambda = 10 + $$ t_{rc} \approx \frac{0.34 N}{\ln(\Lambda)} \sqrt{\frac{r_c^3}{G M_c}} $$ - >>> gdc.relax_time.half_mass_relaxation_time_spitzer1987(M, r_hm, m_avg, G=G, lnLambda=lnLambda).uconvert("Myr") - Quantity(Array(143.51613833, dtype=float64, ...), unit='Myr') + """ - The function also works with raw JAX arrays, in which case the - inputs are assumed to be in compatible units: + m_avg: u.AbstractQuantity | ArrayLike + """Average stellar mass.""" - >>> gdc.relax_time.half_mass_relaxation_time_spitzer1987(M.value, r_hm.value, m_avg.value, G=0.00449, lnLambda=lnLambda) - Array(143.51613833, dtype=float64, ...) + lnLambda: gt.RealScalarLike # noqa: N815 + """Coulomb logarithm.""" - """ # noqa: E501 - return _relaxation_time_spitzer1987( - M, r_hm, m_avg, prefactor=0.17, lnLambda=lnLambda, G=G + _: KW_ONLY + constants: ImmutableMap[str, u.AbstractQuantity] = eqx.field( + default=default_constants, converter=ImmutableMap ) + @ft.partial(jax.jit) + def __call__( + self, + Mc: Antd[BBtAorQSz0, Doc("mass of the cluster")], + r_c: Antd[BBtAorQSz0, Doc("core radius of the cluster")], + /, + ) -> BBtAorQSz0: + r"""Compute the cluster's relaxation time. -@ft.partial(jax.jit) -def core_relaxation_time_spitzer1987( - Mc: Antd[BBtAorQSz0, Doc("mass of the cluster")], - r_c: Antd[BBtAorQSz0, Doc("core radius of the cluster")], - m_avg: Antd[BBtAorQSz0, Doc("average stellar mass")], - /, - *, - G: Antd[BBtAorQSz0, Doc("gravitational constant")], - lnLambda: Antd[gt.RealScalarLike, Doc("Coulomb logarithm")], -) -> BBtAorQSz0: - r"""Compute the cluster's relaxation time. + Spitzer 1987 Equation 2. - Spitzer 1987 Equation 2. + .. math:: - .. math:: + t_r = \frac{0.2 N}{\ln(0.4 N)} \frac{r_c^3}{G M_c} - t_r = \frac{0.2 N}{\ln(0.4 N)} \frac{r_c^3}{G M_c} + Examples + -------- + >>> import unxt as u + >>> import galax.dynamics.cluster as gdc - Examples - -------- - >>> import unxt as u - >>> import galax.dynamics.cluster as gdc + >>> M = u.Quantity(2e3, "Msun") + >>> r_hm = u.Quantity(0.1, "pc") + >>> m_avg = u.Quantity(0.5, "Msun") + >>> lnLambda = 10 - >>> M = u.Quantity(2e3, "Msun") - >>> r_hm = u.Quantity(0.1, "pc") - >>> m_avg = u.Quantity(0.5, "Msun") - >>> G = u.Quantity(0.00449, "pc3 / (Myr2 Msun)") - >>> lnLambda = 10 + >>> func = gdc.relax_time.Spitzer1987Core(m_avg, lnLambda=lnLambda) + >>> func(M, r_hm).uconvert("Myr") + Quantity(Array(1.43380452, dtype=float64), unit='Myr') - >>> gdc.relax_time.core_relaxation_time_spitzer1987(M, r_hm, m_avg, G=G, lnLambda=lnLambda).uconvert("Myr") - Quantity(Array(1.43516138, dtype=float64, ...), unit='Myr') + The function also works with raw JAX arrays, in which case the + inputs are assumed to be in compatible units: - The function also works with raw JAX arrays, in which case the - inputs are assumed to be in compatible units: + >>> func = gdc.relax_time.Spitzer1987Core(m_avg.value, lnLambda=lnLambda, constants={"G": 0.00449}) + >>> func(M.value, r_hm.value) + Array(1.43516138, dtype=float64, ...) - >>> gdc.relax_time.core_relaxation_time_spitzer1987(M.value, r_hm.value, m_avg.value, G=0.00449, lnLambda=lnLambda) - Array(1.43516138, dtype=float64, ...) + """ # noqa: E501 + return _relaxation_time_spitzer1987( + Mc, + r_c, + self.m_avg, + prefactor=0.34, + lnLambda=self.lnLambda, + G=self.constants["G"], + ) - """ # noqa: E501 - return _relaxation_time_spitzer1987( - Mc, r_c, m_avg, prefactor=0.34, lnLambda=lnLambda, G=G - ) + +@dispatch.multi( + (type[Spitzer1987Core], gt.BBtSz0, gt.BBtSz0), + (type[Spitzer1987Core], gt.BBtQuSz0, gt.BBtQuSz0), +) +def relaxation_time( + _: type[Spitzer1987Core], + M_core: BBtAorQSz0, + r_core: BBtAorQSz0, + /, + *, + m_avg: BBtAorQSz0, + **kw: Any, +) -> BBtAorQSz0: + """Compute relaxation time using Spitzer (1987) formula.""" + return Spitzer1987Core(m_avg=m_avg, **kw)(M_core, r_core) ###################################################################### @@ -366,71 +386,69 @@ class Baumgardt1998(AbstractRelaxationTimeMethod): """ + m_avg: u.AbstractQuantity | ArrayLike + """Average stellar mass.""" + _: KW_ONLY + constants: ImmutableMap[str, u.AbstractQuantity | ArrayLike] = eqx.field( + default=default_constants, converter=ImmutableMap + ) -@dispatch.multi( - (type[Baumgardt1998], gt.BBtSz0, gt.BBtSz0, gt.BBtSz0), - (type[Baumgardt1998], gt.BBtQuSz0, gt.BBtQuSz0, gt.BBtQuSz0), -) -def relaxation_time( - _: type[Baumgardt1998], - M: BBtAorQSz0, - r_hm: BBtAorQSz0, - m_avg: BBtAorQSz0, - /, - **kw: Any, -) -> BBtAorQSz0: - """Compute relaxation time using Baumgardt (1998) formula.""" - return relaxation_time_baumgardt1998(M, r_hm, m_avg, **kw) - - -# --------------------------- + @ft.partial(jax.jit) + def __call__( + self, + M: Antd[BBtAorQSz0, Doc("mass of the cluster")], + r_hm: Antd[BBtAorQSz0, Doc("half-mass radius of the cluster")], + /, + ) -> BBtAorQSz0: + r"""Compute the cluster's relaxation time. + Baumgardt 1998 Equation 1. -@dispatch.multi( - (gt.BBtSz0, gt.BBtSz0, gt.BBtSz0), - (gt.BBtQuSz0, gt.BBtQuSz0, gt.BBtQuSz0), -) -@ft.partial(jax.jit) -def relaxation_time_baumgardt1998( - M: Antd[BBtAorQSz0, Doc("mass of the cluster")], - r_hm: Antd[BBtAorQSz0, Doc("half-mass radius of the cluster")], - m_avg: Antd[BBtAorQSz0, Doc("average stellar mass")], - /, - *, - G: Antd[BBtAorQSz0, Doc("gravitational constant")], -) -> BBtAorQSz0: - r"""Compute the cluster's relaxation time. + $$ + t_r = \frac{0.138 \sqrt{M_c} r_{hm}^{3/2}}{\sqrt{G} m_{avg} \ln(0.4 N)} + $$ - Baumgardt 1998 Equation 1. + where $N$ is the number of stars in the cluster, $M_c$ is the mass of the + cluster, $r_{hm}$ is the half-mass radius of the cluster, $m_{avg}$ is the + average stellar mass, and $G$ is the gravitational constant. - $$ - t_r = \frac{0.138 \sqrt{M_c} r_{hm}^{3/2}}{\sqrt{G} m_{avg} \ln(0.4 N)} - $$ + Examples + -------- + >>> import unxt as u + >>> import galax.dynamics.cluster as gdc - where $N$ is the number of stars in the cluster, $M_c$ is the mass of the - cluster, $r_{hm}$ is the half-mass radius of the cluster, $m_{avg}$ is the - average stellar mass, and $G$ is the gravitational constant. + >>> M = u.Quantity(1e4, "Msun") + >>> r_hm = u.Quantity(2, "pc") + >>> m_avg = u.Quantity(0.5, "Msun") - Examples - -------- - >>> import unxt as u - >>> import galax.dynamics.cluster as gdc + >>> gdc.relax_time.Baumgardt1998(m_avg)(M, r_hm).uconvert("Myr") + Quantity(Array(129.50777927, dtype=float64), unit='Myr') - >>> M = u.Quantity(1e4, "Msun") - >>> r_hm = u.Quantity(2, "pc") - >>> m_avg = u.Quantity(0.5, "Msun") - >>> G = u.Quantity(0.00449, "pc3 / (Myr2 Msun)") + The function also works with raw JAX arrays, in which case the + inputs are assumed to be in compatible units: - >>> gdc.relax_time.relaxation_time_baumgardt1998(M, r_hm, m_avg, G=G).uconvert("Myr") - Quantity(Array(129.63033763, dtype=float64, ...), unit='Myr') + >>> func = gdc.relax_time.Baumgardt1998(m_avg.value, constants={"G": 0.00449}) + >>> func(M.value, r_hm.value) + Array(129.63033763, dtype=float64, ...) - The function also works with raw JAX arrays, in which case the - inputs are assumed to be in compatible units: + """ + G = _check_types_match(self.constants["G"], M, name="G") + N = M / self.m_avg + return 0.138 * jnp.sqrt(N * r_hm**3 / (G * self.m_avg)) / jnp.log(0.4 * N) - >>> gdc.relax_time.relaxation_time_baumgardt1998(M.value, r_hm.value, m_avg.value, G=0.00449) - Array(129.63033763, dtype=float64, ...) - """ # noqa: E501 - G = _check_types_match(G, M, name="G") - N = M / m_avg - return 0.138 * jnp.sqrt(N * r_hm**3 / (G * m_avg)) / jnp.log(0.4 * N) +@dispatch.multi( + (type[Baumgardt1998], gt.BBtSz0, gt.BBtSz0), + (type[Baumgardt1998], gt.BBtQuSz0, gt.BBtQuSz0), +) +def relaxation_time( + _: type[Baumgardt1998], + M: BBtAorQSz0, + r_hm: BBtAorQSz0, + /, + *, + m_avg: BBtAorQSz0, + **kw: Any, +) -> BBtAorQSz0: + """Compute relaxation time using Baumgardt (1998) formula.""" + return Baumgardt1998(m_avg=m_avg, **kw)(M, r_hm)