Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions ml_peg/analysis/bulk_crystal/elasticity/analyse_elasticity.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ def shear_mae(elasticity_stats: dict[str, dict[str, Any]]) -> dict[str, float |
title="Bulk modulus density plot",
x_label="Reference bulk modulus / GPa",
y_label="Predicted bulk modulus / GPa",
annotation_metadata={"excluded": "Excluded"},
)
def bulk_density(elasticity_stats: dict[str, dict[str, Any]]) -> dict[str, dict]:
"""
Expand All @@ -165,6 +166,7 @@ def bulk_density(elasticity_stats: dict[str, dict[str, Any]]) -> dict[str, dict]
title="Shear modulus density plot",
x_label="Reference shear modulus / GPa",
y_label="Predicted shear modulus / GPa",
annotation_metadata={"excluded": "Excluded"},
)
def shear_density(elasticity_stats: dict[str, dict[str, Any]]) -> dict[str, dict]:
"""
Expand Down
66 changes: 53 additions & 13 deletions ml_peg/analysis/utils/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,6 +544,8 @@ def plot_density_scatter(
grid_size: int = 80,
max_points_per_cell: int = 5,
seed: int = 0,
hover_metadata: dict[str, str] | None = None,
annotation_metadata: dict[str, str] | None = None,
) -> Callable:
"""
Plot density-coloured parity scatter with legend-based model toggling.
Expand Down Expand Up @@ -571,6 +573,15 @@ def plot_density_scatter(
Maximum number of examples plotted per cell to keep renders responsive.
seed
Seed for deterministic sub-sampling. Default is 0.
hover_metadata
Dictionary mapping metadata keys to display labels for hover tooltips.
Keys are used to look up values in each point's metadata; labels are shown
in the hover text. Pass ``None`` (default) to omit additional hover metadata.
annotation_metadata
Dictionary mapping metadata keys to display labels for model-level
annotations (shown in the text box on the plot). Keys are used to look up
values in the model's metadata dict; labels are shown in the annotation.
Pass ``None`` (default) to omit additional annotation metadata.

Returns
-------
Expand Down Expand Up @@ -687,23 +698,49 @@ def _downsample(
global_max = -np.inf
processed = {}
annotations = []
annotation_fields = annotation_metadata or {}
hover_fields = hover_metadata or {}

for model in results:
data = results[model]
ref_vals = np.asarray(data.get("ref", []), dtype=float)
pred_vals = np.asarray(data.get("pred", []), dtype=float)
meta = data.get("meta") or {}
excluded = meta.get("excluded")
excluded_text = str(excluded) if excluded is not None else "n/a"

# Extract annotation metadata values (for text box)
annotation_values: list[str] = []
for meta_key in annotation_fields:
meta_raw = meta.get(meta_key)
annotation_values.append(
"n/a" if meta_raw is None else str(meta_raw)
)

# Extract hover metadata values (for tooltips)
hover_values: list[str] = []
for meta_key in hover_fields:
meta_raw = meta.get(meta_key)
hover_values.append("n/a" if meta_raw is None else str(meta_raw))

if ref_vals.size == 0 or pred_vals.size == 0:
sampled = ([], [], [])
else:
sampled = _downsample(ref_vals, pred_vals)
global_min = min(global_min, ref_vals.min(), pred_vals.min())
global_max = max(global_max, ref_vals.max(), pred_vals.max())
# Top left corner annotation for each model with exclusion info

# Build annotation text from annotation metadata
summary_text = ""
if annotation_fields:
summary_text = " | ".join(
f"{label}: {value}"
for value, label in zip(
annotation_values, annotation_fields.values(), strict=True
)
)
annotations.append(
{
"text": f"{model} | Excluded: {excluded_text}",
"text": f"{model}"
+ (f" | {summary_text}" if summary_text else ""),
"xref": "paper",
"yref": "paper",
"x": 0.02,
Expand All @@ -717,7 +754,7 @@ def _downsample(
processed[model] = {
"samples": sampled,
"counts": len(ref_vals),
"meta": excluded_text,
"meta": hover_values if hover_fields else None,
}

if not np.isfinite(global_min) or not np.isfinite(global_max):
Expand All @@ -730,12 +767,15 @@ def _downsample(
line_end = global_max + padding

fig = go.Figure()
hovertemplate = (
"<b>Reference:</b> %{x:.3f}<br>"
"<b>Predicted:</b> %{y:.3f}<br>"
"<b>Density:</b> %{customdata[0]:.0f}<br>"
"<b>Excluded:</b> %{meta[0]}<extra></extra>"
)
hover_lines = [
"<b>Reference:</b> %{x:.3f}",
"<b>Predicted:</b> %{y:.3f}",
"<b>Density:</b> %{customdata[0]:.0f}",
]
if hover_fields:
for idx, label in enumerate(hover_fields.values()):
hover_lines.append(f"<b>{label}:</b> %{{meta[{idx}]}}")
hovertemplate = "<br>".join(hover_lines) + "<extra></extra>"

for idx, model in enumerate(results):
sample_x, sample_y, density = processed[model]["samples"]
Expand All @@ -756,7 +796,7 @@ def _downsample(
customdata=np.array(density, dtype=float)[:, None]
if density
else None,
meta=[processed[model]["meta"]],
meta=processed[model]["meta"],
hovertemplate=hovertemplate,
)
)
Expand All @@ -783,7 +823,7 @@ def _downsample(
title={"text": title} if title else None,
xaxis={"title": {"text": x_label}},
yaxis={"title": {"text": y_label}},
annotations=[annotations[0]],
annotations=[annotations[0]] if annotations else [],
meta=layout_meta,
showlegend=True,
legend_title_text="Model",
Expand Down