diff --git a/changelog/168.feature.rst b/changelog/168.feature.rst new file mode 100644 index 00000000..e359d769 --- /dev/null +++ b/changelog/168.feature.rst @@ -0,0 +1 @@ +Added a 'Show raw data' checkbox to the interactive `Explore...` widgets to display the underlying pandas.DataFrame used for visualizations. diff --git a/src/seismometer/api/explore.py b/src/seismometer/api/explore.py index 428f0e10..d2c2ab16 100644 --- a/src/seismometer/api/explore.py +++ b/src/seismometer/api/explore.py @@ -1,8 +1,9 @@ from typing import Any, Optional +import pandas as pd from IPython.display import HTML, display -from seismometer.controls.decorators import disk_cached_html_segment +from seismometer.controls.decorators import disk_cached_html_and_df_segment from seismometer.controls.explore import ExplorationWidget # noqa: from seismometer.controls.explore import ( ExplorationCohortOutcomeInterventionEvaluationWidget, @@ -216,9 +217,9 @@ def on_widget_value_changed(*args): return VBox(children=[comparison_selections, output], layout=BOX_GRID_LAYOUT) -@disk_cached_html_segment +@disk_cached_html_and_df_segment @export -def cohort_list_details(cohort_dict: dict[str, tuple[Any]]) -> HTML: +def cohort_list_details(cohort_dict: dict[str, tuple[Any]]) -> tuple[HTML, pd.DataFrame]: """ Generates an HTML table of cohort details. @@ -229,8 +230,8 @@ def cohort_list_details(cohort_dict: dict[str, tuple[Any]]) -> HTML: Returns ------- - HTML - able indexed by targets, with counts of unique entities, and mean values of the output columns. 
+ tuple[HTML, pd.DataFrame] + Table indexed by targets, with counts of unique entities, and mean values of the output columns, and the data """ from seismometer.data.filter import filter_rule_from_cohort_dictionary @@ -246,7 +247,7 @@ def cohort_list_details(cohort_dict: dict[str, tuple[Any]]) -> HTML: ] cohort_count = data[sg.entity_keys[0]].nunique() if cohort_count < sg.censor_threshold: - return template.render_censored_plot_message(sg.censor_threshold) + return template.render_censored_plot_message(sg.censor_threshold), data groups = data.groupby(target_cols) float_cols = list(data[intervention_cols + outcome_cols].select_dtypes(include=float)) @@ -268,7 +269,7 @@ def cohort_list_details(cohort_dict: dict[str, tuple[Any]]) -> HTML: groupstats.index.rename(new_names, inplace=True) html_table = groupstats.to_html() title = "Summary" - return template.render_title_message(title, html_table) + return template.render_title_message(title, html_table), data # endregion diff --git a/src/seismometer/api/plots.py b/src/seismometer/api/plots.py index 5c88d4d4..671ae555 100644 --- a/src/seismometer/api/plots.py +++ b/src/seismometer/api/plots.py @@ -6,7 +6,7 @@ from IPython.display import HTML, SVG import seismometer.plot as plot -from seismometer.controls.decorators import disk_cached_html_segment +from seismometer.controls.decorators import disk_cached_html_and_df_segment from seismometer.core.autometrics import AutomationManager, store_call_parameters from seismometer.core.decorators import export from seismometer.data import get_cohort_data, get_cohort_performance_data, metric_apis @@ -36,9 +36,11 @@ def plot_cohort_hist(): return _plot_cohort_hist(sg.dataframe, sg.target, sg.output, cohort_col, subgroups, censor_threshold) -@disk_cached_html_segment +@disk_cached_html_and_df_segment @export -def plot_cohort_group_histograms(cohort_col: str, subgroups: list[str], target_column: str, score_column: str) -> HTML: +def plot_cohort_group_histograms( + cohort_col: str, 
subgroups: list[str], target_column: str, score_column: str +) -> tuple[HTML, pd.DataFrame]: """ Generate a histogram plot of predicted probabilities for each subgroup in a cohort. @@ -55,15 +57,14 @@ def plot_cohort_group_histograms(cohort_col: str, subgroups: list[str], target_c Returns ------- - HTML - html visualization of the histogram + tuple[HTML, pd.DataFrame] + html visualization of the histogram and the data used to generate it """ sg = Seismogram() target_column = pdh.event_value(target_column) return _plot_cohort_hist(sg.dataframe, target_column, score_column, cohort_col, subgroups, sg.censor_threshold) -@disk_cached_html_segment def _plot_cohort_hist( dataframe: pd.DataFrame, target: str, @@ -71,7 +72,7 @@ def _plot_cohort_hist( cohort_col: str, subgroups: list[str], censor_threshold: int = 10, -) -> HTML: +) -> tuple[HTML, pd.DataFrame]: """ Creates an HTML segment displaying a histogram of predicted probabilities for each cohort. @@ -107,7 +108,7 @@ def _plot_cohort_hist( cData = cData.loc[cData["cohort"].isin(good_groups)] if len(cData.index) == 0: - return template.render_censored_plot_message(censor_threshold) + return template.render_censored_plot_message(censor_threshold), cData bin_count = 20 bins = np.histogram_bin_edges(cData["pred"], bins=bin_count) @@ -115,9 +116,9 @@ def _plot_cohort_hist( try: svg = plot.cohorts_vertical(cData, plot.histogram_stacked, func_kws={"show_legend": False, "bins": bins}) title = f"Predicted Probabilities by {cohort_col}" - return template.render_title_with_image(title, svg) + return template.render_title_with_image(title, svg), cData except Exception as error: - return template.render_title_message("Error", f"Error: {error}") + return template.render_title_message("Error", f"Error: {error}"), pd.DataFrame() @export @@ -165,11 +166,11 @@ def plot_leadtime_enc(score=None, ref_time=None, target_event=None): @store_call_parameters(cohort_col="cohort_col", subgroups="subgroups") -@disk_cached_html_segment 
+@disk_cached_html_and_df_segment @export def plot_cohort_lead_time( cohort_col: str, subgroups: list[str], event_column: str, score_column: str, threshold: float -) -> HTML: +) -> tuple[HTML, pd.DataFrame]: """ Plots a lead times between the first positive prediction give an threshold and an event. @@ -214,7 +215,6 @@ def plot_cohort_lead_time( ) -@disk_cached_html_segment def _plot_leadtime_enc( dataframe: pd.DataFrame, entity_keys: list[str], @@ -228,7 +228,7 @@ def _plot_leadtime_enc( max_hours: int, x_label: str, censor_threshold: int = 10, -) -> HTML: +) -> tuple[HTML, pd.DataFrame]: """ HTML Plot of time between prediction and target event. @@ -261,21 +261,24 @@ def _plot_leadtime_enc( Returns ------- - HTML - Lead time plot + tuple[HTML, pd.DataFrame] + Lead time plot and the data used to generate it """ if target_event not in dataframe: - logger.error(f"Target event ({target_event}) not found in dataset. Cannot plot leadtime.") - return + msg = f"Target event ({target_event}) not found in dataset. Cannot plot leadtime." + logger.error(msg) + return template.render_title_message("Error", msg), pd.DataFrame() if target_zero not in dataframe: - logger.error(f"Target event time-zero ({target_zero}) not found in dataset. Cannot plot leadtime.") - return + msg = f"Target event time-zero ({target_zero}) not found in dataset. Cannot plot leadtime." 
+ logger.error(msg) + return template.render_title_message("Error", msg), pd.DataFrame() summary_data = dataframe[dataframe[target_event] == 1] if len(summary_data.index) == 0: - logger.error(f"No positive events ({target_event}=1) were found") - return + msg = f"No positive events ({target_event}=1) were found" + logger.error(msg) + return template.render_title_message("Error", msg), pd.DataFrame() cohort_mask = summary_data[cohort_col].isin(subgroups) threshold_mask = summary_data[score] > threshold @@ -291,7 +294,7 @@ def _plot_leadtime_enc( if summary_data is not None and len(summary_data) > censor_threshold: summary_data = summary_data[[target_zero, ref_time, cohort_col]] else: - return template.render_censored_plot_message(censor_threshold) + return template.render_censored_plot_message(censor_threshold), pd.DataFrame() # filter by group size counts = summary_data[cohort_col].value_counts() @@ -303,7 +306,7 @@ def _plot_leadtime_enc( ) if len(summary_data.index) == 0: - return template.render_censored_plot_message(censor_threshold) + return template.render_censored_plot_message(censor_threshold), summary_data # Truncate to minute but plot hour summary_data[x_label] = (summary_data[ref_time] - summary_data[target_zero]).dt.total_seconds() // 60 / 60 @@ -311,11 +314,11 @@ def _plot_leadtime_enc( title = f'Lead Time from {score.replace("_", " ")} to {(target_zero).replace("_", " ")}' rows = summary_data[cohort_col].nunique() svg = plot.leadtime_violin(summary_data, x_label, cohort_col, xmax=max_hours, figsize=(9, 1 + rows)) - return template.render_title_with_image(title, svg) + return template.render_title_with_image(title, svg), summary_data @store_call_parameters(cohort_col="cohort_col", subgroups="subgroups") -@disk_cached_html_segment +@disk_cached_html_and_df_segment @export def plot_cohort_evaluation( cohort_col: str, @@ -324,7 +327,7 @@ def plot_cohort_evaluation( score_column: str, thresholds: list[float], per_context: bool = False, -) -> HTML: +) -> 
tuple[HTML, pd.DataFrame]: """ Plots model performance metrics split by on a cohort attribute. @@ -364,7 +367,6 @@ def plot_cohort_evaluation( ) -@disk_cached_html_segment def _plot_cohort_evaluation( dataframe: pd.DataFrame, entity_keys: list[str], @@ -378,7 +380,7 @@ def _plot_cohort_evaluation( aggregation_method: str = "max", threshold_col: str = "Threshold", ref_time: str = None, -) -> HTML: +) -> tuple[HTML, pd.DataFrame]: """ Plots model performance metrics split by on a cohort attribute. @@ -441,14 +443,14 @@ def _plot_cohort_evaluation( try: assert_valid_performance_metrics_df(plot_data) except ValueError: - return template.render_censored_plot_message(censor_threshold) + return template.render_censored_plot_message(censor_threshold), plot_data svg = plot.cohort_evaluation_vs_threshold(plot_data, cohort_feature=cohort_col, highlight=thresholds) title = f"Model Performance Metrics on {cohort_col} across Thresholds" - return template.render_title_with_image(title, svg) + return template.render_title_with_image(title, svg), plot_data @store_call_parameters(cohort_dict="cohort_dict") -@disk_cached_html_segment +@disk_cached_html_and_df_segment @export def plot_model_evaluation( cohort_dict: dict[str, tuple[Any]], @@ -456,7 +458,7 @@ def plot_model_evaluation( score_column: str, thresholds: list[float], per_context: bool = False, -) -> HTML: +) -> tuple[HTML, pd.DataFrame]: """ Generates a 2x3 plot showing the performance of a model. 
@@ -503,7 +505,6 @@ def plot_model_evaluation( ) -@disk_cached_html_segment def _model_evaluation( dataframe: pd.DataFrame, entity_keys: list[str], @@ -516,7 +517,7 @@ def _model_evaluation( aggregation_method: str = "max", ref_time: Optional[str] = None, cohort: dict = {}, -) -> HTML: +) -> tuple[HTML, pd.DataFrame]: """ plots common model evaluation metrics @@ -563,10 +564,13 @@ def _model_evaluation( requirements = FilterRule.isin(target, (0, 1)) & FilterRule.notna(score_col) data = requirements.filter(data) if len(data.index) < censor_threshold: - return template.render_censored_plot_message(censor_threshold) + return template.render_censored_plot_message(censor_threshold), data if (lcount := data[target].nunique()) != 2: - return template.render_title_message( - "Evaluation Error", f"Model Evaluation requires exactly two classes but found {lcount}" + return ( + template.render_title_message( + "Evaluation Error", f"Model Evaluation requires exactly two classes but found {lcount}" + ), + data, ) # stats and ci handle percentile/percentage independently - evaluation wants 0-100 for displays @@ -595,7 +599,7 @@ def _model_evaluation( show_thresholds=True, highlight=thresholds, ) - return template.render_title_with_image(title, svg) + return template.render_title_with_image(title, svg), data @store_call_parameters @@ -622,7 +626,7 @@ def plot_trend_intervention_outcome() -> HTML: ) -@disk_cached_html_segment +@disk_cached_html_and_df_segment @export def plot_intervention_outcome_timeseries( cohort_col: str, @@ -631,7 +635,7 @@ def plot_intervention_outcome_timeseries( intervention: str, reference_time_col: str, censor_threshold: int = 10, -) -> HTML: +) -> tuple[HTML, pd.DataFrame]: """ Plots two timeseries based on an outcome and an intervention. 
@@ -666,7 +670,6 @@ def plot_intervention_outcome_timeseries( ) -@disk_cached_html_segment def _plot_trend_intervention_outcome( dataframe: pd.DataFrame, entity_keys: list[str], @@ -676,7 +679,7 @@ def _plot_trend_intervention_outcome( intervention: str, reftime: str, censor_threshold: int = 10, -) -> HTML: +) -> tuple[HTML, pd.DataFrame]: """ Plots two timeseries based on selectors; an outcome and then an intervention. @@ -748,7 +751,7 @@ def _plot_trend_intervention_outcome( "Missing Outcome", f"No outcome timeseries plotted; No events with name {outcome}." ) - return HTML(outcome_plot.data + intervention_plot.data) + return HTML(outcome_plot.data + intervention_plot.data), dataframe def _plot_ts_cohort( @@ -842,11 +845,11 @@ def _plot_ts_cohort( @store_call_parameters(cohort_dict="cohort_dict") -@disk_cached_html_segment +@disk_cached_html_and_df_segment @export def plot_model_score_comparison( cohort_dict: dict[str, tuple[Any]], target: str, scores: tuple[str], *, per_context: bool -) -> HTML: +) -> tuple[HTML, pd.DataFrame]: """ Plots a comparison of model scores for a given subpopulation. @@ -903,17 +906,17 @@ def plot_model_score_comparison( try: assert_valid_performance_metrics_df(plot_data) except ValueError: - return template.render_censored_plot_message(sg.censor_threshold) + return template.render_censored_plot_message(sg.censor_threshold), plot_data svg = plot.cohort_evaluation_vs_threshold(plot_data, cohort_feature="ScoreName") title = f"Model Metrics: {', '.join(scores)} vs {target}" - return template.render_title_with_image(title, svg) + return template.render_title_with_image(title, svg), plot_data -@disk_cached_html_segment +@disk_cached_html_and_df_segment @export def plot_model_target_comparison( cohort_dict: dict[str, tuple[Any]], targets: tuple[str], score: str, *, per_context: bool -) -> HTML: +) -> tuple[HTML, pd.DataFrame]: """ Plots a comparison of model targets for a given score and subpopulation. 
@@ -970,15 +973,15 @@ def plot_model_target_comparison( try: assert_valid_performance_metrics_df(plot_data) except ValueError: - return template.render_censored_plot_message(sg.censor_threshold) + return template.render_censored_plot_message(sg.censor_threshold), plot_data svg = plot.cohort_evaluation_vs_threshold(plot_data, cohort_feature="ScoreName") title = f"Model Metrics: {', '.join(targets)} vs {score}" - return template.render_title_with_image(title, svg) + return template.render_title_with_image(title, svg), plot_data # region Explore Any Metric (NNT, etc) @store_call_parameters(cohort_dict="cohort_dict") -@disk_cached_html_segment +@disk_cached_html_and_df_segment @export def plot_binary_classifier_metrics( metric_generator: BinaryClassifierMetricGenerator, @@ -989,7 +992,7 @@ def plot_binary_classifier_metrics( *, per_context: bool = False, table_only: bool = False, -) -> HTML: +) -> tuple[HTML, pd.DataFrame]: """ Generates a plot with model metrics. @@ -1074,7 +1077,7 @@ def binary_classifier_metric_evaluation( aggregation_method: str = "max", ref_time: str = None, table_only: bool = False, -) -> HTML: +) -> tuple[HTML, pd.DataFrame]: """ plots common model evaluation metrics @@ -1118,10 +1121,13 @@ def binary_classifier_metric_evaluation( requirements = FilterRule.isin(target, (0, 1)) & FilterRule.notna(score_col) data = requirements.filter(data) if len(data.index) < censor_threshold: - return template.render_censored_plot_message(censor_threshold) + return template.render_censored_plot_message(censor_threshold), data if (lcount := data[target].nunique()) != 2: - return template.render_title_message( - "Evaluation Error", f"Model Evaluation requires exactly two classes but found {lcount}" + return ( + template.render_title_message( + "Evaluation Error", f"Model Evaluation requires exactly two classes but found {lcount}" + ), + data, ) if isinstance(metrics, str): metrics = [metrics] @@ -1134,8 +1140,8 @@ def binary_classifier_metric_evaluation( if 
log_all: recorder.populate_metrics(attributes=attributes, metrics={metric: stats[metric].to_dict()}) if table_only: - return HTML(stats[metrics].T.to_html()) - return plot.binary_classifier.plot_metric_list(stats, metrics) + return HTML(stats[metrics].T.to_html()), data + return plot.binary_classifier.plot_metric_list(stats, metrics), data # endregion diff --git a/src/seismometer/controls/decorators.py b/src/seismometer/controls/decorators.py index eaac8cea..0dfaa13e 100644 --- a/src/seismometer/controls/decorators.py +++ b/src/seismometer/controls/decorators.py @@ -5,6 +5,7 @@ import hashlib import logging import os +import pickle import shutil from functools import wraps from inspect import signature @@ -27,4 +28,29 @@ def html_save(html, filepath) -> None: filepath.write_text(html.data) +def html_and_df_save(data, filepath) -> None: + """ + Saves a tuple of (HTML, pd.DataFrame) to disk. + """ + html, df = data + html_path = filepath.with_suffix(".html") + df_path = filepath.with_suffix(".pkl") + html_path.write_text(html.data) + with open(df_path, "wb") as f: + pickle.dump(df, f) + + +def html_and_df_load(filepath) -> tuple[HTML, Any]: + """ + Loads a tuple of (HTML, pd.DataFrame) from disk. + """ + html_path = filepath.with_suffix(".html") + df_path = filepath.with_suffix(".pkl") + html = HTML(html_path.read_text()) + with open(df_path, "rb") as f: + df = pickle.load(f) + return html, df + + disk_cached_html_segment = DiskCachedFunction("html", save_fn=html_save, load_fn=html_load, return_type=HTML) +disk_cached_html_and_df_segment = DiskCachedFunction("html_and_df", save_fn=html_and_df_save, load_fn=html_and_df_load) diff --git a/src/seismometer/controls/explore.py b/src/seismometer/controls/explore.py index 09d4f020..ae0fa316 100644 --- a/src/seismometer/controls/explore.py +++ b/src/seismometer/controls/explore.py @@ -35,6 +35,7 @@ class UpdatePlotWidget(Box): UPDATE_PLOTS = "Update" UPDATING_PLOTS = "Updating ..." SHOW_CODE = "Show code?" 
+ SHOW_DATA = "Show raw data?" def __init__(self): self.code_checkbox = Checkbox( @@ -45,10 +46,18 @@ def __init__(self): tooltip="Show the code used to generate the plot.", layout=Layout(margin="var(--jp-widgets-margin) var(--jp-widgets-margin) var(--jp-widgets-margin) 10px;"), ) + self.data_checkbox = Checkbox( + value=False, + description=self.SHOW_DATA, + disabled=False, + indent=False, + tooltip="Show the raw data used to generate the plot.", + layout=Layout(margin="var(--jp-widgets-margin) var(--jp-widgets-margin) var(--jp-widgets-margin) 10px;"), + ) self.plot_button = Button(description=self.UPDATE_PLOTS, button_style="primary") layout = Layout(align_items="flex-start") - children = [self.plot_button, self.code_checkbox] + children = [self.plot_button, self.code_checkbox, self.data_checkbox] super().__init__(layout=layout, children=children) @property @@ -59,6 +68,14 @@ def show_code(self) -> bool: def show_code(self, show_code: bool) -> bool: self.code_checkbox.value = show_code + @property + def show_data(self) -> bool: + return self.data_checkbox.value + + @show_data.setter + def show_data(self, show_data: bool) -> bool: + self.data_checkbox.value = show_data + @property def disabled(self) -> bool: return self.plot_button.disabled @@ -86,6 +103,13 @@ def callback_wrapper(change): self.code_checkbox.observe(callback_wrapper, "value") + def on_toggle_data(self, callback): + @wraps(callback) + def callback_wrapper(change): + callback(self.data_checkbox.value) + + self.data_checkbox.observe(callback_wrapper, "value") + class ModelOptionsWidget(VBox, ValueWidget): value = traitlets.Dict(help="The selected values for the model options.") @@ -874,11 +898,20 @@ def __init__( title = HTML(value=f"""

{title}

""") self.center = Output(layout=Layout(height="max-content", max_width="1200px")) self.code_output = Output(layout=Layout(height="max-content", max_width="1200px")) + self.data_output = Output(layout=Layout(height="max-content", max_width="1200px")) self.option_widget = option_widget self.plot_function = plot_function self.update_plot_widget = UpdatePlotWidget() super().__init__( - children=[title, self.option_widget, self.update_plot_widget, self.code_output, self.center], layout=layout + children=[ + title, + self.option_widget, + self.update_plot_widget, + self.code_output, + self.data_output, + self.center, + ], + layout=layout, ) # show initial plot @@ -886,11 +919,13 @@ def __init__( self.update_plot(initial=True) else: self.current_plot_code = self.NO_CODE_STRING + self.current_plot_data = None # attache event handlers self.option_widget.observe(self._on_option_change, "value") self.update_plot_widget.on_click(self._on_plot_button_click) self.update_plot_widget.on_toggle_code(self._on_toggle_code) + self.update_plot_widget.on_toggle_data(self._on_toggle_data) @property def disabled(self) -> bool: @@ -910,6 +945,17 @@ def show_code(self) -> bool: def show_code(self, show_code: bool): self.update_plot_widget.show_code = show_code + @property + def show_data(self) -> bool: + """ + If the widget should show the plot's data. 
+ """ + return self.update_plot_widget.show_data + + @show_data.setter + def show_data(self, show_data: bool): + self.update_plot_widget.show_data = show_data + @staticmethod def _is_interactive_notebook() -> bool: ip = get_ipython() @@ -937,10 +983,12 @@ def _try_generate_plot(self) -> Any: try: plot_args, plot_kwargs = self.generate_plot_args() self.current_plot_code = self.generate_plot_code(plot_args, plot_kwargs) - plot = self.plot_function(*plot_args, **plot_kwargs) + # The plot function will now return a tuple (plot, data) + plot, self.current_plot_data = self.plot_function(*plot_args, **plot_kwargs) except Exception as e: import traceback + self.current_plot_data = None plot = HTML(f"

Exception:
{e}

{traceback.format_exc()}
") return plot @@ -949,6 +997,7 @@ def _on_plot_button_click(self, button=None): self.option_widget.disabled = True self.update_plot() self._on_toggle_code(self.show_code) + self._on_toggle_data(self.show_data) self.option_widget.disabled = False def generate_plot_code(self, plot_args: tuple = None, plot_kwargs: dict = None) -> str: @@ -987,6 +1036,16 @@ def _on_toggle_code(self, show_code: bool): with self.code_output: display(highlighted_code) + def _on_toggle_data(self, show_data: bool): + """Handle for the toggle data checkbox.""" + self.data_output.clear_output() + if not show_data: + return + + if self.current_plot_data is not None: + with self.data_output: + display(self.current_plot_data) + def _on_option_change(self, change=None): """Enable the plot to be updated.""" self.update_plot_widget.disabled = self.disabled diff --git a/tests/api/test_api_explore.py b/tests/api/test_api_explore.py index 99ecab62..94efdd0e 100644 --- a/tests/api/test_api_explore.py +++ b/tests/api/test_api_explore.py @@ -148,7 +148,7 @@ def test_cohort_list_details_summary_generated(self, fake_seismo): mock_filter.assert_called_once() mock_render.assert_called_once() - assert "mock summary" in result.data + assert "mock summary" in result[0].data def test_cohort_list_details_censored_output(self, fake_seismo): fake_seismo.config.censor_min_count = 10 @@ -161,7 +161,7 @@ def test_cohort_list_details_censored_output(self, fake_seismo): rule.filter.return_value = fake_seismo.dataframe.iloc[:0] # no rows mock_filter.return_value = rule - result = cohort_list_details({"Cohort": ["C1", "C2"]}) + result, _ = cohort_list_details({"Cohort": ["C1", "C2"]}) assert "censored" in result.data.lower() mock_render.assert_called_once() @@ -175,7 +175,7 @@ def test_cohort_list_details_single_target_index_rename(self, fake_seismo): rule.filter.return_value = fake_seismo.dataframe mock_filter.return_value = rule - result = cohort_list_details({"Cohort": ["C1", "C2"]}) + result, _ = 
cohort_list_details({"Cohort": ["C1", "C2"]}) assert "summary" in result.data @@ -250,15 +250,15 @@ def test_plot_binary_classifier_metrics_basic(self, mock_calc, mock_plot, fake_s mock_calc.assert_called_once() mock_plot.assert_called_once() - assert isinstance(result, SVG) - assert "http://www.w3.org/2000/svg" in result.data + assert isinstance(result[0], SVG) + assert "http://www.w3.org/2000/svg" in result[0].data @patch("seismometer.data.performance.BinaryClassifierMetricGenerator.calculate_binary_stats") def test_plot_binary_classifier_metrics_table_only(self, mock_calc, fake_seismo): mock_stats = pd.DataFrame({"Sensitivity": [0.88]}, index=["value"]) mock_calc.return_value = (mock_stats, None) - result = plot_binary_classifier_metrics( + result, _ = plot_binary_classifier_metrics( metric_generator=BinaryClassifierMetricGenerator(rho=0.5), metrics="Sensitivity", cohort_dict={}, @@ -284,7 +284,7 @@ def test_plot_binary_classifier_metrics_nonbinary_target( with patch("seismometer.seismogram.Seismogram", return_value=fake_seismo): fake_seismo.dataframe = df - result = plot_binary_classifier_metrics( + result, _ = plot_binary_classifier_metrics( metric_generator=BinaryClassifierMetricGenerator(rho=0.5), metrics=["Accuracy"], cohort_dict={"Cohort": ("C1",)}, @@ -304,7 +304,7 @@ def test_model_evaluation_not_binary_target(self, mock_filter, mock_render, fake df["event1_Value"] = 1 # single class mock_filter.return_value = df - result = _model_evaluation( + result, _ = _model_evaluation( df, entity_keys=["entity"], target_event="event1", @@ -341,7 +341,7 @@ def test_model_evaluation_valid_path(self, mock_filter, mock_render, mock_plot, mock_filter.return_value = df mock_scores.return_value = df - result = _model_evaluation( + result, _ = _model_evaluation( df, entity_keys=["entity"], target_event="event1", @@ -388,7 +388,7 @@ def test_plot_model_score_comparison( mock_event_score.return_value = df # per_context=False path mock_perf_data.return_value = _mock_perf_df - 
result = plot_model_score_comparison( + result, _ = plot_model_score_comparison( cohort_dict={"Cohort": ("C1", "C2")}, target="event1", scores=("score1",), @@ -431,7 +431,7 @@ def test_plot_model_target_comparison( mock_event_score.return_value = df mock_perf_data.return_value = _mock_perf_df - result = plot_model_target_comparison( + result, _ = plot_model_target_comparison( cohort_dict={"Cohort": ("C1", "C2")}, targets=("event1",), score="score1", @@ -461,7 +461,7 @@ def test_plot_cohort_evaluation_censored_data( mock_get_scores.return_value = df mock_get_perf.return_value = _mock_perf_df - result = _plot_cohort_evaluation( + result, _ = _plot_cohort_evaluation( dataframe=_mock_perf_df, entity_keys=["entity"], target="event1_Value", @@ -494,7 +494,7 @@ def test_plot_cohort_evaluation_success( mock_get_scores.return_value = df mock_get_perf.return_value = _mock_perf_df - result = _plot_cohort_evaluation( + result, _ = _plot_cohort_evaluation( dataframe=_mock_perf_df, entity_keys=["entity"], target="event1_Value", @@ -524,7 +524,7 @@ def test_plot_cohort_group_histograms( mock_seismo.return_value = fake_seismo mock_filter.return_value = fake_seismo.dataframe - result = plot_cohort_group_histograms( + result, _ = plot_cohort_group_histograms( cohort_col="Cohort", subgroups=["C1", "C2"], target_column="event1", @@ -540,7 +540,7 @@ def test_plot_cohort_hist_censored_after_filter(self, mock_render, fake_seismo): # Simulate filtered-out result empty_df = fake_seismo.dataframe.iloc[0:0].copy() - result = _plot_cohort_hist( + result, _ = _plot_cohort_hist( dataframe=empty_df, target="event1_Value", output="score1", @@ -560,7 +560,7 @@ def test_plot_cohort_hist_plot_fails(self, mock_get_cohort_data, mock_plot, mock mock_df = pd.DataFrame({"cohort": ["C1"] * 11 + ["C2"] * 11, "pred": [0.1] * 11 + [0.2] * 11}) mock_get_cohort_data.return_value = mock_df - result = _plot_cohort_hist( + result, _ = _plot_cohort_hist( dataframe=fake_seismo.dataframe, target="event1_Value", 
output="score1", @@ -610,7 +610,7 @@ def test_plot_cohort_lead_time(self, mock_plot, mock_seismo, fake_seismo): def test_leadtime_enc_missing_target_column(self, fake_seismo, caplog): df = fake_seismo.dataframe.drop(columns=["event1_Value"]) with caplog.at_level("ERROR"): - result = _plot_leadtime_enc( + result, _ = _plot_leadtime_enc( df, entity_keys=["entity"], target_event="event1_Value", @@ -623,13 +623,13 @@ def test_leadtime_enc_missing_target_column(self, fake_seismo, caplog): max_hours=48, x_label="Lead Time (hours)", ) - assert result is None + assert "Error" in result.data assert "Target event (event1_Value) not found" in caplog.text def test_leadtime_enc_missing_target_zero_column(self, fake_seismo, caplog): df = fake_seismo.dataframe.drop(columns=["event1_Time"]) with caplog.at_level("ERROR"): - result = _plot_leadtime_enc( + result, _ = _plot_leadtime_enc( df, entity_keys=["entity"], target_event="event1_Value", @@ -642,14 +642,14 @@ def test_leadtime_enc_missing_target_zero_column(self, fake_seismo, caplog): max_hours=48, x_label="Lead Time (hours)", ) - assert result is None + assert "Error" in result.data assert "Target event time-zero (event1_Time) not found" in caplog.text def test_leadtime_enc_no_positive_events(self, fake_seismo, caplog): df = fake_seismo.dataframe.copy() df["event1_Value"] = 0 # force all negative with caplog.at_level("ERROR"): - result = _plot_leadtime_enc( + result, _ = _plot_leadtime_enc( df, entity_keys=["entity"], target_event="event1_Value", @@ -662,7 +662,7 @@ def test_leadtime_enc_no_positive_events(self, fake_seismo, caplog): max_hours=48, x_label="Lead Time (hours)", ) - assert result is None + assert "Error" in result.data assert "No positive events (event1_Value=1) were found" in caplog.text @patch("seismometer.api.plots.pdh.event_score") @@ -673,7 +673,7 @@ def test_leadtime_enc_subgroup_filter_excludes_all(self, mock_render, mock_filte mock_filter.return_value = df[:0] mock_score.return_value = df[:0] - result = 
_plot_leadtime_enc( + result, _ = _plot_leadtime_enc( df, entity_keys=["entity"], target_event="event1_Value", @@ -714,7 +714,7 @@ def test_plot_trend_intervention_outcome_combines_both( self, mock_render, mock_plot, mock_event_value, fake_seismo ): fake_seismo.selected_cohort = ("Cohort", ["C1", "C2"]) - result = plot_trend_intervention_outcome() + result, _ = plot_trend_intervention_outcome() assert isinstance(result, HTML) assert "Outcome" in result.data assert "Intervention" in result.data @@ -728,7 +728,7 @@ def test_plot_trend_intervention_outcome_missing_intervention( self, mock_msg, mock_render, mock_plot, mock_event_value, fake_seismo ): fake_seismo.selected_cohort = ("Cohort", ["C1", "C2"]) - result = plot_trend_intervention_outcome() + result, _ = plot_trend_intervention_outcome() assert isinstance(result, HTML) assert "Missing Intervention" in result.data assert "Outcome" in result.data @@ -741,7 +741,7 @@ def test_plot_trend_intervention_outcome_missing_outcome( self, mock_msg, mock_render, mock_plot, mock_event_value, fake_seismo ): fake_seismo.selected_cohort = ("Cohort", ["C1", "C2"]) - result = plot_trend_intervention_outcome() + result, _ = plot_trend_intervention_outcome() assert isinstance(result, HTML) assert "Missing Outcome" in result.data assert "Intervention" in result.data diff --git a/tests/controls/test_explore.py b/tests/controls/test_explore.py index a5bb0cb5..5a1d6583 100644 --- a/tests/controls/test_explore.py +++ b/tests/controls/test_explore.py @@ -29,6 +29,8 @@ def test_init(self): assert widget.plot_button.disabled is False assert widget.code_checkbox.description == widget.SHOW_CODE assert widget.show_code is False + assert widget.data_checkbox.description == widget.SHOW_DATA + assert widget.show_data is False def test_plot_button_click(self): count = 0 @@ -68,15 +70,28 @@ def on_toggle_callback(button): widget.code_checkbox.value = True assert count == 1 + def test_toggle_data_checkbox(self): + count = 0 + + def 
on_toggle_callback(button): + nonlocal count + count += 1 + + widget = undertest.UpdatePlotWidget() + widget.on_toggle_data(on_toggle_callback) + widget.data_checkbox.value = True + assert count == 1 + class TestExplorationBaseClass: def test_base_class(self, caplog): option_widget = ipywidgets.Checkbox(description="ClickMe") - plot_function = Mock(return_value="some result") + plot_function = Mock(return_value=("some result", "some data")) widget = undertest.ExplorationWidget("ExploreTest", option_widget, plot_function) plot_function.assert_not_called() - assert "Subclasses must implement this method" in widget._try_generate_plot().value + plot = widget._try_generate_plot() + assert "Subclasses must implement this method" in plot.value @pytest.mark.parametrize( "plot_module,plot_code", @@ -88,7 +103,7 @@ def test_base_class(self, caplog): ) def test_args_subclass(self, plot_module, plot_code): option_widget = ipywidgets.Checkbox(description="ClickMe") - plot_function = Mock(return_value="some result") + plot_function = Mock(return_value=("some result", "some data")) plot_function.__name__ = "plot_something" plot_function.__module__ = plot_module @@ -103,11 +118,13 @@ def generate_plot_args(self) -> tuple[tuple, dict]: plot_function.assert_called_once_with(False) assert widget.current_plot_code == plot_code assert widget.show_code is False + assert widget.show_data is False assert widget._try_generate_plot() == "some result" + assert widget.current_plot_data == "some data" def test_kwargs_subclass(self): option_widget = ipywidgets.Checkbox(description="ClickMe") - plot_function = Mock(return_value="some result") + plot_function = Mock(return_value=("some result", "some data")) plot_function.__name__ = "plot_something" plot_function.__module__ = "test_explore" @@ -122,11 +139,13 @@ def generate_plot_args(self) -> tuple[tuple, dict]: plot_function.assert_called_once_with(checkbox=False) assert widget.current_plot_code == "test_explore.plot_something(checkbox=False)" 
assert widget.show_code is False + assert widget.show_data is False assert widget._try_generate_plot() == "some result" + assert widget.current_plot_data == "some data" def test_args_kwargs_subclass(self): option_widget = ipywidgets.Checkbox(description="ClickMe") - plot_function = Mock(return_value="some result") + plot_function = Mock(return_value=("some result", "some data")) plot_function.__name__ = "plot_something" plot_function.__module__ = "test_explore" @@ -141,7 +160,9 @@ def generate_plot_args(self) -> tuple[tuple, dict]: plot_function.assert_called_once_with("test", checkbox=False) assert widget.current_plot_code == "test_explore.plot_something('test', checkbox=False)" assert widget.show_code is False + assert widget.show_data is False assert widget._try_generate_plot() == "some result" + assert widget.current_plot_data == "some data" def test_exception_plot_code_subclass(self): option_widget = ipywidgets.Checkbox(description="ClickMe") @@ -160,10 +181,11 @@ def generate_plot_args(self) -> tuple[tuple, dict]: plot = widget._try_generate_plot() assert "Traceback" in plot.value assert "Test Exception" in plot.value + assert widget.current_plot_data is None def test_no_initial_plot_subclass(self): option_widget = ipywidgets.Checkbox(description="ClickMe") - plot_function = Mock(return_value="some result") + plot_function = Mock(return_value=("some result", "some data")) plot_function.__name__ = "plot_something" plot_function.__module__ = "test_explore" @@ -178,13 +200,16 @@ def generate_plot_args(self) -> tuple[tuple, dict]: widget.center.outputs == [] plot_function.assert_not_called() assert widget.current_plot_code == ExploreFake.NO_CODE_STRING + assert widget.current_plot_data is None assert widget.show_code is False - widget._try_generate_plot() == "" + assert widget.show_data is False + assert widget._try_generate_plot() == "some result" + assert widget.current_plot_data == "some data" @pytest.mark.parametrize("show_code", [True, False]) def 
test_toggle_code_callback(self, show_code, capsys): option_widget = ipywidgets.Checkbox(description="ClickMe") - plot_function = Mock(return_value="some result") + plot_function = Mock(return_value=("some result", "some data")) plot_function.__name__ = "plot_something" plot_function.__module__ = "test_explore" @@ -204,6 +229,29 @@ def generate_plot_args(self) -> tuple[tuple, dict]: code_in_output = "test_explore.plot_something" in stdout.split("\n")[-2] assert code_in_output == show_code + @pytest.mark.parametrize("show_data", [True, False]) + def test_toggle_data_callback(self, show_data, capsys): + option_widget = ipywidgets.Checkbox(description="ClickMe") + plot_function = Mock(return_value=("some result", "some data")) + plot_function.__name__ = "plot_something" + plot_function.__module__ = "test_explore" + + class ExploreFake(undertest.ExplorationWidget): + def __init__(self): + super().__init__("Fake Explorer", option_widget, plot_function) + + def generate_plot_args(self) -> tuple[tuple, dict]: + return ["test"], {"checkbox": self.option_widget.value} + + widget = ExploreFake() + widget.show_data = show_data + widget.data_output = MagicMock() + widget._on_plot_button_click() + stdout = capsys.readouterr().out + assert "some result" in stdout + data_in_output = "some data" in stdout.split("\n")[-2] + assert data_in_output == show_data + # endregion # region Test Model Options Widgets @@ -559,7 +607,7 @@ class TestExplorationSubpopulationWidget: def test_init(self, mock_seismo): fake_seismo = mock_seismo() fake_seismo.available_cohort_groups = {"C1": ["C1.1", "C1.2"], "C2": ["C2.1", "C2.2"]} - plot_function = Mock(return_value="some result") + plot_function = Mock(return_value=("some result", "some data")) plot_function.__name__ = "plot_function" plot_function.__module__ = "test_explore" @@ -569,12 +617,13 @@ def test_init(self, mock_seismo): assert widget.update_plot_widget.disabled assert widget.current_plot_code == "test_explore.plot_function({})" 
plot_function.assert_called_once_with({}) # default value + assert widget.current_plot_data == "some data" @patch.object(seismogram, "Seismogram", return_value=Mock()) def test_option_update(self, mock_seismo): fake_seismo = mock_seismo() fake_seismo.available_cohort_groups = {"C1": ["C1.1", "C1.2"], "C2": ["C2.1", "C2.2"]} - plot_function = Mock(return_value="some result") + plot_function = Mock(return_value=("some result", "some data")) plot_function.__name__ = "plot_function" plot_function.__module__ = "test_explore" @@ -593,6 +642,7 @@ def test_option_update(self, mock_seismo): ] } ) # updated value + assert widget.current_plot_data == "some data" class TestExplorationModelSubgroupEvaluationWidget: @@ -604,7 +654,7 @@ def test_init(self, mock_seismo): fake_seismo.target_cols = ["T1", "T2"] fake_seismo.output_list = ["S1", "S2"] - plot_function = Mock(return_value="some result") + plot_function = Mock(return_value=("some result", "some data")) plot_function.__name__ = "plot_function" plot_function.__module__ = "test_explore" @@ -616,6 +666,7 @@ def test_init(self, mock_seismo): assert widget.update_plot_widget.disabled assert widget.current_plot_code == "test_explore.plot_function({}, 'T1', 'S1', [0.2, 0.1], per_context=False)" plot_function.assert_called_once_with({}, "T1", "S1", [0.2, 0.1], per_context=False) # default value + assert widget.current_plot_data == "some data" @patch.object(seismogram, "Seismogram", return_value=Mock()) def test_option_update(self, mock_seismo): @@ -625,7 +676,7 @@ def test_option_update(self, mock_seismo): fake_seismo.target_cols = ["T1", "T2"] fake_seismo.output_list = ["S1", "S2"] - plot_function = Mock(return_value="some result") + plot_function = Mock(return_value=("some result", "some data")) plot_function.__name__ = "plot_function" plot_function.__module__ = "test_explore" @@ -636,6 +687,7 @@ def test_option_update(self, mock_seismo): widget.option_widget.cohort_list.value = {"C2": ("C2.1",)} widget.update_plot() 
plot_function.assert_called_with({"C2": ("C2.1",)}, "T1", "S1", [0.2, 0.1], per_context=False) # updated value + assert widget.current_plot_data == "some data" class TestExplorationCohortSubclassEvaluationWidget: @@ -651,7 +703,7 @@ def test_init(self, mock_seismo, threshold_handling, thresholds): fake_seismo.target_cols = ["T1", "T2"] fake_seismo.output_list = ["S1", "S2"] - plot_function = Mock(return_value="some result") + plot_function = Mock(return_value=("some result", "some data")) plot_function.__name__ = "plot_function" plot_function.__module__ = "test_explore" @@ -673,6 +725,7 @@ def test_init(self, mock_seismo, threshold_handling, thresholds): "C1", ("C1.1", "C1.2"), "T1", "S1", per_context=False ) # default value assert widget.current_plot_code == expected_code + assert widget.current_plot_data == "some data" class TestExplorationCohortOutcomeInterventionEvaluationWidget: @@ -688,7 +741,7 @@ def test_init(self, mock_seismo): ) fake_seismo.config = fake_config - plot_function = Mock(return_value="some result") + plot_function = Mock(return_value=("some result", "some data")) plot_function.__name__ = "plot_function" plot_function.__module__ = "test_explore" @@ -702,6 +755,7 @@ def test_init(self, mock_seismo): widget.current_plot_code == "test_explore.plot_function('C1', ('C1.1', 'C1.2'), 'O1', 'I1', 'pred_time')" ) plot_function.assert_called_once_with("C1", ("C1.1", "C1.2"), "O1", "I1", "pred_time") # default value + assert widget.current_plot_data == "some data" class TestExplorationScoreComparisonByCohortWidget: @@ -713,7 +767,7 @@ def test_init(self, mock_seismo): fake_seismo.target_cols = ["T1", "T2"] fake_seismo.output_list = ["S1", "S2"] - plot_function = Mock(return_value="some result") + plot_function = Mock(return_value=("some result", "some data")) plot_function.__name__ = "plot_function" plot_function.__module__ = "test_explore" @@ -725,6 +779,7 @@ def test_init(self, mock_seismo): assert widget.update_plot_widget.disabled assert 
widget.current_plot_code == "test_explore.plot_function({}, 'T1', ('S1', 'S2'), per_context=False)" plot_function.assert_called_once_with({}, "T1", ("S1", "S2"), per_context=False) # default value + assert widget.current_plot_data == "some data" class TestExplorationTargetComparisonByCohortWidget: @@ -736,7 +791,7 @@ def test_init(self, mock_seismo): fake_seismo.target_cols = ["T1", "T2"] fake_seismo.output_list = ["S1", "S2"] - plot_function = Mock(return_value="some result") + plot_function = Mock(return_value=("some result", "some data")) plot_function.__name__ = "plot_function" plot_function.__module__ = "test_explore" @@ -748,6 +803,7 @@ def test_init(self, mock_seismo): assert widget.update_plot_widget.disabled assert widget.current_plot_code == "test_explore.plot_function({}, ('T1', 'T2'), 'S1', per_context=False)" plot_function.assert_called_once_with({}, ("T1", "T2"), "S1", per_context=False) # default value + assert widget.current_plot_data == "some data" # endregion @@ -804,7 +860,7 @@ def test_init(self, mock_seismo): fake_seismo.target_cols = ["T1", "T2"] fake_seismo.output_list = ["S1", "S2"] - plot_function = Mock(return_value="some result") + plot_function = Mock(return_value=("some result", "some data")) plot_function.__name__ = "plot_function" plot_function.__module__ = "test_explore" @@ -823,6 +879,7 @@ def test_init(self, mock_seismo): plot_function.assert_called_once_with( metric_generator, ("Precision", "Recall"), {}, "T1", "S1", per_context=False ) + assert widget.current_plot_data == "some data" @patch.object(seismogram, "Seismogram", return_value=Mock()) def test_option_update(self, mock_seismo): @@ -833,7 +890,7 @@ def test_option_update(self, mock_seismo): fake_seismo.target_cols = ["T1", "T2"] fake_seismo.output_list = ["S1", "S2"] - plot_function = Mock(return_value="some result") + plot_function = Mock(return_value=("some result", "some data")) plot_function.__name__ = "plot_function" plot_function.__module__ = "test_explore" @@ -853,6 
+910,7 @@ def test_option_update(self, mock_seismo): + "('F1',), {}, 'T2', 'S2', per_context=False)" ) plot_function.assert_called_with(metric_generator, ("F1",), {}, "T2", "S2", per_context=False) + assert widget.current_plot_data == "some data" class TestExploreBinaryModelAnalytics: