diff --git a/example-notebooks/binary-classifier/classifier_bin.ipynb b/example-notebooks/binary-classifier/classifier_bin.ipynb index fb59d24b..e53f42e5 100644 --- a/example-notebooks/binary-classifier/classifier_bin.ipynb +++ b/example-notebooks/binary-classifier/classifier_bin.ipynb @@ -879,6 +879,61 @@ "sm.ExploreAnalyticsTable()" ] }, + { + "cell_type": "markdown", + "id": "2723b186", + "metadata": {}, + "source": [ + "### Threshold Specific Aggregation" + ] + }, + { + "cell_type": "markdown", + "id": "bca1c81b", + "metadata": {}, + "source": [ + "#### ℹ Info" + ] + }, + { + "cell_type": "markdown", + "id": "0528922a", + "metadata": {}, + "source": [ + "\n", + "This section provides a table for exploring threshold-specific aggregation methods \n", + "(e.g., `first_above_threshold`). \n", + "\n", + "Unlike the standard *Analytics Table*, which summarizes performance metrics across\n", + "multiple thresholds, the *Threshold Aggregation Table* focuses on a **single specified threshold**\n", + "and applies the selected aggregation method before computing summary statistics.\n", + "\n", + "Use this tool to:\n", + "- Inspect how aggregations like `first_above_threshold` affect model results.\n", + "- Compare aggregated outcomes across different scores and targets.\n", + "- View summarized metrics (e.g., Sensitivity, Specificity, PPV, etc.) for the\n", + "aggregated data.\n", + "- Group the results by *Score* or *Target* and optionally combine results per context." 
+ ] + }, + { + "cell_type": "markdown", + "id": "f08689a7", + "metadata": {}, + "source": [ + "#### Visuals" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f8bfbfe1", + "metadata": {}, + "outputs": [], + "source": [ + "sm.ExploreThresholdAggregationTable()" + ] + }, { "cell_type": "markdown", "id": "953be6a9", @@ -1135,7 +1190,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.10" + "version": "3.12.10" }, "toc-autonumbering": false, "toc-showcode": false, diff --git a/src/seismometer/api/reports.py b/src/seismometer/api/reports.py index 49fca303..08859aa5 100644 --- a/src/seismometer/api/reports.py +++ b/src/seismometer/api/reports.py @@ -9,6 +9,7 @@ from seismometer.seismogram import Seismogram from seismometer.table.analytics_table import ExploreBinaryModelAnalytics from seismometer.table.fairness import ExploreBinaryModelFairness +from seismometer.table.threshold_aggregation import ExploreThresholdAggregation logger = logging.getLogger("seismometer") @@ -42,6 +43,22 @@ def __init__(self): super().__init__() +@export +class ExploreThresholdAggregationTable(ExploreThresholdAggregation): + """ + Exploration widget for threshold-specific entity-level aggregation. + + Applies a fixed threshold and aggregation strategy (for example, ``first_above_threshold``) + and generates an AnalyticsTable-style summary table showing the aggregated results. + """ + + def __init__(self): + """ + Passes the table title to the superclass. 
+ """ + super().__init__(title="Threshold Aggregation Table") + + @export class ExploreOrdinalMetrics(ExploreCategoricalPlots): """ diff --git a/src/seismometer/data/binary_performance.py b/src/seismometer/data/binary_performance.py index b5175d4b..86d94b3c 100644 --- a/src/seismometer/data/binary_performance.py +++ b/src/seismometer/data/binary_performance.py @@ -131,6 +131,7 @@ def generate_analytics_data( metrics_to_display: Optional[List[str]] = None, decimals: int = 3, censor_threshold: int = 10, + aggregation_method: Optional[str] = None, ) -> Optional[pd.DataFrame]: """ Generates a DataFrame containing calculated statistics for each combination of scores and targets. @@ -158,6 +159,10 @@ def generate_analytics_data( The number of decimal places for rounding numerical results, by default 3. censor_threshold : int, optional Minimum rows required to generate analytics data, by default 10. + aggregation_method : Optional[str], optional + If provided, overrides the per-target event aggregation method used when building + per-context data via ``event_score``, allowing threshold-specific aggregations such as + ``first_above_threshold`` to be summarized, by default None. 
Returns ------- @@ -188,7 +193,7 @@ def generate_analytics_data( score=score, ref_time=sg.predict_time, ref_event=target, - aggregation_method=sg.event_aggregation_method(target), + aggregation_method=aggregation_method or sg.event_aggregation_method(target), ) if per_context else data diff --git a/src/seismometer/data/pandas_helpers.py b/src/seismometer/data/pandas_helpers.py index d3574124..a2b89fa8 100644 --- a/src/seismometer/data/pandas_helpers.py +++ b/src/seismometer/data/pandas_helpers.py @@ -415,7 +415,9 @@ def _merge_with_strategy( return pd.merge(predictions, one_event_filtered, on=pks, how="left") -def max_aggregation(df: pd.DataFrame, pks: list[str], score: str, ref_time: str, ref_event: str) -> pd.DataFrame: +def max_aggregation( + df: pd.DataFrame, pks: list[str], score: str, ref_time: str, ref_event: str, threshold: float = None +) -> pd.DataFrame: """ Aggregates the DataFrame by selecting the maximum score value. @@ -431,6 +433,8 @@ def max_aggregation(df: pd.DataFrame, pks: list[str], score: str, ref_time: str, The column name containing the time to consider, by default None. ref_event : Optional[str], optional The column name containing the event to consider, by default None. + threshold : Optional[float], optional + Score threshold to compare against, by default None. Returns ------- @@ -446,7 +450,9 @@ def max_aggregation(df: pd.DataFrame, pks: list[str], score: str, ref_time: str, return df.drop_duplicates(subset=pks) -def min_aggregation(df: pd.DataFrame, pks: list[str], score: str, ref_time: str, ref_event: str) -> pd.DataFrame: +def min_aggregation( + df: pd.DataFrame, pks: list[str], score: str, ref_time: str, ref_event: str, threshold: float = None +) -> pd.DataFrame: """ Aggregates the DataFrame by selecting the minimum score value. @@ -462,6 +468,8 @@ def min_aggregation(df: pd.DataFrame, pks: list[str], score: str, ref_time: str, The column name containing the time to consider, by default None. 
ref_event : Optional[str], optional The column name containing the event to consider, by default None. + threshold : Optional[float], optional + Score threshold to compare against, by default None. Returns ------- @@ -477,7 +485,9 @@ def min_aggregation(df: pd.DataFrame, pks: list[str], score: str, ref_time: str, return df.drop_duplicates(subset=pks) -def first_aggregation(df: pd.DataFrame, pks: list[str], score: str, ref_time: str, ref_event: str) -> pd.DataFrame: +def first_aggregation( + df: pd.DataFrame, pks: list[str], score: str, ref_time: str, ref_event: str, threshold: float = None +) -> pd.DataFrame: """ Aggregates the DataFrame by selecting the first occurrence based on event time. @@ -493,6 +503,8 @@ def first_aggregation(df: pd.DataFrame, pks: list[str], score: str, ref_time: st The column name containing the time to consider, by default None. ref_event : Optional[str], optional The column name containing the event to consider, by default None. + threshold : Optional[float], optional + Score threshold to compare against, by default None. Returns ------- @@ -508,7 +520,51 @@ def first_aggregation(df: pd.DataFrame, pks: list[str], score: str, ref_time: st return df.drop_duplicates(subset=pks) -def last_aggregation(df: pd.DataFrame, pks: list[str], score: str, ref_time: str, ref_event: str) -> pd.DataFrame: +def first_above_threshold_aggregation( + df: pd.DataFrame, + pks: list[str], + score: str, + ref_time: Optional[str], + ref_event: Optional[str], + threshold: float, +) -> pd.DataFrame: + """ + Aggregates by selecting the first prediction with a score above the given threshold. + + Parameters + ---------- + df : pd.DataFrame + The dataframe to aggregate. + pks : list[str] + Keys to group by. + score : str + Score column name. + ref_time : str + Time reference column name. + ref_event : str + Not used here but retained for API consistency. + threshold : float + Score threshold to compare against. 
+ + Returns + ------- + pd.DataFrame + Aggregated dataframe with first above-threshold score per group. + """ + ref_score = _resolve_score_col(df, score) + if ref_time is None: + raise ValueError("ref_time is required for first_above_threshold aggregation") + + reference_time = _resolve_time_col(df, ref_time) + df = df[df[ref_score] > threshold] + df = df[df[reference_time].notna()] + df = df.sort_values(by=reference_time) + return df.drop_duplicates(subset=pks) + + +def last_aggregation( + df: pd.DataFrame, pks: list[str], score: str, ref_time: str, ref_event: str, threshold: float = None +) -> pd.DataFrame: """ Aggregates the DataFrame by selecting the last occurrence based on event time. @@ -524,6 +580,8 @@ def last_aggregation(df: pd.DataFrame, pks: list[str], score: str, ref_time: str The column name containing the time to consider, by default None. ref_event : Optional[str], optional The column name containing the event to consider, by default None. + threshold : Optional[float], optional + Score threshold to compare against, by default None. Returns ------- @@ -546,6 +604,7 @@ def event_score( ref_time: Optional[str] = None, ref_event: Optional[str] = None, aggregation_method: str = "max", + threshold: Optional[float] = None, ) -> pd.DataFrame: """ Reduces a dataframe of all predictions to a single row of significance; such as the max or most recent value for @@ -573,6 +632,8 @@ def event_score( the aggregation_method. aggregation_method : str, optional A string describing the method to select a value, by default 'max'. + threshold : Optional[float], optional + Score threshold to compare against, by default None. 
Returns ------- @@ -590,12 +651,13 @@ def event_score( "min": min_aggregation, "first": first_aggregation, "last": last_aggregation, + "first_above_threshold": first_above_threshold_aggregation, } if aggregation_method not in aggregation_methods: raise ValueError(f"Unknown aggregation method: {aggregation_method}") - df = aggregation_methods[aggregation_method](merged_frame, pks, score, ref_time, ref_event) + df = aggregation_methods[aggregation_method](merged_frame, pks, score, ref_time, ref_event, threshold) return df.loc[~np.isnan(df.index)] diff --git a/src/seismometer/table/analytics_table.py b/src/seismometer/table/analytics_table.py index 604b7116..1aef4a17 100644 --- a/src/seismometer/table/analytics_table.py +++ b/src/seismometer/table/analytics_table.py @@ -59,6 +59,7 @@ def __init__( statistics_data: Optional[pd.DataFrame] = None, per_context: bool = False, censor_threshold: int = 10, + aggregation_method: Optional[str] = None, ): """ Initializes the AnalyticsTable object with the necessary data and parameters. @@ -94,6 +95,10 @@ def __init__( censor_threshold : int, optional Minimum number of rows required in the cohort data to enable the generation of an analytics table, by default 10. + aggregation_method : Optional[str], optional + If provided, indicates that the table summarizes threshold-specific aggregation results. + The value is stored and forwarded to ``generate_analytics_data``, where it overrides + the per-target event aggregation method for per-context data, by default None. 
Raises ------ @@ -132,6 +137,7 @@ def __init__( self._initializing = False self.per_context = per_context self.censor_threshold = censor_threshold + self.aggregation_method = aggregation_method def _validate_df_statistics_data(self): if not self._initializing: # Skip validation during initial setup @@ -376,6 +382,7 @@ def _generate_table_data(self) -> Optional[pd.DataFrame]: metrics_to_display=self.metrics_to_display, decimals=self.decimals, censor_threshold=self.censor_threshold, + aggregation_method=self.aggregation_method, ) if data is None: return None diff --git a/src/seismometer/table/threshold_aggregation.py b/src/seismometer/table/threshold_aggregation.py new file mode 100644 index 00000000..b0676ea4 --- /dev/null +++ b/src/seismometer/table/threshold_aggregation.py @@ -0,0 +1,350 @@ +""" +Threshold Aggregation Exploration Widget + +Provides an interactive interface for exploring threshold-specific aggregation methods +(e.g., 'first_above_threshold'). Generates a formatted AnalyticsTable-style summary +after applying the selected aggregation to model predictions. 
+""" + +from typing import Any, Optional + +import traitlets +from ipywidgets import Dropdown, GridBox, Layout, VBox + +from seismometer.controls.explore import ExplorationWidget, _combine_scores_checkbox +from seismometer.controls.selection import MultiselectDropdownWidget, MultiSelectionListWidget +from seismometer.controls.styles import BOX_GRID_LAYOUT, html_title +from seismometer.controls.thresholds import MonotonicProbabilitySliderListWidget +from seismometer.data import pandas_helpers as pdh +from seismometer.data.binary_performance import GENERATED_COLUMNS +from seismometer.data.filter import filter_rule_from_cohort_dictionary +from seismometer.data.performance import THRESHOLD +from seismometer.table.analytics_table import AnalyticsTable + +# region Options Widget --------------------------------------------------------- + + +class ThresholdAggregationOptionsWidget(VBox, traitlets.HasTraits): + """ + Widget for selecting options for threshold-specific aggregation. + + Provides controls to select a target, score, aggregation method, and a fixed threshold. + """ + + value = traitlets.Dict(help="The selected values for the threshold aggregation options.") + + def __init__( + self, + target_cols: tuple[str], + score_cols: tuple[str], + cohort_dict: Optional[dict[str, tuple[Any]]] = None, + *, + aggregation_methods: Optional[tuple[str]] = None, + metrics_to_display: Optional[tuple[str]] = None, + ): + """ + Initializes the threshold aggregation options widget. + + Parameters + ---------- + target_cols : tuple[str] + List of target columns to select from. + score_cols : tuple[str] + List of model score columns to select from. + cohort_dict : dict[str, tuple[Any]], optional + Dictionary of cohort columns and values for filtering, by default None. 
+ aggregation_methods : tuple[str], optional + Supported threshold-based aggregation methods, by default: + ('first_above_threshold') + """ + from seismometer.seismogram import Seismogram + + sg = Seismogram() + aggregation_methods = aggregation_methods or ("first_above_threshold",) + metrics_to_display = metrics_to_display or GENERATED_COLUMNS + + self.title = html_title("Threshold Aggregation Options") + + # Cohort Filter + self._cohort_dict = MultiSelectionListWidget(cohort_dict or sg.available_cohort_groups, title="Cohort Filter") + + self._target_cols = MultiselectDropdownWidget( + options=tuple(map(pdh.event_name, target_cols)), + value=target_cols[:2] if len(target_cols) > 1 else target_cols, + title="Targets", + ) + + self._score_cols = MultiselectDropdownWidget( + options=score_cols, + value=score_cols[:2] if len(score_cols) > 1 else score_cols, + title="Scores", + ) + + # Aggregation Method Selector + self._aggregation_method = Dropdown( + options=aggregation_methods, + value=aggregation_methods[0], + description="Aggregation", + style={"description_width": "min-content"}, + layout=Layout(width="250px"), + ) + + # Threshold Slider + self._threshold = MonotonicProbabilitySliderListWidget( + names=("Threshold",), + value=(0.5,), + ascending=False, + ) + + # Metrics to display + self._metrics_to_display = MultiselectDropdownWidget( + options=GENERATED_COLUMNS, + value=metrics_to_display, + title="Performance Metrics to Display", + ) + + # Group By + self._group_by = Dropdown( + options=["Score", "Target"], + value="Score", + description="Group By", + style={"description_width": "min-content"}, + layout=Layout(width="250px"), + ) + + # Combine Scores Checkbox + self.per_context_checkbox = _combine_scores_checkbox(per_context=False) + + # Observe all widgets for updates + for w in [ + self._cohort_dict, + self._target_cols, + self._score_cols, + self._aggregation_method, + self._threshold, + self._metrics_to_display, + self._group_by, + 
self.per_context_checkbox, + ]: + w.observe(self._on_value_changed, names="value") + + # Layout + grid_layout = Layout( + width="100%", grid_template_columns="repeat(3, 1fr)", justify_items="flex-start", grid_gap="10px" + ) + + # Create a 3-column grid of main controls + grid_box = GridBox( + children=[ + self._target_cols, + self._score_cols, + self._metrics_to_display, + self._aggregation_method, + self._threshold, + self._group_by, + self.per_context_checkbox, + ], + layout=grid_layout, + ) + + # Combine with title and cohort filter above + grid_with_title = VBox( + children=[self.title, grid_box], + layout=Layout(align_items="flex-start"), + ) + + super().__init__(children=[self._cohort_dict, grid_with_title], layout=BOX_GRID_LAYOUT) + self._on_value_changed() + self._disabled = False + + # region Properties + + @property + def disabled(self) -> bool: + return self._disabled + + @disabled.setter + def disabled(self, value: bool): + self._disabled = value + self._cohort_dict.disabled = value + self._target_cols.disabled = value + self._score_cols.disabled = value + self._aggregation_method.disabled = value + self._threshold.disabled = value + self._metrics_to_display.disabled = value + self._group_by.disabled = value + self.per_context_checkbox.disabled = value + + def _on_value_changed(self, change=None): + """Update internal dictionary when any option changes.""" + self.value = { + "cohort_dict": self._cohort_dict.value, + "target_cols": self._target_cols.value, + "score_cols": self._score_cols.value, + "aggregation_method": self._aggregation_method.value, + "threshold": list(self._threshold.value.values())[0], + "metrics_to_display": self._metrics_to_display.value, + "group_by": self._group_by.value, + "group_scores": self.per_context_checkbox.value, + } + + @property + def cohort_dict(self) -> dict[str, tuple[Any]]: + return self._cohort_dict.value + + @property + def target_cols(self) -> tuple[str]: + return self._target_cols.value + + @property + def 
score_cols(self) -> tuple[str]: + return self._score_cols.value + + @property + def aggregation_method(self) -> str: + return self._aggregation_method.value + + @property + def metrics_to_display(self): + return self._metrics_to_display.value + + @property + def group_by(self) -> str: + return self._group_by.value + + @property + def threshold(self) -> float: + return list(self._threshold.value.values())[0] + + @property + def group_scores(self) -> bool: + return self.per_context_checkbox.value + + # endregion + + +# endregion +# region Explore Widget --------------------------------------------------------- + + +class ExploreThresholdAggregation(ExplorationWidget): + """ + Exploration widget for threshold-specific aggregation methods. + + Applies a fixed threshold and aggregation strategy (e.g., 'first_above_threshold'), + then generates a formatted AnalyticsTable-style summary of results. + """ + + def __init__(self, title: Optional[str] = None): + """ + Initializes the threshold aggregation exploration widget. + + Parameters + ---------- + title : str, optional + The title displayed above the control, by default "Threshold Aggregation Explorer". + """ + from seismometer.seismogram import Seismogram + + sg = Seismogram() + self.title = title or "Threshold Aggregation Explorer" + + super().__init__( + title=self.title, + option_widget=ThresholdAggregationOptionsWidget( + target_cols=tuple(map(pdh.event_name, sg.get_binary_targets())), + score_cols=sg.output_list, + cohort_dict=sg.available_cohort_groups, + ), + plot_function=self._plot_threshold_aggregation, + initial_plot=False, + ) + + def _plot_threshold_aggregation( + self, + cohort_dict: dict[str, tuple[Any]], + target_cols: list[str], + score_cols: list[str], + aggregation_method: str, + threshold: float, + metrics_to_display: list[str], + group_by: str, + per_context: bool, + ): + """ + Applies a threshold-based aggregation and renders an AnalyticsTable-style summary. 
+ + Parameters + ---------- + cohort_dict : dict[str, tuple[Any]] + Cohort filter to apply before aggregation. + target_cols : list[str] + Target columns to aggregate over. + score_cols : list[str] + Score columns used for thresholding. + aggregation_method : str + Aggregation strategy to apply (e.g., 'first_above_threshold'). + threshold : float + The score threshold to use. + metrics_to_display : list[str] + Metrics to include in the output table. + group_by : str + Whether to group by "Score" or "Target". + per_context : bool + Whether to aggregate per context instead of globally. + + Returns + ------- + HTML + Rendered AnalyticsTable summary for the aggregated data. + """ + from seismometer.seismogram import Seismogram + + sg = Seismogram() + df = sg.dataframe + + if cohort_dict: + df = filter_rule_from_cohort_dictionary(cohort_dict).filter(df) + + # Build AnalyticsTable-style summary using the existing class + summary_table = AnalyticsTable( + score_columns=score_cols, + target_columns=target_cols, + metric=THRESHOLD, + metric_values=[threshold], + metrics_to_display=metrics_to_display, + title="Threshold Specific Aggregation", + top_level=group_by, + cohort_dict=cohort_dict, + per_context=per_context, + censor_threshold=sg.censor_threshold, + aggregation_method=aggregation_method, + ) + + return summary_table.analytics_table() + + def generate_plot_args(self) -> tuple[tuple, dict]: + """ + Generates arguments for the plot_function. + + Returns + ------- + tuple[tuple, dict] + Positional and keyword arguments to be passed to the plot_function. 
+ """ + opts = self.option_widget + args = ( + opts.cohort_dict, + tuple(map(pdh.event_value, opts.target_cols)), + opts.score_cols, + opts.aggregation_method, + opts.threshold, + opts.metrics_to_display, + opts.group_by, + opts.group_scores, + ) + kwargs = {} + return args, kwargs + + +# endregion diff --git a/tests/controls/test_explore.py b/tests/controls/test_explore.py index a5bb0cb5..805d3e4f 100644 --- a/tests/controls/test_explore.py +++ b/tests/controls/test_explore.py @@ -1,6 +1,7 @@ from unittest.mock import MagicMock, Mock, patch import ipywidgets +import pandas as pd import pytest import seismometer.controls.explore as undertest @@ -19,6 +20,7 @@ FairnessOptionsWidget, binary_metrics_fairness_table, ) +from seismometer.table.threshold_aggregation import ExploreThresholdAggregation, ThresholdAggregationOptionsWidget # region Test Base Classes @@ -1604,4 +1606,226 @@ def test_generate_plot_args_returns_expected_values(self, monkeypatch): assert kwargs == {"per_context": True} +class TestThresholdAggregationOptionsWidget: + @patch.object(seismogram, "Seismogram", return_value=Mock()) + def test_init(self, mock_seismo): + fake_seismo = mock_seismo() + fake_seismo.get_binary_targets.return_value = ["T1", "T2"] + fake_seismo.output_list = ["S1", "S2"] + fake_seismo.available_cohort_groups = {"C1": ["A", "B"]} + + widget = ThresholdAggregationOptionsWidget( + target_cols=("T1", "T2"), + score_cols=("S1", "S2"), + cohort_dict=fake_seismo.available_cohort_groups, + ) + + assert isinstance(widget, ipywidgets.VBox) + assert widget._target_cols.value == ("T1", "T2") + assert widget._score_cols.value == ("S1", "S2") + assert widget._aggregation_method.value == "first_above_threshold" + assert isinstance(widget.threshold, float) + assert widget.group_by in ("Score", "Target") + assert widget.group_scores in (True, False) + assert isinstance(widget.value, dict) + assert widget.value["aggregation_method"] == "first_above_threshold" + + @patch.object(seismogram, 
"Seismogram", return_value=Mock()) + def test_disabled_property(self, mock_seismo): + fake_seismo = mock_seismo() + fake_seismo.get_binary_targets.return_value = ["T1"] + fake_seismo.output_list = ["S1"] + fake_seismo.available_cohort_groups = {"C1": ["A", "B"]} + + widget = ThresholdAggregationOptionsWidget( + target_cols=("T1",), + score_cols=("S1",), + cohort_dict=fake_seismo.available_cohort_groups, + ) + + widget.disabled = True + for child in [ + widget._cohort_dict, + widget._target_cols, + widget._score_cols, + widget._aggregation_method, + widget._threshold, + widget._metrics_to_display, + widget._group_by, + widget.per_context_checkbox, + ]: + assert child.disabled is True + + widget.disabled = False + for child in [ + widget._cohort_dict, + widget._target_cols, + widget._score_cols, + widget._aggregation_method, + widget._threshold, + widget._metrics_to_display, + widget._group_by, + widget.per_context_checkbox, + ]: + assert child.disabled is False + + @patch.object(seismogram, "Seismogram", return_value=Mock()) + def test_on_value_changed(self, mock_seismo): + fake_seismo = mock_seismo() + fake_seismo.get_binary_targets.return_value = ["T1"] + fake_seismo.output_list = ["S1"] + fake_seismo.available_cohort_groups = {"C1": ["A", "B"]} + + widget = ThresholdAggregationOptionsWidget( + target_cols=("T1",), + score_cols=("S1",), + cohort_dict=fake_seismo.available_cohort_groups, + ) + + # Simulate user changing values + widget._cohort_dict.value = {"C1": ["A"]} + widget._target_cols.value = ("T1",) + widget._score_cols.value = ("S1",) + widget._aggregation_method.value = "first_above_threshold" + widget._threshold.value = {"Threshold": 0.75} + widget._metrics_to_display.value = ("AUROC",) + widget._group_by.value = "Target" + widget.per_context_checkbox.value = True + + widget._on_value_changed() + + val = widget.value + assert val["cohort_dict"] == {"C1": ["A"]} + assert val["target_cols"] == ("T1",) + assert val["score_cols"] == ("S1",) + assert 
val["aggregation_method"] == "first_above_threshold" + assert val["threshold"] == 0.75 + assert val["metrics_to_display"] == ("AUROC",) + assert val["group_by"] == "Target" + assert val["group_scores"] is True + + @patch.object(seismogram, "Seismogram", return_value=Mock()) + @pytest.mark.parametrize("group_by", ["Score", "Target"]) + def test_group_by_and_threshold_behavior(self, mock_seismo, group_by): + fake_seismo = mock_seismo() + fake_seismo.get_binary_targets.return_value = ["T1"] + fake_seismo.output_list = ["S1"] + fake_seismo.available_cohort_groups = {"C1": ["A", "B"]} + + widget = ThresholdAggregationOptionsWidget( + target_cols=("T1",), + score_cols=("S1",), + cohort_dict=fake_seismo.available_cohort_groups, + ) + + widget._group_by.value = group_by + widget._threshold.value = {"Threshold": 0.3} + widget._on_value_changed() + + assert widget.group_by == group_by + assert widget.threshold == 0.3 + + @patch.object(seismogram, "Seismogram", return_value=Mock()) + def test_property_accessors(self, mock_seismo): + fake_seismo = mock_seismo() + fake_seismo.get_binary_targets.return_value = ["T1", "T2"] + fake_seismo.output_list = ["S1", "S2"] + fake_seismo.available_cohort_groups = {"C1": ["A", "B"]} + + widget = ThresholdAggregationOptionsWidget( + target_cols=("T1", "T2"), + score_cols=("S1", "S2"), + cohort_dict=fake_seismo.available_cohort_groups, + ) + + assert isinstance(widget.cohort_dict, dict) + assert isinstance(widget.target_cols, tuple) + assert isinstance(widget.score_cols, tuple) + assert isinstance(widget.aggregation_method, str) + assert isinstance(widget.metrics_to_display, tuple) + assert widget.group_by in ("Score", "Target") + assert isinstance(widget.threshold, float) + assert isinstance(widget.group_scores, bool) + + +class TestExploreThresholdAggregation: + @patch.object(seismogram, "Seismogram", return_value=Mock()) + def test_init(self, mock_seismo): + fake_seismo = mock_seismo() + fake_seismo.get_binary_targets.return_value = ["T1", 
"T2"] + fake_seismo.output_list = ["S1", "S2"] + fake_seismo.available_cohort_groups = {"C1": ["A", "B"]} + + widget = ExploreThresholdAggregation(title="Unit Test Title") + + assert widget.title == "Unit Test Title" + assert isinstance(widget.option_widget, ThresholdAggregationOptionsWidget) + assert widget.plot_function == widget._plot_threshold_aggregation + + @patch.object(seismogram, "Seismogram", return_value=Mock()) + def test_plot_threshold_aggregation_builds_expected_table(self, mock_seismo): + fake_seismo = mock_seismo() + fake_seismo.dataframe = pd.DataFrame({"C1": ["A", "B", "A"], "score1": [0.1, 0.2, 0.3], "target1": [0, 1, 0]}) + fake_seismo.censor_threshold = 10 + fake_seismo.get_binary_targets.return_value = ["T1", "T2"] + fake_seismo.output_list = ["S1", "S2"] + fake_seismo.available_cohort_groups = {"C1": ["A", "B"]} + mock_table = Mock(analytics_table=lambda: "HTML_TABLE") + with patch( + "seismometer.table.threshold_aggregation.AnalyticsTable", return_value=mock_table + ) as mock_analytics: + widget = ExploreThresholdAggregation(title="Unit Test Title") + result = widget._plot_threshold_aggregation( + cohort_dict={"C1": ["A"]}, + target_cols=["T1"], + score_cols=["S1"], + aggregation_method="first_above_threshold", + threshold=0.5, + metrics_to_display=["AUROC"], + group_by="Score", + per_context=False, + ) + + assert result == "HTML_TABLE" + mock_analytics.assert_called_once() + args, kwargs = mock_analytics.call_args + assert kwargs["aggregation_method"] == "first_above_threshold" + assert kwargs["metric_values"] == [0.5] + assert kwargs["metric"] == "Threshold" + + @patch.object(seismogram, "Seismogram", return_value=Mock()) + def test_generate_plot_args_returns_expected_values(self, mock_seismo): + fake_seismo = mock_seismo() + fake_seismo.get_binary_targets.return_value = ["T1", "T2"] + fake_seismo.output_list = ["S1", "S2"] + fake_seismo.available_cohort_groups = {"C1": ["A", "B"]} + + widget = ExploreThresholdAggregation(title="Unit Test 
Title") + opt = widget.option_widget + + opt._cohort_dict.value = {"C1": ["A"]} + opt._target_cols.value = ("T1",) + opt._score_cols.value = ("S1",) + opt._aggregation_method.value = "first_above_threshold" + opt._threshold.value = {"Threshold": 0.7} + opt._metrics_to_display.value = ("AUROC",) + opt._group_by.value = "Target" + opt.per_context_checkbox.value = True + opt._on_value_changed() + + args, kwargs = widget.generate_plot_args() + + assert args == ( + {"C1": ["A"]}, + ("T1_Value",), + ("S1",), + "first_above_threshold", + 0.7, + ("AUROC",), + "Target", + True, + ) + assert kwargs == {} + + # endregion diff --git a/tests/data/test_pandas_helpers.py b/tests/data/test_pandas_helpers.py index fff311c2..c6893877 100644 --- a/tests/data/test_pandas_helpers.py +++ b/tests/data/test_pandas_helpers.py @@ -510,6 +510,96 @@ def test_get_model_scores_bypass_when_not_per_context(self): ) pd.testing.assert_frame_equal(result, df) + @pytest.mark.parametrize( + "threshold,expected_scores", + [ + (0.5, {1: 0.6, 2: 0.8}), # both entities exceed threshold + (0.7, {1: 0.9, 2: 0.8}), # only later scores pass + (0.95, {}), # no entity meets threshold → empty + ], + ) + def test_event_score_first_above_threshold_various_thresholds(self, threshold, expected_scores): + """Check that event_score(first_above_threshold) correctly selects the first score above threshold.""" + now = pd.Timestamp("2024-01-01 00:00:00") + df = pd.DataFrame( + { + "Id": [1, 1, 1, 2, 2], + "Score": [0.2, 0.6, 0.9, 0.3, 0.8], + "EventName_Time": [now + pd.Timedelta(hours=h) for h in [0, 1, 2, 0, 1]], + "EventName_Value": [1, 1, 1, 1, 1], + } + ) + + result = undertest.event_score( + df, + pks=["Id"], + score="Score", + ref_time="EventName_Time", + ref_event="EventName", + aggregation_method="first_above_threshold", + threshold=threshold, + ) + + if not expected_scores: + assert result.empty + else: + assert set(result["Id"]) == set(expected_scores) + for i, val in expected_scores.items(): + assert 
pytest.approx(result.loc[result["Id"] == i, "Score"].iloc[0]) == val + + @pytest.mark.parametrize( + "has_time_col, ref_time_arg, expected_error", + [ + # Case 1: Missing column entirely, ref_time points to a non-existent column + (False, "EventName_Time", "Reference time column EventName_Time not found"), + # Case 2: Column exists, but ref_time=None + (True, None, "ref_time is required"), + ], + ) + def test_first_above_threshold_raises_when_ref_time_missing(self, has_time_col, ref_time_arg, expected_error): + """Ensure first_above_threshold_aggregation raises when ref_time is missing or argument is None.""" + now = pd.Timestamp("2024-01-01 00:00:00") + df = pd.DataFrame({"Id": [1, 1], "Score": [0.3, 0.8]}) + + # Add the time column only for the second test case + if has_time_col: + df["EventName_Time"] = [now, now + pd.Timedelta(hours=1)] + + with pytest.raises(ValueError, match=expected_error): + undertest.first_above_threshold_aggregation( + df, + pks=["Id"], + score="Score", + ref_time=ref_time_arg, + ref_event="EventName", + threshold=0.5, + ) + + def test_first_above_threshold_ignores_missing_time_rows(self): + """Rows with NaT in ref_time should be dropped before selecting first above threshold.""" + now = pd.Timestamp("2024-01-01 00:00:00") + df = pd.DataFrame( + { + "Id": [1, 1, 1], + "Score": [0.4, 0.8, 0.9], + "EventName_Time": [now, pd.NaT, now + pd.Timedelta(hours=2)], + } + ) + + result = undertest.first_above_threshold_aggregation( + df, + pks=["Id"], + score="Score", + ref_time="EventName_Time", + ref_event="EventName", + threshold=0.5, + ) + + # should keep only rows with valid times and score > 0.5 + assert not result.empty + assert result["EventName_Time"].notna().all() + assert (result["Score"] > 0.5).all() + class TestMergeEventCounts: def test_skips_time_filter_when_window_none(self, base_counts_data): diff --git a/tests/table/test_threshold_aggregation.py b/tests/table/test_threshold_aggregation.py new file mode 100644 index 
"""Integration-level tests for threshold-specific aggregation in AnalyticsTable.

Reconstructed/reformatted from a whitespace-mangled diff paste of the new file
``tests/table/test_threshold_aggregation.py``.
"""

from unittest.mock import Mock

import pandas as pd
import pytest

from seismometer.configuration import ConfigProvider
from seismometer.configuration.model import Cohort, Event
from seismometer.data.loader import SeismogramLoader
from seismometer.data.performance import THRESHOLD
from seismometer.seismogram import Seismogram
from seismometer.table.analytics_table import AnalyticsTable


def get_test_config(tmp_path):
    """Return a mock ConfigProvider populated with a minimal test configuration."""
    # NOTE(review): ``Mock(autospec=...)`` does not create a specced mock — it
    # only sets an attribute named ``autospec``. The specced form would be
    # ``create_autospec(ConfigProvider)``; confirm before changing, since this
    # setup relies on freely assigning arbitrary attributes below.
    mock_config = Mock(autospec=ConfigProvider)
    # No-op attribute access; its only effect is auto-creating the child mock
    # for ``output_dir`` — presumably intentional, TODO confirm.
    mock_config.output_dir.return_value

    mock_config.events = {
        "event1": Event(source="event1", display_name="event1", window_hr=1),
        "event2": Event(source="event2", display_name="event2", window_hr=2, aggregation_method="min"),
        "event3": Event(source="event3", display_name="event3", window_hr=1),
    }
    mock_config.target = "event1"
    mock_config.entity_keys = ["entity"]
    mock_config.predict_time = "time"
    mock_config.cohorts = [Cohort(source=name) for name in ["cohort1"]]
    mock_config.features = ["one"]
    mock_config.config_dir = tmp_path / "config"
    mock_config.censor_min_count = 0
    mock_config.targets = ["event1", "event2", "event3"]
    mock_config.output_list = ["prediction", "score1", "score2"]

    return mock_config


def get_test_loader(config):
    """Return a mock SeismogramLoader wired to the given config."""
    # NOTE(review): same ``Mock(autospec=...)`` caveat as in get_test_config.
    mock_loader = Mock(autospec=SeismogramLoader)
    mock_loader.config = config
    return mock_loader


def get_test_data():
    """Return a small prediction frame with events, a cohort, scores, and targets."""
    return pd.DataFrame(
        {
            "entity": ["A", "A", "B", "C"],
            "prediction": [1, 2, 3, 4],
            "time": ["2022-01-01", "2022-01-02", "2022-01-03", "2022-01-04"],
            "event1_Value": [0, 1, 0, 1],
            "event1_Time": ["2022-01-01", "2022-01-02", "2022-01-03", "2021-12-31"],
            "event2_Value": [0, 1, 0, 1],
            "event2_Time": ["2022-01-01", "2022-01-02", "2022-01-03", "2022-01-04"],
            "event3_Value": [0, 2, 5, 1],
            "event3_Time": ["2022-01-01", "2022-01-02", "2022-01-03", "2022-01-04"],
            "cohort1": ["A", "A", "A", "B"],
            "score1": [0.1, 0.4, 0.35, 0.8],
            "score2": [0.2, 0.5, 0.3, 0.7],
            "target1": [0, 1, 0, 1],
            "target2": [1, 0, 1, 0],
            "target3": [1, 1, 1, 0],
        }
    )


@pytest.fixture
def fake_seismo(tmp_path):
    """Yield a Seismogram singleton backed by mock config/loader and test data."""
    config = get_test_config(tmp_path)
    loader = get_test_loader(config)
    sg = Seismogram(config, loader)
    sg.dataframe = get_test_data()
    sg.available_cohort_groups = {"cohort1": ["A", "B"]}
    yield sg

    # Seismogram is a singleton; tear it down so later tests start clean.
    Seismogram.kill()


class TestThresholdAggregationIntegration:
    """Integration-level smoke tests for threshold-specific aggregation methods.

    These tests ensure that the new `first_above_threshold` aggregation
    mode integrates cleanly with AnalyticsTable and produces valid output.
    """

    @pytest.mark.parametrize(
        "score_cols,target_cols,thresholds,group_by",
        [
            (["score1"], ["target1"], [0.5], "Score"),
            (["score1", "score2"], ["target1", "target2"], [0.5], "Target"),
            (["score1"], ["target1"], [0.3, 0.6, 0.9], "Score"),
        ],
    )
    def test_first_above_threshold_runs(self, fake_seismo, score_cols, target_cols, thresholds, group_by):
        """Ensure first_above_threshold works for multiple score/target setups."""
        table = AnalyticsTable(
            score_columns=score_cols,
            target_columns=target_cols,
            metric=THRESHOLD,
            metric_values=thresholds,
            metrics_to_display=["AUROC", "PPV"],
            censor_threshold=1,
            cohort_dict={"cohort1": ("A", "B")},
            aggregation_method="first_above_threshold",
            top_level=group_by,
        )

        data = table._generate_table_data()
        assert isinstance(data, pd.DataFrame)
        assert "Score" in data.columns
        assert "Target" in data.columns

        # Rendering the HTML table must also succeed end to end.
        html_table = table.analytics_table()
        assert html_table is not None

    @pytest.mark.parametrize("censor_threshold,expected_none", [(100, True), (1, False)])
    def test_censor_threshold_behavior(self, fake_seismo, censor_threshold, expected_none):
        """Verify censor_threshold filtering behaves consistently."""
        table = AnalyticsTable(
            score_columns=["score1"],
            target_columns=["target1"],
            metric=THRESHOLD,
            metric_values=[0.8],
            metrics_to_display=["AUROC"],
            censor_threshold=censor_threshold,
            cohort_dict={"cohort1": ("A", "B")},
            aggregation_method="first_above_threshold",
        )

        data = table._generate_table_data()
        if expected_none:
            # A censor threshold above the row count suppresses all output.
            assert data is None
        else:
            assert isinstance(data, pd.DataFrame)

    @pytest.mark.parametrize("group_by", ["Score", "Target"])
    def test_top_level_grouping_variants(self, fake_seismo, group_by):
        """Verify top_level variations work for threshold aggregation."""
        table = AnalyticsTable(
            score_columns=["score1", "score2"],
            target_columns=["target1", "target2"],
            metric=THRESHOLD,
            metric_values=[0.7],
            metrics_to_display=["AUROC", "Sensitivity"],
            censor_threshold=1,
            cohort_dict={"cohort1": ("A", "B")},
            aggregation_method="first_above_threshold",
            top_level=group_by,
        )

        data = table._generate_table_data()
        assert isinstance(data, pd.DataFrame)
        html = table.analytics_table()
        assert html is not None