diff --git a/ml_peg/analysis/bulk_crystal/elasticity/analyse_elasticity.py b/ml_peg/analysis/bulk_crystal/elasticity/analyse_elasticity.py
index 6472f0717..b0e07848f 100644
--- a/ml_peg/analysis/bulk_crystal/elasticity/analyse_elasticity.py
+++ b/ml_peg/analysis/bulk_crystal/elasticity/analyse_elasticity.py
@@ -141,6 +141,7 @@ def shear_mae(elasticity_stats: dict[str, dict[str, Any]]) -> dict[str, float |
title="Bulk modulus density plot",
x_label="Reference bulk modulus / GPa",
y_label="Predicted bulk modulus / GPa",
+ annotation_metadata={"excluded": "Excluded"},
)
def bulk_density(elasticity_stats: dict[str, dict[str, Any]]) -> dict[str, dict]:
"""
@@ -165,6 +166,7 @@ def bulk_density(elasticity_stats: dict[str, dict[str, Any]]) -> dict[str, dict]
title="Shear modulus density plot",
x_label="Reference shear modulus / GPa",
y_label="Predicted shear modulus / GPa",
+ annotation_metadata={"excluded": "Excluded"},
)
def shear_density(elasticity_stats: dict[str, dict[str, Any]]) -> dict[str, dict]:
"""
diff --git a/ml_peg/analysis/utils/decorators.py b/ml_peg/analysis/utils/decorators.py
index 3641aab51..34c166eda 100644
--- a/ml_peg/analysis/utils/decorators.py
+++ b/ml_peg/analysis/utils/decorators.py
@@ -544,6 +544,8 @@ def plot_density_scatter(
grid_size: int = 80,
max_points_per_cell: int = 5,
seed: int = 0,
+ hover_metadata: dict[str, str] | None = None,
+ annotation_metadata: dict[str, str] | None = None,
) -> Callable:
"""
Plot density-coloured parity scatter with legend-based model toggling.
@@ -571,6 +573,15 @@ def plot_density_scatter(
Maximum number of examples plotted per cell to keep renders responsive.
seed
Seed for deterministic sub-sampling. Default is 0.
+ hover_metadata
+ Dictionary mapping metadata keys to display labels for hover tooltips.
+ Keys are used to look up values in each point's metadata; labels are shown
+ in the hover text. Pass ``None`` (default) to omit additional hover metadata.
+ annotation_metadata
+ Dictionary mapping metadata keys to display labels for model-level
+ annotations (shown in the text box on the plot). Keys are used to look up
+ values in the model's metadata dict; labels are shown in the annotation.
+ Pass ``None`` (default) to omit additional annotation metadata.
Returns
-------
@@ -687,23 +698,49 @@ def _downsample(
global_max = -np.inf
processed = {}
annotations = []
+ annotation_fields = annotation_metadata or {}
+ hover_fields = hover_metadata or {}
+
for model in results:
data = results[model]
ref_vals = np.asarray(data.get("ref", []), dtype=float)
pred_vals = np.asarray(data.get("pred", []), dtype=float)
meta = data.get("meta") or {}
- excluded = meta.get("excluded")
- excluded_text = str(excluded) if excluded is not None else "n/a"
+
+ # Extract annotation metadata values (for text box)
+ annotation_values: list[str] = []
+ for meta_key in annotation_fields:
+ meta_raw = meta.get(meta_key)
+ annotation_values.append(
+ "n/a" if meta_raw is None else str(meta_raw)
+ )
+
+ # Extract hover metadata values (for tooltips)
+ hover_values: list[str] = []
+ for meta_key in hover_fields:
+ meta_raw = meta.get(meta_key)
+ hover_values.append("n/a" if meta_raw is None else str(meta_raw))
+
if ref_vals.size == 0 or pred_vals.size == 0:
sampled = ([], [], [])
else:
sampled = _downsample(ref_vals, pred_vals)
global_min = min(global_min, ref_vals.min(), pred_vals.min())
global_max = max(global_max, ref_vals.max(), pred_vals.max())
- # Top left corner annotation for each model with exclusion info
+
+ # Build annotation text from annotation metadata
+ summary_text = ""
+ if annotation_fields:
+ summary_text = " | ".join(
+ f"{label}: {value}"
+ for value, label in zip(
+ annotation_values, annotation_fields.values(), strict=True
+ )
+ )
annotations.append(
{
- "text": f"{model} | Excluded: {excluded_text}",
+ "text": f"{model}"
+ + (f" | {summary_text}" if summary_text else ""),
"xref": "paper",
"yref": "paper",
"x": 0.02,
@@ -717,7 +754,7 @@ def _downsample(
processed[model] = {
"samples": sampled,
"counts": len(ref_vals),
- "meta": excluded_text,
+ "meta": hover_values if hover_fields else None,
}
if not np.isfinite(global_min) or not np.isfinite(global_max):
@@ -730,12 +767,15 @@ def _downsample(
line_end = global_max + padding
fig = go.Figure()
- hovertemplate = (
- "Reference: %{x:.3f}
"
- "Predicted: %{y:.3f}
"
- "Density: %{customdata[0]:.0f}
"
- "Excluded: %{meta[0]}"
- )
+ hover_lines = [
+ "Reference: %{x:.3f}",
+ "Predicted: %{y:.3f}",
+ "Density: %{customdata[0]:.0f}",
+ ]
+ if hover_fields:
+ for idx, label in enumerate(hover_fields.values()):
+ hover_lines.append(f"{label}: %{{meta[{idx}]}}")
+ hovertemplate = "
".join(hover_lines) + ""
for idx, model in enumerate(results):
sample_x, sample_y, density = processed[model]["samples"]
@@ -756,7 +796,7 @@ def _downsample(
customdata=np.array(density, dtype=float)[:, None]
if density
else None,
- meta=[processed[model]["meta"]],
+ meta=processed[model]["meta"],
hovertemplate=hovertemplate,
)
)
@@ -783,7 +823,7 @@ def _downsample(
title={"text": title} if title else None,
xaxis={"title": {"text": x_label}},
yaxis={"title": {"text": y_label}},
- annotations=[annotations[0]],
+ annotations=[annotations[0]] if annotations else [],
meta=layout_meta,
showlegend=True,
legend_title_text="Model",