# File: mcx/distributions/pareto.py
#
# NOTE: the accompanying change to mcx/distributions/__init__.py adds
# `from .pareto import Pareto` (between the Normal and Poisson imports) and
# lists "Pareto" among the exported names.
from jax import lax
from jax import numpy as jnp
from jax import random
from jax.scipy import stats

from mcx.distributions import constraints
from mcx.distributions.distribution import Distribution
from mcx.distributions.shapes import promote_shapes


class Pareto(Distribution):
    """Pareto distribution with an optional location shift.

    The log-density is delegated to ``jax.scipy.stats.pareto.logpdf`` with
    ``b=shape``, ``loc=loc`` and ``scale=scale``; samples are drawn so that
    they follow that same parametrization, i.e. they lie in
    ``[loc + scale, inf)``.

    Parameters
    ----------
    shape
        Tail index ``b`` of the distribution (strictly positive).
    scale
        Scale of the distribution (strictly positive).
    loc
        Location shift applied to the distribution. Defaults to 0.
    """

    parameters = {
        "shape": constraints.strictly_positive,
        "scale": constraints.strictly_positive,
    }

    def __init__(self, shape, scale, loc=0):
        self.event_shape = ()
        shape, scale, loc = promote_shapes(shape, scale, loc)
        batch_shape = lax.broadcast_shapes(
            jnp.shape(shape), jnp.shape(scale), jnp.shape(loc)
        )
        self.batch_shape = batch_shape

        # The lower bound of the support moves with `loc`; using `scale`
        # alone would disagree with the density computed in `logpdf`.
        self.support = constraints.closed_interval(loc + scale, jnp.inf)

        self.shape = jnp.broadcast_to(shape, batch_shape)
        self.scale = jnp.broadcast_to(scale, batch_shape)
        self.loc = jnp.broadcast_to(loc, batch_shape)

    def sample(self, rng_key, sample_shape=()):
        """Draw samples of shape ``sample_shape + batch_shape + event_shape``.

        ``jax.random.pareto`` draws from the unit-scale, zero-location Pareto
        distribution; the draws are rescaled by ``scale`` and shifted by
        ``loc`` so that they match the density used in ``logpdf``.
        """
        shape = sample_shape + self.batch_shape + self.event_shape
        standard_draws = random.pareto(key=rng_key, b=self.shape, shape=shape)
        return self.loc + self.scale * standard_draws

    @constraints.limit_to_support
    def logpdf(self, x):
        """Return the log-density at ``x`` via ``jax.scipy.stats.pareto``."""
        return stats.pareto.logpdf(x=x, b=self.shape, loc=self.loc, scale=self.scale)
# File: tests/distributions/pareto_test.py
"""Tests for the sampling and log-density of the Pareto distribution."""
import pytest
from jax import numpy as jnp
from jax import random

from mcx.distributions import Pareto


@pytest.fixture
def rng_key():
    """Fixed PRNG key so that the statistical tests are deterministic."""
    return random.PRNGKey(0)


def pareto_mean(shape, scale):
    """Theoretical mean of a Pareto distribution (infinite when shape <= 1)."""
    if shape <= 1:
        return jnp.inf
    return shape * scale / (shape - 1.0)


def pareto_variance(shape, scale):
    """Theoretical variance of a Pareto distribution (infinite when shape <= 2)."""
    if shape <= 2:
        return jnp.inf
    numerator = (scale ** 2) * shape
    denominator = ((shape - 1) ** 2) * (shape - 2)
    return numerator / denominator


#
# SAMPLING CORRECTNESS
#

sample_means = [
    {"shape": 2, "scale": 0.1, "expected": pareto_mean(shape=2, scale=0.1)},
    {"shape": 10, "scale": 1, "expected": pareto_mean(shape=10, scale=1)},
    {"shape": 10, "scale": 10, "expected": pareto_mean(shape=10, scale=10)},
    {"shape": 100, "scale": 10, "expected": pareto_mean(shape=100, scale=10)},
]


@pytest.mark.parametrize("case", sample_means)
def test_sample_mean(rng_key, case):
    """The empirical mean of many samples matches the theoretical mean."""
    samples = Pareto(shape=case["shape"], scale=case["scale"]).sample(
        rng_key, (1_000_000,)
    )
    avg = jnp.mean(samples, axis=0).item()
    assert avg == pytest.approx(case["expected"], abs=1e-2)


# Only finite-variance cases (shape > 2) are listed: the empirical variance of
# a finite sample is always finite, so it can never compare equal to the
# infinite theoretical variance obtained for shape <= 2.
sample_variances = [
    {"shape": 10, "scale": 1, "expected": pareto_variance(shape=10, scale=1)},
    {"shape": 10, "scale": 10, "expected": pareto_variance(shape=10, scale=10)},
    {"shape": 100, "scale": 10, "expected": pareto_variance(shape=100, scale=10)},
]


@pytest.mark.parametrize("case", sample_variances)
def test_sample_variance(rng_key, case):
    """The empirical variance of many samples matches the theoretical variance."""
    samples = Pareto(shape=case["shape"], scale=case["scale"]).sample(
        rng_key, (1_000_000,)
    )
    var = jnp.var(samples, axis=0).item()
    assert var == pytest.approx(case["expected"], abs=1e-2)


#
# LOGPDF CORRECTNESS
#

