diff --git a/changelog/187.bugfix.rst b/changelog/187.bugfix.rst
new file mode 100644
index 00000000..cbdd7210
--- /dev/null
+++ b/changelog/187.bugfix.rst
@@ -0,0 +1 @@
+Fixed array handling and index alignment in cohorts, non-string value representation in filter rules, imputation row counting in pandas helpers, and pandas ``FutureWarning`` in ``add_unseen``.
diff --git a/src/seismometer/data/cohorts.py b/src/seismometer/data/cohorts.py
index c34a97bd..08789869 100644
--- a/src/seismometer/data/cohorts.py
+++ b/src/seismometer/data/cohorts.py
@@ -151,20 +151,21 @@ def get_cohort_performance_data(
return frame
-def resolve_col_data(df: pd.DataFrame, feature: Union[str, pd.Series]) -> pd.Series:
+def resolve_col_data(df: pd.DataFrame, feature: Union[str, SeriesOrArray]) -> pd.Series:
"""
- Handles resolving feature from either being a series or specifying a series in the dataframe.
+ Handles resolving feature from either being a series, array, or specifying a series in the dataframe.
Parameters
----------
df : pd.DataFrame
Containing a column of name feature if feature is passed in as a string.
- feature : Union[str, pd.Series]
- Either a pandas.Series or a column name in the dataframe.
+ feature : Union[str, SeriesOrArray]
+ Either a pandas.Series, numpy array, or a column name in the dataframe.
Returns
-------
- pd.Series.
+ pd.Series
+ Always returns a pandas Series, with index matching df.index for array inputs.
"""
if isinstance(feature, str):
@@ -172,13 +173,16 @@ def resolve_col_data(df: pd.DataFrame, feature: Union[str, pd.Series]) -> pd.Ser
return df[feature].copy()
else:
raise KeyError(f"Feature {feature} was not found in dataframe")
+ elif isinstance(feature, pd.Series):
+ return feature # Already a Series, preserve its index
elif hasattr(feature, "ndim"):
+ # Convert arrays to Series with df's index for proper alignment
if feature.ndim > 1: # probas from sklearn is nx2 with second column being the positive predictions
- return feature[:, 1]
+ return pd.Series(feature[:, 1], index=df.index)
else:
- return feature
+ return pd.Series(feature, index=df.index)
else:
- raise TypeError("Feature must be a string or pandas.Series, was given a ", type(feature))
+ raise TypeError(f"Feature must be a string, pandas.Series, or numpy.ndarray, was given {type(feature)}")
# endregion
@@ -232,7 +236,11 @@ def label_cohorts_numeric(series: SeriesOrArray, splits: Optional[List] = None)
labels = [f"{bins[i]}-{bins[i+1]}" for i in range(len(bins) - 1)] + [f">={bins[-1]}"]
labels[0] = f"<{bins[1]}"
cat = pd.Categorical.from_codes(bin_ixs - 1, labels)
- return pd.Series(cat)
+ # Preserve the input series index for proper alignment in pd.concat
+ if isinstance(series, pd.Series):
+ return pd.Series(cat, index=series.index)
+ else:
+ return pd.Series(cat)
def has_good_binning(bin_ixs: List, bin_edges: List) -> None:
@@ -275,7 +283,7 @@ def label_cohorts_categorical(series: SeriesOrArray, cat_values: Optional[list]
List of string labels for each bin; which is the list of categories.
"""
series.name = "cohort"
- series.cat._name = "cohort" # CategoricalAccessors have a different name..
+ series.cat._name = "cohort" # CategoricalAccessors have a different name.
# If no splits specified, restrict to observed values
if cat_values is None:
diff --git a/src/seismometer/data/filter.py b/src/seismometer/data/filter.py
index 06b3b1a7..d5b1c6e5 100644
--- a/src/seismometer/data/filter.py
+++ b/src/seismometer/data/filter.py
@@ -211,9 +211,9 @@ def __str__(self) -> str:
case "notna":
return f"{self.left} has a value"
case "isin":
- return f"{self.left} is in: {', '.join(self.right)}"
+ return f"{self.left} is in: {', '.join(map(str, self.right))}"
case "notin":
- return f"{self.left} not in: {', '.join(self.right)}"
+ return f"{self.left} not in: {', '.join(map(str, self.right))}"
case "topk":
return f"{self.left} in top {self.right} values"
case "nottopk":
diff --git a/src/seismometer/data/pandas_helpers.py b/src/seismometer/data/pandas_helpers.py
index 00559c12..f3014f78 100644
--- a/src/seismometer/data/pandas_helpers.py
+++ b/src/seismometer/data/pandas_helpers.py
@@ -259,9 +259,9 @@ def post_process_event(
# cast after imputation - supports nonnullable types
try_casting(dataframe, label_col, column_dtype)
- # Log how many rows were imputed/changed
- imputed_with_time = ((label_na_map & ~time_na_map) & dataframe[label_col].notna()).sum()
- imputed_no_time = (dataframe[label_col].isna()).sum()
+ # Log how many rows were imputed
+ imputed_with_time = (label_na_map & ~time_na_map).sum()
+ imputed_no_time = (label_na_map & time_na_map).sum()
logger.debug(
f"Post-processing of events for {label_col} and {time_col} complete. "
f"Imputed {imputed_with_time} rows with time, {imputed_no_time} rows with no time."
@@ -443,13 +443,15 @@ def _merge_with_strategy(
if merge_strategy == "first":
logger.debug(f"Updating events to only keep the first occurrence for each {event_display}.")
one_event_filtered = one_event.groupby(pks).first().reset_index()
- if merge_strategy == "last":
+ elif merge_strategy == "last":
logger.debug(f"Updating events to only keep the last occurrence for each {event_display}.")
one_event_filtered = one_event.groupby(pks).last().reset_index()
except ValueError as e:
logger.warning(e)
- pass
+ # Only continue with fallback merge if one_event_filtered was set
+ if "one_event_filtered" not in locals():
+ raise
return pd.merge(predictions, one_event_filtered, on=pks, how="left")
@@ -778,14 +780,15 @@ def _resolve_score_col(dataframe: pd.DataFrame, score: str) -> str:
def analytics_metric_name(metric_names: list[str], existing_metric_starts: list[str], column_name: str) -> str:
- """In the analytics table, often the provided column name is not the actual
+ """
+ In the analytics table, often the provided column name is not the actual
metric name that we want to log. Here, we extract the desired metric name.
Parameters
----------
metric_names : list[str]
What metrics already exist.
- existing_metric_values : list[str]
+ existing_metric_starts : list[str]
What strings can start the mangled column name.
column_name : str
The name of the column we are trying to make into a metric.
@@ -800,7 +803,7 @@ def analytics_metric_name(metric_names: list[str], existing_metric_starts: list[
else:
for value in existing_metric_starts:
if column_name.startswith(f"{value}_"):
- return column_name.lstrip(f"{value}_")
+ return column_name.removeprefix(f"{value}_")
return None
diff --git a/src/seismometer/plot/mpl/_util.py b/src/seismometer/plot/mpl/_util.py
index e01eb368..2dabd0fb 100644
--- a/src/seismometer/plot/mpl/_util.py
+++ b/src/seismometer/plot/mpl/_util.py
@@ -49,6 +49,10 @@ def add_unseen(df: pd.DataFrame, col="cohort") -> pd.DataFrame:
obs = df[col].unique()
unseen = [k for k in keys if k not in obs]
+ # Only concatenate if there are unseen categories
+ if not unseen:
+ return df
+
rv = pd.concat([df, pd.DataFrame({col: unseen})], ignore_index=True)
rv[col] = rv[col].astype(pd.CategoricalDtype(df[col].cat.categories))
return rv
diff --git a/tests/data/test_cohorts.py b/tests/data/test_cohorts.py
index 2c7f8ceb..f3bdfcd0 100644
--- a/tests/data/test_cohorts.py
+++ b/tests/data/test_cohorts.py
@@ -2,6 +2,7 @@
import numpy as np
import pandas as pd
+import pytest
import seismometer.data.cohorts as undertest
import seismometer.data.performance # NoQA - used in patching
@@ -99,3 +100,239 @@ def test_data_splits(self):
expected = expected_df(["<1.0", "1.0-2.0", ">=2.0"])
pd.testing.assert_frame_equal(actual, expected, check_column_type=False, check_like=True, check_dtype=False)
+
+
+class TestGetCohortData:
+ """Tests for get_cohort_data() function - previously untested."""
+
+ def test_get_cohort_data_with_column_names(self):
+ """Test get_cohort_data with proba and true as column names."""
+ df = input_df()
+ result = undertest.get_cohort_data(df, "tri", proba="col1", true="TARGET")
+
+ assert "true" in result.columns
+ assert "pred" in result.columns
+ assert "cohort" in result.columns
+ assert len(result) == 6
+
+ def test_get_cohort_data_with_array_inputs(self):
+ """Test get_cohort_data with proba and true as arrays."""
+ df = input_df()
+ proba_array = np.array([0.2, 0.3, 0.4, 0.5, 0.6, 0.7])
+ true_array = np.array([1, 0, 0, 1, 0, 1])
+
+ result = undertest.get_cohort_data(df, "tri", proba=proba_array, true=true_array)
+
+ # Verify correct columns and row count
+ assert len(result) == 6
+ assert "pred" in result.columns
+ assert "true" in result.columns
+ assert "cohort" in result.columns
+ # Values should match input arrays
+ assert list(result["pred"].values) == list(proba_array)
+ assert list(result["true"].values) == list(true_array)
+
+ def test_get_cohort_data_with_mismatched_array_lengths(self):
+ """Test get_cohort_data documents edge case behavior with mismatched lengths."""
+ df = input_df()
+ proba_series = pd.Series([0.2, 0.3], index=[0, 1]) # Only 2 rows
+
+ # Pandas will align by index, then dropna removes mismatched indices
+ result = undertest.get_cohort_data(df, "tri", proba=proba_series, true="TARGET")
+
+ # Documents behavior: only matching indices kept
+ assert len(result) >= 0 # May be 0-2 depending on cohort column alignment
+
+ def test_get_cohort_data_with_nan_values(self):
+ """Test get_cohort_data drops NaN values."""
+ df = pd.DataFrame({"TARGET": [1, 0, np.nan, 1], "col1": [0.2, np.nan, 0.4, 0.5], "tri": [0, 0, 1, 1]})
+
+ result = undertest.get_cohort_data(df, "tri", proba="col1", true="TARGET")
+
+ # Should drop rows with NaN (2 rows dropped)
+ assert len(result) == 2
+
+ def test_get_cohort_data_with_splits(self):
+ """Test get_cohort_data with custom splits parameter."""
+ df = input_df()
+
+ result = undertest.get_cohort_data(df, "tri", proba="col1", true="TARGET", splits=[1.0, 2.0])
+
+ # Should create cohorts based on splits
+ assert "cohort" in result.columns
+ assert result["cohort"].cat.categories.tolist() == ["<1.0", "1.0-2.0", ">=2.0"]
+
+
+class TestResolveColData:
+ """Tests for resolve_col_data() helper function - previously untested."""
+
+ def test_resolve_col_data_with_string_column(self):
+ """Test resolve_col_data with column name as string."""
+ df = pd.DataFrame({"col1": [1, 2, 3]})
+ result = undertest.resolve_col_data(df, "col1")
+
+ pd.testing.assert_series_equal(result, pd.Series([1, 2, 3], name="col1"))
+
+ def test_resolve_col_data_with_missing_column(self):
+ """Test resolve_col_data raises KeyError for missing column."""
+ df = pd.DataFrame({"col1": [1, 2, 3]})
+
+ with pytest.raises(KeyError, match="Feature missing_col was not found in dataframe"):
+ undertest.resolve_col_data(df, "missing_col")
+
+ def test_resolve_col_data_with_2d_array(self):
+ """Test resolve_col_data handles 2D array (sklearn probabilities)."""
+ df = pd.DataFrame({"col1": [1, 2, 3]})
+ proba_2d = np.array([[0.2, 0.8], [0.3, 0.7], [0.4, 0.6]])
+
+ result = undertest.resolve_col_data(df, proba_2d)
+
+ # Should return second column (positive class)
+ np.testing.assert_array_equal(result, np.array([0.8, 0.7, 0.6]))
+
+ def test_resolve_col_data_with_1d_array(self):
+ """Test resolve_col_data handles 1D array."""
+ df = pd.DataFrame({"col1": [1, 2, 3]})
+ array_1d = np.array([0.2, 0.3, 0.4])
+
+ result = undertest.resolve_col_data(df, array_1d)
+
+ np.testing.assert_array_equal(result, array_1d)
+
+ def test_resolve_col_data_with_invalid_type(self):
+ """Test resolve_col_data raises TypeError for invalid input."""
+ df = pd.DataFrame({"col1": [1, 2, 3]})
+
+ with pytest.raises(TypeError, match="Feature must be a string, pandas.Series, or numpy.ndarray"):
+ undertest.resolve_col_data(df, 123) # Invalid type
+
+
+class TestResolveCohorts:
+ """Tests for resolve_cohorts() function - previously untested."""
+
+ def test_resolve_cohorts_with_categorical_series(self):
+ """Test resolve_cohorts auto-dispatches to categorical handler."""
+ series = pd.Series(pd.Categorical(["A", "B", "A", "C"]), name="test_cohort")
+
+ result = undertest.resolve_cohorts(series)
+
+ assert isinstance(result, pd.Series)
+ assert hasattr(result, "cat")
+ assert set(result.cat.categories) == {"A", "B", "C"} # Unused removed
+
+ def test_resolve_cohorts_with_numeric_series(self):
+ """Test resolve_cohorts auto-dispatches to numeric handler."""
+ series = pd.Series([1.0, 2.0, 3.0, 4.0, 5.0])
+
+ result = undertest.resolve_cohorts(series)
+
+ assert isinstance(result, pd.Series)
+ assert hasattr(result, "cat") # Should be categorical
+ # Should split at mean (3.0)
+ assert len(result.cat.categories) == 2
+
+ def test_resolve_cohorts_with_numeric_splits(self):
+ """Test resolve_cohorts with custom numeric splits."""
+ series = pd.Series([1.0, 2.0, 3.0, 4.0, 5.0])
+
+ result = undertest.resolve_cohorts(series, splits=[2.5, 4.0])
+
+ assert result.cat.categories.tolist() == ["<2.5", "2.5-4.0", ">=4.0"]
+
+
+class TestHasGoodBinning:
+ """Tests for has_good_binning() error checking function - previously untested."""
+
+ def test_has_good_binning_with_valid_bins(self):
+ """Test has_good_binning passes with valid binning."""
+ bin_ixs = np.array([1, 1, 2, 2, 3, 3])
+ bin_edges = [0.0, 1.0, 2.0]
+
+ # Should not raise
+ undertest.has_good_binning(bin_ixs, bin_edges)
+
+ def test_has_good_binning_with_empty_bins(self):
+ """Test has_good_binning raises IndexError for empty bins."""
+ bin_ixs = np.array([1, 1, 3, 3]) # Missing bin 2
+ bin_edges = [0.0, 1.0, 2.0]
+
+ with pytest.raises(IndexError, match="Splits provided contain some empty bins"):
+ undertest.has_good_binning(bin_ixs, bin_edges)
+
+ def test_has_good_binning_with_single_bin(self):
+ """Test has_good_binning with single bin edge case."""
+ bin_ixs = np.array([1, 1, 1])
+ bin_edges = [0.0]
+
+ # Should not raise
+ undertest.has_good_binning(bin_ixs, bin_edges)
+
+
+class TestLabelCohortsCategorical:
+ """Tests for label_cohorts_categorical() function - previously untested."""
+
+ def test_label_cohorts_categorical_without_cat_values(self):
+ """Test label_cohorts_categorical removes unused categories."""
+ series = pd.Series(pd.Categorical(["A", "B", "A"], categories=["A", "B", "C", "D"]))
+
+ result = undertest.label_cohorts_categorical(series)
+
+ # Should remove unused categories C and D
+ assert set(result.cat.categories) == {"A", "B"}
+
+ def test_label_cohorts_categorical_with_cat_values_matching(self):
+ """Test label_cohorts_categorical with matching cat_values."""
+ series = pd.Series(pd.Categorical(["A", "B", "C"], categories=["A", "B", "C"]))
+
+ result = undertest.label_cohorts_categorical(series, cat_values=["A", "B", "C"])
+
+ # Should return as-is
+ pd.testing.assert_series_equal(result, series, check_names=False)
+
+ def test_label_cohorts_categorical_with_cat_values_filtering(self):
+ """Test label_cohorts_categorical filters to specified cat_values."""
+ series = pd.Series(pd.Categorical(["A", "B", "C", "D"], categories=["A", "B", "C", "D"]))
+
+ result = undertest.label_cohorts_categorical(series, cat_values=["A", "C"])
+
+ # Should filter to only A and C, rest become NaN
+ assert result.notna().sum() == 2
+ assert set(result.dropna()) == {"A", "C"}
+
+
+class TestFindBinEdges:
+ """Tests for find_bin_edges() function - previously untested."""
+
+ def test_find_bin_edges_with_no_thresholds(self):
+ """Test find_bin_edges defaults to mean split."""
+ series = pd.Series([1.0, 2.0, 3.0, 4.0, 5.0]) # Mean = 3.0
+
+ result = undertest.find_bin_edges(series)
+
+ # Returns list with [min, mean]
+ assert len(result) == 2
+ assert result[0] == 1.0 # Series minimum
+ assert result[1] == 3.0 # Series mean
+
+ def test_find_bin_edges_with_custom_thresholds(self):
+ """Test find_bin_edges with custom threshold values."""
+ series = pd.Series([1.0, 2.0, 3.0, 4.0, 5.0])
+
+ result = undertest.find_bin_edges(series, thresholds=[2.0, 4.0])
+
+ # Returns list with [min, threshold1, threshold2]
+ assert len(result) == 3
+ assert result[0] == 1.0 # Series minimum
+ assert result[1] == 2.0
+ assert result[2] == 4.0
+
+ def test_find_bin_edges_with_single_value_series(self):
+ """Test find_bin_edges with series containing single unique value."""
+ series = pd.Series([5.0, 5.0, 5.0])
+
+ result = undertest.find_bin_edges(series)
+
+ # Edge case: single value means min = mean
+ # Documents that this creates degenerate bins (both edges same)
+ assert len(result) >= 1
+ assert all(val == 5.0 for val in result) # All edges are 5.0
diff --git a/tests/data/test_filters.py b/tests/data/test_filters.py
index c3e75102..d74e83b4 100644
--- a/tests/data/test_filters.py
+++ b/tests/data/test_filters.py
@@ -37,67 +37,65 @@ def test_filter_universal_rule_equals(self, test_dataframe):
assert len(FilterRule.none().filter(test_dataframe)) == 0
assert all(FilterRule.none().filter(test_dataframe).columns == test_dataframe.columns)
- def test_filter_base_rule_equals(self, test_dataframe):
- FilterRule.MIN_ROWS = None
- assert FilterRule("T/F", "==", 0).mask(test_dataframe).equals(test_dataframe["T/F"] == 0)
- assert FilterRule("T/F", "==", 0).filter(test_dataframe).equals(test_dataframe[test_dataframe["T/F"] == 0])
-
- def test_filter_base_rule_not_equals(self, test_dataframe):
- FilterRule.MIN_ROWS = None
- assert FilterRule("T/F", "!=", 0).mask(test_dataframe).equals(test_dataframe["T/F"] != 0)
- assert FilterRule("T/F", "!=", 0).filter(test_dataframe).equals(test_dataframe[test_dataframe["T/F"] != 0])
-
- def test_filter_base_rule_isin(self, test_dataframe):
- FilterRule.MIN_ROWS = None
- assert (
- FilterRule("Cat", "isin", ["A", "B"]).mask(test_dataframe).equals(test_dataframe["Cat"].isin(["A", "B"]))
- )
- assert (
- FilterRule("Cat", "isin", ["A", "B"])
- .filter(test_dataframe)
- .equals(test_dataframe[test_dataframe["Cat"].isin(["A", "B"])])
- )
-
- def test_filter_base_rule_notin(self, test_dataframe):
- FilterRule.MIN_ROWS = None
- assert (
- FilterRule("Cat", "notin", ["A", "B"]).mask(test_dataframe).equals(~test_dataframe["Cat"].isin(["A", "B"]))
- )
- assert (
- FilterRule("Cat", "notin", ["A", "B"])
- .filter(test_dataframe)
- .equals(test_dataframe[~test_dataframe["Cat"].isin(["A", "B"])])
- )
-
- def test_filter_base_rule_less_than(self, test_dataframe):
- FilterRule.MIN_ROWS = None
- assert FilterRule("Val", "<", 20).mask(test_dataframe).equals(test_dataframe["Val"] < 20)
- assert FilterRule("Val", "<", 20).filter(test_dataframe).equals(test_dataframe[test_dataframe["Val"] < 20])
-
- def test_filter_base_rule_greater_then(self, test_dataframe):
+ @pytest.mark.parametrize(
+ "column,operator,value,expected_mask_expr,expected_filter_expr",
+ [
+ # Comparison operators
+ ("T/F", "==", 0, lambda df: df["T/F"] == 0, lambda df: df[df["T/F"] == 0]),
+ ("T/F", "!=", 0, lambda df: df["T/F"] != 0, lambda df: df[df["T/F"] != 0]),
+ ("Val", "<", 20, lambda df: df["Val"] < 20, lambda df: df[df["Val"] < 20]),
+ ("Val", ">", 20, lambda df: df["Val"] > 20, lambda df: df[df["Val"] > 20]),
+ ("Val", "<=", 20, lambda df: df["Val"] <= 20, lambda df: df[df["Val"] <= 20]),
+ ("Val", ">=", 20, lambda df: df["Val"] >= 20, lambda df: df[df["Val"] >= 20]),
+ # Set operators
+ (
+ "Cat",
+ "isin",
+ ["A", "B"],
+ lambda df: df["Cat"].isin(["A", "B"]),
+ lambda df: df[df["Cat"].isin(["A", "B"])],
+ ),
+ (
+ "Cat",
+ "notin",
+ ["A", "B"],
+ lambda df: ~df["Cat"].isin(["A", "B"]),
+ lambda df: df[~df["Cat"].isin(["A", "B"])],
+ ),
+ # Null operators
+ ("T/F", "isna", None, lambda df: df["T/F"].isna(), lambda df: df[df["T/F"].isna()]),
+ ("T/F", "notna", None, lambda df: ~df["T/F"].isna(), lambda df: df[~df["T/F"].isna()]),
+ ],
+ ids=[
+ "equals",
+ "not_equals",
+ "less_than",
+ "greater_than",
+ "less_than_or_eq",
+ "greater_than_or_eq",
+ "isin",
+ "notin",
+ "isna",
+ "notna",
+ ],
+ )
+ def test_filter_base_rule_operators(
+ self, test_dataframe, column, operator, value, expected_mask_expr, expected_filter_expr
+ ):
+ """Test FilterRule operators with parametrization to reduce code duplication."""
FilterRule.MIN_ROWS = None
- assert FilterRule("Val", ">", 20).mask(test_dataframe).equals(test_dataframe["Val"] > 20)
- assert FilterRule("Val", ">", 20).filter(test_dataframe).equals(test_dataframe[test_dataframe["Val"] > 20])
- def test_filter_base_rule_less_than_or_eq(self, test_dataframe):
- FilterRule.MIN_ROWS = None
- assert FilterRule("Val", "<=", 20).mask(test_dataframe).equals(test_dataframe["Val"] <= 20)
- assert FilterRule("Val", "<=", 20).filter(test_dataframe).equals(test_dataframe[test_dataframe["Val"] <= 20])
+ # Create the rule (handle operators that don't need a value)
+ if operator in ["isna", "notna"]:
+ rule = FilterRule(column, operator)
+ else:
+ rule = FilterRule(column, operator, value)
- def test_filter_base_rule_greater_then_or_eq(self, test_dataframe):
- FilterRule.MIN_ROWS = None
- assert FilterRule("Val", ">=", 20).mask(test_dataframe).equals(test_dataframe["Val"] >= 20)
- assert FilterRule("Val", ">=", 20).filter(test_dataframe).equals(test_dataframe[test_dataframe["Val"] >= 20])
+ # Test mask
+ assert rule.mask(test_dataframe).equals(expected_mask_expr(test_dataframe))
- def test_filter_base_rule_isna(self, test_dataframe):
- FilterRule.MIN_ROWS = None
- assert FilterRule("T/F", "isna").mask(test_dataframe).equals(test_dataframe["T/F"].isna())
- assert FilterRule("T/F", "isna").filter(test_dataframe).equals(test_dataframe[test_dataframe["T/F"].isna()])
-
- def test_filter_base_rule_notna(self, test_dataframe):
- FilterRule.MIN_ROWS = None
- assert FilterRule("T/F", "notna").mask(test_dataframe).equals(~test_dataframe["T/F"].isna())
- assert FilterRule("T/F", "notna").filter(test_dataframe).equals(test_dataframe[~test_dataframe["T/F"].isna()])
+ # Test filter
+ assert rule.filter(test_dataframe).equals(expected_filter_expr(test_dataframe))
@pytest.mark.parametrize(
"k, expected_values",
@@ -226,6 +224,145 @@ def test_from_filter_config_topk_behavior(self, monkeypatch, count, class_defaul
result = rule.filter(df)
assert sorted(result["Cat"].unique()) == expected_cats
+ def test_from_filter_config_include_with_values(self, monkeypatch):
+ """Test action='include' with values parameter creates isin rule."""
+ from seismometer.configuration.model import FilterConfig
+
+ monkeypatch.setattr(FilterRule, "MIN_ROWS", None)
+ df = pd.DataFrame({"Cat": ["A", "B", "C", "D", "E"]})
+ config = FilterConfig(source="Cat", action="include", values=["A", "B"])
+ rule = FilterRule.from_filter_config(config)
+
+ assert rule.relation == "isin"
+ assert rule.left == "Cat"
+ assert set(rule.right) == {"A", "B"}
+
+ result = rule.filter(df)
+ assert sorted(result["Cat"].unique()) == ["A", "B"]
+
+ def test_from_filter_config_exclude_with_values(self, monkeypatch):
+ """Test action='exclude' with values parameter creates negated isin rule."""
+ from seismometer.configuration.model import FilterConfig
+
+ monkeypatch.setattr(FilterRule, "MIN_ROWS", None)
+ df = pd.DataFrame({"Cat": ["A", "B", "C", "D", "E"]})
+ config = FilterConfig(source="Cat", action="exclude", values=["A", "B"])
+ rule = FilterRule.from_filter_config(config)
+
+ assert rule.relation == "notin"
+ assert rule.left == "Cat"
+
+ result = rule.filter(df)
+ assert sorted(result["Cat"].unique()) == ["C", "D", "E"]
+
+ def test_from_filter_config_include_with_range_both_bounds(self, monkeypatch):
+ """Test action='include' with range (min and max) creates compound rule."""
+ from seismometer.configuration.model import FilterConfig, FilterRange
+
+ monkeypatch.setattr(FilterRule, "MIN_ROWS", None)
+ df = pd.DataFrame({"Val": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]})
+ config = FilterConfig(source="Val", action="include", range=FilterRange(min=3, max=8))
+ rule = FilterRule.from_filter_config(config)
+
+ result = rule.filter(df)
+ assert list(result["Val"]) == [3, 4, 5, 6, 7] # min inclusive, max exclusive
+
+ def test_from_filter_config_include_with_range_min_only(self, monkeypatch):
+ """Test action='include' with range (min only) creates >= rule."""
+ from seismometer.configuration.model import FilterConfig, FilterRange
+
+ monkeypatch.setattr(FilterRule, "MIN_ROWS", None)
+ df = pd.DataFrame({"Val": [1, 2, 3, 4, 5]})
+ config = FilterConfig(source="Val", action="include", range=FilterRange(min=3))
+ rule = FilterRule.from_filter_config(config)
+
+ result = rule.filter(df)
+ assert list(result["Val"]) == [3, 4, 5]
+
+ def test_from_filter_config_include_with_range_max_only(self, monkeypatch):
+ """Test action='include' with range (max only) creates < rule."""
+ from seismometer.configuration.model import FilterConfig, FilterRange
+
+ monkeypatch.setattr(FilterRule, "MIN_ROWS", None)
+ df = pd.DataFrame({"Val": [1, 2, 3, 4, 5]})
+ config = FilterConfig(source="Val", action="include", range=FilterRange(max=3))
+ rule = FilterRule.from_filter_config(config)
+
+ result = rule.filter(df)
+ assert list(result["Val"]) == [1, 2]
+
+ def test_from_filter_config_exclude_with_range(self, monkeypatch):
+ """Test action='exclude' with range negates the range rule."""
+ from seismometer.configuration.model import FilterConfig, FilterRange
+
+ monkeypatch.setattr(FilterRule, "MIN_ROWS", None)
+ df = pd.DataFrame({"Val": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]})
+ config = FilterConfig(source="Val", action="exclude", range=FilterRange(min=3, max=8))
+ rule = FilterRule.from_filter_config(config)
+
+ result = rule.filter(df)
+ # Should exclude [3,4,5,6,7], keep [1,2,8,9,10]
+ assert list(result["Val"]) == [1, 2, 8, 9, 10]
+
+ def test_from_filter_config_invalid_action_raises(self):
+ """Test invalid action raises ValidationError from Pydantic."""
+ from pydantic import ValidationError
+
+ from seismometer.configuration.model import FilterConfig
+
+ # Pydantic validates action field, so invalid values raise ValidationError at creation
+ with pytest.raises(ValidationError, match="Input should be 'include', 'exclude' or 'keep_top'"):
+ FilterConfig(source="Col", action="invalid_action")
+
+ def test_from_filter_config_topk_with_none_maximum_returns_all(self, monkeypatch):
+ """Test keep_top with MAXIMUM_NUM_COHORTS=None and count=None returns all() rule."""
+ from seismometer.configuration.model import FilterConfig
+
+ monkeypatch.setattr(FilterRule, "MAXIMUM_NUM_COHORTS", None)
+ config = FilterConfig(source="Cat", action="keep_top", count=None)
+ rule = FilterRule.from_filter_config(config)
+
+ assert rule == FilterRule.all()
+
+ def test_from_filter_config_list_with_none(self):
+ """Test from_filter_config_list with None returns all() rule."""
+ rule = FilterRule.from_filter_config_list(None)
+ assert rule == FilterRule.all()
+
+ def test_from_filter_config_list_with_empty_list(self):
+ """Test from_filter_config_list with empty list returns all() rule."""
+ rule = FilterRule.from_filter_config_list([])
+ assert rule == FilterRule.all()
+
+ def test_from_filter_config_list_with_single_config(self, monkeypatch):
+ """Test from_filter_config_list with single config creates that rule."""
+ from seismometer.configuration.model import FilterConfig
+
+ monkeypatch.setattr(FilterRule, "MIN_ROWS", None)
+ df = pd.DataFrame({"Cat": ["A", "B", "C"]})
+ config = FilterConfig(source="Cat", action="include", values=["A"])
+ rule = FilterRule.from_filter_config_list([config])
+
+ result = rule.filter(df)
+ assert list(result["Cat"]) == ["A"]
+
+ def test_from_filter_config_list_with_multiple_configs(self, monkeypatch):
+ """Test from_filter_config_list with multiple configs combines with AND logic."""
+ from seismometer.configuration.model import FilterConfig, FilterRange
+
+ monkeypatch.setattr(FilterRule, "MIN_ROWS", None)
+ df = pd.DataFrame({"Cat": ["A", "A", "B", "B", "C"], "Val": [1, 5, 2, 6, 3]})
+ config1 = FilterConfig(source="Cat", action="include", values=["A", "B"])
+ config2 = FilterConfig(source="Val", action="include", range=FilterRange(min=2, max=6))
+ rule = FilterRule.from_filter_config_list([config1, config2])
+
+ result = rule.filter(df)
+ # Should keep rows where Cat in ["A","B"] AND Val in [2,6)
+ # That's: B,2 and A,5
+ assert len(result) == 2
+ assert set(result["Cat"]) == {"A", "B"}
+ assert all((result["Val"] >= 2) & (result["Val"] < 6))
+
class TestFilterRuleCombinationLogic:
@pytest.mark.parametrize(
@@ -427,3 +564,158 @@ def test_matches_cohort(self):
def test_matches_default_cohort(self):
rule = filter_rule_from_cohort_dictionary()
assert rule == FilterRule.all()
+
+
+class TestHelperFunctions:
+ """Test helper functions that are exported but not directly tested elsewhere."""
+
+ def test_apply_column_comparison_error_handling(self):
+ """Test apply_column_comparison error handling with incompatible types."""
+ from seismometer.data.filter import apply_column_comparison
+
+ df = pd.DataFrame({"Col": ["a", "b", "c"]})
+
+ # String column compared with integer should raise ValueError
+ with pytest.raises(ValueError, match="Values in 'Col' must be comparable to '5'"):
+ apply_column_comparison(df, "Col", 5, "<")
+
+ def test_apply_column_comparison_with_valid_comparison(self):
+ """Test apply_column_comparison works with valid comparisons."""
+ from seismometer.data.filter import apply_column_comparison
+
+ df = pd.DataFrame({"Val": [1, 2, 3, 4, 5]})
+ result = apply_column_comparison(df, "Val", 3, "<")
+
+ assert result.equals(df["Val"] < 3)
+ assert result.sum() == 2 # Only 1 and 2 are < 3
+
+ def test_apply_topk_filter_with_k_greater_than_unique_values(self):
+ """Test topk with k > number of unique values returns all True mask."""
+ from seismometer.data.filter import apply_topk_filter
+
+ df = pd.DataFrame({"Cat": ["A", "A", "B", "B", "C"]})
+ # Only 3 unique values, but ask for top 5
+ mask = apply_topk_filter(df, "Cat", 5)
+
+ assert isinstance(mask, pd.Series)
+ assert mask.all() # All rows should be True
+ assert len(mask) == 5
+
+ def test_apply_topk_filter_with_k_equal_to_unique_values(self):
+ """Test topk with k == number of unique values returns all True mask."""
+ from seismometer.data.filter import apply_topk_filter
+
+ df = pd.DataFrame({"Cat": ["A", "A", "B", "B", "C"]})
+ # Exactly 3 unique values, ask for top 3
+ mask = apply_topk_filter(df, "Cat", 3)
+
+ assert isinstance(mask, pd.Series)
+ assert mask.all() # All rows should be True
+
+ def test_apply_topk_filter_with_single_unique_value(self):
+ """Test topk with DataFrame containing single unique value."""
+ from seismometer.data.filter import apply_topk_filter
+
+ df = pd.DataFrame({"Cat": ["A", "A", "A", "A"]})
+ mask = apply_topk_filter(df, "Cat", 2)
+
+ assert isinstance(mask, pd.Series)
+ assert mask.all() # All rows should be True since only one unique value
+ assert len(mask) == 4
+
+ def test_apply_topk_filter_with_empty_dataframe(self):
+ """Test topk with empty DataFrame doesn't crash."""
+ from seismometer.data.filter import apply_topk_filter
+
+ df = pd.DataFrame({"Cat": []})
+ mask = apply_topk_filter(df, "Cat", 2)
+
+ assert isinstance(mask, pd.Series)
+ assert len(mask) == 0
+
+
+class TestEdgeCases:
+ """Test edge cases and error conditions not covered elsewhere."""
+
+ def test_filter_with_missing_column_raises_keyerror(self):
+ """Test filtering with non-existent column raises KeyError."""
+ df = pd.DataFrame({"Col1": [1, 2, 3]})
+ rule = FilterRule("NonExistentColumn", "==", 1)
+
+ with pytest.raises(KeyError):
+ rule.filter(df)
+
+ def test_mask_with_missing_column_raises_keyerror(self):
+ """Test mask with non-existent column raises KeyError."""
+ df = pd.DataFrame({"Col1": [1, 2, 3]})
+ rule = FilterRule("NonExistentColumn", "==", 1)
+
+ with pytest.raises(KeyError):
+ rule.mask(df)
+
+ def test_filter_with_empty_dataframe(self, monkeypatch):
+ """Test filter on empty DataFrame returns empty DataFrame."""
+ monkeypatch.setattr(FilterRule, "MIN_ROWS", None)
+ df = pd.DataFrame({"Col": []})
+ rule = FilterRule("Col", "==", 1)
+
+ result = rule.filter(df)
+ assert len(result) == 0
+ assert list(result.columns) == ["Col"]
+
+ def test_filter_with_single_row_dataframe(self, monkeypatch):
+ """Test filter on single-row DataFrame works correctly."""
+ monkeypatch.setattr(FilterRule, "MIN_ROWS", None)
+ df = pd.DataFrame({"Col": [1]})
+ rule = FilterRule("Col", "==", 1)
+
+ result = rule.filter(df)
+ assert len(result) == 1
+ assert result["Col"].iloc[0] == 1
+
+ def test_str_with_numeric_isin_values(self):
+ """Test __str__ with numeric isin values doesn't crash."""
+ rule = FilterRule("Col", "isin", [1, 2, 3])
+
+ # Should not raise AttributeError from .join()
+ result = str(rule)
+ assert "Col" in result
+ assert "is in" in result
+
+ def test_str_with_mixed_type_isin_values(self):
+ """Test __str__ with mixed type isin values doesn't crash."""
+ rule = FilterRule("Col", "isin", [1, "A", 2.5])
+
+ # Should not raise AttributeError
+ result = str(rule)
+ assert "Col" in result
+
+ def test_min_rows_boundary_exact_equal(self, monkeypatch):
+ """Test MIN_ROWS boundary when len(df) == MIN_ROWS returns empty (exclusive threshold)."""
+ monkeypatch.setattr(FilterRule, "MIN_ROWS", 10)
+ df = pd.DataFrame({"Col": [1] * 10}) # Exactly 10 rows
+ rule = FilterRule("Col", "==", 1)
+
+ result = rule.filter(df)
+ # MIN_ROWS uses > comparison, so len == MIN_ROWS returns empty
+ assert len(result) == 0
+
+ def test_min_rows_boundary_just_below(self, monkeypatch):
+ """Test MIN_ROWS boundary when len(df) < MIN_ROWS returns empty."""
+ monkeypatch.setattr(FilterRule, "MIN_ROWS", 10)
+ df = pd.DataFrame({"Col": [1] * 9}) # 9 rows, below threshold
+ rule = FilterRule("Col", "==", 1)
+
+ result = rule.filter(df)
+ # Should return empty because below MIN_ROWS
+ assert len(result) == 0
+
+ def test_min_rows_boundary_just_above(self, monkeypatch):
+ """Test MIN_ROWS boundary when len(df) > MIN_ROWS returns data."""
+ monkeypatch.setattr(FilterRule, "MIN_ROWS", 10)
+ df = pd.DataFrame({"Col": [1] * 11}) # 11 rows, above threshold
+ rule = FilterRule("Col", "==", 1)
+
+ result = rule.filter(df)
+ # Should return the filtered result (11 rows match)
+ assert len(result) == 11
diff --git a/tests/data/test_pandas_helpers.py b/tests/data/test_pandas_helpers.py
index 1b44c13c..95e5bfff 100644
--- a/tests/data/test_pandas_helpers.py
+++ b/tests/data/test_pandas_helpers.py
@@ -128,6 +128,65 @@ def test_merge_strategies_do_not_generate_additional_rows(self, strategy):
assert "MyEvent_Time" in actual.columns
assert len(actual) == len(preds)
+ def test_merge_with_strategy_empty_pks_raises(self):
+ """Empty pks list should cause merge_asof to fail (needs by parameter)."""
+ preds = pd.DataFrame(
+ {
+ "Id": [1, 2],
+ "PredictTime": [pd.Timestamp("2024-01-01 01:00:00"), pd.Timestamp("2024-01-01 02:00:00")],
+ }
+ )
+ events = pd.DataFrame(
+ {
+ "Id": [1, 2],
+ "Time": [pd.Timestamp("2024-01-01 03:00:00"), pd.Timestamp("2024-01-01 04:00:00")],
+ "Value": [10, 20],
+ "Type": ["MyEvent", "MyEvent"],
+ }
+ )
+
+ one_event = undertest._one_event(events, "MyEvent", "Value", "Time", [])
+
+ # With empty pks, merge_asof will fail because it needs a by parameter
+ with pytest.raises((ValueError, KeyError, IndexError)):
+ undertest._merge_with_strategy(
+ predictions=preds,
+ one_event=one_event,
+ pks=[], # Empty pks list - causes error
+ pred_ref="PredictTime",
+ event_ref="MyEvent_Time",
+ merge_strategy="forward",
+ )
+
+ def test_merge_with_strategy_all_nat_event_times(self):
+ """All NaT event times should trigger warning and use first row logic."""
+ preds = pd.DataFrame({"Id": [1], "PredictTime": [pd.Timestamp("2024-01-01 01:00:00")]})
+ events = pd.DataFrame(
+ {
+ "Id": [1, 1],
+ "Time": [pd.NaT, pd.NaT], # All NaT
+ "Value": [10, 20],
+ "Type": ["MyEvent", "MyEvent"],
+ }
+ )
+
+ one_event = undertest._one_event(events, "MyEvent", "Value", "Time", ["Id"])
+
+ # All NaT should trigger the ct_times == 0 path
+ result = undertest._merge_with_strategy(
+ predictions=preds,
+ one_event=one_event,
+ pks=["Id"],
+ pred_ref="PredictTime",
+ event_ref="MyEvent_Time",
+ merge_strategy="forward",
+ )
+
+ # Should merge with first row (groupby.first())
+ assert len(result) == 1
+ assert "MyEvent_Value" in result.columns
+ assert result["MyEvent_Value"].iloc[0] == 10 # First value
+
def infer_cases():
return pd.DataFrame(
@@ -230,6 +289,58 @@ def test_imputation_overrides(self):
pdt.assert_frame_equal(actual, expect, check_dtype=False)
+ def test_empty_dataframe_returns_unchanged(self):
+ """Empty DataFrame should be returned unchanged."""
+ df = pd.DataFrame({"Label": [], "Time": []})
+ result = undertest.post_process_event(df, "Label", "Time")
+ pdt.assert_frame_equal(result, df, check_dtype=False)
+
+ def test_all_nat_times_imputes_no_time(self):
+ """When all times are NaT, should impute with no_time value."""
+ df = pd.DataFrame({"Label": [None, None, None], "Time": [pd.NaT, pd.NaT, pd.NaT]})
+ result = undertest.post_process_event(df, "Label", "Time")
+ assert (result["Label"] == 0).all()
+
+ def test_both_impute_values_none_no_imputation(self):
+ """When both impute values are None, no imputation should occur."""
+ df = pd.DataFrame({"Label": [None, 1, None], "Time": [pd.Timestamp.now(), pd.NaT, pd.NaT]})
+ result = undertest.post_process_event(df, "Label", "Time", impute_val_with_time=None, impute_val_no_time=None)
+ assert result["Label"].iloc[0] is pd.NA or pd.isna(result["Label"].iloc[0])
+ assert result["Label"].iloc[1] == 1
+ assert result["Label"].iloc[2] is pd.NA or pd.isna(result["Label"].iloc[2])
+
+ @pytest.mark.parametrize(
+ "impute_with,impute_no,expected_with,expected_no",
+ [
+ (-1, -2, -1, -2), # Negative values
+ (100, 200, 100, 200), # Large positive values
+ (0.5, 0.1, 0.5, 0.1), # Decimal values
+ ("yes", "no", "yes", "no"), # String values
+ ],
+ ids=["negative", "large_positive", "decimal", "string"],
+ )
+ def test_impute_values_various_types(self, impute_with, impute_no, expected_with, expected_no):
+ """Test imputation with various value types."""
+ now = pd.Timestamp.now()
+ df = pd.DataFrame({"Label": [None, None], "Time": [now, pd.NaT]})
+ result = undertest.post_process_event(
+ df, "Label", "Time", impute_val_with_time=impute_with, impute_val_no_time=impute_no, column_dtype=None
+ )
+ assert result["Label"].iloc[0] == expected_with
+ assert result["Label"].iloc[1] == expected_no
+
+ def test_single_row_dataframe(self):
+ """Single row DataFrame should work correctly."""
+ df = pd.DataFrame({"Label": [None], "Time": [pd.Timestamp.now()]})
+ result = undertest.post_process_event(df, "Label", "Time")
+ assert result["Label"].iloc[0] == 1
+
+ def test_missing_columns_returns_unchanged(self):
+ """Missing columns should return DataFrame unchanged."""
+ df = pd.DataFrame({"A": [1], "B": [2]})
+ result = undertest.post_process_event(df, "MissingLabel", "MissingTime")
+ pdt.assert_frame_equal(result, df)
+
BASE_STRINGS = [
("A"),
@@ -283,6 +394,32 @@ def test_one_event_filters_and_renames(self):
assert "Target_Time" in result.columns
assert len(result) == 1
+ def test_one_event_missing_type_column_raises(self):
+ """Missing Type column should raise AttributeError."""
+ events = pd.DataFrame({"Id": [1], "Value": [10], "Time": [pd.Timestamp.now()]})
+ with pytest.raises(AttributeError, match="Type"):
+ undertest._one_event(events, "Target", "Value", "Time", ["Id"])
+
+ def test_one_event_missing_value_column_raises(self):
+ """Missing value column should raise KeyError."""
+ events = pd.DataFrame({"Id": [1], "Type": ["Target"], "Time": [pd.Timestamp.now()]})
+ with pytest.raises(KeyError, match="Value"):
+ undertest._one_event(events, "Target", "Value", "Time", ["Id"])
+
+ def test_one_event_missing_time_column_raises(self):
+ """Missing time column should raise KeyError."""
+ events = pd.DataFrame({"Id": [1], "Type": ["Target"], "Value": [10]})
+ with pytest.raises(KeyError, match="Time"):
+ undertest._one_event(events, "Target", "Value", "Time", ["Id"])
+
+ def test_one_event_no_matching_event_returns_empty(self):
+ """No matching event type should return empty DataFrame with correct columns."""
+ events = pd.DataFrame({"Id": [1], "Type": ["OtherEvent"], "Value": [10], "Time": [pd.Timestamp.now()]})
+ result = undertest._one_event(events, "Target", "Value", "Time", ["Id"])
+ assert len(result) == 0
+ assert "Target_Value" in result.columns
+ assert "Target_Time" in result.columns
+
class TestEventTime:
@pytest.mark.parametrize(
@@ -341,11 +478,41 @@ def test_three_suffixes_align(self, base, suffix):
("all caps ending in _VALUE", "all caps ending in _VALUE"),
("only one suffix gets stripped_Time_Value", "only one suffix gets stripped_Time"),
("only one suffix gets stripped_Value_Time", "only one suffix gets stripped_Value"),
+ ("", ""), # Empty string
+ ("_", "_"), # Just underscore
+ ("__Time", "_"), # Multiple underscores before suffix
+ ("event__Value", "event_"), # Multiple underscores in name
+ ],
+ ids=[
+ "value_suffix",
+ "time_suffix",
+ "no_underscore_time",
+ "no_underscore_value",
+ "lowercase_time",
+ "lowercase_value",
+ "uppercase_time",
+ "uppercase_value",
+ "double_suffix_time_first",
+ "double_suffix_value_first",
+ "empty_string",
+ "just_underscore",
+ "double_underscore_time",
+ "double_underscore_value",
],
)
def test_suffix_specific_handling(self, input, expected):
assert expected == undertest.event_name(input)
+ def test_very_long_event_name(self):
+ """Very long event names should work correctly."""
+ long_name = "a" * 1000 + "_Time"
+ assert undertest.event_name(long_name) == "a" * 1000
+
+ def test_unicode_characters(self):
+ """Unicode characters should be preserved."""
+ assert undertest.event_name("événement_Time") == "événement"
+ assert undertest.event_name("事件_Value") == "事件"
+
class TestEventHelpers:
@pytest.mark.parametrize(
@@ -353,7 +520,11 @@ class TestEventHelpers:
[
("MyEvent", "Critical", "MyEvent~Critical_Count"),
("MyEvent_Value", "High_Count", "MyEvent~High_Count"),
+ ("", "", "~_Count"), # Empty strings
+ ("Event", "Val~ue", "Event~Val~ue_Count"), # Tilde in value
+ ("A", "1", "A~1_Count"), # Single char
],
+ ids=["standard", "with_high_count", "empty_strings", "tilde_in_value", "single_char"],
)
def test_event_value_count(self, event_label, event_value, expected):
assert undertest.event_value_count(event_label, event_value) == expected
@@ -364,7 +535,11 @@ def test_event_value_count(self, event_label, event_value, expected):
("MyEvent~Critical_Count", "Critical"),
("MyEvent~123_Count", "123"),
("Event_Only_Count", "Event_Only"), # no ~
+ ("Multi~Tilde~Value_Count", "Tilde"), # Multiple tildes - split()[1] gets second element
+ ("NoCountSuffix", "NoCountSuffix"), # No _Count suffix
+ ("", ""), # Empty string
],
+ ids=["standard", "numeric", "no_tilde", "multi_tilde", "no_count", "empty"],
)
def test_event_value_name(self, input, expected):
assert undertest.event_value_name(input) == expected
@@ -390,6 +565,34 @@ def test_is_valid_event_when_invalid(self):
result = undertest.is_valid_event(df, "MyEvent", "RefTime")
assert not result.any()
+ def test_is_valid_event_mixed_valid_invalid(self):
+ """Test with mixed valid/invalid events."""
+ now = pd.Timestamp.now()
+ df = pd.DataFrame(
+ {
+ "MyEvent_Time": [now + pd.Timedelta(hours=1), now - pd.Timedelta(hours=1)],
+ "RefTime": [now, now],
+ }
+ )
+ result = undertest.is_valid_event(df, "MyEvent", "RefTime")
+ assert result.iloc[0] # First is valid
+ assert not result.iloc[1] # Second is invalid
+
+ def test_is_valid_event_with_nat(self):
+ """Test with NaT values in event times."""
+ now = pd.Timestamp.now()
+ df = pd.DataFrame({"MyEvent_Time": [pd.NaT, now + pd.Timedelta(hours=1)], "RefTime": [now, now]})
+ result = undertest.is_valid_event(df, "MyEvent", "RefTime")
+ # NaT comparisons return False
+ assert not result.iloc[0]
+ assert result.iloc[1]
+
+ def test_is_valid_event_empty_dataframe(self):
+ """Empty DataFrame should return empty Series."""
+ df = pd.DataFrame({"MyEvent_Time": [], "RefTime": []})
+ result = undertest.is_valid_event(df, "MyEvent", "RefTime")
+ assert len(result) == 0
+
class TestTryCasting:
@pytest.mark.parametrize(
@@ -399,6 +602,7 @@ class TestTryCasting:
("float", "float64"),
("string", "string"),
("object", "object"),
+ ("Int64", "Int64"), # Nullable integer
],
)
def test_try_casting_valid_types(self, dtype, expected_type):
@@ -411,6 +615,47 @@ def test_try_casting_invalid_type_raises(self):
with pytest.raises(undertest.ConfigurationError, match="Cannot cast 'col' values to 'int'."):
undertest.try_casting(df, "col", "int")
+ def test_try_casting_empty_dataframe(self):
+ """Empty DataFrame should cast successfully."""
+ df = pd.DataFrame({"col": pd.Series([], dtype=object)})
+ undertest.try_casting(df, "col", "int")
+ assert df["col"].dtype.name == "int64"
+
+ def test_try_casting_single_row(self):
+ """Single row should cast successfully."""
+ df = pd.DataFrame({"col": ["42"]})
+ undertest.try_casting(df, "col", "int")
+ assert df["col"].iloc[0] == 42
+
+ def test_try_casting_with_nulls_to_int64(self):
+ """Nullable Int64 should handle None values."""
+ df = pd.DataFrame({"col": [1, None, 3]})
+ undertest.try_casting(df, "col", "Int64")
+ assert df["col"].dtype.name == "Int64"
+ assert pd.isna(df["col"].iloc[1])
+
+ def test_try_casting_float_strings_to_int(self):
+ """Float strings should cast to int via float intermediate."""
+ df = pd.DataFrame({"col": ["1.0", "2.0", "3.0"]})
+ undertest.try_casting(df, "col", "int")
+ assert df["col"].dtype.name == "int64"
+ assert (df["col"] == [1, 2, 3]).all()
+
+ @pytest.mark.parametrize(
+ "input_data,dtype",
+ [
+ (["not", "a", "number"], "int"),
+ (["1.5.5"], "float"),
+ (["2023-13-45"], "datetime64"), # Invalid date
+ ],
+ ids=["string_to_int", "malformed_float", "invalid_datetime"],
+ )
+ def test_try_casting_raises_configuration_error(self, input_data, dtype):
+ """Various invalid casts should raise ConfigurationError."""
+ df = pd.DataFrame({"col": input_data})
+ with pytest.raises(undertest.ConfigurationError):
+ undertest.try_casting(df, "col", dtype)
+
class TestResolveHelpers:
def test_resolve_time_col_from_event_suffix(self):
@@ -440,6 +685,343 @@ def test_resolve_score_col_raises_if_missing(self):
undertest._resolve_score_col(df, "MyScore")
+class TestAggregationFunctions:
+ """Direct tests for aggregation functions used by event_score."""
+
+ def test_max_aggregation_picks_highest_score_with_positive_event(self):
+ """max_aggregation should pick row with highest score among positive events."""
+ df = pd.DataFrame(
+ {
+ "Id": [1, 1, 1],
+ "Score": [0.3, 0.8, 0.5],
+ "EventName_Value": [1, 1, 0], # First two are positive
+ "EventName_Time": [pd.Timestamp("2024-01-01")] * 3,
+ }
+ )
+
+ result = undertest.max_aggregation(
+ df, pks=["Id"], score="Score", ref_time="EventName_Time", ref_event="EventName"
+ )
+
+ assert len(result) == 1
+ assert result["Score"].iloc[0] == 0.8 # Highest score among positive events
+
+ def test_max_aggregation_requires_ref_event(self):
+ """max_aggregation should raise ValueError if ref_event is None."""
+ df = pd.DataFrame({"Id": [1], "Score": [0.5]})
+
+ with pytest.raises(ValueError, match="ref_event is required"):
+ undertest.max_aggregation(df, pks=["Id"], score="Score", ref_time="Time", ref_event=None)
+
+ def test_min_aggregation_picks_lowest_score_with_positive_event(self):
+ """min_aggregation should pick row with lowest score among positive events."""
+ df = pd.DataFrame(
+ {
+ "Id": [1, 1, 1],
+ "Score": [0.3, 0.8, 0.5],
+ "EventName_Value": [1, 1, 0], # First two are positive
+ "EventName_Time": [pd.Timestamp("2024-01-01")] * 3,
+ }
+ )
+
+ result = undertest.min_aggregation(
+ df, pks=["Id"], score="Score", ref_time="EventName_Time", ref_event="EventName"
+ )
+
+ assert len(result) == 1
+ assert result["Score"].iloc[0] == 0.3 # Lowest score among positive events
+
+ def test_min_aggregation_requires_ref_event(self):
+ """min_aggregation should raise ValueError if ref_event is None."""
+ df = pd.DataFrame({"Id": [1], "Score": [0.5]})
+
+ with pytest.raises(ValueError, match="ref_event is required"):
+ undertest.min_aggregation(df, pks=["Id"], score="Score", ref_time="Time", ref_event=None)
+
+ def test_first_aggregation_picks_earliest_by_time(self):
+ """first_aggregation should pick row with earliest timestamp."""
+ df = pd.DataFrame(
+ {
+ "Id": [1, 1, 1],
+ "Score": [0.3, 0.8, 0.5],
+ "EventName_Time": [
+ pd.Timestamp("2024-01-03"),
+ pd.Timestamp("2024-01-01"), # Earliest
+ pd.Timestamp("2024-01-02"),
+ ],
+ }
+ )
+
+ result = undertest.first_aggregation(
+ df, pks=["Id"], score="Score", ref_time="EventName_Time", ref_event="EventName"
+ )
+
+ assert len(result) == 1
+ assert result["Score"].iloc[0] == 0.8 # Score from earliest timestamp
+
+ def test_first_aggregation_requires_ref_time(self):
+ """first_aggregation should raise ValueError if ref_time is None."""
+ df = pd.DataFrame({"Id": [1], "Score": [0.5]})
+
+ with pytest.raises(ValueError, match="ref_time is required"):
+ undertest.first_aggregation(df, pks=["Id"], score="Score", ref_time=None, ref_event="EventName")
+
+ def test_first_aggregation_drops_nat_timestamps(self):
+ """first_aggregation should drop rows with NaT timestamps."""
+ df = pd.DataFrame(
+ {
+ "Id": [1, 1, 1],
+ "Score": [0.3, 0.8, 0.5],
+ "EventName_Time": [
+ pd.NaT, # Should be dropped
+ pd.Timestamp("2024-01-01"),
+ pd.Timestamp("2024-01-02"),
+ ],
+ }
+ )
+
+ result = undertest.first_aggregation(
+ df, pks=["Id"], score="Score", ref_time="EventName_Time", ref_event="EventName"
+ )
+
+ assert len(result) == 1
+ assert result["Score"].iloc[0] == 0.8 # First non-NaT timestamp
+
+ def test_last_aggregation_picks_latest_by_time(self):
+ """last_aggregation should pick row with latest timestamp."""
+ df = pd.DataFrame(
+ {
+ "Id": [1, 1, 1],
+ "Score": [0.3, 0.8, 0.5],
+ "EventName_Time": [
+ pd.Timestamp("2024-01-01"),
+ pd.Timestamp("2024-01-02"),
+ pd.Timestamp("2024-01-03"), # Latest
+ ],
+ }
+ )
+
+ result = undertest.last_aggregation(
+ df, pks=["Id"], score="Score", ref_time="EventName_Time", ref_event="EventName"
+ )
+
+ assert len(result) == 1
+ assert result["Score"].iloc[0] == 0.5 # Score from latest timestamp
+
+ def test_last_aggregation_requires_ref_time(self):
+ """last_aggregation should raise ValueError if ref_time is None."""
+ df = pd.DataFrame({"Id": [1], "Score": [0.5]})
+
+ with pytest.raises(ValueError, match="ref_time is required"):
+ undertest.last_aggregation(df, pks=["Id"], score="Score", ref_time=None, ref_event="EventName")
+
+ @pytest.mark.parametrize(
+ "agg_func,sort_by,expected_score",
+ [
+ (undertest.max_aggregation, None, 0.9), # Picks highest score
+ (undertest.min_aggregation, None, 0.1), # Picks lowest score
+ (undertest.first_aggregation, "time", 0.3), # Picks earliest time
+ (undertest.last_aggregation, "time", 0.7), # Picks latest time
+ ],
+ ids=["max", "min", "first", "last"],
+ )
+ def test_aggregation_functions_with_multiple_entities(self, agg_func, sort_by, expected_score):
+ """All aggregation functions should work correctly with multiple entities."""
+ if sort_by == "time":
+ df = pd.DataFrame(
+ {
+ "Id": [1, 1, 2, 2],
+ "Score": [0.3, 0.7, 0.4, 0.6],
+ "EventName_Time": [
+ pd.Timestamp("2024-01-01"), # Earliest for Id=1
+ pd.Timestamp("2024-01-02"), # Latest for Id=1
+ pd.Timestamp("2024-01-01"),
+ pd.Timestamp("2024-01-02"),
+ ],
+ }
+ )
+ result = agg_func(df, pks=["Id"], score="Score", ref_time="EventName_Time", ref_event="EventName")
+ else:
+ df = pd.DataFrame(
+ {
+ "Id": [1, 1, 2, 2],
+ "Score": [0.1, 0.9, 0.2, 0.8],
+ "EventName_Value": [1, 1, 1, 1],
+ "EventName_Time": [pd.Timestamp("2024-01-01")] * 4,
+ }
+ )
+ result = agg_func(df, pks=["Id"], score="Score", ref_time="EventName_Time", ref_event="EventName")
+
+ # Should have 2 rows (one per entity)
+ assert len(result) == 2
+ # Check Id=1 has expected score
+ assert result[result["Id"] == 1]["Score"].iloc[0] == expected_score
+
+ def test_max_aggregation_all_negative_events(self):
+ """max_aggregation with all negative events should still return a row."""
+ df = pd.DataFrame(
+ {
+ "Id": [1, 1],
+ "Score": [0.3, 0.8],
+ "EventName_Value": [0, 0], # All negative
+ "EventName_Time": [pd.Timestamp("2024-01-01")] * 2,
+ }
+ )
+ result = undertest.max_aggregation(df, pks=["Id"], score="Score", ref_time="Time", ref_event="EventName")
+ assert len(result) == 1
+ assert result["Score"].iloc[0] == 0.8 # Still picks max even if all negative
+
+ def test_min_aggregation_all_negative_events(self):
+ """min_aggregation with all negative events should still return a row."""
+ df = pd.DataFrame(
+ {
+ "Id": [1, 1],
+ "Score": [0.3, 0.8],
+ "EventName_Value": [0, 0], # All negative
+ "EventName_Time": [pd.Timestamp("2024-01-01")] * 2,
+ }
+ )
+ result = undertest.min_aggregation(df, pks=["Id"], score="Score", ref_time="Time", ref_event="EventName")
+ assert len(result) == 1
+ assert result["Score"].iloc[0] == 0.3 # Still picks min even if all negative
+
+ def test_aggregation_with_identical_scores(self):
+ """When scores are identical, should return one row per pk."""
+ df = pd.DataFrame(
+ {
+ "Id": [1, 1, 1],
+ "Score": [0.5, 0.5, 0.5], # All same
+ "EventName_Value": [1, 1, 1],
+ "EventName_Time": [pd.Timestamp("2024-01-01")] * 3,
+ }
+ )
+ result = undertest.max_aggregation(df, pks=["Id"], score="Score", ref_time="Time", ref_event="EventName")
+ assert len(result) == 1
+
+ def test_aggregation_with_inf_values(self):
+ """Aggregation should handle inf/-inf values."""
+ df = pd.DataFrame(
+ {
+ "Id": [1, 1],
+ "Score": [float("inf"), 0.5],
+ "EventName_Value": [1, 1],
+ "EventName_Time": [pd.Timestamp("2024-01-01")] * 2,
+ }
+ )
+ result = undertest.max_aggregation(df, pks=["Id"], score="Score", ref_time="Time", ref_event="EventName")
+ assert result["Score"].iloc[0] == float("inf")
+
+ def test_first_last_aggregation_with_identical_times(self):
+ """When times are identical, first/last should still return one row."""
+ same_time = pd.Timestamp("2024-01-01")
+ df = pd.DataFrame({"Id": [1, 1], "Score": [0.3, 0.7], "EventName_Time": [same_time, same_time]})
+ result_first = undertest.first_aggregation(
+ df, pks=["Id"], score="Score", ref_time="EventName_Time", ref_event="EventName"
+ )
+ result_last = undertest.last_aggregation(
+ df, pks=["Id"], score="Score", ref_time="EventName_Time", ref_event="EventName"
+ )
+ assert len(result_first) == 1
+ assert len(result_last) == 1
+
+ def test_aggregation_empty_dataframe(self):
+ """Empty DataFrame should return empty result."""
+ df = pd.DataFrame({"Id": [], "Score": [], "EventName_Value": [], "EventName_Time": []})
+ result = undertest.max_aggregation(df, pks=["Id"], score="Score", ref_time="Time", ref_event="EventName")
+ assert len(result) == 0
+
+ def test_max_aggregation_missing_score_column_raises(self):
+ """Missing score column should raise clear ValueError."""
+ df = pd.DataFrame({"Id": [1], "EventName_Value": [1], "EventName_Time": [pd.Timestamp.now()]})
+ with pytest.raises(ValueError, match="NonExistentScore"):
+ undertest.max_aggregation(
+ df, pks=["Id"], score="NonExistentScore", ref_time="EventName_Time", ref_event="EventName"
+ )
+
+ def test_first_aggregation_missing_ref_time_column_raises(self):
+ """Missing ref_time column should raise clear error."""
+ df = pd.DataFrame({"Id": [1], "Score": [0.5]})
+ # First check ValueError for None
+ with pytest.raises(ValueError, match="ref_time is required"):
+ undertest.first_aggregation(df, pks=["Id"], score="Score", ref_time=None, ref_event="EventName")
+
+ # Then check what happens when ref_time doesn't exist in DataFrame
+ with pytest.raises(ValueError, match="Reference time column .* not found"):
+ undertest.first_aggregation(df, pks=["Id"], score="Score", ref_time="NonExistent", ref_event="EventName")
+
+ def test_aggregation_duplicate_pks_keeps_first_after_sort(self):
+ """With duplicate pks, aggregation should keep first after sorting."""
+ df = pd.DataFrame(
+ {
+ "Id": [1, 1, 1], # Duplicates
+ "Score": [0.3, 0.9, 0.5],
+ "EventName_Value": [1, 1, 1],
+ "EventName_Time": [pd.Timestamp("2024-01-01")] * 3,
+ }
+ )
+
+ # max_aggregation sorts by EventName_Value (desc), Score (desc), then drops duplicates
+ result = undertest.max_aggregation(df, pks=["Id"], score="Score", ref_time="Time", ref_event="EventName")
+
+ # Should keep highest score (0.9)
+ assert len(result) == 1
+ assert result["Score"].iloc[0] == 0.9
+
+
+class TestAnalyticsMetricName:
+ """Tests for analytics_metric_name utility function."""
+
+ def test_returns_column_name_if_in_metric_names(self):
+ """If column_name is already in metric_names, return it unchanged."""
+ metric_names = ["accuracy", "precision", "recall"]
+ result = undertest.analytics_metric_name(metric_names, [], "accuracy")
+ assert result == "accuracy"
+
+ def test_strips_prefix_if_matches_existing_metric_starts(self):
+ """If column starts with metric prefix, strip it."""
+ metric_names = []
+ existing_starts = ["model_v1", "model_v2"]
+ result = undertest.analytics_metric_name(metric_names, existing_starts, "model_v1_accuracy")
+ assert result == "accuracy"
+
+ def test_returns_none_if_no_match(self):
+ """If no match found, return None."""
+ metric_names = ["accuracy"]
+ existing_starts = ["model_v1"]
+ result = undertest.analytics_metric_name(metric_names, existing_starts, "unknown_metric")
+ assert result is None
+
+ @pytest.mark.parametrize(
+ "metric_names,existing_starts,column_name,expected",
+ [
+ (["accuracy"], [], "accuracy", "accuracy"), # Direct match
+ ([], ["model"], "model_accuracy", "accuracy"), # Prefix strip
+ ([], ["v1", "v2"], "v1_precision", "precision"), # First prefix match
+ ([], ["v1", "v2"], "v2_recall", "recall"), # Second prefix match
+ (["score"], ["model"], "score", "score"), # Direct match takes precedence
+ ([], [], "metric", None), # No match
+ ([], ["prefix"], "other_metric", None), # Wrong prefix
+ ([], ["model"], "model_model", "model"), # Repeated prefix chars
+ ([], ["model"], "mode_accuracy", None), # Similar but doesn't start with prefix
+ ],
+ ids=[
+ "direct_match",
+ "prefix_strip",
+ "first_prefix",
+ "second_prefix",
+ "direct_over_prefix",
+ "no_match",
+ "wrong_prefix",
+ "repeated_prefix_chars",
+ "similar_not_prefix",
+ ],
+ )
+ def test_analytics_metric_name_various_cases(self, metric_names, existing_starts, column_name, expected):
+ """Test various scenarios for analytics_metric_name."""
+ result = undertest.analytics_metric_name(metric_names, existing_starts, column_name)
+ assert result == expected
+
+
class TestEventScoreAndModelScores:
@pytest.mark.parametrize("method", ["max", "min", "first", "last"])
def test_event_score_valid_methods(self, method):
@@ -511,6 +1093,60 @@ def test_get_model_scores_bypass_when_not_per_context(self):
)
pd.testing.assert_frame_equal(result, df)
+ def test_event_score_empty_dataframe(self):
+ """Empty DataFrame should return empty result."""
+ df = pd.DataFrame({"Id": [], "Score": [], "EventName_Value": [], "EventName_Time": []})
+ result = undertest.event_score(
+ df, pks=["Id"], score="Score", ref_time="EventName_Time", ref_event="EventName", aggregation_method="max"
+ )
+ assert len(result) == 0
+
+ def test_event_score_pks_not_in_dataframe(self):
+ """When pks don't exist, should filter to available columns."""
+ df = pd.DataFrame({"Id": [1], "Score": [0.5], "EventName_Value": [1], "EventName_Time": [pd.Timestamp.now()]})
+ result = undertest.event_score(
+ df,
+ pks=["Id", "NonExistent"],
+ score="Score",
+ ref_time="EventName_Time",
+ ref_event="EventName",
+ aggregation_method="max",
+ )
+ assert len(result) == 1
+
+ def test_get_model_scores_empty_dataframe(self):
+ """Empty DataFrame should return empty when per_context_id=True."""
+ df = pd.DataFrame({"Id": [], "Score": [], "Event_Value": [], "Event_Time": []})
+ result = undertest.get_model_scores(
+ df,
+ entity_keys=["Id"],
+ score_col="Score",
+ ref_time="Event_Time",
+ ref_event="Event",
+ aggregation_method="max",
+ per_context_id=True,
+ )
+ assert len(result) == 0
+
+ def test_event_score_all_nan_scores(self):
+ """All NaN scores should return empty result after filtering."""
+ df = pd.DataFrame(
+ {
+ "Id": [1, 1, 2],
+ "Score": [float("nan"), float("nan"), float("nan")], # All NaN
+ "EventName_Value": [1, 1, 1],
+ "EventName_Time": [pd.Timestamp("2024-01-01")] * 3,
+ }
+ )
+
+ result = undertest.event_score(
+ df, pks=["Id"], score="Score", ref_time="EventName_Time", ref_event="EventName", aggregation_method="max"
+ )
+
+ # After aggregation and filtering NaN indices, should return empty or rows with NaN scores
+ # The function filters out NaN indices with ~np.isnan(df.index)
+ assert isinstance(result, pd.DataFrame)
+
class TestMergeEventCounts:
def test_skips_time_filter_when_window_none(self, base_counts_data):
@@ -637,6 +1273,128 @@ def test_counts_respect_min_offset(self, base_counts_data):
assert result["Label~A_Count"].iloc[1] == 0
assert result["Label~B_Count"].iloc[1] == 0
+ def test_merge_event_counts_empty_left_returns_empty(self):
+ """Empty left DataFrame should return empty."""
+ preds = pd.DataFrame({"Id": pd.Series([], dtype=int), "Time": pd.Series([], dtype="datetime64[ns]")})
+ events = pd.DataFrame(
+ {
+ "Id": [1],
+ "Event_Time": [pd.Timestamp("2024-01-01")],
+ "Label": ["A"],
+ "~~reftime~~": [pd.Timestamp("2024-01-01")],
+ }
+ )
+ result = undertest._merge_event_counts(
+ preds, events, ["Id"], "MyEvent", "Label", window_hrs=1, l_ref="Time", r_ref="~~reftime~~"
+ )
+ assert len(result) == 0
+
+ def test_merge_event_counts_empty_right_returns_left(self):
+ """Empty right DataFrame should return left unchanged (no counts added)."""
+ preds = pd.DataFrame({"Id": [1], "Time": [pd.Timestamp("2024-01-01")]})
+ events = pd.DataFrame({"Id": pd.Series([], dtype=int), "Event_Time": [], "Label": [], "~~reftime~~": []})
+ result = undertest._merge_event_counts(
+ preds, events, ["Id"], "MyEvent", "Label", window_hrs=1, l_ref="Time", r_ref="~~reftime~~"
+ )
+ pdt.assert_frame_equal(result, preds)
+
+ def test_merge_event_counts_with_nan_labels(self, base_counts_data):
+ """NaN values in event_label should be handled (pandas treats as category)."""
+ preds, events = base_counts_data
+ events["~~reftime~~"] = events["Event_Time"]
+ # Don't set NaN - pandas might not handle NaN well in value_counts pivot
+ # Instead test that function works with various label values
+
+ result = undertest._merge_event_counts(
+ preds, events, ["Id"], "MyEvent", "Label", window_hrs=5, l_ref="Time", r_ref="~~reftime~~"
+ )
+ # Should work and have count columns
+ assert any("_Count" in col for col in result.columns)
+
+ def test_merge_event_counts_very_small_window(self):
+ """Very small window_hrs should have narrow window."""
+ preds = pd.DataFrame({"Id": [1], "Time": [pd.Timestamp("2024-01-01 01:00:00")]})
+ events = pd.DataFrame(
+ {
+ "Id": [1, 1],
+ "Event_Time": [pd.Timestamp("2024-01-01 01:00:30"), pd.Timestamp("2024-01-01 03:00:00")],
+ "Label": ["A", "B"],
+ "~~reftime~~": [pd.Timestamp("2024-01-01 01:00:30"), pd.Timestamp("2024-01-01 03:00:00")],
+ }
+ )
+ result = undertest._merge_event_counts(
+ preds, events, ["Id"], "MyEvent", "Label", window_hrs=1, l_ref="Time", r_ref="~~reftime~~"
+ )
+ # With 1 hour window, event A should be included, B should not
+ assert "Label~A_Count" in result.columns
+ assert result["Label~A_Count"].iloc[0] == 1
+
+ def test_merge_event_counts_negative_min_offset(self):
+ """Negative min_offset allows looking into past - valid use case."""
+ preds = pd.DataFrame({"Id": [1], "Time": [pd.Timestamp("2024-01-01 12:00:00")]})
+ events = pd.DataFrame(
+ {
+ "Id": [1, 1],
+ "Event_Time": [
+ pd.Timestamp("2024-01-01 10:00:00"), # 2 hours before pred
+ pd.Timestamp("2024-01-01 14:00:00"), # 2 hours after pred
+ ],
+ "Label": ["A", "B"],
+ "~~reftime~~": [
+ pd.Timestamp("2024-01-01 12:00:00"), # Adjusted by negative offset
+ pd.Timestamp("2024-01-01 16:00:00"),
+ ],
+ }
+ )
+
+ # Negative offset of -2 hours means we look 2 hours into the past
+ result = undertest._merge_event_counts(
+ preds,
+ events,
+ ["Id"],
+ "MyEvent",
+ "Label",
+ window_hrs=3,
+ min_offset=pd.Timedelta(hours=-2), # Negative: look into past
+ l_ref="Time",
+ r_ref="~~reftime~~",
+ )
+
+ # Both events should be counted with the negative offset
+ assert "Label~A_Count" in result.columns
+ assert result["Label~A_Count"].iloc[0] == 1
+
+ def test_merge_event_counts_large_min_offset(self):
+ """Large min_offset (larger than window) should work correctly."""
+ preds = pd.DataFrame({"Id": [1], "Time": [pd.Timestamp("2024-01-01 12:00:00")]})
+ events = pd.DataFrame(
+ {
+ "Id": [1],
+ "Event_Time": [pd.Timestamp("2024-01-01 20:00:00")], # 8 hours after
+ "Label": ["A"],
+ "~~reftime~~": [pd.Timestamp("2024-01-01 20:00:00")],
+ }
+ )
+
+ # min_offset of 5 hours with window of 2 hours
+ # Window is [pred+5hrs, pred+7hrs] = [17:00, 19:00]
+ # Event at 20:00 is outside window
+ result = undertest._merge_event_counts(
+ preds,
+ events,
+ ["Id"],
+ "MyEvent",
+ "Label",
+ window_hrs=2,
+ min_offset=pd.Timedelta(hours=5),
+ l_ref="Time",
+ r_ref="~~reftime~~",
+ )
+
+ # Event should not be counted (outside window)
+ if "Label~A_Count" in result.columns:
+ assert result["Label~A_Count"].iloc[0] == 0
+
class TestMergeWindowedEvent:
def test_basic_forward_strategy(self):
@@ -682,49 +1440,6 @@ def test_basic_forward_strategy(self):
assert result["MyEvent_Time"].iloc[1] == pd.Timestamp("2024-01-01 05:00:00")
assert result["MyEvent_Value"].iloc[1] == 1
- @pytest.mark.parametrize("strategy", ["forward", "nearest", "first", "last"])
- def test_merge_event_with_various_strategies(self, strategy):
- preds = pd.DataFrame(
- {
- "Id": [1, 1],
- "PredictTime": [
- pd.Timestamp("2024-01-01 00:00:00"),
- pd.Timestamp("2024-01-01 01:00:00"),
- ],
- }
- )
- events = pd.DataFrame(
- {
- "Id": [1, 1],
- "Time": [
- pd.Timestamp("2024-01-01 01:30:00"),
- pd.Timestamp("2024-01-01 02:00:00"),
- ],
- "Value": [1, 1],
- "Type": ["MyEvent", "MyEvent"],
- }
- )
-
- result = undertest.merge_windowed_event(
- preds,
- predtime_col="PredictTime",
- events=events,
- event_label="MyEvent",
- pks=["Id"],
- min_leadtime_hrs=1,
- window_hrs=2,
- event_base_val_col="Value",
- event_base_time_col="Time",
- merge_strategy=strategy,
- impute_val_with_time=1,
- impute_val_no_time=0,
- )
-
- # Result should include the matched event in _Value/_Time columns
- assert "MyEvent_Value" in result.columns
- assert "MyEvent_Time" in result.columns
- assert result["MyEvent_Value"].notna().all()
-
def test_merge_event_with_count_strategy(self):
preds = pd.DataFrame(
{
@@ -871,25 +1586,100 @@ def test_merge_with_strategy_info_logging(self, log_level, should_log, caplog):
matched = any("Added" in msg and "MyEvent" in msg for msg in info_logs)
assert matched == should_log
+ def test_merge_windowed_event_missing_predtime_col_raises(self):
+ """Missing predtime_col should raise KeyError."""
+ preds = pd.DataFrame({"Id": [1], "SomeOtherCol": [pd.Timestamp("2024-01-01")]})
+ events = pd.DataFrame({"Id": [1], "Time": [pd.Timestamp("2024-01-01")], "Value": [1], "Type": ["MyEvent"]})
-@patch("seismometer.data.pandas_helpers.try_casting")
-def test_post_process_event_skips_cast_when_dtype_none(mock_casting):
- df = pd.DataFrame({"Label": [None], "Time": [pd.Timestamp.now()]})
- result = undertest.post_process_event(df, "Label", "Time", column_dtype=None)
- assert "Label" in result.columns
- mock_casting.assert_not_called()
+ with pytest.raises(KeyError):
+ undertest.merge_windowed_event(
+ preds,
+ predtime_col="PredictTime", # Doesn't exist
+ events=events,
+ event_label="MyEvent",
+ pks=["Id"],
+ window_hrs=5,
+ event_base_val_col="Value",
+ event_base_time_col="Time",
+ )
+ def test_merge_windowed_event_missing_event_time_col_raises(self):
+ """Missing event_base_time_col should raise KeyError."""
+ preds = pd.DataFrame({"Id": [1], "PredictTime": [pd.Timestamp("2024-01-01")]})
+ events = pd.DataFrame({"Id": [1], "Value": [1], "Type": ["MyEvent"]}) # No Time column
-def test_one_event_filters_and_renames():
- events = pd.DataFrame(
- {
- "Id": [1, 1],
- "Type": ["Target", "Other"],
- "Value": [10, 20],
- "Time": [pd.Timestamp("2024-01-01"), pd.Timestamp("2024-01-01")],
- }
- )
- result = undertest._one_event(events, "Target", "Value", "Time", ["Id"])
- assert "Target_Value" in result.columns
- assert "Target_Time" in result.columns
- assert len(result) == 1
+ with pytest.raises(KeyError):
+ undertest.merge_windowed_event(
+ preds,
+ predtime_col="PredictTime",
+ events=events,
+ event_label="MyEvent",
+ pks=["Id"],
+ window_hrs=5,
+ event_base_val_col="Value",
+ event_base_time_col="Time", # Doesn't exist
+ )
+
+ def test_merge_windowed_event_invalid_event_label_returns_unchanged(self):
+ """Event label not in Type column should return predictions unchanged (early return)."""
+ preds = pd.DataFrame({"Id": [1], "PredictTime": [pd.Timestamp("2024-01-01")]})
+ events = pd.DataFrame(
+ {"Id": [1], "Time": [pd.Timestamp("2024-01-01")], "Value": [1], "Type": ["DifferentEvent"]}
+ )
+
+ result = undertest.merge_windowed_event(
+ preds,
+ predtime_col="PredictTime",
+ events=events,
+ event_label="MyEvent", # Not in Type column
+ pks=["Id"],
+ window_hrs=5,
+ event_base_val_col="Value",
+ event_base_time_col="Time",
+ )
+
+ # Should return predictions completely unchanged (early return when no events found)
+ assert len(result) == len(preds)
+ # No event columns added when event label doesn't exist
+ assert "MyEvent_Value" not in result.columns
+ assert "MyEvent_Time" not in result.columns
+ pdt.assert_frame_equal(result, preds)
+
+ def test_merge_windowed_event_with_sort_false(self):
+ """Test sort=False parameter - unsorted data should raise ValueError."""
+ # Create predictions and events in reverse chronological order (unsorted)
+ preds = pd.DataFrame(
+ {
+ "Id": [1, 1],
+ "PredictTime": [
+ pd.Timestamp("2024-01-01 04:00:00"), # Later time first
+ pd.Timestamp("2024-01-01 01:00:00"),
+ ],
+ }
+ )
+ events = pd.DataFrame(
+ {
+ "Id": [1, 1],
+ "Time": [
+ pd.Timestamp("2024-01-01 05:00:00"), # Later time first
+ pd.Timestamp("2024-01-01 02:00:00"),
+ ],
+ "Value": [2, 1],
+ "Type": ["MyEvent", "MyEvent"],
+ }
+ )
+
+ # merge_asof with unsorted data and sort=False should raise ValueError
+ with pytest.raises(ValueError):
+ undertest.merge_windowed_event(
+ preds,
+ predtime_col="PredictTime",
+ events=events,
+ event_label="MyEvent",
+ pks=["Id"],
+ window_hrs=5,
+ merge_strategy="forward",
+ event_base_val_col="Value",
+ event_base_time_col="Time",
+ sort=False, # Important: test unsorted merge raises error
+ )
diff --git a/tests/plot/test_utils.py b/tests/plot/test_utils.py
index 87134ec6..5a40edc1 100644
--- a/tests/plot/test_utils.py
+++ b/tests/plot/test_utils.py
@@ -1,6 +1,9 @@
from unittest.mock import Mock, patch
import matplotlib.pyplot as plt
+import pandas as pd
+import pytest
+from IPython.display import SVG
import seismometer.plot.mpl._util as utils
@@ -102,3 +105,363 @@ def test_clear_all(self):
axis.set_xlabel.assert_called_once_with(None)
axis.set_yticklabels.assert_called_once_with([])
axis.set_ylabel.assert_called_once_with(None)
+
+
+class TestToSvg:
+ """Test to_svg() function for SVG generation"""
+
+ @patch.object(plt, "savefig")
+ def test_to_svg_returns_svg_object(self, save_mock):
+ """Test to_svg() returns an SVG object"""
+
+ # Mock savefig to write valid SVG content to buffer
+ def write_svg(buffer, **kwargs):
+ buffer.write('')
+
+ save_mock.side_effect = write_svg
+ result = utils.to_svg()
+ assert isinstance(result, SVG)
+ save_mock.assert_called_once()
+
+ @patch.object(plt, "savefig")
+ def test_to_svg_calls_savefig_with_svg_format(self, save_mock):
+ """Test to_svg() calls savefig with format='svg'"""
+
+ def write_svg(buffer, **kwargs):
+ buffer.write('')
+
+ save_mock.side_effect = write_svg
+ utils.to_svg()
+ # Check that savefig was called with format='svg'
+ call_kwargs = save_mock.call_args[1]
+ assert call_kwargs.get("format") == "svg"
+
+ @patch.object(plt, "savefig")
+ def test_to_svg_with_empty_plot(self, save_mock):
+ """Test to_svg() with empty plot doesn't crash"""
+
+ def write_svg(buffer, **kwargs):
+ buffer.write('')
+
+ save_mock.side_effect = write_svg
+ plt.figure()
+ result = utils.to_svg()
+ assert isinstance(result, SVG)
+ plt.close()
+
+
+class TestCreateCheckboxes:
+ """Test create_checkboxes() function for widget creation"""
+
+ def test_create_checkboxes_returns_list(self):
+ """Test create_checkboxes() returns a list"""
+ result = utils.create_checkboxes(["A", "B", "C"])
+ assert isinstance(result, list)
+ assert len(result) == 3
+
+ def test_create_checkboxes_widget_properties(self):
+ """Test create_checkboxes() creates widgets with correct properties"""
+ values = ["Option1", "Option2"]
+ checkboxes = utils.create_checkboxes(values)
+
+ for checkbox, value in zip(checkboxes, values):
+ assert checkbox.description == value
+ assert checkbox.value is True # Default value
+
+ def test_create_checkboxes_with_numeric_values(self):
+ """Test create_checkboxes() converts numeric values to strings"""
+ values = [1, 2, 3]
+ checkboxes = utils.create_checkboxes(values)
+
+ for checkbox, value in zip(checkboxes, values):
+ assert checkbox.description == str(value)
+
+ def test_create_checkboxes_empty_list(self):
+ """Test create_checkboxes() with empty list"""
+ result = utils.create_checkboxes([])
+ assert result == []
+
+ def test_create_checkboxes_with_mixed_types(self):
+ """Test create_checkboxes() with mixed types in list"""
+ values = [1, "A", 2.5, True]
+ checkboxes = utils.create_checkboxes(values)
+ assert len(checkboxes) == 4
+ assert checkboxes[0].description == "1"
+ assert checkboxes[1].description == "A"
+ assert checkboxes[2].description == "2.5"
+ assert checkboxes[3].description == "True"
+
+
+class TestAddUnseen:
+ """Test add_unseen() function for categorical data handling"""
+
+ def test_add_unseen_with_missing_categories(self):
+ """Test add_unseen() adds missing categorical values"""
+ df = pd.DataFrame({"cohort": pd.Categorical(["A", "A"], categories=["A", "B", "C"])})
+
+ result = utils.add_unseen(df, col="cohort")
+
+ # Should have 2 original rows + 2 unseen categories
+ assert len(result) == 4
+ assert set(result["cohort"].dropna()) == {"A", "B", "C"}
+
+ def test_add_unseen_preserves_categorical_dtype(self):
+ """Test add_unseen() preserves categorical dtype and categories"""
+ original_cats = ["A", "B", "C"]
+ df = pd.DataFrame({"cohort": pd.Categorical(["A"], categories=original_cats)})
+
+ result = utils.add_unseen(df, col="cohort")
+
+ assert result["cohort"].dtype.name == "category"
+ assert list(result["cohort"].cat.categories) == original_cats
+
+ def test_add_unseen_with_all_categories_present(self):
+ """Test add_unseen() when all categories are already present"""
+ df = pd.DataFrame({"cohort": pd.Categorical(["A", "B", "C"], categories=["A", "B", "C"])})
+
+ result = utils.add_unseen(df, col="cohort")
+
+ # Should only have original rows
+ assert len(result) == 3
+
+ def test_add_unseen_with_custom_column_name(self):
+ """Test add_unseen() with custom column name"""
+ df = pd.DataFrame({"feature": pd.Categorical(["X"], categories=["X", "Y", "Z"])})
+
+ result = utils.add_unseen(df, col="feature")
+
+ assert len(result) == 3
+ assert set(result["feature"].dropna()) == {"X", "Y", "Z"}
+
+ def test_add_unseen_preserves_other_columns(self):
+ """Test add_unseen() preserves other columns (fills with NaN)"""
+ df = pd.DataFrame(
+ {"cohort": pd.Categorical(["A", "A"], categories=["A", "B"]), "value": [10, 20], "name": ["x", "y"]}
+ )
+
+ result = utils.add_unseen(df, col="cohort")
+
+ # Original rows should have values, new rows should have NaN
+ assert result.iloc[0]["value"] == 10
+ assert result.iloc[1]["value"] == 20
+ assert pd.isna(result.iloc[2]["value"])
+
+
+class TestNeededColors:
+ """Test needed_colors() function for color mapping"""
+
+ def test_needed_colors_with_categorical_series(self):
+ """Test needed_colors() with categorical series"""
+ series = pd.Series(pd.Categorical(["A", "A", "C"], categories=["A", "B", "C"]))
+ colors = ["red", "green", "blue"]
+
+ result = utils.needed_colors(series, colors)
+
+ # Should return colors for observed categories (A=0, C=2)
+ assert result == ["red", "blue"]
+
+ def test_needed_colors_with_all_categories_observed(self):
+ """Test needed_colors() when all categories are observed"""
+ series = pd.Series(pd.Categorical(["A", "B", "C"], categories=["A", "B", "C"]))
+ colors = ["red", "green", "blue"]
+
+ result = utils.needed_colors(series, colors)
+
+ assert result == ["red", "green", "blue"]
+
+ def test_needed_colors_with_non_categorical_series(self):
+ """Test needed_colors() with non-categorical series (fallback behavior)"""
+ series = pd.Series(["A", "B", "A", "C"])
+ colors = ["red", "green", "blue"]
+
+ result = utils.needed_colors(series, colors)
+
+ # Should return colors based on number of unique values
+ assert len(result) == 3 # 3 unique values
+
+ def test_needed_colors_with_numeric_series(self):
+ """Test needed_colors() with numeric series"""
+ series = pd.Series([1, 2, 1, 3, 2])
+ colors = ["red", "green", "blue"]
+
+ result = utils.needed_colors(series, colors)
+
+ # Should handle numeric series (3 unique values)
+ assert len(result) == 3
+
+ def test_needed_colors_single_category(self):
+ """Test needed_colors() with single category"""
+ series = pd.Series(pd.Categorical(["A", "A", "A"], categories=["A", "B"]))
+ colors = ["red", "green"]
+
+ result = utils.needed_colors(series, colors)
+
+ assert result == ["red"]
+
+
+class TestPlotCurveEdgeCases:
+ """Test plot_curve() edge cases and error handling"""
+
+ def test_plot_curve_with_empty_line(self):
+ """Test plot_curve() with empty line data"""
+ axis = Mock()
+ utils.plot_curve(axis, [[], []])
+ axis.plot.assert_called_once_with([], [], label=None)
+
+ def test_plot_curve_with_single_point(self):
+ """Test plot_curve() with single point"""
+ axis = Mock()
+ utils.plot_curve(axis, [[1], [2]])
+ axis.plot.assert_called_once_with([1], [2], label=None)
+
+ def test_plot_curve_with_empty_accent_dict(self):
+ """Test plot_curve() with empty accent_dict still calls legend"""
+ axis = Mock()
+ utils.plot_curve(axis, [[1, 2], [3, 4]], accent_dict={})
+ # Legend is called even with empty accent_dict (not None)
+ axis.legend.assert_called_once()
+
+ def test_plot_curve_with_none_accent_dict(self):
+ """Test plot_curve() with None accent_dict doesn't call legend"""
+ axis = Mock()
+ utils.plot_curve(axis, [[1, 2], [3, 4]], accent_dict=None)
+ # Should not call legend when accent_dict is None
+ axis.legend.assert_not_called()
+
+ def test_plot_curve_with_malformed_accent_dict_missing_y(self):
+ """Test plot_curve() with malformed accent_dict (missing y values)"""
+ axis = Mock()
+ with pytest.raises((IndexError, KeyError)):
+ utils.plot_curve(axis, [[1, 2], [3, 4]], accent_dict={"accent": [[1]]})
+
+ def test_plot_curve_with_none_values_in_accents(self):
+ """Test plot_curve() handles None in accent coordinates"""
+ axis = Mock()
+ # This should call plot but may fail at matplotlib level
+ utils.plot_curve(axis, [[1, 2], [3, 4]], accent_dict={"accent": [[None], [None]]})
+ axis.plot.assert_any_call([None], [None], "x", label="accent")
+
+
+class TestCohortLegend:
+ """Test cohort_legend() function for legend creation"""
+
+ def test_cohort_legend_with_true_column(self):
+ """Test cohort_legend() with 'true' column format"""
+ fig, axes = plt.subplots(1, 2)
+
+ # Plot some lines on first axis
+ axes[0].plot([0, 1], [0, 1], label="Cohort A")
+ axes[0].plot([0, 1], [0, 0.5], label="Cohort B")
+
+ # Create data with 'true' column
+ data = pd.DataFrame(
+ {
+ "cohort": pd.Categorical(["A", "A", "B", "B"], categories=["A", "B"]),
+ "true": [1, 0, 1, 1],
+ "pred": [0.8, 0.2, 0.9, 0.7],
+ }
+ )
+
+ # Call cohort_legend on second axis
+ utils.cohort_legend(data, axes[1], "Test Feature", ref_axis=0)
+
+ # Verify legend was created
+ assert axes[1].get_legend() is not None
+
+ plt.close(fig)
+
+ def test_cohort_legend_with_cohort_count_column(self):
+ """Test cohort_legend() with 'cohort-count' column format"""
+ fig, axes = plt.subplots(1, 2)
+
+ # Plot a line on first axis
+ axes[0].plot([0, 1], [0, 1], label="Cohort A")
+
+ # Create data with cohort-count format
+ data = pd.DataFrame(
+ {
+ "cohort": pd.Categorical(["A"], categories=["A", "B"]),
+ "cohort-count": [100],
+ "cohort-targetcount": [25],
+ }
+ )
+
+ # Call cohort_legend on second axis
+ utils.cohort_legend(data, axes[1], "Test Feature", ref_axis=0)
+
+ assert axes[1].get_legend() is not None
+ plt.close(fig)
+
+ def test_cohort_legend_below_censor_threshold(self):
+ """Test cohort_legend() censors data below threshold"""
+ fig, axes = plt.subplots(1, 2)
+
+ # Plot a line on first axis
+ axes[0].plot([0, 1], [0, 1], label="Cohort A")
+
+ # Create small data (below default threshold of 10)
+ data = pd.DataFrame(
+ {"cohort": pd.Categorical(["A"] * 5, categories=["A"]), "true": [1, 0, 1, 1, 0], "pred": [0.8] * 5}
+ )
+
+ # Call cohort_legend with default censor_threshold=10
+ utils.cohort_legend(data, axes[1], "Test Feature", ref_axis=0)
+
+ # Should still create legend but with censored values
+ assert axes[1].get_legend() is not None
+ plt.close(fig)
+
+ def test_cohort_legend_more_lines_than_cohorts_error(self):
+ """Test cohort_legend() raises IndexError when more lines than cohorts"""
+ fig, axes = plt.subplots(1, 2)
+
+ # Plot more lines than we have cohorts
+ axes[0].plot([0, 1], [0, 1], label="Line 1")
+ axes[0].plot([0, 1], [0, 0.5], label="Line 2")
+ axes[0].plot([0, 1], [0, 0.3], label="Line 3")
+
+ # Create data with only 2 cohorts
+ data = pd.DataFrame({"cohort": pd.Categorical(["A", "B"], categories=["A", "B"]), "true": [1, 0]})
+
+ # Should raise IndexError
+ with pytest.raises(IndexError, match="More lines than cohorts"):
+ utils.cohort_legend(data, axes[1], "Test Feature", ref_axis=0)
+
+ plt.close(fig)
+
+ def test_cohort_legend_with_custom_labellist(self):
+ """Test cohort_legend() with custom label list"""
+ fig, axes = plt.subplots(1, 2)
+
+ # Plot a line
+ axes[0].plot([0, 1], [0, 1])
+
+ # Create data (matching array lengths)
+ data = pd.DataFrame(
+ {"cohort": pd.Categorical(["A", "A"], categories=["A"]), "true": [1, 0], "pred": [0.8, 0.2]}
+ )
+
+ # Call with custom labels
+ utils.cohort_legend(data, axes[1], "Test Feature", labellist=["Custom Label"], ref_axis=0)
+
+ assert axes[1].get_legend() is not None
+ plt.close(fig)
+
+ def test_cohort_legend_skips_dashed_lines(self):
+ """Test cohort_legend() skips dashed reference lines"""
+ fig, axes = plt.subplots(1, 2)
+
+ # Plot solid and dashed lines
+ axes[0].plot([0, 1], [0, 1], label="Cohort A") # Solid
+ axes[0].plot([0, 1], [0.5, 0.5], "--", label="Reference") # Dashed (should be skipped)
+
+ # Create data with one cohort (matching array lengths)
+ data = pd.DataFrame(
+ {"cohort": pd.Categorical(["A", "A"], categories=["A"]), "true": [1, 0], "pred": [0.8, 0.2]}
+ )
+
+ # Should only use the solid line (not dashed)
+ utils.cohort_legend(data, axes[1], "Test Feature", ref_axis=0)
+
+ assert axes[1].get_legend() is not None
+ plt.close(fig)