Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 56 additions & 1 deletion example-notebooks/binary-classifier/classifier_bin.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -879,6 +879,61 @@
"sm.ExploreAnalyticsTable()"
]
},
{
"cell_type": "markdown",
"id": "2723b186",
"metadata": {},
"source": [
"### Threshold Specific Aggregation"
]
},
{
"cell_type": "markdown",
"id": "bca1c81b",
"metadata": {},
"source": [
"#### ℹ Info"
]
},
{
"cell_type": "markdown",
"id": "0528922a",
"metadata": {},
"source": [
"\n",
"This section provides a table for exploring threshold-specific aggregation methods \n",
"(e.g., `first_above_threshold`). \n",
"\n",
"Unlike the standard *Analytics Table*, which summarizes performance metrics across\n",
"multiple thresholds, the *Threshold Aggregation Table* focuses on a **single specified threshold**\n",
"and applies the selected aggregation method before computing summary statistics.\n",
"\n",
"Use this tool to:\n",
"- Inspect how aggregations like `first_above_threshold` affect model results.\n",
"- Compare aggregated outcomes across different scores and targets.\n",
"- View summarized metrics (e.g., Sensitivity, Specificity, PPV, etc.) for the\n",
"aggregated data.\n",
"- Group the results by *Score* or *Target* and optionally combine results per context."
]
},
{
"cell_type": "markdown",
"id": "f08689a7",
"metadata": {},
"source": [
"#### Visuals"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f8bfbfe1",
"metadata": {},
"outputs": [],
"source": [
"sm.ExploreThresholdAggregationTable()"
]
},
{
"cell_type": "markdown",
"id": "953be6a9",
Expand Down Expand Up @@ -1135,7 +1190,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.10"
"version": "3.12.10"
},
"toc-autonumbering": false,
"toc-showcode": false,
Expand Down
17 changes: 17 additions & 0 deletions src/seismometer/api/reports.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from seismometer.seismogram import Seismogram
from seismometer.table.analytics_table import ExploreBinaryModelAnalytics
from seismometer.table.fairness import ExploreBinaryModelFairness
from seismometer.table.threshold_aggregation import ExploreThresholdAggregation

logger = logging.getLogger("seismometer")

Expand Down Expand Up @@ -42,6 +43,22 @@ def __init__(self):
super().__init__()


@export
class ExploreThresholdAggregationTable(ExploreThresholdAggregation):
    """
    Widget for exploring threshold-specific, entity-level aggregation results.

    Fixes a single score threshold together with an aggregation strategy
    (such as ``first_above_threshold``) and renders an AnalyticsTable-style
    summary of the aggregated predictions.
    """

    def __init__(self):
        """
        Initializes the superclass with the display title for this widget.
        """
        super().__init__(title="Threshold Aggregation Table")


@export
class ExploreOrdinalMetrics(ExploreCategoricalPlots):
"""
Expand Down
7 changes: 6 additions & 1 deletion src/seismometer/data/binary_performance.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ def generate_analytics_data(
metrics_to_display: Optional[List[str]] = None,
decimals: int = 3,
censor_threshold: int = 10,
aggregation_method: Optional[str] = None,
) -> Optional[pd.DataFrame]:
"""
Generates a DataFrame containing calculated statistics for each combination of scores and targets.
Expand Down Expand Up @@ -158,6 +159,10 @@ def generate_analytics_data(
The number of decimal places for rounding numerical results, by default 3.
censor_threshold : int, optional
Minimum rows required to generate analytics data, by default 10.
aggregation_method : Optional[str], optional
If provided, indicates that the table is being used to summarize threshold-specific aggregation results.
This parameter is not used directly in this function, but it can be useful for customizing the title
or other aspects of the table when it is part of a threshold aggregation analysis, by default None.

Returns
-------
Expand Down Expand Up @@ -188,7 +193,7 @@ def generate_analytics_data(
score=score,
ref_time=sg.predict_time,
ref_event=target,
aggregation_method=sg.event_aggregation_method(target),
aggregation_method=aggregation_method or sg.event_aggregation_method(target),
)
if per_context
else data
Expand Down
72 changes: 67 additions & 5 deletions src/seismometer/data/pandas_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,7 +415,9 @@ def _merge_with_strategy(
return pd.merge(predictions, one_event_filtered, on=pks, how="left")


def max_aggregation(df: pd.DataFrame, pks: list[str], score: str, ref_time: str, ref_event: str) -> pd.DataFrame:
def max_aggregation(
df: pd.DataFrame, pks: list[str], score: str, ref_time: str, ref_event: str, threshold: float = None
) -> pd.DataFrame:
"""
Aggregates the DataFrame by selecting the maximum score value.

Expand All @@ -431,6 +433,8 @@ def max_aggregation(df: pd.DataFrame, pks: list[str], score: str, ref_time: str,
The column name containing the time to consider, by default None.
ref_event : Optional[str], optional
The column name containing the event to consider, by default None.
threshold : Optional[float], optional
Score threshold to compare against, by default None.

Returns
-------
Expand All @@ -446,7 +450,9 @@ def max_aggregation(df: pd.DataFrame, pks: list[str], score: str, ref_time: str,
return df.drop_duplicates(subset=pks)


def min_aggregation(df: pd.DataFrame, pks: list[str], score: str, ref_time: str, ref_event: str) -> pd.DataFrame:
def min_aggregation(
df: pd.DataFrame, pks: list[str], score: str, ref_time: str, ref_event: str, threshold: float = None
) -> pd.DataFrame:
"""
Aggregates the DataFrame by selecting the minimum score value.

Expand All @@ -462,6 +468,8 @@ def min_aggregation(df: pd.DataFrame, pks: list[str], score: str, ref_time: str,
The column name containing the time to consider, by default None.
ref_event : Optional[str], optional
The column name containing the event to consider, by default None.
threshold : Optional[float], optional
Score threshold to compare against, by default None.

Returns
-------
Expand All @@ -477,7 +485,9 @@ def min_aggregation(df: pd.DataFrame, pks: list[str], score: str, ref_time: str,
return df.drop_duplicates(subset=pks)


def first_aggregation(df: pd.DataFrame, pks: list[str], score: str, ref_time: str, ref_event: str) -> pd.DataFrame:
def first_aggregation(
df: pd.DataFrame, pks: list[str], score: str, ref_time: str, ref_event: str, threshold: float = None
) -> pd.DataFrame:
"""
Aggregates the DataFrame by selecting the first occurrence based on event time.

Expand All @@ -493,6 +503,8 @@ def first_aggregation(df: pd.DataFrame, pks: list[str], score: str, ref_time: st
The column name containing the time to consider, by default None.
ref_event : Optional[str], optional
The column name containing the event to consider, by default None.
threshold : Optional[float], optional
Score threshold to compare against, by default None.

Returns
-------
Expand All @@ -508,7 +520,51 @@ def first_aggregation(df: pd.DataFrame, pks: list[str], score: str, ref_time: st
return df.drop_duplicates(subset=pks)


def last_aggregation(df: pd.DataFrame, pks: list[str], score: str, ref_time: str, ref_event: str) -> pd.DataFrame:
def first_above_threshold_aggregation(
    df: pd.DataFrame,
    pks: list[str],
    score: str,
    ref_time: Optional[str],
    ref_event: Optional[str],
    threshold: float,
) -> pd.DataFrame:
    """
    Aggregates by selecting the first prediction with a score above the given threshold.

    Rows whose score does not exceed the threshold, or whose reference time is
    missing, are excluded before the earliest remaining row per group is kept.

    Parameters
    ----------
    df : pd.DataFrame
        The dataframe to aggregate.
    pks : list[str]
        Keys to group by.
    score : str
        Score column name.
    ref_time : str
        Time reference column name.
    ref_event : str
        Not used here but retained for API consistency.
    threshold : float
        Score threshold to compare against; must not be None.

    Returns
    -------
    pd.DataFrame
        Aggregated dataframe with first above-threshold score per group.

    Raises
    ------
    ValueError
        If ``threshold`` or ``ref_time`` is None; both are required for this
        aggregation method.
    """
    # Validate up front: event_score defaults threshold to None, and comparing a
    # pandas column against None would otherwise raise an opaque TypeError.
    if threshold is None:
        raise ValueError("threshold is required for first_above_threshold aggregation")
    if ref_time is None:
        raise ValueError("ref_time is required for first_above_threshold aggregation")

    ref_score = _resolve_score_col(df, score)
    reference_time = _resolve_time_col(df, ref_time)

    # Single boolean mask instead of three chained filtered copies.
    eligible = df[(df[ref_score] > threshold) & df[reference_time].notna()]
    # Earliest qualifying row per group (drop_duplicates keeps the first after sorting).
    return eligible.sort_values(by=reference_time).drop_duplicates(subset=pks)


def last_aggregation(
df: pd.DataFrame, pks: list[str], score: str, ref_time: str, ref_event: str, threshold: float = None
) -> pd.DataFrame:
"""
Aggregates the DataFrame by selecting the last occurrence based on event time.

Expand All @@ -524,6 +580,8 @@ def last_aggregation(df: pd.DataFrame, pks: list[str], score: str, ref_time: str
The column name containing the time to consider, by default None.
ref_event : Optional[str], optional
The column name containing the event to consider, by default None.
threshold : Optional[float], optional
Score threshold to compare against, by default None.

Returns
-------
Expand All @@ -546,6 +604,7 @@ def event_score(
ref_time: Optional[str] = None,
ref_event: Optional[str] = None,
aggregation_method: str = "max",
threshold: Optional[float] = None,
) -> pd.DataFrame:
"""
Reduces a dataframe of all predictions to a single row of significance; such as the max or most recent value for
Expand Down Expand Up @@ -573,6 +632,8 @@ def event_score(
the aggregation_method.
aggregation_method : str, optional
A string describing the method to select a value, by default 'max'.
threshold : Optional[float], optional
Score threshold to compare against, by default None.

Returns
-------
Expand All @@ -590,12 +651,13 @@ def event_score(
"min": min_aggregation,
"first": first_aggregation,
"last": last_aggregation,
"first_above_threshold": first_above_threshold_aggregation,
}

if aggregation_method not in aggregation_methods:
raise ValueError(f"Unknown aggregation method: {aggregation_method}")

df = aggregation_methods[aggregation_method](merged_frame, pks, score, ref_time, ref_event)
df = aggregation_methods[aggregation_method](merged_frame, pks, score, ref_time, ref_event, threshold)
return df.loc[~np.isnan(df.index)]


Expand Down
7 changes: 7 additions & 0 deletions src/seismometer/table/analytics_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def __init__(
statistics_data: Optional[pd.DataFrame] = None,
per_context: bool = False,
censor_threshold: int = 10,
aggregation_method: Optional[str] = None,
):
"""
Initializes the AnalyticsTable object with the necessary data and parameters.
Expand Down Expand Up @@ -94,6 +95,10 @@ def __init__(
censor_threshold : int, optional
Minimum number of rows required in the cohort data to enable the generation of an analytics table,
by default 10.
aggregation_method : Optional[str], optional
If provided, indicates that the table is being used to summarize threshold-specific aggregation results.
This parameter is not used directly in this class, but it can be useful for customizing the title
or other aspects of the table when it is part of a threshold aggregation analysis, by default None.

Raises
------
Expand Down Expand Up @@ -132,6 +137,7 @@ def __init__(
self._initializing = False
self.per_context = per_context
self.censor_threshold = censor_threshold
self.aggregation_method = aggregation_method

def _validate_df_statistics_data(self):
if not self._initializing: # Skip validation during initial setup
Expand Down Expand Up @@ -376,6 +382,7 @@ def _generate_table_data(self) -> Optional[pd.DataFrame]:
metrics_to_display=self.metrics_to_display,
decimals=self.decimals,
censor_threshold=self.censor_threshold,
aggregation_method=self.aggregation_method,
)
if data is None:
return None
Expand Down
Loading