Skip to content

Commit faf8f95

Browse files
documentation
1 parent 66ceb34 commit faf8f95

File tree

4 files changed

+136
-51
lines changed

4 files changed

+136
-51
lines changed

docs/examples/plot_5_emcee_arviz_numpyro.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""
2-
# Using external samples easily
2+
# Using external samples
33
44
`emcee`, `arviz`, and `numpyro` are all popular MCMC packages. ChainConsumer
55
provides class methods to turn results from these packages into chains efficiently.
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
"""
2+
# Multimodal distributions
3+
4+
`emcee`, `arviz`, and `numpyro` are all popular MCMC packages. ChainConsumer
5+
provides class methods to turn results from these packages into chains efficiently.
6+
7+
If you want to request support for another type of chain, please open a
8+
[discussion](https://github.com/Samreay/ChainConsumer/discussions) with a code
9+
example, and we can add it in. The brave may even provide a PR!
10+
"""
11+
12+
import numpy as np
13+
import pandas as pd
14+
15+
from chainconsumer import Chain, ChainConsumer
16+
from chainconsumer.statistics import SummaryStatistic
17+
18+
# %%
19+
# First, let's build some dummy data
20+
21+
rng = np.random.default_rng(42)
22+
size = 60_000
23+
24+
eta = rng.normal(loc=0.0, scale=0.8, size=size)
25+
26+
phi = np.asarray(
27+
[rng.gamma(shape=2.5, scale=0.4, size=size // 2) - 3.0, 3.0 - rng.gamma(shape=5.0, scale=0.35, size=(size // 2))]
28+
).flatten()
29+
30+
rng.shuffle(phi)
31+
32+
df = pd.DataFrame({"eta": eta, "phi": phi})
33+
34+
# %%
35+
# To build a multimodal chain, you simply have to pass `multimodal=True` when building the chain. To work, it requires
36+
# you to specify `SummaryStatistic.HDI` as the summary statistic.
37+
38+
chain_multimodal = Chain(
39+
samples=df.copy(),
40+
name="posterior-multimodal",
41+
statistics=SummaryStatistic.HDI,
42+
multimodal=True, # <- Here
43+
)
44+
45+
# %%
46+
# Now, if you add this `Chain` to a plotter, it will try to look for sub-intervals and display them.
47+
48+
cc = ChainConsumer()
49+
cc.add_chain(chain_multimodal)
50+
fig = cc.plotter.plot()
51+
52+
# %%
53+
# Let's compare with what would happen if you don't use a multimodal chain. We use the same data as before but don't
54+
# warn `ChainConsumer` that we expect the chains to be multimodal.
55+
56+
chain_unimodal = Chain(samples=df.copy(), name="posterior-unimodal", statistics=SummaryStatistic.HDI, multimodal=False)
57+
58+
cc.add_chain(chain_unimodal)
59+
fig = cc.plotter.plot()
60+
61+
# %%
62+
# Let's try with even more modes.
63+
64+
eta = np.asarray(
65+
[
66+
rng.normal(loc=-3, scale=0.8, size=size // 3),
67+
rng.normal(loc=0.0, scale=0.8, size=size // 3),
68+
rng.normal(loc=+3, scale=0.8, size=size // 3),
69+
]
70+
).flatten()
71+
72+
73+
rng.shuffle(eta)
74+
75+
df = pd.DataFrame({"eta": eta, "phi": phi})
76+
77+
chain_multimodal = Chain(
78+
samples=df.copy(), name="posterior-multimodal", statistics=SummaryStatistic.HDI, multimodal=True
79+
)
80+
81+
cc = ChainConsumer()
82+
cc.add_chain(chain_multimodal)
83+
fig = cc.plotter.plot()
84+
fig.tight_layout()

src/chainconsumer/analysis.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def _mask_to_intervals(
4141
if mask[-1]:
4242
ends = np.concatenate((ends, len(mask) - 1))
4343

44-
intervals = [(float(x[s]), float(x[e])) for s, e in zip(starts, ends, strict=False) if x[e] > x[s]]
44+
intervals = [(float(x[s]), float(x[e])) for s, e in zip(starts, ends, strict=True) if x[e] > x[s]]
4545

4646
return intervals
4747

@@ -202,8 +202,8 @@ def get_summary(
202202
203203
Returns:
204204
Mapping from chain name to marginal summaries. Each entry is either a single
205-
:class:`~chainconsumer.analysis.Bound` for unimodal chains or a list of bounds
206-
describing the disjoint HDI bands when ``chain.multimodal`` is set.
205+
[`Bound`][chainconsumer.analysis.Bound] for unimodal chains or a list of bounds
206+
describing the disjoint HDI bands when `chain.multimodal` is set.
207207
"""
208208
results = {}
209209
if chains is None:

src/chainconsumer/plotter.py

Lines changed: 48 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -961,59 +961,60 @@ def _plot_bars(
961961
if chain.bar_shade:
962962
base_bound = self.parent.analysis.get_parameter_summary(chain, column)
963963

964-
if chain.multimodal:
965-
intervals = self.parent.analysis.get_parameter_hdi_intervals(chain, column)
966-
display_bounds = self.parent.analysis.get_parameter_multimodal_bounds(
967-
chain,
968-
column,
969-
intervals,
970-
)
971-
972-
# If we get a single interval, fallback to unimodal HDI
973-
if len(intervals) < 2:
974-
intervals = [(base_bound.lower, base_bound.upper)]
975-
976-
else:
977-
display_bounds = base_bound
978-
intervals = [(display_bounds.lower, display_bounds.upper)]
979-
intervals = np.clip(intervals, a_min=xs.min(), a_max=xs.max())
980-
981-
for lower, upper_ in intervals:
982-
x = np.linspace(lower, upper_, 1000)
983-
984-
if flip:
985-
ax.fill_betweenx(
986-
x,
987-
np.zeros_like(x),
988-
interpolator(x),
989-
color=chain.color,
990-
alpha=0.2,
991-
zorder=chain.zorder,
964+
if base_bound is not None and base_bound.lower is not None and base_bound.upper is not None:
965+
if chain.multimodal:
966+
intervals = self.parent.analysis.get_parameter_hdi_intervals(chain, column)
967+
display_bounds = self.parent.analysis.get_parameter_multimodal_bounds(
968+
chain,
969+
column,
970+
intervals,
992971
)
993972

973+
# If we get a single interval, fallback to unimodal HDI
974+
if len(intervals) < 2:
975+
intervals = [(base_bound.lower, base_bound.upper)]
976+
994977
else:
995-
ax.fill_between(
996-
x,
997-
np.zeros_like(x),
998-
interpolator(x),
999-
color=chain.color,
1000-
alpha=0.2,
1001-
zorder=chain.zorder,
1002-
)
978+
display_bounds = base_bound
979+
intervals = [(display_bounds.lower, display_bounds.upper)]
980+
intervals = np.clip(intervals, a_min=xs.min(), a_max=xs.max())
981+
982+
for lower, upper_ in intervals:
983+
x = np.linspace(lower, upper_, 1000)
984+
985+
if flip:
986+
ax.fill_betweenx(
987+
x,
988+
np.zeros_like(x),
989+
interpolator(x),
990+
color=chain.color,
991+
alpha=0.2,
992+
zorder=chain.zorder,
993+
)
1003994

1004-
if summary:
1005-
label = self.config.get_label(column)
1006-
label_core = label.strip("$") if isinstance(column, str) else None
1007-
label_text = label_core or None
995+
else:
996+
ax.fill_between(
997+
x,
998+
np.zeros_like(x),
999+
interpolator(x),
1000+
color=chain.color,
1001+
alpha=0.2,
1002+
zorder=chain.zorder,
1003+
)
10081004

1009-
title = self.parent.analysis.get_parameter_text(
1010-
display_bounds,
1011-
wrap=True,
1012-
label=label_text,
1013-
)
1005+
if summary:
1006+
label = self.config.get_label(column)
1007+
label_core = label.strip("$") if isinstance(column, str) else None
1008+
label_text = label_core or None
1009+
1010+
title = self.parent.analysis.get_parameter_text(
1011+
display_bounds,
1012+
wrap=True,
1013+
label=label_text,
1014+
)
10141015

1015-
if title:
1016-
ax.set_title(title, fontsize=self.config.summary_font_size)
1016+
if title:
1017+
ax.set_title(title, fontsize=self.config.summary_font_size)
10171018

10181019
return float(ys.max())
10191020

0 commit comments

Comments
 (0)