Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ docs/example_notebooks/notebooks

# virtualenv
venv/
.venv/
ENV/

# IDE settings
Expand Down
1 change: 1 addition & 0 deletions changelog/143.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Updated per-context visualizations to label combined-score results as “Per Context” (instead of “Per Encounter”) and clarified how “combine scores” aggregates across (entity_id, context_id); cohort summaries now optionally include a “Contexts” count when a context_id is configured.
14 changes: 7 additions & 7 deletions src/seismometer/api/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,10 +221,10 @@ def _plot_leadtime_enc(
target_event: str,
target_zero: str,
score: str,
threshold: list[float],
threshold: float,
ref_time: str,
cohort_col: str,
subgroups: list[any],
subgroups: list[Any],
max_hours: int,
x_label: str,
censor_threshold: int = 10,
Expand All @@ -240,7 +240,7 @@ def _plot_leadtime_enc(
event column
target_zero : str
event value
threshold : str
threshold : float
score thresholds
score : str
score column
Expand All @@ -250,7 +250,7 @@ def _plot_leadtime_enc(
entity key column
cohort_col : str
cohort column name
subgroups : list[any]
subgroups : list[Any]
cohort groups from the cohort column
x_label : str
label for the x axis of the plot
Expand Down Expand Up @@ -540,7 +540,7 @@ def _model_evaluation(
censor_threshold : int, optional
minimum rows to allow in a plot, by default 10
per_context_id : bool, optional
report only the max score for a given entity context, by default False
If True, aggregate scores per (entity_id, context_id) context, by default False
aggregation_method : str, optional
method to reduce multiple scores into a single value before calculation of performance, by default "max"
ignored if per_context_id is False
Expand Down Expand Up @@ -596,7 +596,7 @@ def _model_evaluation(
attributes=params | cohort,
metrics={metric: stats[[metric, "Threshold"]].set_index("Threshold").to_dict()},
)
title = f"Overall Performance for {target_event} (Per {'Encounter' if per_context_id else 'Observation'})"
title = f"Overall Performance for {target_event} (Per {'Context' if per_context_id else 'Observation'})"
svg = plot.evaluation(
stats,
ci_data=ci_data,
Expand Down Expand Up @@ -1101,7 +1101,7 @@ def binary_classifier_metric_evaluation(
censor_threshold : int, optional
minimum rows to allow in a plot, by default 10
per_context_id : bool, optional
report only the max score for a given entity context, by default False
If True, combine scores per (entity_id, context_id) as defined in usage_config.yml
aggregation_method : str, optional
method to reduce multiple scores into a single value before calculation of performance, by default "max"
ignored if per_context_id is False
Expand Down
7 changes: 5 additions & 2 deletions src/seismometer/api/templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,21 +172,24 @@ def _get_cohort_summary_dataframes(by_target: bool, by_score: bool) -> dict[str,
The dictionary, indexed by cohort attribute (e.g. Race), of summary dataframes.
"""
sg = Seismogram()
context_id = getattr(sg.config, "context_id", None)

dfs: dict[str, list[str]] = {}

available_cohort_groups = sg.available_cohort_groups

for attribute, options in available_cohort_groups.items():
df = default_cohort_summaries(sg.dataframe, attribute, options, sg.config.entity_id)
df = default_cohort_summaries(sg.dataframe, attribute, options, sg.config.entity_id, context_id)
styled = _style_cohort_summaries(df, attribute)

dfs[attribute] = [styled.to_html()]

if by_score or by_target:
groupby_groups, grab_groups, index_rename = _score_target_levels_and_index(attribute, by_target, by_score)

results = score_target_cohort_summaries(sg.dataframe, groupby_groups, grab_groups, sg.config.entity_id)
results = score_target_cohort_summaries(
sg.dataframe, groupby_groups, grab_groups, sg.config.entity_id, context_id
)
results_styled = _style_score_target_cohort_summaries(results, index_rename, attribute)

dfs[attribute].append(results_styled.to_html())
Expand Down
38 changes: 35 additions & 3 deletions src/seismometer/data/summaries.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

@export
def default_cohort_summaries(
dataframe: pd.DataFrame, attribute: str, options: list[str], entity_id_col: str
dataframe: pd.DataFrame, attribute: str, options: list[str], entity_id_col: str, context_id_col: str | None = None
) -> pd.DataFrame:
"""
Generate a dataframe of summary counts from the input dataframe.
Expand Down Expand Up @@ -40,7 +40,23 @@ def default_cohort_summaries(
.rename("Entities")
)

return pd.concat([left, right], axis=1).reindex(options)
pieces = [left, right]
if context_id_col and context_id_col in dataframe.columns:
contexts = (
pdh.event_score(
dataframe,
[entity_id_col, context_id_col],
sg.output,
sg.predict_time,
sg.target,
sg.event_aggregation_method(sg.target),
)[attribute]
.value_counts()
.rename("Contexts")
)
pieces.append(contexts)

return pd.concat(pieces, axis=1).reindex(options)


@export
Expand All @@ -49,6 +65,7 @@ def score_target_cohort_summaries(
groupby_groups: list[str],
grab_groups: list[str],
entity_id_col: str,
context_id_col: str | None = None,
) -> pd.DataFrame:
"""
Generate a dataframe of summary counts from the input dataframe.
Expand Down Expand Up @@ -79,4 +96,19 @@ def score_target_cohort_summaries(
)
entities = df[grab_groups].groupby(groupby_groups, observed=False).size().rename("Entities").astype("Int64")

return pd.DataFrame(pd.concat([predictions, entities], axis=1)).fillna(0)
pieces = [predictions, entities]
if context_id_col and context_id_col in dataframe.columns:
ctx_df = pdh.event_score(
dataframe,
[entity_id_col, context_id_col],
sg.output,
sg.predict_time,
sg.target,
sg.event_aggregation_method(sg.target),
)
contexts = (
ctx_df[grab_groups].groupby(groupby_groups, observed=False).size().rename("Contexts").astype("Int64")
)
pieces.append(contexts)

return pd.DataFrame(pd.concat(pieces, axis=1)).fillna(0)
55 changes: 55 additions & 0 deletions tests/data/test_summaries.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,3 +113,58 @@ def test_event_score_match_score_target_summaries(

# Ensuring they produce the same number of entities for each score-target-cohort group
assert entities_event_score.tolist() == entities_summary.tolist()

@patch.object(seismogram, "Seismogram", return_value=Mock())
def test_default_summaries_adds_contexts_when_context_id_col_provided(self, mock_seismo, prediction_data):
    """Passing ``context_id_col`` adds a Contexts column whose counts match an
    independent ``event_score`` computation and are never below Entities."""
    seismo = mock_seismo()
    seismo.output = "Score"
    seismo.target = "Target"
    seismo.predict_time = "Target"  # fixture has no separate time column; reuse Target
    seismo.event_aggregation_method = lambda _: "max"

    data = prediction_data.copy()

    # Alternate context ids within each entity so Contexts >= Entities can actually happen.
    data["Context"] = data.groupby("ID").cumcount() % 2

    summary = undertest.default_cohort_summaries(data, "Has_ECG", [1, 2, 3, 4, 5], "ID", context_id_col="Context")

    assert "Contexts" in summary.columns

    expected_contexts = (
        event_score(data, ["ID", "Context"], "Score", "Target", "Target", "max")["Has_ECG"]
        .value_counts()
        .rename("Contexts")
    )
    pd.testing.assert_series_equal(
        summary["Contexts"].dropna(),
        expected_contexts.reindex([1, 2, 3, 4, 5]).dropna(),
        check_names=False,
    )

    # An entity may span several contexts, never fewer (compare only rows with both counts).
    present = summary[["Entities", "Contexts"]].dropna()
    assert (present["Contexts"] >= present["Entities"]).all()

@patch.object(seismogram, "Seismogram", return_value=Mock())
def test_score_target_summaries_adds_contexts_when_context_id_col_provided(
    self, mock_seismo, prediction_data, expected_score_target_summary_cuts
):
    """Score/target cohort summaries gain a Contexts column that dominates Entities
    when a ``context_id_col`` is supplied."""
    seismo = mock_seismo()
    seismo.output = "Score"
    seismo.target = "Target"
    seismo.predict_time = "Target"  # fixture has no separate time column; reuse Target
    seismo.event_aggregation_method = lambda _: "max"

    data = prediction_data.copy()
    # Alternate context ids within each entity so multiple contexts per entity exist.
    data["Context"] = data.groupby("ID").cumcount() % 2

    summary = undertest.score_target_cohort_summaries(
        data,
        ["Has_ECG", expected_score_target_summary_cuts],
        ["Has_ECG", "Score"],
        "ID",
        context_id_col="Context",
    )

    assert "Contexts" in summary.columns
    assert (summary["Contexts"] >= summary["Entities"]).all()
49 changes: 49 additions & 0 deletions tests/plot/test_cohort_plots.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import warnings

import pandas as pd
from IPython.display import SVG
from pandas.errors import SettingWithCopyWarning

import seismometer.plot as plot
from seismometer.api.plots import _plot_leadtime_enc


def test_plot_leadtime_enc_no_settingwithcopywarning(monkeypatch):
    """Regression test: ``_plot_leadtime_enc`` must not emit SettingWithCopyWarning."""
    # Replace the renderer with a stub; only the dataframe handling is under test.
    monkeypatch.setattr(plot, "leadtime_violin", lambda *a, **k: SVG("<svg></svg>"))

    zero_times = pd.to_datetime(["2025-01-01 00:00:00"] * 4)
    pred_times = pd.to_datetime(
        ["2025-01-01 02:00:00", "2025-01-01 03:00:00", "2025-01-01 01:00:00", "2025-01-01 04:00:00"]
    )
    frame = pd.DataFrame(
        {
            "cohort": ["A", "A", "B", "B"],
            "event": [1, 1, 1, 1],
            "time_zero": zero_times,
            "pred_time": pred_times,
            "score": [0.9, 0.8, 0.95, 0.7],
            "entity_id": [1, 1, 2, 2],
        }
    )

    with warnings.catch_warnings(record=True) as captured:
        # Record every occurrence so a single emission is enough to fail the test.
        warnings.simplefilter("always", SettingWithCopyWarning)

        result = _plot_leadtime_enc(
            dataframe=frame,
            entity_keys=["entity_id"],
            target_event="event",
            target_zero="time_zero",
            score="score",
            threshold=0.75,
            ref_time="pred_time",
            cohort_col="cohort",
            subgroups=["A", "B"],
            max_hours=24,
            x_label="Lead Time (hours)",
            censor_threshold=0,
        )

    assert result is not None
    assert hasattr(result, "data")
    offenders = [w for w in captured if isinstance(w.message, SettingWithCopyWarning)]
    assert not offenders, "Expected no SettingWithCopyWarning, but one was emitted."
4 changes: 2 additions & 2 deletions tests/test_startup.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from datetime import datetime
from datetime import datetime, timezone
from unittest.mock import Mock, patch

import pytest
Expand Down Expand Up @@ -41,7 +41,7 @@ def fake_seismo(tmp_path):
@patch.object(seismometer.data.loader, "loader_factory", new=fake_data_loader)
class TestStartup:
def test_debug_logs_with_formatter(self, capsys):
expected_date_str = "[" + datetime.now().strftime("%Y-%m-%d")
expected_date_str = "[" + datetime.now(timezone.utc).strftime("%Y-%m-%d")

run_startup(log_level=logging.DEBUG)

Expand Down