Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions mcx/distributions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from .lognormal import LogNormal
from .mvnormal import MvNormal
from .normal import Normal
from .pareto import Pareto
from .poisson import Poisson
from .uniform import Uniform

Expand All @@ -30,6 +31,7 @@
"MvNormal",
"HalfNormal",
"Normal",
"Pareto",
"Poisson",
"Uniform",
]
35 changes: 35 additions & 0 deletions mcx/distributions/pareto.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from jax import lax
from jax import numpy as jnp
from jax import random
from jax.scipy import stats

from mcx.distributions import constraints
from mcx.distributions.distribution import Distribution
from mcx.distributions.shapes import promote_shapes


class Pareto(Distribution):
    """Pareto distribution with ``shape``, ``scale`` and optional ``loc``.

    The log-density is delegated to ``jax.scipy.stats.pareto.logpdf`` with
    ``b=shape``, so the density is non-zero only for ``x >= loc + scale``.

    Parameters
    ----------
    shape
        Tail index ``b`` of the distribution (strictly positive).
    scale
        Scale parameter (strictly positive).
    loc
        Location shift applied to the distribution; defaults to 0.
    """

    parameters = {
        "shape": constraints.strictly_positive,
        "scale": constraints.strictly_positive,
    }

    def __init__(self, shape, scale, loc=0):
        # The lower bound of the support is shifted by `loc`, matching the
        # `loc`/`scale` convention used by `logpdf` below. With the default
        # `loc=0` this is the same bound as `scale` alone.
        self.support = constraints.closed_interval(loc + scale, jnp.inf)
        self.event_shape = ()
        shape, scale, loc = promote_shapes(shape, scale, loc)
        batch_shape = lax.broadcast_shapes(
            jnp.shape(shape), jnp.shape(scale), jnp.shape(loc)
        )
        self.batch_shape = batch_shape
        self.shape = jnp.broadcast_to(shape, batch_shape)
        self.scale = jnp.broadcast_to(scale, batch_shape)
        self.loc = jnp.broadcast_to(loc, batch_shape)

    def sample(self, rng_key, sample_shape=()):
        """Draw samples of shape ``sample_shape + batch_shape + event_shape``."""
        shape = sample_shape + self.batch_shape + self.event_shape
        # `random.pareto` draws from the standard Pareto (support [1, inf));
        # scale and shift so that samples follow the same location-scale
        # family that `logpdf` evaluates.
        standard_draws = random.pareto(key=rng_key, b=self.shape, shape=shape)
        return self.loc + self.scale * standard_draws

    # NOTE(review): `limit_to_support` presumably handles values outside
    # `self.support` -- confirm against mcx.distributions.constraints.
    @constraints.limit_to_support
    def logpdf(self, x):
        """Return the log-density of ``x`` under this distribution."""
        return stats.pareto.logpdf(x=x, b=self.shape, loc=self.loc, scale=self.scale)
244 changes: 244 additions & 0 deletions tests/distributions/pareto_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,244 @@
import pytest
from jax import numpy as jnp
from jax import random

from mcx.distributions import Pareto


@pytest.fixture
def rng_key():
    """Provide a JAX PRNG key built from the fixed seed 0."""
    seed = 0
    return random.PRNGKey(seed)


def pareto_mean(shape, scale):
    """Return the closed-form Pareto mean used as the test reference.

    The mean is ``shape * scale / (shape - 1)`` when ``shape > 1`` and
    infinite otherwise.
    """
    return shape * scale / (shape - 1.0) if shape > 1 else jnp.inf


def pareto_variance(shape, scale):
    """Return the closed-form Pareto variance used as the test reference.

    The variance is infinite for ``shape <= 2``; otherwise it equals
    ``scale**2 * shape / ((shape - 1)**2 * (shape - 2))``.
    """
    if shape <= 2:
        return jnp.inf
    return ((scale ** 2) * shape) / (((shape - 1) ** 2) * (shape - 2))

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great addition! However, before I merge we'll need to add tests for the shape and the support! Would you mind adding those?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

