Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -95,3 +95,4 @@ example-notebooks/binary-classifier/data/*
example-notebooks/binary-classifier/outputs/*

.seismometer_cache/
.DS_Store
1 change: 1 addition & 0 deletions changelog/174.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add support for a ``top_k`` option and a configurable ``other_value`` in cohort configuration
6 changes: 6 additions & 0 deletions src/seismometer/configuration/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,12 @@ class Cohort(BaseModel):
splits: Optional[list[Any]] = []
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like `splits` is already used for categorical values, where it changes unknown columns to np.nan; that might raise an error here (should test).

""" An optional list of 'inner edges' used to create a set of cohorts from a continuous attribute."""

top_k: Optional[int] = None
"""If set, only the top K most common values will be selected; all others grouped into 'Other'."""

other_value: Union[float, str] = None
"""Value to use for the 'Other' category when grouping less common values. Only used if top_k is set."""

@field_validator("display_name")
def default_display_name(cls, display_name: str, values: dict) -> str:
"""Ensures that display_name exists, setting it to the source name if not provided."""
Expand Down
69 changes: 68 additions & 1 deletion src/seismometer/data/cohorts.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Optional, Tuple, Union
from typing import Any, List, Optional, Tuple, Union

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -185,6 +185,73 @@ def resolve_col_data(df: pd.DataFrame, feature: Union[str, pd.Series]) -> pd.Ser
# region Labels


@export
def resolve_top_k_cohorts(series: SeriesOrArray, top_k: int, other_value: Any = None) -> pd.Series:
    """
    Keep the top K most frequent values of a series and replace all other values.

    The most common values (by ``value_counts``) are kept unchanged; every other
    value is replaced with ``other_value``. The result is a categorical Series whose
    categories are the top K values followed by ``other_value``.

    A missing-value ``other_value`` (``np.nan``) is never registered as a category,
    because pandas does not allow null categories; the replaced entries are simply
    missing in the result. If ``other_value`` already is one of the top K values it
    is not added a second time, because categories must be unique.

    Parameters
    ----------
    series : SeriesOrArray
        The input data. Array-likes are wrapped in a ``pd.Series`` first.
    top_k : int
        The number of most frequent values to preserve.
    other_value : Any, optional
        The value used to replace less frequent values. If None, ``np.nan`` is used
        for numeric data and ``"Other"`` otherwise.

    Returns
    -------
    pd.Series
        A categorical series with the same length and name as the input, where values
        outside the top K are replaced with ``other_value``.

    Examples
    --------
    >>> import pandas as pd
    >>> s = pd.Series(['A', 'B', 'A', 'C', 'B', 'D', 'A'])
    >>> resolve_top_k_cohorts(s, top_k=2)
    0        A
    1        B
    2        A
    3    Other
    4        B
    5    Other
    6        A
    dtype: category
    Categories (3, object): ['A', 'B', 'Other']
    """
    # The annotation allows plain arrays, but value_counts/where/name need a Series.
    if not isinstance(series, pd.Series):
        series = pd.Series(series)

    # Choose the default replacement based on the dtype of the data.
    if other_value is None:
        other_value = np.nan if pd.api.types.is_numeric_dtype(series) else "Other"

    # Most frequent values first; missing values are not counted by value_counts.
    top_k_values = list(series.value_counts().nlargest(top_k).index)
    in_top_k = series.isin(top_k_values)

    # A categorical input rejects a fill value that is not one of its categories,
    # so fall back to plain objects before substituting.
    if isinstance(series.dtype, pd.CategoricalDtype):
        series = series.astype(object)

    resolved = series.where(in_top_k, other_value)

    # Categories must be unique and non-null: skip a NaN or already-present other_value.
    categories = list(top_k_values)
    other_is_missing = pd.api.types.is_scalar(other_value) and pd.isna(other_value)
    if not other_is_missing and other_value not in categories:
        categories.append(other_value)

    return pd.Series(pd.Categorical(resolved, categories=categories), name=series.name)


@export
def resolve_cohorts(series: SeriesOrArray, splits: Optional[List] = None) -> pd.Series:
"""
Expand Down
12 changes: 11 additions & 1 deletion src/seismometer/seismogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from seismometer.configuration.model import Metric
from seismometer.core.patterns import Singleton
from seismometer.data import pandas_helpers as pdh
from seismometer.data import resolve_cohorts
from seismometer.data import resolve_cohorts, resolve_top_k_cohorts
from seismometer.data.loader import SeismogramLoader
from seismometer.report.alerting import AlertConfigProvider

Expand Down Expand Up @@ -381,6 +381,16 @@ def create_cohorts(self) -> None:
except IndexError as exc:
logger.warning(f"Failed to resolve cohort {disp_attr}: {exc}")
continue
elif cohort.top_k is not None:
try:
new_col = resolve_top_k_cohorts(
self.dataframe[cohort.source],
top_k=cohort.top_k,
other_value=cohort.other_value,
)
except ValueError as exc:
logger.warning(f"Failed to resolve top K cohorts {disp_attr}: {exc}")
continue
else:
new_col = pd.Series(pd.Categorical(self.dataframe[cohort.source]))

Expand Down
32 changes: 28 additions & 4 deletions tests/configuration/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,30 +219,54 @@ def test_multiple_sources_require_display(self):

class TestCohort:
    """Checks the dumped form of ``Cohort`` for defaults and for each optional field."""

    def test_default_values(self):
        """With only ``source`` given, display_name mirrors it and the optional fields keep their defaults."""
        expected = {"source": "source", "display_name": "source", "splits": [], "top_k": None, "other_value": None}
        cohort = undertest.Cohort(source="source")

        assert expected == cohort.model_dump()

    def test_set_displayname(self):
        """An explicit display_name is kept instead of falling back to the source name."""
        expected = {"source": "source", "display_name": "display", "splits": [], "top_k": None, "other_value": None}
        cohort = undertest.Cohort(source="source", display_name="display")

        assert expected == cohort.model_dump()

    def test_allows_splits(self):
        """A list of splits is stored as given."""
        split_list = ["split1", "split2"]
        expected = {
            "source": "source",
            "display_name": "source",
            "splits": split_list,
            "top_k": None,
            "other_value": None,
        }
        cohort = undertest.Cohort(source="source", splits=split_list)

        assert expected == cohort.model_dump()

    def test_strips_other_keys(self):
        """Keys that are not model fields do not appear in the dumped model."""
        expected = {"source": "source", "display_name": "source", "splits": [], "top_k": None, "other_value": None}
        cohort = undertest.Cohort(source="source", other="other")

        assert expected == cohort.model_dump()

    def test_allows_other_string(self):
        """top_k may be combined with a string other_value."""
        expected = {
            "source": "source",
            "display_name": "source",
            "splits": [],
            "top_k": 10,
            "other_value": "SmallCounts",
        }
        cohort = undertest.Cohort(source="source", top_k=10, other_value="SmallCounts")

        assert expected == cohort.model_dump()

    def test_allows_other_number(self):
        """top_k may be combined with a numeric other_value."""
        expected = {"source": "source", "display_name": "source", "splits": [], "top_k": 10, "other_value": 5}
        cohort = undertest.Cohort(source="source", top_k=10, other_value=5)

        assert expected == cohort.model_dump()


class TestDataUsage:
@pytest.mark.parametrize(
Expand Down
Loading
Loading