Add pareto distribution #82

Status: Open. tblazina wants to merge 9 commits into rlouf:master from tblazina:add-pareto-distribution.
Commits (9):
- 74eb7b0 Add pareto distribution (tblazina)
- 03ee596 Fix sample method (tblazina)
- d41346b Update parameter names in Pareto distribution (tblazina)
- 957a6f3 Merge branch 'master' into add-pareto-distribution (tblazina)
- 1b9dd28 Update pareto implementation (tblazina)
- 7de71bf Add some first tests for Pareto distribution (tblazina)
- ea00e36 Fix import sorting and linting errors (tblazina)
- 403d897 Add more tests for Pareto distribution (tblazina)
- dadb1c7 Merge branch 'master' into add-pareto-distribution (tblazina)
The first file in the diff adds the Pareto distribution implementation (35 new lines):

```python
from jax import lax
from jax import numpy as jnp
from jax import random
from jax.scipy import stats

from mcx.distributions import constraints
from mcx.distributions.distribution import Distribution
from mcx.distributions.shapes import promote_shapes


class Pareto(Distribution):
    parameters = {
        "shape": constraints.strictly_positive,
        "scale": constraints.strictly_positive,
    }

    def __init__(self, shape, scale, loc=0):
        self.support = constraints.closed_interval(scale, jnp.inf)
        self.event_shape = ()
        shape, scale, loc = promote_shapes(shape, scale, loc)
        batch_shape = lax.broadcast_shapes(
            jnp.shape(shape), jnp.shape(scale), jnp.shape(loc)
        )
        self.batch_shape = batch_shape
        self.shape = jnp.broadcast_to(shape, batch_shape)
        self.scale = jnp.broadcast_to(scale, batch_shape)
        self.loc = jnp.broadcast_to(loc, batch_shape)

    def sample(self, rng_key, sample_shape=()):
        shape = sample_shape + self.batch_shape + self.event_shape
        return self.scale * random.pareto(key=rng_key, b=self.shape, shape=shape)

    @constraints.limit_to_support
    def logpdf(self, x):
        return stats.pareto.logpdf(x=x, b=self.shape, loc=self.loc, scale=self.scale)
```
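For orientation, here is a minimal usage sketch of the class above. This is not part of the PR; it assumes Pareto is re-exported from mcx.distributions, as the test file below imports it:

```python
from jax import numpy as jnp
from jax import random

from mcx.distributions import Pareto

rng_key = random.PRNGKey(0)

# Pareto(shape=3, scale=1) has support [1, inf).
dist = Pareto(shape=3.0, scale=1.0)

# Draw a batch of samples; sample_shape is prepended to batch_shape.
samples = dist.sample(rng_key, sample_shape=(1000,))
print(samples.shape)  # (1000,)

# Evaluate the log-density; points below the scale fall outside the
# support and map to -inf via the limit_to_support decorator.
log_prob = dist.logpdf(jnp.array([0.5, 1.5, 2.0]))
print(log_prob)  # first entry is -inf
```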
The second file in the diff adds the test suite for the Pareto distribution (244 new lines):

```python
import pytest
from jax import numpy as jnp
from jax import random

from mcx.distributions import Pareto


@pytest.fixture
def rng_key():
    return random.PRNGKey(0)


def pareto_mean(shape, scale):
    if shape > 1:
        return shape * scale / (shape - 1.0)
    else:
        return jnp.inf


def pareto_variance(shape, scale):
    if shape <= 2:
        return jnp.inf
    else:
        numerator = (scale ** 2) * shape
        denominator = ((shape - 1) ** 2) * (shape - 2)
        return numerator / denominator


#
# SAMPLING CORRECTNESS
#

sample_means = [
    {"shape": 2, "scale": 0.1, "expected": pareto_mean(shape=2, scale=0.1)},
    {"shape": 10, "scale": 1, "expected": pareto_mean(shape=10, scale=1)},
    {"shape": 10, "scale": 10, "expected": pareto_mean(shape=10, scale=10)},
    {"shape": 100, "scale": 10, "expected": pareto_mean(shape=100, scale=10)},
]


@pytest.mark.parametrize("case", sample_means)
def test_sample_mean(rng_key, case):
    samples = Pareto(shape=case["shape"], scale=case["scale"]).sample(
        rng_key, (1_000_000,)
    )
    avg = jnp.mean(samples, axis=0).item()
    assert avg == pytest.approx(case["expected"], abs=1e-2)


sample_variances = [
    {"shape": 2, "scale": 0.1, "expected": pareto_variance(shape=2, scale=0.1)},
    {"shape": 10, "scale": 1, "expected": pareto_variance(shape=10, scale=1)},
    {"shape": 10, "scale": 10, "expected": pareto_variance(shape=10, scale=10)},
    {"shape": 100, "scale": 10, "expected": pareto_variance(shape=100, scale=10)},
]


@pytest.mark.parametrize("case", sample_variances)
def test_sample_variance(rng_key, case):
    samples = Pareto(shape=case["shape"], scale=case["scale"]).sample(
        rng_key, (1_000_000,)
    )
    var = jnp.var(samples, axis=0).item()
    assert var == pytest.approx(case["expected"], abs=1e-2)


#
# LOGPDF CORRECTNESS
#

out_of_support_cases = [
    {"shape": 3, "scale": 1, "x": 0.5, "expected": -jnp.inf},
    {"shape": 3, "scale": 1, "x": -1, "expected": -jnp.inf},
]


@pytest.mark.parametrize("case", out_of_support_cases)
def test_logpdf_out_of_support(case):
    logprob = Pareto(shape=case["shape"], scale=case["scale"]).logpdf(case["x"])
    assert logprob == case["expected"]


#
# LOGPDF SHAPES
#

expected_logpdf_shapes = [
    {"shape": 3, "scale": 1, "x": 1, "expected_shape": ()},
    {"shape": 3, "scale": 1, "x": jnp.array([1, 2]), "expected_shape": (2,)},
]


@pytest.mark.parametrize("case", expected_logpdf_shapes)
def test_logpdf_shape(case):
    log_prob = Pareto(shape=case["shape"], scale=case["scale"]).logpdf(case["x"])
    assert log_prob.shape == case["expected_shape"]


#
# SAMPLING SHAPES
#

scalar_argument_expected_shapes = [
    {"sample_shape": (), "expected_shape": ()},
    {"sample_shape": (100,), "expected_shape": (100,)},
    {"sample_shape": (100, 10), "expected_shape": (100, 10)},
    {"sample_shape": (1, 100), "expected_shape": (1, 100)},
]


@pytest.mark.parametrize("case", scalar_argument_expected_shapes)
def test_sample_shape_scalar_arguments(rng_key, case):
    """Test the correctness of broadcasting when both arguments are
    scalars. We test scalar arguments separately from array arguments
    since scalars are edge cases when it comes to broadcasting.
    """
    samples = Pareto(scale=1, shape=1).sample(rng_key, case["sample_shape"])
    assert samples.shape == case["expected_shape"]


array_argument_expected_shapes_zero_dim = [
    {
        "shape": 1,
        "scale": jnp.array([1, 2, 3]),
        "sample_shape": (),
        "expected_shape": (3,),
    },
    {
        "shape": jnp.array([1, 2, 3]),
        "scale": 1,
        "sample_shape": (),
        "expected_shape": (3,),
    },
    {
        "shape": 1,
        "scale": jnp.array([[1, 2], [3, 4]]),
        "sample_shape": (),
        "expected_shape": (2, 2),
    },
    {
        "shape": jnp.array([1, 2]),
        "scale": jnp.array([[1, 2], [3, 4]]),
        "sample_shape": (),
        "expected_shape": (2, 2),
    },
]


@pytest.mark.parametrize("case", array_argument_expected_shapes_zero_dim)
def test_sample_shape_array_arguments_no_sample_shape(rng_key, case):
    """Test the correctness of broadcasting when arguments can be arrays."""
    samples = Pareto(shape=case["shape"], scale=case["scale"]).sample(
        rng_key, case["sample_shape"]
    )
    assert samples.shape == case["expected_shape"]


array_argument_expected_shapes_one_dim = [
    {
        "shape": 1,
        "scale": jnp.array([1, 2, 3]),
        "sample_shape": (100,),
        "expected_shape": (100, 3),
    },
    {
        "shape": jnp.array([1, 2, 3]),
        "scale": 1,
        "sample_shape": (100,),
        "expected_shape": (100, 3),
    },
    {
        "shape": 1,
        "scale": jnp.array([[1, 2], [3, 4]]),
        "sample_shape": (100,),
        "expected_shape": (100, 2, 2),
    },
    {
        "shape": jnp.array([1, 2]),
        "scale": jnp.array([[1, 2], [3, 4]]),
        "sample_shape": (100,),
        "expected_shape": (100, 2, 2),
    },
]


@pytest.mark.parametrize("case", array_argument_expected_shapes_one_dim)
def test_sample_shape_array_arguments_1d_sample_shape(rng_key, case):
    samples = Pareto(shape=case["shape"], scale=case["scale"]).sample(
        rng_key, case["sample_shape"]
    )
    assert samples.shape == case["expected_shape"]


array_argument_expected_shapes_two_dims = [
    {
        "shape": 1,
        "scale": jnp.array([1, 2, 3]),
        "sample_shape": (100, 2),
        "expected_shape": (100, 2, 3),
    },
    {
        "shape": jnp.array([1, 2, 3]),
        "scale": 1,
        "sample_shape": (100, 3),
        "expected_shape": (100, 3, 3),
    },
    {
        "shape": 1,
        "scale": jnp.array([[1, 2], [3, 4]]),
        "sample_shape": (100, 2),
        "expected_shape": (100, 2, 2, 2),
    },
    {
        "shape": jnp.array([1, 2]),
        "scale": jnp.array([[1, 2], [3, 4]]),
        "sample_shape": (100, 2),
        "expected_shape": (100, 2, 2, 2),
    },
]


@pytest.mark.parametrize("case", array_argument_expected_shapes_two_dims)
def test_sample_shape_array_arguments_2d_sample_shape(rng_key, case):
    samples = Pareto(shape=case["shape"], scale=case["scale"]).sample(
        rng_key, case["sample_shape"]
    )
    assert samples.shape == case["expected_shape"]
```
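For reference, the pareto_mean and pareto_variance helpers above encode the standard Pareto (Type I) moments, with shape $\alpha$ and scale $x_m$; these are textbook results matching the Wikipedia page cited in the discussion below:

```latex
\mathbb{E}[X] =
\begin{cases}
  \infty & \alpha \le 1 \\[4pt]
  \dfrac{\alpha x_m}{\alpha - 1} & \alpha > 1
\end{cases}
\qquad
\operatorname{Var}(X) =
\begin{cases}
  \infty & \alpha \le 2 \\[4pt]
  \dfrac{x_m^2\,\alpha}{(\alpha - 1)^2 (\alpha - 2)} & \alpha > 2
\end{cases}
```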
rlouf:
Great addition! However, before I merge we'll need to add tests for the shape and the support! Would you mind adding those?
tblazina:
Indeed, I was planning on it when I get some time, hopefully in the next few days!
rlouf:
Is there anything I can do to help?
tblazina:
I'll let you know when I get to it this weekend. Over the last 2.5 weeks I had a kidney stone, which involved two surgeries and about five nights in hospitals, but things seem to be resolved now. 2021 has not been my year in terms of health. Nonetheless, I should finally have some time this weekend!
tblazina:
OK, I found some time to add more tests, but I'm having one issue with a failing test for the variance in the case where the shape parameter is <= 2, and I'm not entirely sure what I've implemented wrong. Not being totally familiar with the Pareto distribution, I've mostly followed the information on https://en.wikipedia.org/wiki/Pareto_distribution, which states that the variance should be infinite when the shape parameter is <= 2; however, this is not the case in the current implementation. I'd appreciate some feedback!
rlouf:
Extensive test suite, great job!
Remember that we defined shape = b in this case. The variance should thus be theoretically infinite when $shape < 1$ per the formulae above. Then, if you measure the variance of samples drawn from the distribution, you should get a very large number but not strictly $\infty$. You can check that $\sigma > 10\mu$, for instance, when $0 < shape < 1$. It would also be nice to check that $\mu \rightarrow \infty$ when $shape < 0$.
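As one possible reading of that suggestion, such a tail check might look roughly like the sketch below. This is illustrative only and not part of the PR; the shape value of 0.5 and the factor of 10 are assumptions following the reviewer's $\sigma > 10\mu$ hint:

```python
from jax import numpy as jnp
from jax import random

from mcx.distributions import Pareto


def test_sample_heavy_tail():
    # With a small shape parameter the theoretical variance diverges,
    # so the empirical standard deviation should dwarf the empirical
    # mean instead of settling near any finite target value.
    rng_key = random.PRNGKey(0)
    samples = Pareto(shape=0.5, scale=1.0).sample(rng_key, (1_000_000,))
    assert jnp.std(samples) > 10 * jnp.mean(samples)
```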
tblazina:
Alright, I'll update the tests to reflect this. Thanks for the clarification!
tblazina:
Sorry, I was away from this too long and am a bit confused, because in your suggestion you are using $\sigma$ and $\mu$ notation and I'm not sure what you are referring to when you say "The variance should thus be theoretically infinite when $shape < 1$ per the formulae above." I'm not sure exactly which formulae you mean, because in the way I've implemented it, having $shape < 1$ doesn't result in the variance being infinite.
I get that the variance of the samples won't strictly be $\infty$, but I think I have implemented the Pareto distribution incorrectly and can't figure out what I've done wrong. I would need some additional assistance, thanks!