indeed, was planning on it when I get some time, hopefully in the next few days!

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there anything I can do to help?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll let you know when I get to it this weekend. Last 2.5 weeks I had a kidney stone which involved two surgeries and like 5 nights in hospitals, but things seem to be resolved now. 2021 has not been my year in terms of health. Nonetheless, I should finally have some time this weekend!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, I found some time to add more tests, but I'm having one issue with a failing test for the variance in the case when the shape parameter is <= 2, and I'm not entirely sure what I've implemented wrong. Not being totally familiar with the Pareto distribution, I've just followed the information on https://en.wikipedia.org/wiki/Pareto_distribution, which states that the variance should be infinite when the shape parameter is <= 2; however, this is not the case in the current implementation. I'd appreciate some feedback!

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Extensive test suite, great job!

Remember that we defined shape = b in this case. The variance should thus be theoretically infinite when $shape < 1$ per the formulae above.

Then, if you measure the variance of samples drawn from the distribution, you should get a very large number but not strictly $\infty$. You can check that $\sigma > 10 \mu$ for instance when $0 < shape < 1$. It would also be nice to check that $\mu \rightarrow \infty$ when $shape < 0$.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Alright, I'll update the tests to reflect this. Thanks for the clarification!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I was away from this too long and am a bit confused: in your suggestion you use $\sigma$ and $\mu$ notation, and I'm not sure what you are referring to when you say "The variance should thus be theoretically infinite when $shape < 1$ per the formulae above." I'm not sure which formulae you are referring to, because in the way I've implemented it, having $shape < 1$ doesn't result in the variance being infinite:

        numerator = (scale ** 2) * shape
        denominator = ((shape - 1) ** 2) * (shape - 2)
        return numerator / denominator

I get that the variance of the samples won't strictly be $\infty$, but I think I have implemented the Pareto distribution incorrectly and can't figure out what I've done wrong. I would need some additional assistance, thanks!


#
# SAMPLING CORRECTNESS
#

# (shape, scale) pairs whose empirical mean is compared with `pareto_mean`.
sample_means = [
    {"shape": shape, "scale": scale, "expected": pareto_mean(shape=shape, scale=scale)}
    for shape, scale in ((2, 0.1), (10, 1), (10, 10), (100, 10))
]


@pytest.mark.parametrize("case", sample_means)
def test_sample_mean(rng_key, case):
    """The mean of one million draws matches the expected value within 1e-2."""
    distribution = Pareto(shape=case["shape"], scale=case["scale"])
    draws = distribution.sample(rng_key, (1_000_000,))
    empirical_mean = jnp.mean(draws, axis=0).item()
    assert empirical_mean == pytest.approx(case["expected"], abs=1e-2)


# (shape, scale) pairs whose empirical variance is compared with
# `pareto_variance`.
sample_variances = [
    {
        "shape": shape,
        "scale": scale,
        "expected": pareto_variance(shape=shape, scale=scale),
    }
    for shape, scale in ((2, 0.1), (10, 1), (10, 10), (100, 10))
]


@pytest.mark.parametrize("case", sample_variances)
def test_sample_variance(rng_key, case):
    """The variance of one million draws matches the expected value within 1e-2.

    Cases whose expected variance is infinite are skipped: the variance of a
    finite sample is always finite, and ``pytest.approx`` treats infinity as
    equal only to itself, so the comparison could never succeed.
    """
    expected = case["expected"]
    if expected == jnp.inf:
        pytest.skip(
            "infinite theoretical variance cannot be matched by a finite sample"
        )
    samples = Pareto(shape=case["shape"], scale=case["scale"]).sample(
        rng_key, (1_000_000,)
    )
    var = jnp.var(samples, axis=0).item()
    assert var == pytest.approx(expected, abs=1e-2)


#
# LOGPDF CORRECTNESS
#

# Points below the scale (x < 1 for scale=1) must get a log-density of -inf.
out_of_support_cases = [
    {"shape": 3, "scale": 1, "x": point, "expected": -jnp.inf}
    for point in (0.5, -1)
]


@pytest.mark.parametrize("case", out_of_support_cases)
def test_logpdf_out_of_support(case):
    """The log-density at each listed point equals the case's expected value."""
    distribution = Pareto(shape=case["shape"], scale=case["scale"])
    log_density = distribution.logpdf(case["x"])
    assert log_density == case["expected"]


#
# LOGPDF SHAPES
#

# Inputs to `logpdf` paired with the shape the result should have.
expected_logpdf_shapes = [
    {"shape": 3, "scale": 1, "x": value, "expected_shape": result_shape}
    for value, result_shape in (
        (1, ()),
        (jnp.array([1, 2]), (2,)),
    )
]


