diff --git a/changelog/168.feature.rst b/changelog/168.feature.rst
new file mode 100644
index 00000000..e359d769
--- /dev/null
+++ b/changelog/168.feature.rst
@@ -0,0 +1 @@
+Added a 'Show raw data' checkbox to the interactive ``Explore...`` widgets to display the underlying ``pandas.DataFrame`` used for visualizations.
diff --git a/src/seismometer/api/explore.py b/src/seismometer/api/explore.py
index 428f0e10..d2c2ab16 100644
--- a/src/seismometer/api/explore.py
+++ b/src/seismometer/api/explore.py
@@ -1,8 +1,9 @@
from typing import Any, Optional
+import pandas as pd
from IPython.display import HTML, display
-from seismometer.controls.decorators import disk_cached_html_segment
+from seismometer.controls.decorators import disk_cached_html_and_df_segment
from seismometer.controls.explore import ExplorationWidget # noqa:
from seismometer.controls.explore import (
ExplorationCohortOutcomeInterventionEvaluationWidget,
@@ -216,9 +217,9 @@ def on_widget_value_changed(*args):
return VBox(children=[comparison_selections, output], layout=BOX_GRID_LAYOUT)
-@disk_cached_html_segment
+@disk_cached_html_and_df_segment
@export
-def cohort_list_details(cohort_dict: dict[str, tuple[Any]]) -> HTML:
+def cohort_list_details(cohort_dict: dict[str, tuple[Any]]) -> tuple[HTML, pd.DataFrame]:
"""
Generates an HTML table of cohort details.
@@ -229,8 +230,8 @@ def cohort_list_details(cohort_dict: dict[str, tuple[Any]]) -> HTML:
Returns
-------
- HTML
- able indexed by targets, with counts of unique entities, and mean values of the output columns.
+ tuple[HTML, pd.DataFrame]
+ Table indexed by targets, with counts of unique entities, and mean values of the output columns, plus the underlying data
"""
from seismometer.data.filter import filter_rule_from_cohort_dictionary
@@ -246,7 +247,7 @@ def cohort_list_details(cohort_dict: dict[str, tuple[Any]]) -> HTML:
]
cohort_count = data[sg.entity_keys[0]].nunique()
if cohort_count < sg.censor_threshold:
- return template.render_censored_plot_message(sg.censor_threshold)
+ return template.render_censored_plot_message(sg.censor_threshold), data
groups = data.groupby(target_cols)
float_cols = list(data[intervention_cols + outcome_cols].select_dtypes(include=float))
@@ -268,7 +269,7 @@ def cohort_list_details(cohort_dict: dict[str, tuple[Any]]) -> HTML:
groupstats.index.rename(new_names, inplace=True)
html_table = groupstats.to_html()
title = "Summary"
- return template.render_title_message(title, html_table)
+ return template.render_title_message(title, html_table), data
# endregion
diff --git a/src/seismometer/api/plots.py b/src/seismometer/api/plots.py
index 5c88d4d4..671ae555 100644
--- a/src/seismometer/api/plots.py
+++ b/src/seismometer/api/plots.py
@@ -6,7 +6,7 @@
from IPython.display import HTML, SVG
import seismometer.plot as plot
-from seismometer.controls.decorators import disk_cached_html_segment
+from seismometer.controls.decorators import disk_cached_html_and_df_segment
from seismometer.core.autometrics import AutomationManager, store_call_parameters
from seismometer.core.decorators import export
from seismometer.data import get_cohort_data, get_cohort_performance_data, metric_apis
@@ -36,9 +36,11 @@ def plot_cohort_hist():
return _plot_cohort_hist(sg.dataframe, sg.target, sg.output, cohort_col, subgroups, censor_threshold)
-@disk_cached_html_segment
+@disk_cached_html_and_df_segment
@export
-def plot_cohort_group_histograms(cohort_col: str, subgroups: list[str], target_column: str, score_column: str) -> HTML:
+def plot_cohort_group_histograms(
+ cohort_col: str, subgroups: list[str], target_column: str, score_column: str
+) -> tuple[HTML, pd.DataFrame]:
"""
Generate a histogram plot of predicted probabilities for each subgroup in a cohort.
@@ -55,15 +57,14 @@ def plot_cohort_group_histograms(cohort_col: str, subgroups: list[str], target_c
Returns
-------
- HTML
- html visualization of the histogram
+ tuple[HTML, pd.DataFrame]
+ html visualization of the histogram and the data used to generate it
"""
sg = Seismogram()
target_column = pdh.event_value(target_column)
return _plot_cohort_hist(sg.dataframe, target_column, score_column, cohort_col, subgroups, sg.censor_threshold)
-@disk_cached_html_segment
def _plot_cohort_hist(
dataframe: pd.DataFrame,
target: str,
@@ -71,7 +72,7 @@ def _plot_cohort_hist(
cohort_col: str,
subgroups: list[str],
censor_threshold: int = 10,
-) -> HTML:
+) -> tuple[HTML, pd.DataFrame]:
"""
Creates an HTML segment displaying a histogram of predicted probabilities for each cohort.
@@ -107,7 +108,7 @@ def _plot_cohort_hist(
cData = cData.loc[cData["cohort"].isin(good_groups)]
if len(cData.index) == 0:
- return template.render_censored_plot_message(censor_threshold)
+ return template.render_censored_plot_message(censor_threshold), cData
bin_count = 20
bins = np.histogram_bin_edges(cData["pred"], bins=bin_count)
@@ -115,9 +116,9 @@ def _plot_cohort_hist(
try:
svg = plot.cohorts_vertical(cData, plot.histogram_stacked, func_kws={"show_legend": False, "bins": bins})
title = f"Predicted Probabilities by {cohort_col}"
- return template.render_title_with_image(title, svg)
+ return template.render_title_with_image(title, svg), cData
except Exception as error:
- return template.render_title_message("Error", f"Error: {error}")
+ return template.render_title_message("Error", f"Error: {error}"), pd.DataFrame()
@export
@@ -165,11 +166,11 @@ def plot_leadtime_enc(score=None, ref_time=None, target_event=None):
@store_call_parameters(cohort_col="cohort_col", subgroups="subgroups")
-@disk_cached_html_segment
+@disk_cached_html_and_df_segment
@export
def plot_cohort_lead_time(
cohort_col: str, subgroups: list[str], event_column: str, score_column: str, threshold: float
-) -> HTML:
+) -> tuple[HTML, pd.DataFrame]:
"""
Plots a lead times between the first positive prediction give an threshold and an event.
@@ -214,7 +215,6 @@ def plot_cohort_lead_time(
)
-@disk_cached_html_segment
def _plot_leadtime_enc(
dataframe: pd.DataFrame,
entity_keys: list[str],
@@ -228,7 +228,7 @@ def _plot_leadtime_enc(
max_hours: int,
x_label: str,
censor_threshold: int = 10,
-) -> HTML:
+) -> tuple[HTML, pd.DataFrame]:
"""
HTML Plot of time between prediction and target event.
@@ -261,21 +261,24 @@ def _plot_leadtime_enc(
Returns
-------
- HTML
- Lead time plot
+ tuple[HTML, pd.DataFrame]
+ Lead time plot and the data used to generate it
"""
if target_event not in dataframe:
- logger.error(f"Target event ({target_event}) not found in dataset. Cannot plot leadtime.")
- return
+ msg = f"Target event ({target_event}) not found in dataset. Cannot plot leadtime."
+ logger.error(msg)
+ return template.render_title_message("Error", msg), pd.DataFrame()
if target_zero not in dataframe:
- logger.error(f"Target event time-zero ({target_zero}) not found in dataset. Cannot plot leadtime.")
- return
+ msg = f"Target event time-zero ({target_zero}) not found in dataset. Cannot plot leadtime."
+ logger.error(msg)
+ return template.render_title_message("Error", msg), pd.DataFrame()
summary_data = dataframe[dataframe[target_event] == 1]
if len(summary_data.index) == 0:
- logger.error(f"No positive events ({target_event}=1) were found")
- return
+ msg = f"No positive events ({target_event}=1) were found"
+ logger.error(msg)
+ return template.render_title_message("Error", msg), pd.DataFrame()
cohort_mask = summary_data[cohort_col].isin(subgroups)
threshold_mask = summary_data[score] > threshold
@@ -291,7 +294,7 @@ def _plot_leadtime_enc(
if summary_data is not None and len(summary_data) > censor_threshold:
summary_data = summary_data[[target_zero, ref_time, cohort_col]]
else:
- return template.render_censored_plot_message(censor_threshold)
+ return template.render_censored_plot_message(censor_threshold), pd.DataFrame()
# filter by group size
counts = summary_data[cohort_col].value_counts()
@@ -303,7 +306,7 @@ def _plot_leadtime_enc(
)
if len(summary_data.index) == 0:
- return template.render_censored_plot_message(censor_threshold)
+ return template.render_censored_plot_message(censor_threshold), summary_data
# Truncate to minute but plot hour
summary_data[x_label] = (summary_data[ref_time] - summary_data[target_zero]).dt.total_seconds() // 60 / 60
@@ -311,11 +314,11 @@ def _plot_leadtime_enc(
title = f'Lead Time from {score.replace("_", " ")} to {(target_zero).replace("_", " ")}'
rows = summary_data[cohort_col].nunique()
svg = plot.leadtime_violin(summary_data, x_label, cohort_col, xmax=max_hours, figsize=(9, 1 + rows))
- return template.render_title_with_image(title, svg)
+ return template.render_title_with_image(title, svg), summary_data
@store_call_parameters(cohort_col="cohort_col", subgroups="subgroups")
-@disk_cached_html_segment
+@disk_cached_html_and_df_segment
@export
def plot_cohort_evaluation(
cohort_col: str,
@@ -324,7 +327,7 @@ def plot_cohort_evaluation(
score_column: str,
thresholds: list[float],
per_context: bool = False,
-) -> HTML:
+) -> tuple[HTML, pd.DataFrame]:
"""
Plots model performance metrics split by on a cohort attribute.
@@ -364,7 +367,6 @@ def plot_cohort_evaluation(
)
-@disk_cached_html_segment
def _plot_cohort_evaluation(
dataframe: pd.DataFrame,
entity_keys: list[str],
@@ -378,7 +380,7 @@ def _plot_cohort_evaluation(
aggregation_method: str = "max",
threshold_col: str = "Threshold",
ref_time: str = None,
-) -> HTML:
+) -> tuple[HTML, pd.DataFrame]:
"""
Plots model performance metrics split by on a cohort attribute.
@@ -441,14 +443,14 @@ def _plot_cohort_evaluation(
try:
assert_valid_performance_metrics_df(plot_data)
except ValueError:
- return template.render_censored_plot_message(censor_threshold)
+ return template.render_censored_plot_message(censor_threshold), plot_data
svg = plot.cohort_evaluation_vs_threshold(plot_data, cohort_feature=cohort_col, highlight=thresholds)
title = f"Model Performance Metrics on {cohort_col} across Thresholds"
- return template.render_title_with_image(title, svg)
+ return template.render_title_with_image(title, svg), plot_data
@store_call_parameters(cohort_dict="cohort_dict")
-@disk_cached_html_segment
+@disk_cached_html_and_df_segment
@export
def plot_model_evaluation(
cohort_dict: dict[str, tuple[Any]],
@@ -456,7 +458,7 @@ def plot_model_evaluation(
score_column: str,
thresholds: list[float],
per_context: bool = False,
-) -> HTML:
+) -> tuple[HTML, pd.DataFrame]:
"""
Generates a 2x3 plot showing the performance of a model.
@@ -503,7 +505,6 @@ def plot_model_evaluation(
)
-@disk_cached_html_segment
def _model_evaluation(
dataframe: pd.DataFrame,
entity_keys: list[str],
@@ -516,7 +517,7 @@ def _model_evaluation(
aggregation_method: str = "max",
ref_time: Optional[str] = None,
cohort: dict = {},
-) -> HTML:
+) -> tuple[HTML, pd.DataFrame]:
"""
plots common model evaluation metrics
@@ -563,10 +564,13 @@ def _model_evaluation(
requirements = FilterRule.isin(target, (0, 1)) & FilterRule.notna(score_col)
data = requirements.filter(data)
if len(data.index) < censor_threshold:
- return template.render_censored_plot_message(censor_threshold)
+ return template.render_censored_plot_message(censor_threshold), data
if (lcount := data[target].nunique()) != 2:
- return template.render_title_message(
- "Evaluation Error", f"Model Evaluation requires exactly two classes but found {lcount}"
+ return (
+ template.render_title_message(
+ "Evaluation Error", f"Model Evaluation requires exactly two classes but found {lcount}"
+ ),
+ data,
)
# stats and ci handle percentile/percentage independently - evaluation wants 0-100 for displays
@@ -595,7 +599,7 @@ def _model_evaluation(
show_thresholds=True,
highlight=thresholds,
)
- return template.render_title_with_image(title, svg)
+ return template.render_title_with_image(title, svg), data
@store_call_parameters
@@ -622,7 +626,7 @@ def plot_trend_intervention_outcome() -> HTML:
)
-@disk_cached_html_segment
+@disk_cached_html_and_df_segment
@export
def plot_intervention_outcome_timeseries(
cohort_col: str,
@@ -631,7 +635,7 @@ def plot_intervention_outcome_timeseries(
intervention: str,
reference_time_col: str,
censor_threshold: int = 10,
-) -> HTML:
+) -> tuple[HTML, pd.DataFrame]:
"""
Plots two timeseries based on an outcome and an intervention.
@@ -666,7 +670,6 @@ def plot_intervention_outcome_timeseries(
)
-@disk_cached_html_segment
def _plot_trend_intervention_outcome(
dataframe: pd.DataFrame,
entity_keys: list[str],
@@ -676,7 +679,7 @@ def _plot_trend_intervention_outcome(
intervention: str,
reftime: str,
censor_threshold: int = 10,
-) -> HTML:
+) -> tuple[HTML, pd.DataFrame]:
"""
Plots two timeseries based on selectors; an outcome and then an intervention.
@@ -748,7 +751,7 @@ def _plot_trend_intervention_outcome(
"Missing Outcome", f"No outcome timeseries plotted; No events with name {outcome}."
)
- return HTML(outcome_plot.data + intervention_plot.data)
+ return HTML(outcome_plot.data + intervention_plot.data), dataframe
def _plot_ts_cohort(
@@ -842,11 +845,11 @@ def _plot_ts_cohort(
@store_call_parameters(cohort_dict="cohort_dict")
-@disk_cached_html_segment
+@disk_cached_html_and_df_segment
@export
def plot_model_score_comparison(
cohort_dict: dict[str, tuple[Any]], target: str, scores: tuple[str], *, per_context: bool
-) -> HTML:
+) -> tuple[HTML, pd.DataFrame]:
"""
Plots a comparison of model scores for a given subpopulation.
@@ -903,17 +906,17 @@ def plot_model_score_comparison(
try:
assert_valid_performance_metrics_df(plot_data)
except ValueError:
- return template.render_censored_plot_message(sg.censor_threshold)
+ return template.render_censored_plot_message(sg.censor_threshold), plot_data
svg = plot.cohort_evaluation_vs_threshold(plot_data, cohort_feature="ScoreName")
title = f"Model Metrics: {', '.join(scores)} vs {target}"
- return template.render_title_with_image(title, svg)
+ return template.render_title_with_image(title, svg), plot_data
-@disk_cached_html_segment
+@disk_cached_html_and_df_segment
@export
def plot_model_target_comparison(
cohort_dict: dict[str, tuple[Any]], targets: tuple[str], score: str, *, per_context: bool
-) -> HTML:
+) -> tuple[HTML, pd.DataFrame]:
"""
Plots a comparison of model targets for a given score and subpopulation.
@@ -970,15 +973,15 @@ def plot_model_target_comparison(
try:
assert_valid_performance_metrics_df(plot_data)
except ValueError:
- return template.render_censored_plot_message(sg.censor_threshold)
+ return template.render_censored_plot_message(sg.censor_threshold), plot_data
svg = plot.cohort_evaluation_vs_threshold(plot_data, cohort_feature="ScoreName")
title = f"Model Metrics: {', '.join(targets)} vs {score}"
- return template.render_title_with_image(title, svg)
+ return template.render_title_with_image(title, svg), plot_data
# region Explore Any Metric (NNT, etc)
@store_call_parameters(cohort_dict="cohort_dict")
-@disk_cached_html_segment
+@disk_cached_html_and_df_segment
@export
def plot_binary_classifier_metrics(
metric_generator: BinaryClassifierMetricGenerator,
@@ -989,7 +992,7 @@ def plot_binary_classifier_metrics(
*,
per_context: bool = False,
table_only: bool = False,
-) -> HTML:
+) -> tuple[HTML, pd.DataFrame]:
"""
Generates a plot with model metrics.
@@ -1074,7 +1077,7 @@ def binary_classifier_metric_evaluation(
aggregation_method: str = "max",
ref_time: str = None,
table_only: bool = False,
-) -> HTML:
+) -> tuple[HTML, pd.DataFrame]:
"""
plots common model evaluation metrics
@@ -1118,10 +1121,13 @@ def binary_classifier_metric_evaluation(
requirements = FilterRule.isin(target, (0, 1)) & FilterRule.notna(score_col)
data = requirements.filter(data)
if len(data.index) < censor_threshold:
- return template.render_censored_plot_message(censor_threshold)
+ return template.render_censored_plot_message(censor_threshold), data
if (lcount := data[target].nunique()) != 2:
- return template.render_title_message(
- "Evaluation Error", f"Model Evaluation requires exactly two classes but found {lcount}"
+ return (
+ template.render_title_message(
+ "Evaluation Error", f"Model Evaluation requires exactly two classes but found {lcount}"
+ ),
+ data,
)
if isinstance(metrics, str):
metrics = [metrics]
@@ -1134,8 +1140,8 @@ def binary_classifier_metric_evaluation(
if log_all:
recorder.populate_metrics(attributes=attributes, metrics={metric: stats[metric].to_dict()})
if table_only:
- return HTML(stats[metrics].T.to_html())
- return plot.binary_classifier.plot_metric_list(stats, metrics)
+ return HTML(stats[metrics].T.to_html()), data
+ return plot.binary_classifier.plot_metric_list(stats, metrics), data
# endregion
diff --git a/src/seismometer/controls/decorators.py b/src/seismometer/controls/decorators.py
index eaac8cea..0dfaa13e 100644
--- a/src/seismometer/controls/decorators.py
+++ b/src/seismometer/controls/decorators.py
@@ -5,6 +5,7 @@
import hashlib
import logging
import os
+import pickle
import shutil
from functools import wraps
from inspect import signature
@@ -27,4 +28,29 @@ def html_save(html, filepath) -> None:
filepath.write_text(html.data)
+def html_and_df_save(data, filepath) -> None:
+ """
+ Saves a tuple of (HTML, pd.DataFrame) to disk.
+ """
+ html, df = data
+ html_path = filepath.with_suffix(".html")
+ df_path = filepath.with_suffix(".pkl")
+ html_path.write_text(html.data)
+ with open(df_path, "wb") as f:
+ pickle.dump(df, f)
+
+
+def html_and_df_load(filepath) -> tuple[HTML, Any]:
+ """
+ Loads a tuple of (HTML, pd.DataFrame) from disk.
+ """
+ html_path = filepath.with_suffix(".html")
+ df_path = filepath.with_suffix(".pkl")
+ html = HTML(html_path.read_text())
+ with open(df_path, "rb") as f:
+ df = pickle.load(f)
+ return html, df
+
+
disk_cached_html_segment = DiskCachedFunction("html", save_fn=html_save, load_fn=html_load, return_type=HTML)
+disk_cached_html_and_df_segment = DiskCachedFunction("html_and_df", save_fn=html_and_df_save, load_fn=html_and_df_load)
diff --git a/src/seismometer/controls/explore.py b/src/seismometer/controls/explore.py
index 09d4f020..ae0fa316 100644
--- a/src/seismometer/controls/explore.py
+++ b/src/seismometer/controls/explore.py
@@ -35,6 +35,7 @@ class UpdatePlotWidget(Box):
UPDATE_PLOTS = "Update"
UPDATING_PLOTS = "Updating ..."
SHOW_CODE = "Show code?"
+ SHOW_DATA = "Show raw data?"
def __init__(self):
self.code_checkbox = Checkbox(
@@ -45,10 +46,18 @@ def __init__(self):
tooltip="Show the code used to generate the plot.",
layout=Layout(margin="var(--jp-widgets-margin) var(--jp-widgets-margin) var(--jp-widgets-margin) 10px;"),
)
+ self.data_checkbox = Checkbox(
+ value=False,
+ description=self.SHOW_DATA,
+ disabled=False,
+ indent=False,
+ tooltip="Show the raw data used to generate the plot.",
+ layout=Layout(margin="var(--jp-widgets-margin) var(--jp-widgets-margin) var(--jp-widgets-margin) 10px;"),
+ )
self.plot_button = Button(description=self.UPDATE_PLOTS, button_style="primary")
layout = Layout(align_items="flex-start")
- children = [self.plot_button, self.code_checkbox]
+ children = [self.plot_button, self.code_checkbox, self.data_checkbox]
super().__init__(layout=layout, children=children)
@property
@@ -59,6 +68,14 @@ def show_code(self) -> bool:
def show_code(self, show_code: bool) -> bool:
self.code_checkbox.value = show_code
+ @property
+ def show_data(self) -> bool:
+ return self.data_checkbox.value
+
+ @show_data.setter
+ def show_data(self, show_data: bool) -> bool:
+ self.data_checkbox.value = show_data
+
@property
def disabled(self) -> bool:
return self.plot_button.disabled
@@ -86,6 +103,13 @@ def callback_wrapper(change):
self.code_checkbox.observe(callback_wrapper, "value")
+ def on_toggle_data(self, callback):
+ @wraps(callback)
+ def callback_wrapper(change):
+ callback(self.data_checkbox.value)
+
+ self.data_checkbox.observe(callback_wrapper, "value")
+
class ModelOptionsWidget(VBox, ValueWidget):
value = traitlets.Dict(help="The selected values for the model options.")
@@ -874,11 +898,20 @@ def __init__(
title = HTML(value=f"""
{title}
""")
self.center = Output(layout=Layout(height="max-content", max_width="1200px"))
self.code_output = Output(layout=Layout(height="max-content", max_width="1200px"))
+ self.data_output = Output(layout=Layout(height="max-content", max_width="1200px"))
self.option_widget = option_widget
self.plot_function = plot_function
self.update_plot_widget = UpdatePlotWidget()
super().__init__(
- children=[title, self.option_widget, self.update_plot_widget, self.code_output, self.center], layout=layout
+ children=[
+ title,
+ self.option_widget,
+ self.update_plot_widget,
+ self.code_output,
+ self.data_output,
+ self.center,
+ ],
+ layout=layout,
)
# show initial plot
@@ -886,11 +919,13 @@ def __init__(
self.update_plot(initial=True)
else:
self.current_plot_code = self.NO_CODE_STRING
+ self.current_plot_data = None
# attache event handlers
self.option_widget.observe(self._on_option_change, "value")
self.update_plot_widget.on_click(self._on_plot_button_click)
self.update_plot_widget.on_toggle_code(self._on_toggle_code)
+ self.update_plot_widget.on_toggle_data(self._on_toggle_data)
@property
def disabled(self) -> bool:
@@ -910,6 +945,17 @@ def show_code(self) -> bool:
def show_code(self, show_code: bool):
self.update_plot_widget.show_code = show_code
+ @property
+ def show_data(self) -> bool:
+ """
+ If the widget should show the plot's data.
+ """
+ return self.update_plot_widget.show_data
+
+ @show_data.setter
+ def show_data(self, show_data: bool):
+ self.update_plot_widget.show_data = show_data
+
@staticmethod
def _is_interactive_notebook() -> bool:
ip = get_ipython()
@@ -937,10 +983,12 @@ def _try_generate_plot(self) -> Any:
try:
plot_args, plot_kwargs = self.generate_plot_args()
self.current_plot_code = self.generate_plot_code(plot_args, plot_kwargs)
- plot = self.plot_function(*plot_args, **plot_kwargs)
+ # The plot function returns a tuple of (plot, data); the DataFrame is kept for the 'Show raw data' view
+ plot, self.current_plot_data = self.plot_function(*plot_args, **plot_kwargs)
except Exception as e:
import traceback
+ self.current_plot_data = None
plot = HTML(f"Exception: {e}
{traceback.format_exc()} ")
return plot
@@ -949,6 +997,7 @@ def _on_plot_button_click(self, button=None):
self.option_widget.disabled = True
self.update_plot()
self._on_toggle_code(self.show_code)
+ self._on_toggle_data(self.show_data)
self.option_widget.disabled = False
def generate_plot_code(self, plot_args: tuple = None, plot_kwargs: dict = None) -> str:
@@ -987,6 +1036,16 @@ def _on_toggle_code(self, show_code: bool):
with self.code_output:
display(highlighted_code)
+ def _on_toggle_data(self, show_data: bool):
+ """Handle for the toggle data checkbox."""
+ self.data_output.clear_output()
+ if not show_data:
+ return
+
+ if self.current_plot_data is not None:
+ with self.data_output:
+ display(self.current_plot_data)
+
def _on_option_change(self, change=None):
"""Enable the plot to be updated."""
self.update_plot_widget.disabled = self.disabled
diff --git a/tests/api/test_api_explore.py b/tests/api/test_api_explore.py
index 99ecab62..94efdd0e 100644
--- a/tests/api/test_api_explore.py
+++ b/tests/api/test_api_explore.py
@@ -148,7 +148,7 @@ def test_cohort_list_details_summary_generated(self, fake_seismo):
mock_filter.assert_called_once()
mock_render.assert_called_once()
- assert "mock summary" in result.data
+ assert "mock summary" in result[0].data
def test_cohort_list_details_censored_output(self, fake_seismo):
fake_seismo.config.censor_min_count = 10
@@ -161,7 +161,7 @@ def test_cohort_list_details_censored_output(self, fake_seismo):
rule.filter.return_value = fake_seismo.dataframe.iloc[:0] # no rows
mock_filter.return_value = rule
- result = cohort_list_details({"Cohort": ["C1", "C2"]})
+ result, _ = cohort_list_details({"Cohort": ["C1", "C2"]})
assert "censored" in result.data.lower()
mock_render.assert_called_once()
@@ -175,7 +175,7 @@ def test_cohort_list_details_single_target_index_rename(self, fake_seismo):
rule.filter.return_value = fake_seismo.dataframe
mock_filter.return_value = rule
- result = cohort_list_details({"Cohort": ["C1", "C2"]})
+ result, _ = cohort_list_details({"Cohort": ["C1", "C2"]})
assert "summary" in result.data
@@ -250,15 +250,15 @@ def test_plot_binary_classifier_metrics_basic(self, mock_calc, mock_plot, fake_s
mock_calc.assert_called_once()
mock_plot.assert_called_once()
- assert isinstance(result, SVG)
- assert "http://www.w3.org/2000/svg" in result.data
+ assert isinstance(result[0], SVG)
+ assert "http://www.w3.org/2000/svg" in result[0].data
@patch("seismometer.data.performance.BinaryClassifierMetricGenerator.calculate_binary_stats")
def test_plot_binary_classifier_metrics_table_only(self, mock_calc, fake_seismo):
mock_stats = pd.DataFrame({"Sensitivity": [0.88]}, index=["value"])
mock_calc.return_value = (mock_stats, None)
- result = plot_binary_classifier_metrics(
+ result, _ = plot_binary_classifier_metrics(
metric_generator=BinaryClassifierMetricGenerator(rho=0.5),
metrics="Sensitivity",
cohort_dict={},
@@ -284,7 +284,7 @@ def test_plot_binary_classifier_metrics_nonbinary_target(
with patch("seismometer.seismogram.Seismogram", return_value=fake_seismo):
fake_seismo.dataframe = df
- result = plot_binary_classifier_metrics(
+ result, _ = plot_binary_classifier_metrics(
metric_generator=BinaryClassifierMetricGenerator(rho=0.5),
metrics=["Accuracy"],
cohort_dict={"Cohort": ("C1",)},
@@ -304,7 +304,7 @@ def test_model_evaluation_not_binary_target(self, mock_filter, mock_render, fake
df["event1_Value"] = 1 # single class
mock_filter.return_value = df
- result = _model_evaluation(
+ result, _ = _model_evaluation(
df,
entity_keys=["entity"],
target_event="event1",
@@ -341,7 +341,7 @@ def test_model_evaluation_valid_path(self, mock_filter, mock_render, mock_plot,
mock_filter.return_value = df
mock_scores.return_value = df
- result = _model_evaluation(
+ result, _ = _model_evaluation(
df,
entity_keys=["entity"],
target_event="event1",
@@ -388,7 +388,7 @@ def test_plot_model_score_comparison(
mock_event_score.return_value = df # per_context=False path
mock_perf_data.return_value = _mock_perf_df
- result = plot_model_score_comparison(
+ result, _ = plot_model_score_comparison(
cohort_dict={"Cohort": ("C1", "C2")},
target="event1",
scores=("score1",),
@@ -431,7 +431,7 @@ def test_plot_model_target_comparison(
mock_event_score.return_value = df
mock_perf_data.return_value = _mock_perf_df
- result = plot_model_target_comparison(
+ result, _ = plot_model_target_comparison(
cohort_dict={"Cohort": ("C1", "C2")},
targets=("event1",),
score="score1",
@@ -461,7 +461,7 @@ def test_plot_cohort_evaluation_censored_data(
mock_get_scores.return_value = df
mock_get_perf.return_value = _mock_perf_df
- result = _plot_cohort_evaluation(
+ result, _ = _plot_cohort_evaluation(
dataframe=_mock_perf_df,
entity_keys=["entity"],
target="event1_Value",
@@ -494,7 +494,7 @@ def test_plot_cohort_evaluation_success(
mock_get_scores.return_value = df
mock_get_perf.return_value = _mock_perf_df
- result = _plot_cohort_evaluation(
+ result, _ = _plot_cohort_evaluation(
dataframe=_mock_perf_df,
entity_keys=["entity"],
target="event1_Value",
@@ -524,7 +524,7 @@ def test_plot_cohort_group_histograms(
mock_seismo.return_value = fake_seismo
mock_filter.return_value = fake_seismo.dataframe
- result = plot_cohort_group_histograms(
+ result, _ = plot_cohort_group_histograms(
cohort_col="Cohort",
subgroups=["C1", "C2"],
target_column="event1",
@@ -540,7 +540,7 @@ def test_plot_cohort_hist_censored_after_filter(self, mock_render, fake_seismo):
# Simulate filtered-out result
empty_df = fake_seismo.dataframe.iloc[0:0].copy()
- result = _plot_cohort_hist(
+ result, _ = _plot_cohort_hist(
dataframe=empty_df,
target="event1_Value",
output="score1",
@@ -560,7 +560,7 @@ def test_plot_cohort_hist_plot_fails(self, mock_get_cohort_data, mock_plot, mock
mock_df = pd.DataFrame({"cohort": ["C1"] * 11 + ["C2"] * 11, "pred": [0.1] * 11 + [0.2] * 11})
mock_get_cohort_data.return_value = mock_df
- result = _plot_cohort_hist(
+ result, _ = _plot_cohort_hist(
dataframe=fake_seismo.dataframe,
target="event1_Value",
output="score1",
@@ -610,7 +610,7 @@ def test_plot_cohort_lead_time(self, mock_plot, mock_seismo, fake_seismo):
def test_leadtime_enc_missing_target_column(self, fake_seismo, caplog):
df = fake_seismo.dataframe.drop(columns=["event1_Value"])
with caplog.at_level("ERROR"):
- result = _plot_leadtime_enc(
+ result, _ = _plot_leadtime_enc(
df,
entity_keys=["entity"],
target_event="event1_Value",
@@ -623,13 +623,13 @@ def test_leadtime_enc_missing_target_column(self, fake_seismo, caplog):
max_hours=48,
x_label="Lead Time (hours)",
)
- assert result is None
+ assert "Error" in result.data
assert "Target event (event1_Value) not found" in caplog.text
def test_leadtime_enc_missing_target_zero_column(self, fake_seismo, caplog):
df = fake_seismo.dataframe.drop(columns=["event1_Time"])
with caplog.at_level("ERROR"):
- result = _plot_leadtime_enc(
+ result, _ = _plot_leadtime_enc(
df,
entity_keys=["entity"],
target_event="event1_Value",
@@ -642,14 +642,14 @@ def test_leadtime_enc_missing_target_zero_column(self, fake_seismo, caplog):
max_hours=48,
x_label="Lead Time (hours)",
)
- assert result is None
+ assert "Error" in result.data
assert "Target event time-zero (event1_Time) not found" in caplog.text
def test_leadtime_enc_no_positive_events(self, fake_seismo, caplog):
df = fake_seismo.dataframe.copy()
df["event1_Value"] = 0 # force all negative
with caplog.at_level("ERROR"):
- result = _plot_leadtime_enc(
+ result, _ = _plot_leadtime_enc(
df,
entity_keys=["entity"],
target_event="event1_Value",
@@ -662,7 +662,7 @@ def test_leadtime_enc_no_positive_events(self, fake_seismo, caplog):
max_hours=48,
x_label="Lead Time (hours)",
)
- assert result is None
+ assert "Error" in result.data
assert "No positive events (event1_Value=1) were found" in caplog.text
@patch("seismometer.api.plots.pdh.event_score")
@@ -673,7 +673,7 @@ def test_leadtime_enc_subgroup_filter_excludes_all(self, mock_render, mock_filte
mock_filter.return_value = df[:0]
mock_score.return_value = df[:0]
- result = _plot_leadtime_enc(
+ result, _ = _plot_leadtime_enc(
df,
entity_keys=["entity"],
target_event="event1_Value",
@@ -714,7 +714,7 @@ def test_plot_trend_intervention_outcome_combines_both(
self, mock_render, mock_plot, mock_event_value, fake_seismo
):
fake_seismo.selected_cohort = ("Cohort", ["C1", "C2"])
- result = plot_trend_intervention_outcome()
+ result, _ = plot_trend_intervention_outcome()
assert isinstance(result, HTML)
assert "Outcome" in result.data
assert "Intervention" in result.data
@@ -728,7 +728,7 @@ def test_plot_trend_intervention_outcome_missing_intervention(
self, mock_msg, mock_render, mock_plot, mock_event_value, fake_seismo
):
fake_seismo.selected_cohort = ("Cohort", ["C1", "C2"])
- result = plot_trend_intervention_outcome()
+ result, _ = plot_trend_intervention_outcome()
assert isinstance(result, HTML)
assert "Missing Intervention" in result.data
assert "Outcome" in result.data
@@ -741,7 +741,7 @@ def test_plot_trend_intervention_outcome_missing_outcome(
self, mock_msg, mock_render, mock_plot, mock_event_value, fake_seismo
):
fake_seismo.selected_cohort = ("Cohort", ["C1", "C2"])
- result = plot_trend_intervention_outcome()
+ result, _ = plot_trend_intervention_outcome()
assert isinstance(result, HTML)
assert "Missing Outcome" in result.data
assert "Intervention" in result.data
diff --git a/tests/controls/test_explore.py b/tests/controls/test_explore.py
index a5bb0cb5..5a1d6583 100644
--- a/tests/controls/test_explore.py
+++ b/tests/controls/test_explore.py
@@ -29,6 +29,8 @@ def test_init(self):
assert widget.plot_button.disabled is False
assert widget.code_checkbox.description == widget.SHOW_CODE
assert widget.show_code is False
+ assert widget.data_checkbox.description == widget.SHOW_DATA
+ assert widget.show_data is False
def test_plot_button_click(self):
count = 0
@@ -68,15 +70,28 @@ def on_toggle_callback(button):
widget.code_checkbox.value = True
assert count == 1
+ def test_toggle_data_checkbox(self):
+ count = 0
+
+ def on_toggle_callback(button):
+ nonlocal count
+ count += 1
+
+ widget = undertest.UpdatePlotWidget()
+ widget.on_toggle_data(on_toggle_callback)
+ widget.data_checkbox.value = True
+ assert count == 1
+
class TestExplorationBaseClass:
def test_base_class(self, caplog):
option_widget = ipywidgets.Checkbox(description="ClickMe")
- plot_function = Mock(return_value="some result")
+ plot_function = Mock(return_value=("some result", "some data"))
widget = undertest.ExplorationWidget("ExploreTest", option_widget, plot_function)
plot_function.assert_not_called()
- assert "Subclasses must implement this method" in widget._try_generate_plot().value
+ plot = widget._try_generate_plot()
+ assert "Subclasses must implement this method" in plot.value
@pytest.mark.parametrize(
"plot_module,plot_code",
@@ -88,7 +103,7 @@ def test_base_class(self, caplog):
)
def test_args_subclass(self, plot_module, plot_code):
option_widget = ipywidgets.Checkbox(description="ClickMe")
- plot_function = Mock(return_value="some result")
+ plot_function = Mock(return_value=("some result", "some data"))
plot_function.__name__ = "plot_something"
plot_function.__module__ = plot_module
@@ -103,11 +118,13 @@ def generate_plot_args(self) -> tuple[tuple, dict]:
plot_function.assert_called_once_with(False)
assert widget.current_plot_code == plot_code
assert widget.show_code is False
+ assert widget.show_data is False
assert widget._try_generate_plot() == "some result"
+ assert widget.current_plot_data == "some data"
def test_kwargs_subclass(self):
option_widget = ipywidgets.Checkbox(description="ClickMe")
- plot_function = Mock(return_value="some result")
+ plot_function = Mock(return_value=("some result", "some data"))
plot_function.__name__ = "plot_something"
plot_function.__module__ = "test_explore"
@@ -122,11 +139,13 @@ def generate_plot_args(self) -> tuple[tuple, dict]:
plot_function.assert_called_once_with(checkbox=False)
assert widget.current_plot_code == "test_explore.plot_something(checkbox=False)"
assert widget.show_code is False
+ assert widget.show_data is False
assert widget._try_generate_plot() == "some result"
+ assert widget.current_plot_data == "some data"
def test_args_kwargs_subclass(self):
option_widget = ipywidgets.Checkbox(description="ClickMe")
- plot_function = Mock(return_value="some result")
+ plot_function = Mock(return_value=("some result", "some data"))
plot_function.__name__ = "plot_something"
plot_function.__module__ = "test_explore"
@@ -141,7 +160,9 @@ def generate_plot_args(self) -> tuple[tuple, dict]:
plot_function.assert_called_once_with("test", checkbox=False)
assert widget.current_plot_code == "test_explore.plot_something('test', checkbox=False)"
assert widget.show_code is False
+ assert widget.show_data is False
assert widget._try_generate_plot() == "some result"
+ assert widget.current_plot_data == "some data"
def test_exception_plot_code_subclass(self):
option_widget = ipywidgets.Checkbox(description="ClickMe")
@@ -160,10 +181,11 @@ def generate_plot_args(self) -> tuple[tuple, dict]:
plot = widget._try_generate_plot()
assert "Traceback" in plot.value
assert "Test Exception" in plot.value
+ assert widget.current_plot_data is None
def test_no_initial_plot_subclass(self):
option_widget = ipywidgets.Checkbox(description="ClickMe")
- plot_function = Mock(return_value="some result")
+ plot_function = Mock(return_value=("some result", "some data"))
plot_function.__name__ = "plot_something"
plot_function.__module__ = "test_explore"
@@ -178,13 +200,16 @@ def generate_plot_args(self) -> tuple[tuple, dict]:
widget.center.outputs == []
plot_function.assert_not_called()
assert widget.current_plot_code == ExploreFake.NO_CODE_STRING
+ assert widget.current_plot_data is None
assert widget.show_code is False
- widget._try_generate_plot() == ""
+ assert widget.show_data is False
+ assert widget._try_generate_plot() == "some result"
+ assert widget.current_plot_data == "some data"
@pytest.mark.parametrize("show_code", [True, False])
def test_toggle_code_callback(self, show_code, capsys):
option_widget = ipywidgets.Checkbox(description="ClickMe")
- plot_function = Mock(return_value="some result")
+ plot_function = Mock(return_value=("some result", "some data"))
plot_function.__name__ = "plot_something"
plot_function.__module__ = "test_explore"
@@ -204,6 +229,29 @@ def generate_plot_args(self) -> tuple[tuple, dict]:
code_in_output = "test_explore.plot_something" in stdout.split("\n")[-2]
assert code_in_output == show_code
+ @pytest.mark.parametrize("show_data", [True, False])
+ def test_toggle_data_callback(self, show_data, capsys):
+ option_widget = ipywidgets.Checkbox(description="ClickMe")
+ plot_function = Mock(return_value=("some result", "some data"))
+ plot_function.__name__ = "plot_something"
+ plot_function.__module__ = "test_explore"
+
+ class ExploreFake(undertest.ExplorationWidget):
+ def __init__(self):
+ super().__init__("Fake Explorer", option_widget, plot_function)
+
+ def generate_plot_args(self) -> tuple[tuple, dict]:
+ return ["test"], {"checkbox": self.option_widget.value}
+
+ widget = ExploreFake()
+ widget.show_data = show_data
+ widget.data_output = MagicMock()
+ widget._on_plot_button_click()
+ stdout = capsys.readouterr().out
+ assert "some result" in stdout
+ data_in_output = "some data" in stdout.split("\n")[-2]
+ assert data_in_output == show_data
+
# endregion
# region Test Model Options Widgets
@@ -559,7 +607,7 @@ class TestExplorationSubpopulationWidget:
def test_init(self, mock_seismo):
fake_seismo = mock_seismo()
fake_seismo.available_cohort_groups = {"C1": ["C1.1", "C1.2"], "C2": ["C2.1", "C2.2"]}
- plot_function = Mock(return_value="some result")
+ plot_function = Mock(return_value=("some result", "some data"))
plot_function.__name__ = "plot_function"
plot_function.__module__ = "test_explore"
@@ -569,12 +617,13 @@ def test_init(self, mock_seismo):
assert widget.update_plot_widget.disabled
assert widget.current_plot_code == "test_explore.plot_function({})"
plot_function.assert_called_once_with({}) # default value
+ assert widget.current_plot_data == "some data"
@patch.object(seismogram, "Seismogram", return_value=Mock())
def test_option_update(self, mock_seismo):
fake_seismo = mock_seismo()
fake_seismo.available_cohort_groups = {"C1": ["C1.1", "C1.2"], "C2": ["C2.1", "C2.2"]}
- plot_function = Mock(return_value="some result")
+ plot_function = Mock(return_value=("some result", "some data"))
plot_function.__name__ = "plot_function"
plot_function.__module__ = "test_explore"
@@ -593,6 +642,7 @@ def test_option_update(self, mock_seismo):
]
}
) # updated value
+ assert widget.current_plot_data == "some data"
class TestExplorationModelSubgroupEvaluationWidget:
@@ -604,7 +654,7 @@ def test_init(self, mock_seismo):
fake_seismo.target_cols = ["T1", "T2"]
fake_seismo.output_list = ["S1", "S2"]
- plot_function = Mock(return_value="some result")
+ plot_function = Mock(return_value=("some result", "some data"))
plot_function.__name__ = "plot_function"
plot_function.__module__ = "test_explore"
@@ -616,6 +666,7 @@ def test_init(self, mock_seismo):
assert widget.update_plot_widget.disabled
assert widget.current_plot_code == "test_explore.plot_function({}, 'T1', 'S1', [0.2, 0.1], per_context=False)"
plot_function.assert_called_once_with({}, "T1", "S1", [0.2, 0.1], per_context=False) # default value
+ assert widget.current_plot_data == "some data"
@patch.object(seismogram, "Seismogram", return_value=Mock())
def test_option_update(self, mock_seismo):
@@ -625,7 +676,7 @@ def test_option_update(self, mock_seismo):
fake_seismo.target_cols = ["T1", "T2"]
fake_seismo.output_list = ["S1", "S2"]
- plot_function = Mock(return_value="some result")
+ plot_function = Mock(return_value=("some result", "some data"))
plot_function.__name__ = "plot_function"
plot_function.__module__ = "test_explore"
@@ -636,6 +687,7 @@ def test_option_update(self, mock_seismo):
widget.option_widget.cohort_list.value = {"C2": ("C2.1",)}
widget.update_plot()
plot_function.assert_called_with({"C2": ("C2.1",)}, "T1", "S1", [0.2, 0.1], per_context=False) # updated value
+ assert widget.current_plot_data == "some data"
class TestExplorationCohortSubclassEvaluationWidget:
@@ -651,7 +703,7 @@ def test_init(self, mock_seismo, threshold_handling, thresholds):
fake_seismo.target_cols = ["T1", "T2"]
fake_seismo.output_list = ["S1", "S2"]
- plot_function = Mock(return_value="some result")
+ plot_function = Mock(return_value=("some result", "some data"))
plot_function.__name__ = "plot_function"
plot_function.__module__ = "test_explore"
@@ -673,6 +725,7 @@ def test_init(self, mock_seismo, threshold_handling, thresholds):
"C1", ("C1.1", "C1.2"), "T1", "S1", per_context=False
) # default value
assert widget.current_plot_code == expected_code
+ assert widget.current_plot_data == "some data"
class TestExplorationCohortOutcomeInterventionEvaluationWidget:
@@ -688,7 +741,7 @@ def test_init(self, mock_seismo):
)
fake_seismo.config = fake_config
- plot_function = Mock(return_value="some result")
+ plot_function = Mock(return_value=("some result", "some data"))
plot_function.__name__ = "plot_function"
plot_function.__module__ = "test_explore"
@@ -702,6 +755,7 @@ def test_init(self, mock_seismo):
widget.current_plot_code == "test_explore.plot_function('C1', ('C1.1', 'C1.2'), 'O1', 'I1', 'pred_time')"
)
plot_function.assert_called_once_with("C1", ("C1.1", "C1.2"), "O1", "I1", "pred_time") # default value
+ assert widget.current_plot_data == "some data"
class TestExplorationScoreComparisonByCohortWidget:
@@ -713,7 +767,7 @@ def test_init(self, mock_seismo):
fake_seismo.target_cols = ["T1", "T2"]
fake_seismo.output_list = ["S1", "S2"]
- plot_function = Mock(return_value="some result")
+ plot_function = Mock(return_value=("some result", "some data"))
plot_function.__name__ = "plot_function"
plot_function.__module__ = "test_explore"
@@ -725,6 +779,7 @@ def test_init(self, mock_seismo):
assert widget.update_plot_widget.disabled
assert widget.current_plot_code == "test_explore.plot_function({}, 'T1', ('S1', 'S2'), per_context=False)"
plot_function.assert_called_once_with({}, "T1", ("S1", "S2"), per_context=False) # default value
+ assert widget.current_plot_data == "some data"
class TestExplorationTargetComparisonByCohortWidget:
@@ -736,7 +791,7 @@ def test_init(self, mock_seismo):
fake_seismo.target_cols = ["T1", "T2"]
fake_seismo.output_list = ["S1", "S2"]
- plot_function = Mock(return_value="some result")
+ plot_function = Mock(return_value=("some result", "some data"))
plot_function.__name__ = "plot_function"
plot_function.__module__ = "test_explore"
@@ -748,6 +803,7 @@ def test_init(self, mock_seismo):
assert widget.update_plot_widget.disabled
assert widget.current_plot_code == "test_explore.plot_function({}, ('T1', 'T2'), 'S1', per_context=False)"
plot_function.assert_called_once_with({}, ("T1", "T2"), "S1", per_context=False) # default value
+ assert widget.current_plot_data == "some data"
# endregion
@@ -804,7 +860,7 @@ def test_init(self, mock_seismo):
fake_seismo.target_cols = ["T1", "T2"]
fake_seismo.output_list = ["S1", "S2"]
- plot_function = Mock(return_value="some result")
+ plot_function = Mock(return_value=("some result", "some data"))
plot_function.__name__ = "plot_function"
plot_function.__module__ = "test_explore"
@@ -823,6 +879,7 @@ def test_init(self, mock_seismo):
plot_function.assert_called_once_with(
metric_generator, ("Precision", "Recall"), {}, "T1", "S1", per_context=False
)
+ assert widget.current_plot_data == "some data"
@patch.object(seismogram, "Seismogram", return_value=Mock())
def test_option_update(self, mock_seismo):
@@ -833,7 +890,7 @@ def test_option_update(self, mock_seismo):
fake_seismo.target_cols = ["T1", "T2"]
fake_seismo.output_list = ["S1", "S2"]
- plot_function = Mock(return_value="some result")
+ plot_function = Mock(return_value=("some result", "some data"))
plot_function.__name__ = "plot_function"
plot_function.__module__ = "test_explore"
@@ -853,6 +910,7 @@ def test_option_update(self, mock_seismo):
+ "('F1',), {}, 'T2', 'S2', per_context=False)"
)
plot_function.assert_called_with(metric_generator, ("F1",), {}, "T2", "S2", per_context=False)
+ assert widget.current_plot_data == "some data"
class TestExploreBinaryModelAnalytics: