diff --git a/ml_peg/analysis/bulk_crystal/elasticity/analyse_elasticity.py b/ml_peg/analysis/bulk_crystal/elasticity/analyse_elasticity.py index 6472f0717..b0e07848f 100644 --- a/ml_peg/analysis/bulk_crystal/elasticity/analyse_elasticity.py +++ b/ml_peg/analysis/bulk_crystal/elasticity/analyse_elasticity.py @@ -141,6 +141,7 @@ def shear_mae(elasticity_stats: dict[str, dict[str, Any]]) -> dict[str, float | title="Bulk modulus density plot", x_label="Reference bulk modulus / GPa", y_label="Predicted bulk modulus / GPa", + annotation_metadata={"excluded": "Excluded"}, ) def bulk_density(elasticity_stats: dict[str, dict[str, Any]]) -> dict[str, dict]: """ @@ -165,6 +166,7 @@ def bulk_density(elasticity_stats: dict[str, dict[str, Any]]) -> dict[str, dict] title="Shear modulus density plot", x_label="Reference shear modulus / GPa", y_label="Predicted shear modulus / GPa", + annotation_metadata={"excluded": "Excluded"}, ) def shear_density(elasticity_stats: dict[str, dict[str, Any]]) -> dict[str, dict]: """ diff --git a/ml_peg/analysis/utils/decorators.py b/ml_peg/analysis/utils/decorators.py index 3641aab51..34c166eda 100644 --- a/ml_peg/analysis/utils/decorators.py +++ b/ml_peg/analysis/utils/decorators.py @@ -544,6 +544,8 @@ def plot_density_scatter( grid_size: int = 80, max_points_per_cell: int = 5, seed: int = 0, + hover_metadata: dict[str, str] | None = None, + annotation_metadata: dict[str, str] | None = None, ) -> Callable: """ Plot density-coloured parity scatter with legend-based model toggling. @@ -571,6 +573,15 @@ def plot_density_scatter( Maximum number of examples plotted per cell to keep renders responsive. seed Seed for deterministic sub-sampling. Default is 0. + hover_metadata + Dictionary mapping metadata keys to display labels for hover tooltips. + Keys are used to look up values in each point's metadata; labels are shown + in the hover text. Pass ``None`` (default) to omit additional hover metadata. + annotation_metadata + Dictionary mapping metadata keys to display labels for model-level + annotations (shown in the text box on the plot). Keys are used to look up + values in the model's metadata dict; labels are shown in the annotation. + Pass ``None`` (default) to omit additional annotation metadata. Returns ------- @@ -687,23 +698,49 @@ def _downsample( global_max = -np.inf processed = {} annotations = [] + annotation_fields = annotation_metadata or {} + hover_fields = hover_metadata or {} + for model in results: data = results[model] ref_vals = np.asarray(data.get("ref", []), dtype=float) pred_vals = np.asarray(data.get("pred", []), dtype=float) meta = data.get("meta") or {} - excluded = meta.get("excluded") - excluded_text = str(excluded) if excluded is not None else "n/a" + + # Extract annotation metadata values (for text box) + annotation_values: list[str] = [] + for meta_key in annotation_fields: + meta_raw = meta.get(meta_key) + annotation_values.append( + "n/a" if meta_raw is None else str(meta_raw) + ) + + # Extract hover metadata values (for tooltips) + hover_values: list[str] = [] + for meta_key in hover_fields: + meta_raw = meta.get(meta_key) + hover_values.append("n/a" if meta_raw is None else str(meta_raw)) + if ref_vals.size == 0 or pred_vals.size == 0: sampled = ([], [], []) else: sampled = _downsample(ref_vals, pred_vals) global_min = min(global_min, ref_vals.min(), pred_vals.min()) global_max = max(global_max, ref_vals.max(), pred_vals.max()) - # Top left corner annotation for each model with exclusion info + + # Build annotation text from annotation metadata + summary_text = "" + if annotation_fields: + summary_text = " | ".join( + f"{label}: {value}" + for value, label in zip( + annotation_values, annotation_fields.values(), strict=True + ) + ) annotations.append( { - "text": f"{model} | Excluded: {excluded_text}", + "text": f"{model}" + + (f" | {summary_text}" if summary_text else ""), "xref": "paper", "yref": "paper", "x": 0.02, @@ -717,7 +754,7 @@ def _downsample( processed[model] = { "samples": sampled, "counts": len(ref_vals), - "meta": excluded_text, + "meta": hover_values if hover_fields else None, } if not np.isfinite(global_min) or not np.isfinite(global_max): @@ -730,12 +767,15 @@ def _downsample( line_end = global_max + padding fig = go.Figure() - hovertemplate = ( - "Reference: %{x:.3f}
" - "Predicted: %{y:.3f}
" - "Density: %{customdata[0]:.0f}
" - "Excluded: %{meta[0]}" - ) + hover_lines = [ + "Reference: %{x:.3f}", + "Predicted: %{y:.3f}", + "Density: %{customdata[0]:.0f}", + ] + if hover_fields: + for idx, label in enumerate(hover_fields.values()): + hover_lines.append(f"{label}: %{{meta[{idx}]}}") + hovertemplate = "
".join(hover_lines) + "" for idx, model in enumerate(results): sample_x, sample_y, density = processed[model]["samples"] @@ -756,7 +796,7 @@ def _downsample( customdata=np.array(density, dtype=float)[:, None] if density else None, - meta=[processed[model]["meta"]], + meta=processed[model]["meta"], hovertemplate=hovertemplate, ) ) @@ -783,7 +823,7 @@ def _downsample( title={"text": title} if title else None, xaxis={"title": {"text": x_label}}, yaxis={"title": {"text": y_label}}, - annotations=[annotations[0]], + annotations=[annotations[0]] if annotations else [], meta=layout_meta, showlegend=True, legend_title_text="Model",