Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog/168.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Added a 'Show raw data' checkbox to the interactive `Explore...` widgets to display the underlying pandas.DataFrame used for visualizations.
15 changes: 8 additions & 7 deletions src/seismometer/api/explore.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from typing import Any, Optional

import pandas as pd
from IPython.display import HTML, display

from seismometer.controls.decorators import disk_cached_html_segment
from seismometer.controls.decorators import disk_cached_html_and_df_segment
from seismometer.controls.explore import ExplorationWidget # noqa:
from seismometer.controls.explore import (
ExplorationCohortOutcomeInterventionEvaluationWidget,
Expand Down Expand Up @@ -216,9 +217,9 @@ def on_widget_value_changed(*args):
return VBox(children=[comparison_selections, output], layout=BOX_GRID_LAYOUT)


@disk_cached_html_segment
@disk_cached_html_and_df_segment
@export
def cohort_list_details(cohort_dict: dict[str, tuple[Any]]) -> HTML:
def cohort_list_details(cohort_dict: dict[str, tuple[Any]]) -> tuple[HTML, pd.DataFrame]:
"""
Generates an HTML table of cohort details.

Expand All @@ -229,8 +230,8 @@ def cohort_list_details(cohort_dict: dict[str, tuple[Any]]) -> HTML:

Returns
-------
HTML
able indexed by targets, with counts of unique entities, and mean values of the output columns.
tuple[HTML, pd.DataFrame]
able indexed by targets, with counts of unique entities, and mean values of the output columns, and the data
"""
from seismometer.data.filter import filter_rule_from_cohort_dictionary

Expand All @@ -246,7 +247,7 @@ def cohort_list_details(cohort_dict: dict[str, tuple[Any]]) -> HTML:
]
cohort_count = data[sg.entity_keys[0]].nunique()
if cohort_count < sg.censor_threshold:
return template.render_censored_plot_message(sg.censor_threshold)
return template.render_censored_plot_message(sg.censor_threshold), data

groups = data.groupby(target_cols)
float_cols = list(data[intervention_cols + outcome_cols].select_dtypes(include=float))
Expand All @@ -268,7 +269,7 @@ def cohort_list_details(cohort_dict: dict[str, tuple[Any]]) -> HTML:
groupstats.index.rename(new_names, inplace=True)
html_table = groupstats.to_html()
title = "Summary"
return template.render_title_message(title, html_table)
return template.render_title_message(title, html_table), data


# endregion
Loading
Loading