@pytest.mark.parametrize("case", expected_logpdf_shapes)
def test_logpdf_shape(case):
    """The array returned by ``logpdf`` has the expected shape."""
    distribution = Pareto(shape=case["shape"], scale=case["scale"])
    result = distribution.logpdf(case["x"])
    assert result.shape == case["expected_shape"]


#
# SAMPLING SHAPES
#

# With scalar parameters the sample has exactly the requested sample shape.
scalar_argument_expected_shapes = [
    {"sample_shape": requested, "expected_shape": requested}
    for requested in ((), (100,), (100, 10), (1, 100))
]


@pytest.mark.parametrize("case", scalar_argument_expected_shapes)
def test_sample_shape_scalar_arguments(rng_key, case):
    """Check the sample shape when both parameters are scalars.

    Scalar arguments are tested apart from array arguments because scalars
    are edge cases when it comes to broadcasting.
    """
    distribution = Pareto(scale=1, shape=1)
    draws = distribution.sample(rng_key, case["sample_shape"])
    assert draws.shape == case["expected_shape"]


# Array-valued parameters with an empty sample shape: each entry lists
# (shape, scale, expected sample shape).
array_argument_expected_shapes_zero_dim = [
    {
        "shape": shape,
        "scale": scale,
        "sample_shape": (),
        "expected_shape": expected,
    }
    for shape, scale, expected in (
        (1, jnp.array([1, 2, 3]), (3,)),
        (jnp.array([1, 2, 3]), 1, (3,)),
        (1, jnp.array([[1, 2], [3, 4]]), (2, 2)),
        (jnp.array([1, 2]), jnp.array([[1, 2], [3, 4]]), (2, 2)),
    )
]


@pytest.mark.parametrize("case", array_argument_expected_shapes_zero_dim)
def test_sample_shape_array_arguments_no_sample_shape(rng_key, case):
    """Check broadcasting of array parameters for the case's sample shape."""
    distribution = Pareto(shape=case["shape"], scale=case["scale"])
    draws = distribution.sample(rng_key, case["sample_shape"])
    assert draws.shape == case["expected_shape"]


# Array-valued parameters with a 1-d sample shape: each entry lists
# (shape, scale, expected sample shape).
array_argument_expected_shapes_one_dim = [
    {
        "shape": shape,
        "scale": scale,
        "sample_shape": (100,),
        "expected_shape": expected,
    }
    for shape, scale, expected in (
        (1, jnp.array([1, 2, 3]), (100, 3)),
        (jnp.array([1, 2, 3]), 1, (100, 3)),
        (1, jnp.array([[1, 2], [3, 4]]), (100, 2, 2)),
        (jnp.array([1, 2]), jnp.array([[1, 2], [3, 4]]), (100, 2, 2)),
    )
]


@pytest.mark.parametrize("case", array_argument_expected_shapes_one_dim)
def test_sample_shape_array_arguments_1d_sample_shape(rng_key, case):
    """Check that samples have the case's expected shape for array parameters."""
    distribution = Pareto(shape=case["shape"], scale=case["scale"])
    draws = distribution.sample(rng_key, case["sample_shape"])
    assert draws.shape == case["expected_shape"]


# Array-valued parameters with a 2-d sample shape: each entry lists
# (shape, scale, sample shape, expected sample shape).
array_argument_expected_shapes_two_dims = [
    {
        "shape": shape,
        "scale": scale,
        "sample_shape": requested,
        "expected_shape": expected,
    }
    for shape, scale, requested, expected in (
        (1, jnp.array([1, 2, 3]), (100, 2), (100, 2, 3)),
        (jnp.array([1, 2, 3]), 1, (100, 3), (100, 3, 3)),
        (1, jnp.array([[1, 2], [3, 4]]), (100, 2), (100, 2, 2, 2)),
        (jnp.array([1, 2]), jnp.array([[1, 2], [3, 4]]), (100, 2), (100, 2, 2, 2)),
    )
]


@pytest.mark.parametrize("case", array_argument_expected_shapes_two_dims)
def test_sample_shape_array_arguments_2d_sample_shape(rng_key, case):
    """Check that samples have the case's expected shape for array parameters."""
    distribution = Pareto(shape=case["shape"], scale=case["scale"])
    draws = distribution.sample(rng_key, case["sample_shape"])
    assert draws.shape == case["expected_shape"]