From 965cbec9772ab1a8ce5c765ecfcd9c81cef79a16 Mon Sep 17 00:00:00 2001 From: joehart2001 Date: Thu, 29 Jan 2026 22:53:30 +0000 Subject: [PATCH 1/3] generalise density scatter plot to not default to e.g. elastic moduli excluded in the annotations and hover data --- .../elasticity/analyse_elasticity.py | 2 + ml_peg/analysis/utils/decorators.py | 45 +++++++++++++------ 2 files changed, 34 insertions(+), 13 deletions(-) diff --git a/ml_peg/analysis/bulk_crystal/elasticity/analyse_elasticity.py b/ml_peg/analysis/bulk_crystal/elasticity/analyse_elasticity.py index 6472f0717..cf3f90b2b 100644 --- a/ml_peg/analysis/bulk_crystal/elasticity/analyse_elasticity.py +++ b/ml_peg/analysis/bulk_crystal/elasticity/analyse_elasticity.py @@ -141,6 +141,7 @@ def shear_mae(elasticity_stats: dict[str, dict[str, Any]]) -> dict[str, float | title="Bulk modulus density plot", x_label="Reference bulk modulus / GPa", y_label="Predicted bulk modulus / GPa", + hover_metadata=(("excluded", "Excluded"),), ) def bulk_density(elasticity_stats: dict[str, dict[str, Any]]) -> dict[str, dict]: """ @@ -165,6 +166,7 @@ def bulk_density(elasticity_stats: dict[str, dict[str, Any]]) -> dict[str, dict] title="Shear modulus density plot", x_label="Reference shear modulus / GPa", y_label="Predicted shear modulus / GPa", + hover_metadata=(("excluded", "Excluded"),), ) def shear_density(elasticity_stats: dict[str, dict[str, Any]]) -> dict[str, dict]: """ diff --git a/ml_peg/analysis/utils/decorators.py b/ml_peg/analysis/utils/decorators.py index 3641aab51..17bc34a81 100644 --- a/ml_peg/analysis/utils/decorators.py +++ b/ml_peg/analysis/utils/decorators.py @@ -544,6 +544,7 @@ def plot_density_scatter( grid_size: int = 80, max_points_per_cell: int = 5, seed: int = 0, + hover_metadata: tuple[tuple[str, str], ...] | None = None, ) -> Callable: """ Plot density-coloured parity scatter with legend-based model toggling. 
@@ -571,6 +572,10 @@ def plot_density_scatter( Maximum number of examples plotted per cell to keep renders responsive. seed Seed for deterministic sub-sampling. Default is 0. + hover_metadata + Sequence of ``(meta_key, label)`` pairs to include in hover text/annotations. + Defaults to showing ``("excluded", "Excluded")``; set to ``None`` for no extra + metadata. Returns ------- @@ -687,23 +692,34 @@ def _downsample( global_max = -np.inf processed = {} annotations = [] + metadata_fields = hover_metadata or () for model in results: data = results[model] ref_vals = np.asarray(data.get("ref", []), dtype=float) pred_vals = np.asarray(data.get("pred", []), dtype=float) meta = data.get("meta") or {} - excluded = meta.get("excluded") - excluded_text = str(excluded) if excluded is not None else "n/a" + meta_values: list[str] = [] + for meta_key, _ in metadata_fields: + meta_raw = meta.get(meta_key) + meta_values.append("n/a" if meta_raw is None else str(meta_raw)) if ref_vals.size == 0 or pred_vals.size == 0: sampled = ([], [], []) else: sampled = _downsample(ref_vals, pred_vals) global_min = min(global_min, ref_vals.min(), pred_vals.min()) global_max = max(global_max, ref_vals.max(), pred_vals.max()) - # Top left corner annotation for each model with exclusion info + summary_text = "" + if metadata_fields: + summary_text = " | ".join( + f"{label}: {value}" + for value, (_, label) in zip( + meta_values, metadata_fields, strict=False + ) + ) annotations.append( { - "text": f"{model} | Excluded: {excluded_text}", + "text": f"{model}" + + (f" | {summary_text}" if summary_text else ""), "xref": "paper", "yref": "paper", "x": 0.02, @@ -717,7 +733,7 @@ def _downsample( processed[model] = { "samples": sampled, "counts": len(ref_vals), - "meta": excluded_text, + "meta": meta_values if metadata_fields else None, } if not np.isfinite(global_min) or not np.isfinite(global_max): @@ -730,12 +746,15 @@ def _downsample( line_end = global_max + padding fig = go.Figure() - hovertemplate = ( 
- "Reference: %{x:.3f}
" - "Predicted: %{y:.3f}
" - "Density: %{customdata[0]:.0f}
" - "Excluded: %{meta[0]}" - ) + hover_lines = [ + "Reference: %{x:.3f}", + "Predicted: %{y:.3f}", + "Density: %{customdata[0]:.0f}", + ] + if metadata_fields: + for idx, (_, label) in enumerate(metadata_fields): + hover_lines.append(f"{label}: %{{meta[{idx}]}}") + hovertemplate = "
".join(hover_lines) + "" for idx, model in enumerate(results): sample_x, sample_y, density = processed[model]["samples"] @@ -756,7 +775,7 @@ def _downsample( customdata=np.array(density, dtype=float)[:, None] if density else None, - meta=[processed[model]["meta"]], + meta=processed[model]["meta"], hovertemplate=hovertemplate, ) ) @@ -783,7 +802,7 @@ def _downsample( title={"text": title} if title else None, xaxis={"title": {"text": x_label}}, yaxis={"title": {"text": y_label}}, - annotations=[annotations[0]], + annotations=[annotations[0]] if annotations else [], meta=layout_meta, showlegend=True, legend_title_text="Model", From 8fa31f4c3c2a4013a55a7eb569c9056eaedb3cd5 Mon Sep 17 00:00:00 2001 From: joehart2001 Date: Mon, 2 Feb 2026 12:06:02 +0000 Subject: [PATCH 2/3] update docstring --- ml_peg/analysis/utils/decorators.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ml_peg/analysis/utils/decorators.py b/ml_peg/analysis/utils/decorators.py index 17bc34a81..d1497ea35 100644 --- a/ml_peg/analysis/utils/decorators.py +++ b/ml_peg/analysis/utils/decorators.py @@ -573,9 +573,9 @@ def plot_density_scatter( seed Seed for deterministic sub-sampling. Default is 0. hover_metadata - Sequence of ``(meta_key, label)`` pairs to include in hover text/annotations. - Defaults to showing ``("excluded", "Excluded")``; set to ``None`` for no extra - metadata. + Sequence of ``(metadata_key, label)`` pairs to include in hover text/ + annotations. + Pass ``None`` (default) to omit additional metadata. 
Returns ------- From 254f394bed4cace6c15365a78b48fb8e81a9a92d Mon Sep 17 00:00:00 2001 From: joehart2001 Date: Mon, 2 Feb 2026 19:40:52 +0000 Subject: [PATCH 3/3] change hover metadata to a dict and separate hover and annotation metadata --- .../elasticity/analyse_elasticity.py | 4 +- ml_peg/analysis/utils/decorators.py | 49 +++++++++++++------ 2 files changed, 37 insertions(+), 16 deletions(-) diff --git a/ml_peg/analysis/bulk_crystal/elasticity/analyse_elasticity.py b/ml_peg/analysis/bulk_crystal/elasticity/analyse_elasticity.py index cf3f90b2b..b0e07848f 100644 --- a/ml_peg/analysis/bulk_crystal/elasticity/analyse_elasticity.py +++ b/ml_peg/analysis/bulk_crystal/elasticity/analyse_elasticity.py @@ -141,7 +141,7 @@ def shear_mae(elasticity_stats: dict[str, dict[str, Any]]) -> dict[str, float | title="Bulk modulus density plot", x_label="Reference bulk modulus / GPa", y_label="Predicted bulk modulus / GPa", - hover_metadata=(("excluded", "Excluded"),), + annotation_metadata={"excluded": "Excluded"}, ) def bulk_density(elasticity_stats: dict[str, dict[str, Any]]) -> dict[str, dict]: """ @@ -166,7 +166,7 @@ def bulk_density(elasticity_stats: dict[str, dict[str, Any]]) -> dict[str, dict] title="Shear modulus density plot", x_label="Reference shear modulus / GPa", y_label="Predicted shear modulus / GPa", - hover_metadata=(("excluded", "Excluded"),), + annotation_metadata={"excluded": "Excluded"}, ) def shear_density(elasticity_stats: dict[str, dict[str, Any]]) -> dict[str, dict]: """ diff --git a/ml_peg/analysis/utils/decorators.py b/ml_peg/analysis/utils/decorators.py index d1497ea35..34c166eda 100644 --- a/ml_peg/analysis/utils/decorators.py +++ b/ml_peg/analysis/utils/decorators.py @@ -544,7 +544,8 @@ def plot_density_scatter( grid_size: int = 80, max_points_per_cell: int = 5, seed: int = 0, - hover_metadata: tuple[tuple[str, str], ...] 
| None = None, + hover_metadata: dict[str, str] | None = None, + annotation_metadata: dict[str, str] | None = None, ) -> Callable: """ Plot density-coloured parity scatter with legend-based model toggling. @@ -573,9 +574,14 @@ def plot_density_scatter( seed Seed for deterministic sub-sampling. Default is 0. hover_metadata - Sequence of ``(metadata_key, label)`` pairs to include in hover text/ - annotations. - Pass ``None`` (default) to omit additional metadata. + Dictionary mapping metadata keys to display labels for hover tooltips. + Keys are used to look up values in the model's metadata dict; labels are shown + in the hover text. Pass ``None`` (default) to omit additional hover metadata. + annotation_metadata + Dictionary mapping metadata keys to display labels for model-level + annotations (shown in the text box on the plot). Keys are used to look up + values in the model's metadata dict; labels are shown in the annotation. + Pass ``None`` (default) to omit additional annotation metadata. 
Returns ------- @@ -692,28 +698,43 @@ def _downsample( global_max = -np.inf processed = {} annotations = [] - metadata_fields = hover_metadata or () + annotation_fields = annotation_metadata or {} + hover_fields = hover_metadata or {} + for model in results: data = results[model] ref_vals = np.asarray(data.get("ref", []), dtype=float) pred_vals = np.asarray(data.get("pred", []), dtype=float) meta = data.get("meta") or {} - meta_values: list[str] = [] - for meta_key, _ in metadata_fields: + + # Extract annotation metadata values (for text box) + annotation_values: list[str] = [] + for meta_key in annotation_fields: meta_raw = meta.get(meta_key) - meta_values.append("n/a" if meta_raw is None else str(meta_raw)) + annotation_values.append( + "n/a" if meta_raw is None else str(meta_raw) + ) + + # Extract hover metadata values (for tooltips) + hover_values: list[str] = [] + for meta_key in hover_fields: + meta_raw = meta.get(meta_key) + hover_values.append("n/a" if meta_raw is None else str(meta_raw)) + if ref_vals.size == 0 or pred_vals.size == 0: sampled = ([], [], []) else: sampled = _downsample(ref_vals, pred_vals) global_min = min(global_min, ref_vals.min(), pred_vals.min()) global_max = max(global_max, ref_vals.max(), pred_vals.max()) + + # Build annotation text from annotation metadata summary_text = "" - if metadata_fields: + if annotation_fields: summary_text = " | ".join( f"{label}: {value}" - for value, (_, label) in zip( - meta_values, metadata_fields, strict=False + for value, label in zip( + annotation_values, annotation_fields.values(), strict=True ) ) annotations.append( @@ -733,7 +754,7 @@ def _downsample( processed[model] = { "samples": sampled, "counts": len(ref_vals), - "meta": meta_values if metadata_fields else None, + "meta": hover_values if hover_fields else None, } if not np.isfinite(global_min) or not np.isfinite(global_max): @@ -751,8 +772,8 @@ def _downsample( "Predicted: %{y:.3f}", "Density: %{customdata[0]:.0f}", ] - if metadata_fields: - 
for idx, (_, label) in enumerate(metadata_fields): + if hover_fields: + for idx, label in enumerate(hover_fields.values()): hover_lines.append(f"{label}: %{{meta[{idx}]}}") hovertemplate = "
".join(hover_lines) + ""