From 74eb7b012650170ff4f38e0e450ef43758bdc5c1 Mon Sep 17 00:00:00 2001 From: Tim Blazina Date: Sun, 14 Feb 2021 23:06:00 +0100 Subject: [PATCH 1/7] Add pareto distribution --- mcx/distributions/pareto.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) create mode 100644 mcx/distributions/pareto.py diff --git a/mcx/distributions/pareto.py b/mcx/distributions/pareto.py new file mode 100644 index 00000000..0ab3d0a7 --- /dev/null +++ b/mcx/distributions/pareto.py @@ -0,0 +1,28 @@ +from jax import numpy as np +from jax import random +from jax.scipy import stats + +from mcx.distributions import constraints +from mcx.distributions.distribution import Distribution +from mcx.distributions.shapes import broadcast_batch_shape + + +class Pareto(Distribution): + parameters = { + "b": constraints.strictly_positive, + } + + def __init__(self, a, m): + self.support = constraints.closed_interval(m, np.inf) + self.event_shape = () + self.batch_shape = broadcast_batch_shape(np.shape(a), np.shape(m)) + self.a = a + self.m = m + + def sample(self, rng_key, sample_shape=()): + shape = sample_shape + self.batch_shape + self.event_shape + return random.pareto(key=rng_key, b=self.b, shape=shape) + + @constraints.limit_to_support + def logpdf(self, x): + return stats.pareto.logpdf(x=x, b=self.b, loc=self.m) From 03ee5968e9a24457e3169901d23dec9a6289c209 Mon Sep 17 00:00:00 2001 From: Tim Blazina Date: Sun, 14 Feb 2021 23:31:28 +0100 Subject: [PATCH 2/7] Fix sample method This is necessary to account for the jax.random.pareto function using the type II Pareto distribution --- mcx/distributions/pareto.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mcx/distributions/pareto.py b/mcx/distributions/pareto.py index 0ab3d0a7..3c5fc7c8 100644 --- a/mcx/distributions/pareto.py +++ b/mcx/distributions/pareto.py @@ -21,7 +21,7 @@ def __init__(self, a, m): def sample(self, rng_key, sample_shape=()): shape = sample_shape + self.batch_shape + self.event_shape - return random.pareto(key=rng_key, b=self.b, shape=shape) + return self.m * (1 + random.pareto(key=rng_key, b=self.b, shape=shape)) @constraints.limit_to_support def logpdf(self, x): From d41346b47c1c1c0145633aa9f1770aff44bd1167 Mon Sep 17 00:00:00 2001 From: Tim Blazina Date: Thu, 18 Feb 2021 23:15:06 +0100 Subject: [PATCH 3/7] Update parameter names in Pareto distribution Also added Pareto distribution to mcx.distribution init --- mcx/distributions/__init__.py | 2 ++ mcx/distributions/pareto.py | 18 ++++++++++-------- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/mcx/distributions/__init__.py b/mcx/distributions/__init__.py index 0a18a670..e531e32b 100644 --- a/mcx/distributions/__init__.py +++ b/mcx/distributions/__init__.py @@ -9,6 +9,7 @@ from .lognormal import LogNormal from .mvnormal import MvNormal from .normal import Normal +from .pareto import Pareto from .poisson import Poisson from .uniform import Uniform @@ -23,6 +24,7 @@ "Exponential", "LogNormal", "Normal", + "Pareto", "Poisson", "Uniform", "MvNormal", diff --git a/mcx/distributions/pareto.py b/mcx/distributions/pareto.py index 3c5fc7c8..43222ce2 100644 --- a/mcx/distributions/pareto.py +++ b/mcx/distributions/pareto.py @@ -9,20 +9,22 @@ class Pareto(Distribution): parameters = { - "b": constraints.strictly_positive, + "shape": constraints.strictly_positive, + "scale": constraints.strictly_positive, } - def __init__(self, a, m): - self.support = constraints.closed_interval(m, np.inf) + def __init__(self, shape, scale, loc=0): + self.support = constraints.closed_interval(scale, np.inf) self.event_shape = () - self.batch_shape = broadcast_batch_shape(np.shape(a), np.shape(m)) - self.a = a - self.m = m + self.batch_shape = broadcast_batch_shape(np.shape(shape), np.shape(scale)) + self.shape = shape + self.scale = scale + self.loc = loc def sample(self, rng_key, sample_shape=()): shape = sample_shape + self.batch_shape + self.event_shape - return self.m * (1 + random.pareto(key=rng_key, b=self.b, shape=shape)) + return self.scale * (random.pareto(key=rng_key, b=self.shape, shape=shape)) @constraints.limit_to_support def logpdf(self, x): - return stats.pareto.logpdf(x=x, b=self.b, loc=self.m) + return stats.pareto.logpdf(x=x, b=self.shape, loc=self.loc, scale=self.scale) From 1b9dd2827d037717d038878730b07d9d4b6c3fb6 Mon Sep 17 00:00:00 2001 From: Tim Blazina Date: Sun, 7 Mar 2021 22:26:56 +0100 Subject: [PATCH 4/7] Update pareto implementation Add in promo_shapes and brodcast_to --- mcx/distributions/pareto.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/mcx/distributions/pareto.py b/mcx/distributions/pareto.py index 43222ce2..51d63580 100644 --- a/mcx/distributions/pareto.py +++ b/mcx/distributions/pareto.py @@ -1,10 +1,10 @@ -from jax import numpy as np -from jax import random +from jax import numpy as jnp +from jax import lax, random from jax.scipy import stats from mcx.distributions import constraints from mcx.distributions.distribution import Distribution -from mcx.distributions.shapes import broadcast_batch_shape +from mcx.distributions.shapes import promote_shapes class Pareto(Distribution): @@ -14,12 +14,14 @@ class Pareto(Distribution): } def __init__(self, shape, scale, loc=0): - self.support = constraints.closed_interval(scale, np.inf) + self.support = constraints.closed_interval(scale, jnp.inf) self.event_shape = () - self.batch_shape = broadcast_batch_shape(np.shape(shape), np.shape(scale)) - self.shape = shape - self.scale = scale - self.loc = loc + shape, scale, loc = promote_shapes(shape, scale, loc) + batch_shape = lax.broadcast_shapes(jnp.shape(shape), jnp.shape(scale), jnp.shape(loc)) + self.batch_shape = batch_shape + self.shape = jnp.broadcast_to(shape, batch_shape) + self.scale = jnp.broadcast_to(scale, batch_shape) + self.loc = jnp.broadcast_to(loc, batch_shape) def sample(self, rng_key, sample_shape=()): shape = sample_shape + self.batch_shape + self.event_shape From 7de71bf6172b0f70b32b0f66659490b7cef992b7 Mon Sep 17 00:00:00 2001 From: Tim Blazina Date: Sun, 7 Mar 2021 22:27:13 +0100 Subject: [PATCH 5/7] Add some first tests for Pareto distribution --- tests/distributions/pareto_test.py | 47 ++++++++++++++++++++++++++++++ 1 file changed, 47 insertions(+) create mode 100644 tests/distributions/pareto_test.py diff --git a/tests/distributions/pareto_test.py b/tests/distributions/pareto_test.py new file mode 100644 index 00000000..5652fd40 --- /dev/null +++ b/tests/distributions/pareto_test.py @@ -0,0 +1,47 @@ +import pytest +from jax import numpy as jnp +from jax import random + +from mcx.distributions import Pareto + + +@pytest.fixture +def rng_key(): + return random.PRNGKey(0) + + +def pareto_mean(shape, scale): + if shape > 1: + return shape * scale / (shape - 1.0) + else: + return jnp.inf + + +def pareto_variance(shape, scale): + if shape <= 2: + return jnp.inf + else: + numerator = (scale ** 2) * shape + denominator = ((shape - 1) **2 ) * (shape - 2) + return numerator / denominator + + +# +# SAMPLING CORRECTNESS +# + +sample_means = [ + {"shape": 2, "scale": 0.1, "expected": pareto_mean(shape=2, scale=0.1)}, + {"shape": 10, "scale": 1, "expected": pareto_mean(shape=10, scale=1)}, + {"shape": 10, "scale": 10, "expected": pareto_mean(shape=10, scale=10)}, + {"shape": 100, "scale": 10, "expected": pareto_mean(shape=100, scale=10)}, +] + + +@pytest.mark.parametrize("case", sample_means) +def test_sample_mean(rng_key, case): + samples = Pareto(shape=case["shape"], scale=case["scale"]).sample( + rng_key, (1_000_000,) + ) + avg = jnp.mean(samples, axis=0).item() + assert avg == pytest.approx(case["expected"], abs=1e-2) From ea00e3682ee9a6e443b2c0533b06f9f8af3409e9 Mon Sep 17 00:00:00 2001 From: Tim Blazina Date: Sun, 7 Mar 2021 22:47:46 +0100 Subject: [PATCH 6/7] Fix import sorting and linting errors --- mcx/distributions/pareto.py | 7 +++++-- tests/distributions/pareto_test.py | 4 ++-- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/mcx/distributions/pareto.py b/mcx/distributions/pareto.py index 51d63580..295e9a06 100644 --- a/mcx/distributions/pareto.py +++ b/mcx/distributions/pareto.py @@ -1,5 +1,6 @@ +from jax import lax from jax import numpy as jnp -from jax import lax, random +from jax import random from jax.scipy import stats from mcx.distributions import constraints @@ -17,7 +18,9 @@ def __init__(self, shape, scale, loc=0): self.support = constraints.closed_interval(scale, jnp.inf) self.event_shape = () shape, scale, loc = promote_shapes(shape, scale, loc) - batch_shape = lax.broadcast_shapes(jnp.shape(shape), jnp.shape(scale), jnp.shape(loc)) + batch_shape = lax.broadcast_shapes( + jnp.shape(shape), jnp.shape(scale), jnp.shape(loc) + ) self.batch_shape = batch_shape self.shape = jnp.broadcast_to(shape, batch_shape) self.scale = jnp.broadcast_to(scale, batch_shape) diff --git a/tests/distributions/pareto_test.py b/tests/distributions/pareto_test.py index 5652fd40..22020df2 100644 --- a/tests/distributions/pareto_test.py +++ b/tests/distributions/pareto_test.py @@ -22,7 +22,7 @@ def pareto_variance(shape, scale): return jnp.inf else: numerator = (scale ** 2) * shape - denominator = ((shape - 1) **2 ) * (shape - 2) + denominator = ((shape - 1) ** 2) * (shape - 2) return numerator / denominator @@ -33,7 +33,7 @@ def pareto_variance(shape, scale): sample_means = [ {"shape": 2, "scale": 0.1, "expected": pareto_mean(shape=2, scale=0.1)}, {"shape": 10, "scale": 1, "expected": pareto_mean(shape=10, scale=1)}, - {"shape": 10, "scale": 10, "expected": pareto_mean(shape=10, scale=10)}, + {"shape": 10, "scale": 10, "expected": pareto_mean(shape=10, scale=10)}, {"shape": 100, "scale": 10, "expected": pareto_mean(shape=100, scale=10)}, ] From 403d897de102ec9c64f5fb772b80648b8ae80d3d Mon Sep 17 00:00:00 2001 From: Tim Blazina Date: Wed, 31 Mar 2021 12:23:00 +0200 Subject: [PATCH 7/7] Add more tests for Pareto distribution --- tests/distributions/pareto_test.py | 197 +++++++++++++++++++++++++++++ 1 file changed, 197 insertions(+) diff --git a/tests/distributions/pareto_test.py b/tests/distributions/pareto_test.py index 22020df2..e6dc155d 100644 --- a/tests/distributions/pareto_test.py +++ b/tests/distributions/pareto_test.py @@ -45,3 +45,200 @@ def test_sample_mean(rng_key, case): ) avg = jnp.mean(samples, axis=0).item() assert avg == pytest.approx(case["expected"], abs=1e-2) + + +sample_variances = [ + {"shape": 2, "scale": 0.1, "expected": pareto_variance(shape=2, scale=0.1)}, + {"shape": 10, "scale": 1, "expected": pareto_variance(shape=10, scale=1)}, + {"shape": 10, "scale": 10, "expected": pareto_variance(shape=10, scale=10)}, + {"shape": 100, "scale": 10, "expected": pareto_variance(shape=100, scale=10)}, +] + + +@pytest.mark.parametrize("case", sample_variances) +def test_sample_variance(rng_key, case): + samples = Pareto(shape=case["shape"], scale=case["scale"]).sample( + rng_key, (1_000_000,) + ) + var = jnp.var(samples, axis=0).item() + assert var == pytest.approx(case["expected"], abs=1e-2) + + +# +# LOGPDF CORRECTNESS +# + +out_of_support_cases = [ + {"shape": 3, "scale": 1, "x": 0.5, "expected": -jnp.inf}, + {"shape": 3, "scale": 1, "x": -1, "expected": -jnp.inf}, +] + + +@pytest.mark.parametrize("case", out_of_support_cases) +def test_logpdf_out_of_support(case): + logprob = Pareto(shape=case["shape"], scale=case["scale"]).logpdf(case["x"]) + assert logprob == case["expected"] + + +# +# LOGPDF SHAPES +# + +expected_logpdf_shapes = [ + {"shape": 3, "scale": 1, "x": 1, "expected_shape": ()}, + {"shape": 3, "scale": 1, "x": jnp.array([1, 2]), "expected_shape": (2,)}, +] + + +@pytest.mark.parametrize("case", expected_logpdf_shapes) +def test_logpdf_shape(case): + log_prob = Pareto(shape=case["shape"], scale=case["scale"]).logpdf(case["x"]) + assert log_prob.shape == case["expected_shape"] + + +# +# SAMPLING SHAPE +# + +# +# SAMPLING SHAPES +# + +scalar_argument_expected_shapes = [ + {"sample_shape": (), "expected_shape": ()}, + {"sample_shape": (100,), "expected_shape": (100,)}, + { + "sample_shape": (100, 10), + "expected_shape": ( + 100, + 10, + ), + }, + { + "sample_shape": (1, 100), + "expected_shape": ( + 1, + 100, + ), + }, +] + + +@pytest.mark.parametrize("case", scalar_argument_expected_shapes) +def test_sample_shape_scalar_arguments(rng_key, case): + """Test the correctness of broadcasting when both arguments are + scalars. We test scalars arguments separately from array arguments + since scalars are edge cases when it comes to broadcasting. + + """ + samples = Pareto(scale=1, shape=1).sample(rng_key, case["sample_shape"]) + assert samples.shape == case["expected_shape"] + + +array_argument_expected_shapes_zero_dim = [ + { + "shape": 1, + "scale": jnp.array([1, 2, 3]), + "sample_shape": (), + "expected_shape": (3,), + }, + { + "shape": jnp.array([1, 2, 3]), + "scale": 1, + "sample_shape": (), + "expected_shape": (3,), + }, + { + "shape": 1, + "scale": jnp.array([[1, 2], [3, 4]]), + "sample_shape": (), + "expected_shape": (2, 2), + }, + { + "shape": jnp.array([1, 2]), + "scale": jnp.array([[1, 2], [3, 4]]), + "sample_shape": (), + "expected_shape": (2, 2), + }, +] + + +@pytest.mark.parametrize("case", array_argument_expected_shapes_zero_dim) +def test_sample_shape_array_arguments_no_sample_shape(rng_key, case): + """Test the correctness of broadcasting when arguments can be arrays.""" + samples = Pareto(shape=case["shape"], scale=case["scale"]).sample( + rng_key, case["sample_shape"] + ) + assert samples.shape == case["expected_shape"] + + +array_argument_expected_shapes_one_dim = [ + { + "shape": 1, + "scale": jnp.array([1, 2, 3]), + "sample_shape": (100,), + "expected_shape": (100, 3), + }, + { + "shape": jnp.array([1, 2, 3]), + "scale": 1, + "sample_shape": (100,), + "expected_shape": (100, 3), + }, + { + "shape": 1, + "scale": jnp.array([[1, 2], [3, 4]]), + "sample_shape": (100,), + "expected_shape": (100, 2, 2), + }, + { + "shape": jnp.array([1, 2]), + "scale": jnp.array([[1, 2], [3, 4]]), + "sample_shape": (100,), + "expected_shape": (100, 2, 2), + }, +] + + +@pytest.mark.parametrize("case", array_argument_expected_shapes_one_dim) +def test_sample_shape_array_arguments_1d_sample_shape(rng_key, case): + samples = Pareto(shape=case["shape"], scale=case["scale"]).sample( + rng_key, case["sample_shape"] + ) + assert samples.shape == case["expected_shape"] + + +array_argument_expected_shapes_two_dims = [ + { + "shape": 1, + "scale": jnp.array([1, 2, 3]), + "sample_shape": (100, 2), + "expected_shape": (100, 2, 3), + }, + { + "shape": jnp.array([1, 2, 3]), + "scale": 1, + "sample_shape": (100, 3), + "expected_shape": (100, 3, 3), + }, + { + "shape": 1, + "scale": jnp.array([[1, 2], [3, 4]]), + "sample_shape": (100, 2), + "expected_shape": (100, 2, 2, 2), + }, + { + "shape": jnp.array([1, 2]), + "scale": jnp.array([[1, 2], [3, 4]]), + "sample_shape": (100, 2), + "expected_shape": (100, 2, 2, 2), + }, +] + + +@pytest.mark.parametrize("case", array_argument_expected_shapes_two_dims) +def test_sample_shape_array_arguments_2d_sample_shape(rng_key, case): + samples = Pareto(shape=case["shape"], scale=case["scale"]).sample( + rng_key, case["sample_shape"] + ) + assert samples.shape == case["expected_shape"]