diff --git a/.gitignore b/.gitignore index 6acec023..270374e4 100644 --- a/.gitignore +++ b/.gitignore @@ -95,3 +95,4 @@ example-notebooks/binary-classifier/data/* example-notebooks/binary-classifier/outputs/* .seismometer_cache/ +.DS_Store diff --git a/changelog/174.feature.rst b/changelog/174.feature.rst new file mode 100644 index 00000000..9ae54fa8 --- /dev/null +++ b/changelog/174.feature.rst @@ -0,0 +1 @@ +Add support for an other value in cohort configuration \ No newline at end of file diff --git a/src/seismometer/configuration/model.py b/src/seismometer/configuration/model.py index a20c333f..fcd1cdfa 100644 --- a/src/seismometer/configuration/model.py +++ b/src/seismometer/configuration/model.py @@ -193,6 +193,12 @@ class Cohort(BaseModel): splits: Optional[list[Any]] = [] """ An optional list of 'inner edges' used to create a set of cohorts from a continuous attribute.""" + top_k: Optional[int] = None + """If set, only the top K most common values will be selected; all others grouped into 'Other'.""" + + other_value: Optional[Union[float, str]] = None + """Value to use for the 'Other' category when grouping less common values. 
Only used if top_k is set.""" + @field_validator("display_name") def default_display_name(cls, display_name: str, values: dict) -> str: """Ensures that display_name exists, setting it to the source name if not provided.""" diff --git a/src/seismometer/data/cohorts.py b/src/seismometer/data/cohorts.py index c34a97bd..439a19b2 100644 --- a/src/seismometer/data/cohorts.py +++ b/src/seismometer/data/cohorts.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Tuple, Union +from typing import Any, List, Optional, Tuple, Union import numpy as np import pandas as pd @@ -185,6 +185,73 @@ def resolve_col_data(df: pd.DataFrame, feature: Union[str, pd.Series]) -> pd.Ser # region Labels +@export +def resolve_top_k_cohorts(series: SeriesOrArray, top_k: int, other_value: Any = None) -> pd.Series: + """ + Extract the top K most frequent values from a series and replace all other values. + + This function identifies the most common values in the input series based on their + frequency, keeps them unchanged, and replaces all less common values with the + specified 'other_value'. The result is returned as a categorical Pandas Series + with the top K values and the 'other_value' as categories. + + The input data series containing values to be processed. + The number of most frequent values to preserve. + The value to use for replacing less frequent values. + If None, np.nan is used for numeric series and "Other" otherwise (default). + + A categorical Series with the same length as the input, where only + the top K most frequent values are preserved and all other values are + replaced with other_value. + + Examples + -------- + >>> import pandas as pd + >>> import numpy as np + >>> s = pd.Series(['A', 'B', 'A', 'C', 'B', 'D', 'A']) + >>> resolve_top_k_cohorts(s, top_k=2) + 0 A + 1 B + 2 A + 3 Other + 4 B + 5 Other + 6 A + dtype: category + Categories (3, object): ['A', 'B', 'Other'] + + Parameters + ---------- + series : SeriesOrArray + The input data series. + top_k : int + The number of top cohorts to select. 
+ other_value : Any, optional + The value to use for the 'Other' category. (default: np.nan for numeric, "Other" otherwise) + + Returns + ------- + pd.Series + A series with values not in the top K replaced with other_value. + """ + # Choose appropriate default value based on series dtype + if other_value is None: + if pd.api.types.is_numeric_dtype(series): + other_value = np.nan + else: + other_value = "Other" + + # Get the value counts and select the top K values + top_k_values = series.value_counts().nlargest(top_k).index + + # Replace values not in top_k_values with other_value + resolved = series.where(series.isin(top_k_values), other_value) + + # Make it categorical; NaN cannot be a category, so only append other_value when it is non-null + resolved = pd.Categorical(resolved, categories=list(top_k_values) + ([] if pd.isna(other_value) else [other_value])) + return pd.Series(resolved, name=series.name) + + @export def resolve_cohorts(series: SeriesOrArray, splits: Optional[List] = None) -> pd.Series: """ diff --git a/src/seismometer/seismogram.py b/src/seismometer/seismogram.py index 26a5da80..d7f716b7 100644 --- a/src/seismometer/seismogram.py +++ b/src/seismometer/seismogram.py @@ -10,7 +10,7 @@ from seismometer.configuration.model import Metric from seismometer.core.patterns import Singleton from seismometer.data import pandas_helpers as pdh -from seismometer.data import resolve_cohorts +from seismometer.data import resolve_cohorts, resolve_top_k_cohorts from seismometer.data.loader import SeismogramLoader from seismometer.report.alerting import AlertConfigProvider @@ -381,6 +381,16 @@ def create_cohorts(self) -> None: except IndexError as exc: logger.warning(f"Failed to resolve cohort {disp_attr}: {exc}") continue + elif cohort.top_k is not None: + try: + new_col = resolve_top_k_cohorts( + self.dataframe[cohort.source], + top_k=cohort.top_k, + other_value=cohort.other_value, + ) + except ValueError as exc: + logger.warning(f"Failed to resolve top K cohorts {disp_attr}: {exc}") + continue else: 
new_col = pd.Series(pd.Categorical(self.dataframe[cohort.source])) diff --git a/tests/configuration/test_model.py b/tests/configuration/test_model.py index 8333b34e..89300f5b 100644 --- a/tests/configuration/test_model.py +++ b/tests/configuration/test_model.py @@ -219,30 +219,54 @@ def test_multiple_sources_require_display(self): class TestCohort: def test_default_values(self): - expected = {"source": "source", "display_name": "source", "splits": []} + expected = {"source": "source", "display_name": "source", "splits": [], "top_k": None, "other_value": None} cohort = undertest.Cohort(source="source") assert expected == cohort.model_dump() def test_set_displayname(self): - expected = {"source": "source", "display_name": "display", "splits": []} + expected = {"source": "source", "display_name": "display", "splits": [], "top_k": None, "other_value": None} cohort = undertest.Cohort(source="source", display_name="display") assert expected == cohort.model_dump() def test_allows_splits(self): split_list = ["split1", "split2"] - expected = {"source": "source", "display_name": "source", "splits": split_list} + expected = { + "source": "source", + "display_name": "source", + "splits": split_list, + "top_k": None, + "other_value": None, + } cohort = undertest.Cohort(source="source", splits=split_list) assert expected == cohort.model_dump() def test_strips_other_keys(self): - expected = {"source": "source", "display_name": "source", "splits": []} + expected = {"source": "source", "display_name": "source", "splits": [], "top_k": None, "other_value": None} cohort = undertest.Cohort(source="source", other="other") assert expected == cohort.model_dump() + def test_allows_other_string(self): + expected = { + "source": "source", + "display_name": "source", + "splits": [], + "top_k": 10, + "other_value": "SmallCounts", + } + cohort = undertest.Cohort(source="source", top_k=10, other_value="SmallCounts") + + assert expected == cohort.model_dump() + + def 
test_allows_other_number(self): + expected = {"source": "source", "display_name": "source", "splits": [], "top_k": 10, "other_value": 5} + cohort = undertest.Cohort(source="source", top_k=10, other_value=5) + + assert expected == cohort.model_dump() + class TestDataUsage: @pytest.mark.parametrize( diff --git a/tests/data/test_cohorts.py b/tests/data/test_cohorts.py index 2c7f8ceb..b443c3b9 100644 --- a/tests/data/test_cohorts.py +++ b/tests/data/test_cohorts.py @@ -99,3 +99,262 @@ def test_data_splits(self): expected = expected_df(["<1.0", "1.0-2.0", ">=2.0"]) pd.testing.assert_frame_equal(actual, expected, check_column_type=False, check_like=True, check_dtype=False) + + def test_get_cohort_data(self): + """Test the get_cohort_data function.""" + df = input_df() + + # Test with default parameters + result = undertest.get_cohort_data(df, "tri", proba="col1") + + # Check output format and column names + assert list(result.columns) == ["true", "pred", "cohort"] + assert len(result) == len(df) + assert hasattr(result["cohort"], "cat") + + # Test with series inputs instead of column names + result_series = undertest.get_cohort_data( + df, + "tri", + proba=pd.Series(df["col1"].values), # Convert to Series to avoid numpy array error + true=df["TARGET"], # Series + ) + + # Check that result_series has the same shape + assert result_series.shape == result.shape + + # Test with custom splits + result_splits = undertest.get_cohort_data(df, "tri", proba="col1", splits=[1.0, 2.0]) + + # Should have cohort categories for the specified splits + categories = result_splits["cohort"].cat.categories.tolist() + # Check that we have categories covering the range with our splits + assert any(cat.startswith("<") for cat in categories) # Has a "less than" bin + assert any("-" in cat for cat in categories) # Has a "greater than" bin + + +class Test_Cohort_Transforms: + """Tests resolving cohort transforms used during loading seismometer data.""" + + def 
test_resolve_top_k_cohorts_string(self): + """Test resolve_top_k_cohorts with string data.""" + # Create a test series with string data + s = pd.Series(["A", "B", "A", "C", "B", "D", "A"], name="string_series") + + # Test top_k=2 with default other_value + result = undertest.resolve_top_k_cohorts(s, top_k=2) + + # Get the top 2 most frequent values + top_values = s.value_counts().nlargest(2).index.tolist() # Should be ['A', 'B'] + + # Check that top values are preserved, others are "Other" + for i, val in enumerate(s): + if val in top_values: + assert result[i] == val + else: + assert result[i] == "Other" + + def test_resolve_top_k_cohorts_numeric(self): + """Test resolve_top_k_cohorts with numeric data.""" + # Create a test series with numeric data + s = pd.Series([1, 2, 1, 3, 2, 4, 1], name="numeric_series") + + # Test top_k=2 with custom other_value instead of np.nan + result = undertest.resolve_top_k_cohorts(s, top_k=2, other_value=-1) + + # Basic verification that top values are preserved and others are changed + top_values = s.value_counts().nlargest(2).index.tolist() # Should be [1, 2] + # Check that original top values are preserved + for i, val in enumerate(s): + if val in top_values: + assert result[i] == val + else: + assert result[i] == -1 + + def test_resolve_top_k_cohorts_custom_other(self): + """Test resolve_top_k_cohorts with custom other_value.""" + s = pd.Series(["A", "B", "A", "C", "B", "D", "A"], name="string_series") + + # Test with custom other_value + result = undertest.resolve_top_k_cohorts(s, top_k=2, other_value="MISC") + + # Get the top 2 most frequent values + top_values = s.value_counts().nlargest(2).index.tolist() # Should be ['A', 'B'] + + # Check that top values are preserved, others are set to the custom value "MISC" + for i, val in enumerate(s): + if val in top_values: + assert result[i] == val + else: + assert result[i] == "MISC" + + def test_resolve_cohorts_numeric(self): + """Test resolve_cohorts with numeric data.""" + # Create 
a test series with numeric data + s = pd.Series([1.5, 2.5, 3.5, 4.5, 5.5], name="numeric_series") + + # Test with specific splits + result = undertest.resolve_cohorts(s, splits=[2, 4]) + + # Check that the correct categories are created + assert "<2" in result.iloc[0] # First value should be in first bin + assert "2-4" in result.iloc[1] # Second value should be in middle bin + assert ">=4" in result.iloc[3] # Fourth value should be in last bin + + # Check that we have all values + assert len(result) == 5 + + # Test without splits (should use mean as threshold) + result = undertest.resolve_cohorts(s) + # Check the pattern without checking exact formatting of the mean + assert result.iloc[0].startswith("<") # First values below mean + assert result.iloc[3].startswith(">=") # Last values above mean + + def test_resolve_cohorts_categorical(self): + """Test resolve_cohorts with categorical data.""" + # Create a categorical series with name already set to 'cohort' + s = pd.Series(pd.Categorical(["A", "B", "C", "A", "D"], categories=["A", "B", "C", "D", "E"]), name="cohort") + + # Patching the label_cohorts_categorical function to avoid _name attribute error + with patch.object( + undertest, + "label_cohorts_categorical", + return_value=pd.Series(["A", np.nan, "C", "A", np.nan], dtype="category", name="cohort"), + ): + result = undertest.resolve_cohorts(s, splits=["A", "C"]) + + # Basic checks on the result + assert result[0] == "A" + assert pd.isna(result[1]) + assert result[2] == "C" + + # Test without specifying splits (should remove unused categories) + s = pd.Series( + pd.Categorical(["A", "B", "C"], categories=["A", "B", "C", "D", "E"]), + name="cohort", # Important to set the name + ) + + # Mock the behavior to avoid _name attribute issues + with patch.object( + undertest, + "label_cohorts_categorical", + return_value=pd.Series(pd.Categorical(["A", "B", "C"], categories=["A", "B", "C"]), name="cohort"), + ): + result = undertest.resolve_cohorts(s) + # Just verify 
the result has all values + assert len(result) == 3 + + def test_find_bin_edges(self): + """Test find_bin_edges function.""" + s = pd.Series([1, 2, 3, 4, 5]) + + # Test with specified thresholds + result = undertest.find_bin_edges(s, [2, 4]) + assert result == [1, 2, 4] + + # Test with single threshold + result = undertest.find_bin_edges(s, 3) + assert result == [1, 3] + + # Test with no threshold (should use mean) + result = undertest.find_bin_edges(s) + assert result == [1, 3] # mean of [1,2,3,4,5] is 3 + + def test_has_good_binning(self): + """Test has_good_binning function.""" + # Good binning case + bin_edges = [1, 3, 5] + bin_ixs = np.array([1, 1, 2, 2, 3, 3]) # 3 unique values + + # Should not raise an error + try: + undertest.has_good_binning(bin_ixs, bin_edges) + # Test passes if no exception + except Exception as e: + assert False, f"has_good_binning raised unexpected exception: {e}" + + # Bad binning case (empty bin) + bin_edges = [1, 3, 5, 7] + bin_ixs = np.array([1, 1, 2, 2, 4, 4]) # Only 3 unique values + + # Should raise an error + try: + undertest.has_good_binning(bin_ixs, bin_edges) + assert False, "Expected IndexError but none was raised" + except IndexError: + pass # Expected behavior + + def test_label_cohorts_numeric(self): + """Test label_cohorts_numeric function directly.""" + s = pd.Series([1.0, 2.5, 3.0, 4.5, 5.0]) + + # Test with splits + result = undertest.label_cohorts_numeric(s, splits=[2.0, 4.0]) + + # Verify binning pattern + assert result[0].startswith("<") # First value should be < 2.0 + assert "2.0-4.0" in result[1] # Middle values in middle bin + assert "2.0-4.0" in result[2] + assert result[3].startswith(">=") # Last values >= 4.0 + assert result[4].startswith(">=") # Last values >= 4.0 + + def test_label_cohorts_categorical(self): + """Test label_cohorts_categorical function directly.""" + # Create a categorical series with name already set to 'cohort' + s = pd.Series(pd.Categorical(["A", "B", "C", "D", "A"], categories=["A", 
"B", "C", "D", "E"]), name="cohort") + + # Test with subset of categories + result = undertest.label_cohorts_categorical(s, cat_values=["A", "C"]) + + # Verify the filtering + assert result[0] == "A" + assert pd.isna(result[1]) + assert result[2] == "C" + assert pd.isna(result[3]) + assert result[4] == "A" + + # Test without specifying categories + s_for_remove = pd.Series( + pd.Categorical(["A", "B", "C", "D"], categories=["A", "B", "C", "D", "E", "F"]), name="cohort" + ) + result = undertest.label_cohorts_categorical(s_for_remove) + + # Should have only observed categories + assert set(result.cat.categories) == set(["A", "B", "C", "D"]) + + def test_resolve_col_data(self): + """Test resolve_col_data function.""" + df = pd.DataFrame({"feature1": [1, 2, 3], "feature2": [4, 5, 6]}) + + # Test with string column name + result = undertest.resolve_col_data(df, "feature1") + pd.testing.assert_series_equal(result, df["feature1"]) + + # Test with series + series = pd.Series([7, 8, 9]) + result = undertest.resolve_col_data(df, series) + pd.testing.assert_series_equal(result, series) + + # Test with numpy array + array = np.array([10, 11, 12]) + result = undertest.resolve_col_data(df, array) + np.testing.assert_array_equal(result, array) + + # Test with 2D array (like sklearn probabilities output) + array_2d = np.array([[0.1, 0.9], [0.2, 0.8], [0.3, 0.7]]) + result = undertest.resolve_col_data(df, array_2d) + np.testing.assert_array_equal(result, array_2d[:, 1]) + + # Test with invalid string + try: + undertest.resolve_col_data(df, "nonexistent_feature") + assert False, "Should have raised KeyError" + except KeyError: + pass # Expected + + # Test with invalid type + try: + undertest.resolve_col_data(df, 123) # Not a string or series + assert False, "Should have raised TypeError" + except TypeError: + pass # Expected