Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog/185.misc.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Added unit tests for data merging and event processing utilities.
28 changes: 18 additions & 10 deletions src/seismometer/data/cohorts.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,34 +151,38 @@ def get_cohort_performance_data(
return frame


def resolve_col_data(df: pd.DataFrame, feature: Union[str, pd.Series]) -> pd.Series:
def resolve_col_data(df: pd.DataFrame, feature: Union[str, SeriesOrArray]) -> pd.Series:
"""
Handles resolving feature from either being a series or specifying a series in the dataframe.
Handles resolving feature from either being a series, array, or specifying a series in the dataframe.

Parameters
----------
df : pd.DataFrame
Containing a column of name feature if feature is passed in as a string.
feature : Union[str, pd.Series]
Either a pandas.Series or a column name in the dataframe.
feature : Union[str, SeriesOrArray]
Either a pandas.Series, numpy array, or a column name in the dataframe.

Returns
-------
pd.Series.
pd.Series
Always returns a pandas Series, with index matching df.index for array inputs.
"""

if isinstance(feature, str):
if feature in df.columns:
return df[feature].copy()
else:
raise KeyError(f"Feature {feature} was not found in dataframe")
elif isinstance(feature, pd.Series):
return feature # Already a Series, preserve its index
elif hasattr(feature, "ndim"):
# Convert arrays to Series with df's index for proper alignment
if feature.ndim > 1: # probas from sklearn is nx2 with second column being the positive predictions
return feature[:, 1]
return pd.Series(feature[:, 1], index=df.index)
else:
return feature
return pd.Series(feature, index=df.index)
else:
raise TypeError("Feature must be a string or pandas.Series, was given a ", type(feature))
raise TypeError(f"Feature must be a string, pandas.Series, or numpy.ndarray, was given {type(feature)}")


# endregion
Expand Down Expand Up @@ -232,7 +236,11 @@ def label_cohorts_numeric(series: SeriesOrArray, splits: Optional[List] = None)
labels = [f"{bins[i]}-{bins[i+1]}" for i in range(len(bins) - 1)] + [f">={bins[-1]}"]
labels[0] = f"<{bins[1]}"
cat = pd.Categorical.from_codes(bin_ixs - 1, labels)
return pd.Series(cat)
# Preserve the input series index for proper alignment in pd.concat
if isinstance(series, pd.Series):
return pd.Series(cat, index=series.index)
else:
return pd.Series(cat)


def has_good_binning(bin_ixs: List, bin_edges: List) -> None:
Expand Down Expand Up @@ -275,7 +283,7 @@ def label_cohorts_categorical(series: SeriesOrArray, cat_values: Optional[list]
List of string labels for each bin; which is the list of categories.
"""
series.name = "cohort"
series.cat._name = "cohort" # CategoricalAccessors have a different name..
series.cat._name = "cohort" # CategoricalAccessors have a different name.

# If no splits specified, restrict to observed values
if cat_values is None:
Expand Down
4 changes: 2 additions & 2 deletions src/seismometer/data/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,9 +211,9 @@ def __str__(self) -> str:
case "notna":
return f"{self.left} has a value"
case "isin":
return f"{self.left} is in: {', '.join(self.right)}"
return f"{self.left} is in: {', '.join(map(str, self.right))}"
case "notin":
return f"{self.left} not in: {', '.join(self.right)}"
return f"{self.left} not in: {', '.join(map(str, self.right))}"
case "topk":
return f"{self.left} in top {self.right} values"
case "nottopk":
Expand Down
19 changes: 11 additions & 8 deletions src/seismometer/data/pandas_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,9 +259,9 @@ def post_process_event(
# cast after imputation - supports nonnullable types
try_casting(dataframe, label_col, column_dtype)

# Log how many rows were imputed/changed
imputed_with_time = ((label_na_map & ~time_na_map) & dataframe[label_col].notna()).sum()
imputed_no_time = (dataframe[label_col].isna()).sum()
# Log how many rows were imputed
imputed_with_time = (label_na_map & ~time_na_map).sum()
imputed_no_time = (label_na_map & time_na_map).sum()
logger.debug(
f"Post-processing of events for {label_col} and {time_col} complete. "
f"Imputed {imputed_with_time} rows with time, {imputed_no_time} rows with no time."
Expand Down Expand Up @@ -443,13 +443,15 @@ def _merge_with_strategy(
if merge_strategy == "first":
logger.debug(f"Updating events to only keep the first occurrence for each {event_display}.")
one_event_filtered = one_event.groupby(pks).first().reset_index()
if merge_strategy == "last":
elif merge_strategy == "last":
logger.debug(f"Updating events to only keep the last occurrence for each {event_display}.")
one_event_filtered = one_event.groupby(pks).last().reset_index()

except ValueError as e:
logger.warning(e)
pass
# Only continue with fallback merge if one_event_filtered was set
if "one_event_filtered" not in locals():
raise

return pd.merge(predictions, one_event_filtered, on=pks, how="left")

Expand Down Expand Up @@ -778,14 +780,15 @@ def _resolve_score_col(dataframe: pd.DataFrame, score: str) -> str:


def analytics_metric_name(metric_names: list[str], existing_metric_starts: list[str], column_name: str) -> str:
"""In the analytics table, often the provided column name is not the actual
"""
In the analytics table, often the provided column name is not the actual
metric name that we want to log. Here, we extract the desired metric name.

Parameters
----------
metric_names : list[str]
What metrics already exist.
existing_metric_values : list[str]
existing_metric_starts : list[str]
What strings can start the mangled column name.
column_name : str
The name of the column we are trying to make into a metric.
Expand All @@ -800,7 +803,7 @@ def analytics_metric_name(metric_names: list[str], existing_metric_starts: list[
else:
for value in existing_metric_starts:
if column_name.startswith(f"{value}_"):
return column_name.lstrip(f"{value}_")
return column_name.removeprefix(f"{value}_")
return None


Expand Down
4 changes: 4 additions & 0 deletions src/seismometer/plot/mpl/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@ def add_unseen(df: pd.DataFrame, col="cohort") -> pd.DataFrame:
obs = df[col].unique()
unseen = [k for k in keys if k not in obs]

# Only concatenate if there are unseen categories
if not unseen:
return df

rv = pd.concat([df, pd.DataFrame({col: unseen})], ignore_index=True)
rv[col] = rv[col].astype(pd.CategoricalDtype(df[col].cat.categories))
return rv
Expand Down
210 changes: 209 additions & 1 deletion tests/api/test_api_explore.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,19 @@
from ipywidgets import HTML as WidgetHTML

from seismometer import Seismogram
from seismometer.api.explore import ExplorationWidget, ExploreBinaryModelMetrics, cohort_list_details
from seismometer.api.explore import (
ExplorationWidget,
ExploreBinaryModelMetrics,
ExploreCohortEvaluation,
ExploreCohortHistograms,
ExploreCohortLeadTime,
ExploreCohortOutcomeInterventionTimes,
ExploreModelEvaluation,
ExploreModelScoreComparison,
ExploreModelTargetComparison,
ExploreSubgroups,
cohort_list_details,
)
from seismometer.api.plots import (
_model_evaluation,
_plot_cohort_evaluation,
Expand All @@ -18,6 +30,7 @@
plot_cohort_lead_time,
plot_intervention_outcome_timeseries,
plot_leadtime_enc,
plot_model_evaluation,
plot_model_score_comparison,
plot_model_target_comparison,
plot_trend_intervention_outcome,
Expand Down Expand Up @@ -770,3 +783,198 @@ def generate_plot_args(self):
assert isinstance(result, WidgetHTML)
assert "Traceback" in result.value
assert "kaboom" in result.value


# ============================================================================
# WIDGET INITIALIZATION TESTS
# ============================================================================


class TestWidgetClassInitialization:
    """Smoke tests confirming that each Exploration widget class constructs cleanly."""

    def test_explore_subgroups_initialization(self, fake_seismo):
        """ExploreSubgroups constructs and wires cohort_list_details as its plot function."""
        with patch("seismometer.seismogram.Seismogram", return_value=fake_seismo):
            explorer = ExploreSubgroups()
        assert explorer is not None
        assert hasattr(explorer, "plot_function")
        assert explorer.plot_function == cohort_list_details

    def test_explore_model_evaluation_initialization(self, fake_seismo):
        """ExploreModelEvaluation constructs and wires plot_model_evaluation as its plot function."""
        with patch("seismometer.seismogram.Seismogram", return_value=fake_seismo):
            explorer = ExploreModelEvaluation()
        assert explorer is not None
        assert hasattr(explorer, "plot_function")
        assert explorer.plot_function == plot_model_evaluation

    def test_explore_model_score_comparison_initialization(self, fake_seismo):
        """ExploreModelScoreComparison constructs and wires plot_model_score_comparison."""
        with patch("seismometer.seismogram.Seismogram", return_value=fake_seismo):
            explorer = ExploreModelScoreComparison()
        assert explorer is not None
        assert hasattr(explorer, "plot_function")
        assert explorer.plot_function == plot_model_score_comparison

    def test_explore_model_target_comparison_initialization(self, fake_seismo):
        """ExploreModelTargetComparison constructs and wires plot_model_target_comparison."""
        with patch("seismometer.seismogram.Seismogram", return_value=fake_seismo):
            explorer = ExploreModelTargetComparison()
        assert explorer is not None
        assert hasattr(explorer, "plot_function")
        assert explorer.plot_function == plot_model_target_comparison

    def test_explore_cohort_evaluation_initialization(self, fake_seismo):
        """ExploreCohortEvaluation constructs with a plot_function attribute."""
        with patch("seismometer.seismogram.Seismogram", return_value=fake_seismo):
            explorer = ExploreCohortEvaluation()
        assert explorer is not None
        # The parent class wraps plot_function, so only its presence is checked here.
        assert hasattr(explorer, "plot_function")

    def test_explore_cohort_histograms_initialization(self, fake_seismo):
        """ExploreCohortHistograms constructs with a plot_function attribute."""
        with patch("seismometer.seismogram.Seismogram", return_value=fake_seismo):
            explorer = ExploreCohortHistograms()
        assert explorer is not None
        assert hasattr(explorer, "plot_function")

    def test_explore_cohort_lead_time_initialization(self, fake_seismo):
        """ExploreCohortLeadTime constructs with a plot_function attribute."""
        with patch("seismometer.seismogram.Seismogram", return_value=fake_seismo):
            explorer = ExploreCohortLeadTime()
        assert explorer is not None
        assert hasattr(explorer, "plot_function")

    def test_explore_cohort_outcome_intervention_times_initialization(self, fake_seismo):
        """ExploreCohortOutcomeInterventionTimes wires plot_intervention_outcome_timeseries."""
        with patch("seismometer.seismogram.Seismogram", return_value=fake_seismo):
            explorer = ExploreCohortOutcomeInterventionTimes()
        assert explorer is not None
        assert hasattr(explorer, "plot_function")
        assert explorer.plot_function == plot_intervention_outcome_timeseries

    @pytest.mark.parametrize("rho,expected_rho", [(None, 1 / 3), (0.0, 1 / 3), (0.5, 0.5), (1.0, 1.0)])
    def test_explore_binary_model_metrics_initialization_with_rho(self, fake_seismo, rho, expected_rho):
        """ExploreBinaryModelMetrics resolves rho correctly across representative values."""
        with patch("seismometer.seismogram.Seismogram", return_value=fake_seismo):
            explorer = ExploreBinaryModelMetrics(rho=rho)
        assert explorer is not None
        assert hasattr(explorer, "plot_function")
        assert hasattr(explorer, "metric_generator")
        generator = explorer.metric_generator
        assert isinstance(generator, BinaryClassifierMetricGenerator)
        # rho values of None and 0.0 both fall back to the default of 1/3.
        assert abs(generator.rho - expected_rho) < 1e-10


# ============================================================================
# cohort_list() EDGE CASES
# ============================================================================


class TestCohortListFunction:
    """Edge-case and error-handling coverage for the cohort_list_details function.

    Note: cohort_list() widget creation is already tested in TestExploreSubgroups.test_cohort_list_widget_rendering
    """

    @patch("seismometer.html.template.render_title_message", return_value=HTML("summary"))
    @patch("seismometer.data.filter.filter_rule_from_cohort_dictionary")
    @patch("seismometer.seismogram.Seismogram")
    def test_cohort_list_details_with_empty_cohort_dict(self, seismogram_cls, rule_factory, render_message, fake_seismo):
        """An empty cohort dictionary falls through to the unfiltered dataframe."""
        seismogram_cls.return_value = fake_seismo

        # With no cohort restrictions, the filter rule passes the data through untouched.
        filter_rule = MagicMock()
        filter_rule.filter.return_value = fake_seismo.dataframe
        rule_factory.return_value = filter_rule

        rendered = cohort_list_details({})

        assert isinstance(rendered, HTML)
        rule_factory.assert_called_once_with({})

    @patch("seismometer.data.filter.filter_rule_from_cohort_dictionary")
    @patch("seismometer.seismogram.Seismogram")
    def test_cohort_list_details_exception_handling(self, seismogram_cls, rule_factory, fake_seismo):
        """A failure while building the filter rule propagates to the caller."""
        seismogram_cls.return_value = fake_seismo

        # Simulate the filter-rule construction blowing up on a bad column.
        rule_factory.side_effect = KeyError("Invalid cohort column")

        with pytest.raises(KeyError, match="Invalid cohort column"):
            cohort_list_details({"InvalidColumn": ["Value"]})

    @patch("seismometer.html.template.render_censored_plot_message", return_value=HTML("censored"))
    @patch("seismometer.data.filter.filter_rule_from_cohort_dictionary")
    @patch("seismometer.seismogram.Seismogram")
    def test_cohort_list_details_below_censor_threshold(self, seismogram_cls, rule_factory, render_message, fake_seismo):
        """Data below the censor threshold renders the censored message instead of a summary."""
        fake_seismo.config.censor_min_count = 100
        seismogram_cls.return_value = fake_seismo

        filter_rule = MagicMock()
        filter_rule.filter.return_value = fake_seismo.dataframe
        rule_factory.return_value = filter_rule

        rendered = cohort_list_details({"Cohort": ["C1"]})

        assert "censored" in rendered.data.lower()
        render_message.assert_called_once()

    @patch("seismometer.html.template.render_title_message", return_value=HTML("summary"))
    @patch("seismometer.data.filter.filter_rule_from_cohort_dictionary")
    @patch("seismometer.seismogram.Seismogram")
    def test_cohort_list_details_with_no_context_id(self, seismogram_cls, rule_factory, render_message, fake_seismo):
        """A None context_id still produces a rendered summary."""
        fake_seismo.config.context_id = None
        seismogram_cls.return_value = fake_seismo

        filter_rule = MagicMock()
        filter_rule.filter.return_value = fake_seismo.dataframe
        rule_factory.return_value = filter_rule

        rendered = cohort_list_details({"Cohort": ["C1", "C2"]})

        assert isinstance(rendered, HTML)
        render_message.assert_called_once()

    @patch("seismometer.html.template.render_title_message", return_value=HTML("summary"))
    @patch("seismometer.data.filter.filter_rule_from_cohort_dictionary")
    @patch("seismometer.seismogram.Seismogram")
    def test_cohort_list_details_with_multiple_targets(self, seismogram_cls, rule_factory, render_message, fake_seismo):
        """Multiple target events are summarized without error."""
        fake_seismo.config.targets = ["event1", "event2", "event3"]
        seismogram_cls.return_value = fake_seismo

        filter_rule = MagicMock()
        filter_rule.filter.return_value = fake_seismo.dataframe
        rule_factory.return_value = filter_rule

        rendered = cohort_list_details({"Cohort": ["C1", "C2"]})

        assert isinstance(rendered, HTML)
        assert "summary" in rendered.data

    @patch("seismometer.html.template.render_title_message", return_value=HTML("summary"))
    @patch("seismometer.data.filter.filter_rule_from_cohort_dictionary")
    @patch("seismometer.seismogram.Seismogram")
    def test_cohort_list_details_with_no_interventions_or_outcomes(
        self, seismogram_cls, rule_factory, render_message, fake_seismo
    ):
        """Empty interventions/outcomes configuration still yields a summary."""
        fake_seismo.config.interventions = {}
        fake_seismo.config.outcomes = {}
        seismogram_cls.return_value = fake_seismo

        filter_rule = MagicMock()
        filter_rule.filter.return_value = fake_seismo.dataframe
        rule_factory.return_value = filter_rule

        rendered = cohort_list_details({"Cohort": ["C1"]})

        assert isinstance(rendered, HTML)
        assert "summary" in rendered.data
Loading
Loading