out_of_support_cases = [
    {"shape": 3, "scale": 1, "x": 0.5, "expected": -jnp.inf},
    {"shape": 3, "scale": 1, "x": -1, "expected": -jnp.inf},
]


@pytest.mark.parametrize("case", out_of_support_cases)
def test_logpdf_out_of_support(case):
    """Values below the scale have a log-probability of minus infinity."""
    logprob = Pareto(shape=case["shape"], scale=case["scale"]).logpdf(case["x"])
    assert logprob == case["expected"]


#
# LOGPDF SHAPES
#

expected_logpdf_shapes = [
    {"shape": 3, "scale": 1, "x": 1, "expected_shape": ()},
    {"shape": 3, "scale": 1, "x": jnp.array([1, 2]), "expected_shape": (2,)},
]


@pytest.mark.parametrize("case", expected_logpdf_shapes)
def test_logpdf_shape(case):
    """The log-probability has the same shape as the evaluated values."""
    log_prob = Pareto(shape=case["shape"], scale=case["scale"]).logpdf(case["x"])
    assert log_prob.shape == case["expected_shape"]


#
# SAMPLING SHAPES
#

scalar_argument_expected_shapes = [
    {"sample_shape": (), "expected_shape": ()},
    {"sample_shape": (100,), "expected_shape": (100,)},
    {"sample_shape": (100, 10), "expected_shape": (100, 10)},
    {"sample_shape": (1, 100), "expected_shape": (1, 100)},
]


@pytest.mark.parametrize("case", scalar_argument_expected_shapes)
def test_sample_shape_scalar_arguments(rng_key, case):
    """Test the correctness of broadcasting when both arguments are
    scalars. We test scalar arguments separately from array arguments
    since scalars are edge cases when it comes to broadcasting.

    """
    samples = Pareto(scale=1, shape=1).sample(rng_key, case["sample_shape"])
    assert samples.shape == case["expected_shape"]


array_argument_expected_shapes_zero_dim = [
    {
        "shape": 1,
        "scale": jnp.array([1, 2, 3]),
        "sample_shape": (),
        "expected_shape": (3,),
    },
    {
        "shape": jnp.array([1, 2, 3]),
        "scale": 1,
        "sample_shape": (),
        "expected_shape": (3,),
    },
    {
        "shape": 1,
        "scale": jnp.array([[1, 2], [3, 4]]),
        "sample_shape": (),
        "expected_shape": (2, 2),
    },
    {
        "shape": jnp.array([1, 2]),
        "scale": jnp.array([[1, 2], [3, 4]]),
        "sample_shape": (),
        "expected_shape": (2, 2),
    },
]


@pytest.mark.parametrize("case", array_argument_expected_shapes_zero_dim)
def test_sample_shape_array_arguments_no_sample_shape(rng_key, case):
    """Test the correctness of broadcasting when arguments can be arrays."""
    samples = Pareto(shape=case["shape"], scale=case["scale"]).sample(
        rng_key, case["sample_shape"]
    )
    assert samples.shape == case["expected_shape"]


array_argument_expected_shapes_one_dim = [
    {
        "shape": 1,
        "scale": jnp.array([1, 2, 3]),
        "sample_shape": (100,),
        "expected_shape": (100, 3),
    },
    {
        "shape": jnp.array([1, 2, 3]),
        "scale": 1,
        "sample_shape": (100,),
        "expected_shape": (100, 3),
    },
    {
        "shape": 1,
        "scale": jnp.array([[1, 2], [3, 4]]),
        "sample_shape": (100,),
        "expected_shape": (100, 2, 2),
    },
    {
        "shape": jnp.array([1, 2]),
        "scale": jnp.array([[1, 2], [3, 4]]),
        "sample_shape": (100,),
        "expected_shape": (100, 2, 2),
    },
]


@pytest.mark.parametrize("case", array_argument_expected_shapes_one_dim)
def test_sample_shape_array_arguments_1d_sample_shape(rng_key, case):
    """Array arguments combined with a one-dimensional sample shape."""
    samples = Pareto(shape=case["shape"], scale=case["scale"]).sample(
        rng_key, case["sample_shape"]
    )
    assert samples.shape == case["expected_shape"]


array_argument_expected_shapes_two_dims = [
    {
        "shape": 1,
        "scale": jnp.array([1, 2, 3]),
        "sample_shape": (100, 2),
        "expected_shape": (100, 2, 3),
    },
    {
        "shape": jnp.array([1, 2, 3]),
        "scale": 1,
        "sample_shape": (100, 3),
        "expected_shape": (100, 3, 3),
    },
    {
        "shape": 1,
        "scale": jnp.array([[1, 2], [3, 4]]),
        "sample_shape": (100, 2),
        "expected_shape": (100, 2, 2, 2),
    },
    {
        "shape": jnp.array([1, 2]),
        "scale": jnp.array([[1, 2], [3, 4]]),
        "sample_shape": (100, 2),
        "expected_shape": (100, 2, 2, 2),
    },
]


@pytest.mark.parametrize("case", array_argument_expected_shapes_two_dims)
def test_sample_shape_array_arguments_2d_sample_shape(rng_key, case):
    """Array arguments combined with a two-dimensional sample shape."""
    samples = Pareto(shape=case["shape"], scale=case["scale"]).sample(
        rng_key, case["sample_shape"]
    )
    assert samples.shape == case["expected_shape"]