From b9a841d4e9b3f4a8a479aafaaaba978c9589cf67 Mon Sep 17 00:00:00 2001 From: Uriah Finkel Date: Sun, 14 Dec 2025 17:01:30 +0200 Subject: [PATCH 1/8] feat: close #236 --- src/rtichoke/__init__.py | 5 ++- src/rtichoke/utility/decision.py | 71 ++++++++++++++++++++++++++++++++ 2 files changed, 75 insertions(+), 1 deletion(-) diff --git a/src/rtichoke/__init__.py b/src/rtichoke/__init__.py index c60435b..b95277a 100644 --- a/src/rtichoke/__init__.py +++ b/src/rtichoke/__init__.py @@ -24,7 +24,10 @@ # create_calibration_curve as create_calibration_curve, # ) -from rtichoke.utility.decision import create_decision_curve as create_decision_curve +from rtichoke.utility.decision import ( + create_decision_curve as create_decision_curve, + create_decision_curve_times as create_decision_curve_times, +) from rtichoke.utility.decision import plot_decision_curve as plot_decision_curve from rtichoke.performance_data.performance_data import ( diff --git a/src/rtichoke/utility/decision.py b/src/rtichoke/utility/decision.py index ca2b71c..09b75c2 100644 --- a/src/rtichoke/utility/decision.py +++ b/src/rtichoke/utility/decision.py @@ -6,7 +6,12 @@ from plotly.graph_objs._figure import Figure from rtichoke.helpers.plotly_helper_functions import ( _create_rtichoke_plotly_curve_binary, + _create_plotly_curve_times, _plot_rtichoke_curve_binary, + _create_rtichoke_curve_list_times, +) +from rtichoke.performance_data.performance_data_times import ( + prepare_performance_data_times, ) import numpy as np import polars as pl @@ -162,3 +167,69 @@ def plot_decision_curve( max_p_threshold=max_p_threshold, ) return fig + + +def create_decision_curve_times( + probs: Dict[str, np.ndarray], + reals: Union[np.ndarray, Dict[str, np.ndarray]], + times: Union[np.ndarray, Dict[str, np.ndarray]], + fixed_time_horizons: list[float], + decision_type: str = "conventional", + heuristics_sets: list[Dict] = [ + { + "censoring_heuristic": "adjusted", + "competing_heuristic": "adjusted_as_negative", + } + ], + min_p_threshold: float = 0, + max_p_threshold: float = 1, + by: float = 0.01, + stratified_by: Sequence[str] = ["probability_threshold"], + size: int = 600, + color_values: List[str] = [ + "#1b9e77", + "#d95f02", + "#7570b3", + "#e7298a", + "#07004D", + "#E6AB02", + "#FE5F55", + "#54494B", + "#006E90", + "#BC96E6", + "#52050A", + "#1F271B", + "#BE7C4D", + "#63768D", + "#08A045", + "#320A28", + "#82FF9E", + "#2176FF", + "#D1603D", + "#585123", + ], +) -> Figure: + """Create time-dependent Decision Curve.""" + + if decision_type == "conventional": + curve = "decision" + else: + curve = "interventions avoided" + + performance_data_times = prepare_performance_data_times( + probs, + reals, + times, + by=by, + fixed_time_horizons=fixed_time_horizons, + heuristics_sets=heuristics_sets, + stratified_by=["probability_threshold"], + ) + + rtichoke_curve_list_times = _create_rtichoke_curve_list_times( + performance_data_times, stratified_by="probability_threshold", curve=curve + ) + + fig = _create_plotly_curve_times(rtichoke_curve_list_times) + + return fig From 738c4d3ed3ffeb5b5bfb9ddd6e4363f7720793a9 Mon Sep 17 00:00:00 2001 From: Uriah Finkel Date: Sun, 14 Dec 2025 18:37:46 +0200 Subject: [PATCH 2/8] feat: close #238 --- src/rtichoke/__init__.py | 16 +++++- src/rtichoke/discrimination/gains.py | 56 +++++++++++++++++++ src/rtichoke/discrimination/lift.py | 56 +++++++++++++++++++ .../discrimination/precision_recall.py | 56 +++++++++++++++++++ src/rtichoke/discrimination/roc.py | 56 +++++++++++++++++++ .../helpers/plotly_helper_functions.py | 41 ++++++++++++++ src/rtichoke/utility/decision.py | 23 +++----- 7 files changed, 287 insertions(+), 17 deletions(-) diff --git a/src/rtichoke/__init__.py b/src/rtichoke/__init__.py index b95277a..cc5b47a 100644 --- a/src/rtichoke/__init__.py +++ b/src/rtichoke/__init__.py @@ -4,20 +4,30 @@ __version__ = version("rtichoke") -from rtichoke.discrimination.roc import create_roc_curve as create_roc_curve +from rtichoke.discrimination.roc import ( + create_roc_curve as create_roc_curve, + create_roc_curve_times as create_roc_curve_times, +) from rtichoke.discrimination.roc import plot_roc_curve as plot_roc_curve -from rtichoke.discrimination.lift import create_lift_curve as create_lift_curve +from rtichoke.discrimination.lift import ( + create_lift_curve as create_lift_curve, + create_lift_curve_times as create_lift_curve_times, +) from rtichoke.discrimination.lift import plot_lift_curve as plot_lift_curve from rtichoke.discrimination.precision_recall import ( create_precision_recall_curve as create_precision_recall_curve, + create_precision_recall_curve_times as create_precision_recall_curve_times, ) from rtichoke.discrimination.precision_recall import ( plot_precision_recall_curve as plot_precision_recall_curve, ) -from rtichoke.discrimination.gains import create_gains_curve as create_gains_curve +from rtichoke.discrimination.gains import ( + create_gains_curve as create_gains_curve, + create_gains_curve_times as create_gains_curve_times, +) from rtichoke.discrimination.gains import plot_gains_curve as plot_gains_curve # from rtichoke.calibration.calibration import ( diff --git a/src/rtichoke/discrimination/gains.py b/src/rtichoke/discrimination/gains.py index 39334ef..59366f4 100644 --- a/src/rtichoke/discrimination/gains.py +++ b/src/rtichoke/discrimination/gains.py @@ -5,6 +5,7 @@ from typing import Dict, List, Sequence, Union from plotly.graph_objs._figure import Figure from rtichoke.helpers.plotly_helper_functions import ( + _create_rtichoke_plotly_curve_times, _create_rtichoke_plotly_curve_binary, _plot_rtichoke_curve_binary, ) @@ -123,3 +124,58 @@ def plot_gains_curve( curve="gains", ) return fig + + +def create_gains_curve_times( + probs: Dict[str, np.ndarray], + reals: Union[np.ndarray, Dict[str, np.ndarray]], + times: Union[np.ndarray, Dict[str, np.ndarray]], + fixed_time_horizons: list[float], + heuristics_sets: list[Dict] = [ + { + "censoring_heuristic": "adjusted", + "competing_heuristic": "adjusted_as_negative", + } + ], + by: float = 0.01, + stratified_by: Sequence[str] = ["probability_threshold"], + size: int = 600, + color_values: List[str] = [ + "#1b9e77", + "#d95f02", + "#7570b3", + "#e7298a", + "#07004D", + "#E6AB02", + "#FE5F55", + "#54494B", + "#006E90", + "#BC96E6", + "#52050A", + "#1F271B", + "#BE7C4D", + "#63768D", + "#08A045", + "#320A28", + "#82FF9E", + "#2176FF", + "#D1603D", + "#585123", + ], +) -> Figure: + """Create time-dependent Lift Curve.""" + + fig = _create_rtichoke_plotly_curve_times( + probs, + reals, + times, + fixed_time_horizons=fixed_time_horizons, + heuristics_sets=heuristics_sets, + by=by, + stratified_by=stratified_by, + size=size, + color_values=color_values, + curve="gains", + ) + + return fig diff --git a/src/rtichoke/discrimination/lift.py b/src/rtichoke/discrimination/lift.py index 65f2553..5f358af 100644 --- a/src/rtichoke/discrimination/lift.py +++ b/src/rtichoke/discrimination/lift.py @@ -5,6 +5,7 @@ from typing import Dict, List, Sequence, Union from plotly.graph_objs._figure import Figure from rtichoke.helpers.plotly_helper_functions import ( + _create_rtichoke_plotly_curve_times, _create_rtichoke_plotly_curve_binary, _plot_rtichoke_curve_binary, ) @@ -123,3 +124,58 @@ def plot_lift_curve( curve="lift", ) return fig + + +def create_lift_curve_times( + probs: Dict[str, np.ndarray], + reals: Union[np.ndarray, Dict[str, np.ndarray]], + times: Union[np.ndarray, Dict[str, np.ndarray]], + fixed_time_horizons: list[float], + heuristics_sets: list[Dict] = [ + { + "censoring_heuristic": "adjusted", + "competing_heuristic": "adjusted_as_negative", + } + ], + by: float = 0.01, + stratified_by: Sequence[str] = ["probability_threshold"], + size: int = 600, + color_values: List[str] = [ + "#1b9e77", + "#d95f02", + "#7570b3", + "#e7298a", + "#07004D", + "#E6AB02", + "#FE5F55", + "#54494B", + "#006E90", + "#BC96E6", + "#52050A", + "#1F271B", + "#BE7C4D", + "#63768D", + "#08A045", + "#320A28", + "#82FF9E", + "#2176FF", + "#D1603D", + "#585123", + ], +) -> Figure: + """Create time-dependent Lift Curve.""" + + fig = _create_rtichoke_plotly_curve_times( + probs, + reals, + times, + fixed_time_horizons=fixed_time_horizons, + heuristics_sets=heuristics_sets, + by=by, + stratified_by=stratified_by, + size=size, + color_values=color_values, + curve="lift", + ) + + return fig diff --git a/src/rtichoke/discrimination/precision_recall.py b/src/rtichoke/discrimination/precision_recall.py index 06314d7..565cf5c 100644 --- a/src/rtichoke/discrimination/precision_recall.py +++ b/src/rtichoke/discrimination/precision_recall.py @@ -5,6 +5,7 @@ from typing import Dict, List, Sequence, Union from plotly.graph_objs._figure import Figure from rtichoke.helpers.plotly_helper_functions import ( + _create_rtichoke_plotly_curve_times, _create_rtichoke_plotly_curve_binary, _plot_rtichoke_curve_binary, ) @@ -123,3 +124,58 @@ def plot_precision_recall_curve( curve="precision recall", ) return fig + + +def create_precision_recall_curve_times( + probs: Dict[str, np.ndarray], + reals: Union[np.ndarray, Dict[str, np.ndarray]], + times: Union[np.ndarray, Dict[str, np.ndarray]], + fixed_time_horizons: list[float], + heuristics_sets: list[Dict] = [ + { + "censoring_heuristic": "adjusted", + "competing_heuristic": "adjusted_as_negative", + } + ], + by: float = 0.01, + stratified_by: Sequence[str] = ["probability_threshold"], + size: int = 600, + color_values: List[str] = [ + "#1b9e77", + "#d95f02", + "#7570b3", + "#e7298a", + "#07004D", + "#E6AB02", + "#FE5F55", + "#54494B", + "#006E90", + "#BC96E6", + "#52050A", + "#1F271B", + "#BE7C4D", + "#63768D", + "#08A045", + "#320A28", + "#82FF9E", + "#2176FF", + "#D1603D", + "#585123", + ], +) -> Figure: + """Create time-dependent Lift Curve.""" + + fig = _create_rtichoke_plotly_curve_times( + probs, + reals, + times, + fixed_time_horizons=fixed_time_horizons, + heuristics_sets=heuristics_sets, + by=by, + stratified_by=stratified_by, + size=size, + color_values=color_values, + curve="precision recall", + ) + + return fig diff --git a/src/rtichoke/discrimination/roc.py b/src/rtichoke/discrimination/roc.py index ae2d7e4..d8a8ed0 100644 --- a/src/rtichoke/discrimination/roc.py +++ b/src/rtichoke/discrimination/roc.py @@ -5,6 +5,7 @@ from typing import Dict, List, Union, Sequence from plotly.graph_objs._figure import Figure from rtichoke.helpers.plotly_helper_functions import ( + _create_rtichoke_plotly_curve_times, _create_rtichoke_plotly_curve_binary, _plot_rtichoke_curve_binary, ) @@ -124,3 +125,58 @@ def plot_roc_curve( ) return fig + + +def create_roc_curve_times( + probs: Dict[str, np.ndarray], + reals: Union[np.ndarray, Dict[str, np.ndarray]], + times: Union[np.ndarray, Dict[str, np.ndarray]], + fixed_time_horizons: list[float], + heuristics_sets: list[Dict] = [ + { + "censoring_heuristic": "adjusted", + "competing_heuristic": "adjusted_as_negative", + } + ], + by: float = 0.01, + stratified_by: Sequence[str] = ["probability_threshold"], + size: int = 600, + color_values: List[str] = [ + "#1b9e77", + "#d95f02", + "#7570b3", + "#e7298a", + "#07004D", + "#E6AB02", + "#FE5F55", + "#54494B", + "#006E90", + "#BC96E6", + "#52050A", + "#1F271B", + "#BE7C4D", + "#63768D", + "#08A045", + "#320A28", + "#82FF9E", + "#2176FF", + "#D1603D", + "#585123", + ], +) -> Figure: + """Create time-dependent Lift Curve.""" + + fig = _create_rtichoke_plotly_curve_times( + probs, + reals, + times, + fixed_time_horizons=fixed_time_horizons, + heuristics_sets=heuristics_sets, + by=by, + stratified_by=stratified_by, + size=size, + color_values=color_values, + curve="roc", + ) + + return fig diff --git a/src/rtichoke/helpers/plotly_helper_functions.py b/src/rtichoke/helpers/plotly_helper_functions.py index 38d1885..5189b33 100644 --- a/src/rtichoke/helpers/plotly_helper_functions.py +++ b/src/rtichoke/helpers/plotly_helper_functions.py @@ -8,6 +8,9 @@ from typing import Any, Dict, Union, Sequence import numpy as np from rtichoke.performance_data.performance_data import prepare_performance_data +from rtichoke.performance_data.performance_data_times import ( + prepare_performance_data_times, +) _HOVER_LABELS = { "false_positive_rate": "1 - Specificity (FPR)", @@ -51,6 +54,44 @@ def _create_rtichoke_plotly_curve_binary( return fig +def _create_rtichoke_plotly_curve_times( + probs: Dict[str, np.ndarray], + reals: Union[np.ndarray, Dict[str, np.ndarray]], + times: Union[np.ndarray, Dict[str, np.ndarray]], + fixed_time_horizons: list[float], + heuristics_sets: list[Dict] = [ + { + "censoring_heuristic": "adjusted", + "competing_heuristic": "adjusted_as_negative", + } + ], + min_p_threshold: float = 0, + max_p_threshold: float = 1, + by: float = 0.01, + stratified_by: Sequence[str] = ["probability_threshold"], + size: int = 600, + color_values=None, + curve: str = "roc", +) -> go.Figure: + performance_data = prepare_performance_data_times( + probs, + reals, + times, + by=by, + fixed_time_horizons=fixed_time_horizons, + heuristics_sets=heuristics_sets, + stratified_by=stratified_by, + ) + + rtichoke_curve_list_times = _create_rtichoke_curve_list_times( + performance_data, stratified_by=stratified_by[0], curve=curve + ) + + fig = _create_plotly_curve_times(rtichoke_curve_list_times) + + return fig + + def _plot_rtichoke_curve_binary( performance_data: pl.DataFrame, stratified_by: str = "probability_threshold", diff --git a/src/rtichoke/utility/decision.py b/src/rtichoke/utility/decision.py index 09b75c2..50f0e6d 100644 --- a/src/rtichoke/utility/decision.py +++ b/src/rtichoke/utility/decision.py @@ -6,12 +6,8 @@ from plotly.graph_objs._figure import Figure from rtichoke.helpers.plotly_helper_functions import ( _create_rtichoke_plotly_curve_binary, - _create_plotly_curve_times, + _create_rtichoke_plotly_curve_times, _plot_rtichoke_curve_binary, - _create_rtichoke_curve_list_times, -) -from rtichoke.performance_data.performance_data_times import ( - prepare_performance_data_times, ) import numpy as np import polars as pl @@ -216,20 +212,19 @@ def create_decision_curve_times( else: curve = "interventions avoided" - performance_data_times = prepare_performance_data_times( + fig = _create_rtichoke_plotly_curve_times( probs, reals, times, - by=by, fixed_time_horizons=fixed_time_horizons, heuristics_sets=heuristics_sets, - stratified_by=["probability_threshold"], - ) - - rtichoke_curve_list_times = _create_rtichoke_curve_list_times( - performance_data_times, stratified_by="probability_threshold", curve=curve + by=by, + stratified_by=stratified_by, + size=size, + color_values=color_values, + curve=curve, + min_p_threshold=min_p_threshold, + max_p_threshold=max_p_threshold, ) - fig = _create_plotly_curve_times(rtichoke_curve_list_times) - return fig From 493d227b73ec9d4554764566599170334297b314 Mon Sep 17 00:00:00 2001 From: Uriah Finkel Date: Sun, 14 Dec 2025 20:16:31 +0200 Subject: [PATCH 3/8] fix: allow text for interactive marker --- .../helpers/plotly_helper_functions.py | 36 +++++++++++-------- 1 file changed, 22 insertions(+), 14 deletions(-) diff --git a/src/rtichoke/helpers/plotly_helper_functions.py b/src/rtichoke/helpers/plotly_helper_functions.py index 5189b33..bf2c770 100644 --- a/src/rtichoke/helpers/plotly_helper_functions.py +++ b/src/rtichoke/helpers/plotly_helper_functions.py @@ -1531,7 +1531,7 @@ def _xy_for_reference( def _xy_at_cutoff( group: str, cutoff: float, fixed_time_horizon: float - ) -> tuple[Any, Any]: + ) -> tuple[Any, Any, Any]: row = ( rtichoke_curve_list["performance_data_ready_for_curve"] .filter( @@ -1541,13 +1541,13 @@ def _xy_at_cutoff( & pl.col("x").is_not_null() & pl.col("y").is_not_null() ) - .select(["x", "y"]) + .select(["x", "y", "text"]) .limit(1) ) if row.height == 0: - return None, None + return None, None, None r = row.row(0) - return r[0], r[1] + return r[0], r[1], r[2] non_interactive_curve = [] for fixed_time_horizon in rtichoke_curve_list["fixed_time_horizons"]: @@ -1581,10 +1581,10 @@ def _xy_at_cutoff( marker_traces: list[go.Scatter] = [] for fixed_time_horizon in rtichoke_curve_list["fixed_time_horizons"]: for group in rtichoke_curve_list["reference_group_keys"]: - x_val, y_val = ( + x_val, y_val, text_val = ( _xy_at_cutoff(group, initial_cutoff, fixed_time_horizon) if initial_cutoff is not None - else (None, None) + else (None, None, None) ) marker_traces.append( go.Scatter( @@ -1596,19 +1596,26 @@ def _xy_at_cutoff( "color": ( rtichoke_curve_list["colors_dictionary"].get(group) if rtichoke_curve_list["multiple_reference_groups"] - else "#f6e3be", + else "#f6e3be" ), "line": {"width": 3, "color": "black"}, }, name=f"{group} @ cutoff", legendgroup=group, hoverlabel=dict( - bgcolor=rtichoke_curve_list["colors_dictionary"].get(group), - bordercolor=rtichoke_curve_list["colors_dictionary"].get(group), - font_color="white", + bgcolor="#f6e3be" + if not rtichoke_curve_list["multiple_reference_groups"] + else rtichoke_curve_list["colors_dictionary"].get(group), + bordercolor="#f6e3be" + if not rtichoke_curve_list["multiple_reference_groups"] + else rtichoke_curve_list["colors_dictionary"].get(group), + font_color="black" + if not rtichoke_curve_list["multiple_reference_groups"] + else "white", ), showlegend=False, hoverinfo="text", + text=text_val, visible=fixed_time_horizon == initial_fixed_time_horizon, ) ) @@ -1765,6 +1772,7 @@ def _create_plotly_curve_binary(rtichoke_curve_list: dict[str, Any]) -> go.Figur go.Scatter( x=[], y=[], + text=[], mode="markers", marker={ "size": 12, @@ -1834,13 +1842,13 @@ def xy_at_cutoff(group, c): & pl.col("x").is_not_null() & pl.col("y").is_not_null() ) - .select(["x", "y"]) + .select(["x", "y", "text"]) .limit(1) ) if row.height == 0: - return None, None - r = row.row(0) # (x, y) - return r[0], r[1] + return None, None, None + r = row.row(0) + return r[0], r[1], r[2] steps = [ { From 7fc9ca650164de8bc6cc09955fae48d80d81878e Mon Sep 17 00:00:00 2001 From: Uriah Finkel Date: Sun, 14 Dec 2025 20:34:48 +0200 Subject: [PATCH 4/8] fix: text being updated for interactive markers --- .../helpers/plotly_helper_functions.py | 61 ++++++++++--------- 1 file changed, 33 insertions(+), 28 deletions(-) diff --git a/src/rtichoke/helpers/plotly_helper_functions.py b/src/rtichoke/helpers/plotly_helper_functions.py index bf2c770..acbed79 100644 --- a/src/rtichoke/helpers/plotly_helper_functions.py +++ b/src/rtichoke/helpers/plotly_helper_functions.py @@ -1656,35 +1656,40 @@ def _xy_at_cutoff( ) ) - cutoff_steps = [ - { - "method": "restyle", - "args": [ - { - "x": [ - [xy[0]] if xy[0] is not None else [] - for xy in marker_points_at_cutoff - ], - "y": [ - [xy[1]] if xy[1] is not None else [] - for xy in marker_points_at_cutoff - ], - }, - cutoff_target_indices, - ], - "label": f"{cutoff:g}", - } - for cutoff in rtichoke_curve_list["cutoffs"] - for marker_points_at_cutoff in [ - [ - _xy_at_cutoff(group, cutoff, fixed_time_horizon) - if cutoff is not None - else (None, None) - for fixed_time_horizon in rtichoke_curve_list["fixed_time_horizons"] - for group in rtichoke_curve_list["reference_group_keys"] - ] + def marker_values_for_cutoff( + cutoff: float, + ) -> tuple[list[list], list[list], list[list]]: + marker_values = [ + _xy_at_cutoff(group, cutoff, fixed_time_horizon) + if cutoff is not None + else (None, None, None) + for fixed_time_horizon in rtichoke_curve_list["fixed_time_horizons"] + for group in rtichoke_curve_list["reference_group_keys"] ] - ] + + xs = [[x] if x is not None else [] for x, _, _ in marker_values] + ys = [[y] if y is not None else [] for _, y, _ in marker_values] + texts = [[text] if text is not None else [] for _, _, text in marker_values] + + return xs, ys, texts + + cutoff_steps = [] + for cutoff in rtichoke_curve_list["cutoffs"]: + xs, ys, texts = marker_values_for_cutoff(cutoff) + cutoff_steps.append( + { + "method": "restyle", + "args": [ + { + "x": xs, + "y": ys, + "text": texts, + }, + cutoff_target_indices, + ], + "label": f"{cutoff:g}", + } + ) steps_fixed_time_horizon = [] total_traces = num_curve_traces + num_marker_traces + len(reference_traces) From e8fff83c0f2da17a3c3151fe545c684652126d10 Mon Sep 17 00:00:00 2001 From: Uriah Finkel Date: Sun, 14 Dec 2025 21:57:26 +0200 Subject: [PATCH 5/8] fix: try to sort hover text for binary-outcome plots --- .../helpers/plotly_helper_functions.py | 111 +++++++++++------- 1 file changed, 67 insertions(+), 44 deletions(-) diff --git a/src/rtichoke/helpers/plotly_helper_functions.py b/src/rtichoke/helpers/plotly_helper_functions.py index acbed79..099b787 100644 --- a/src/rtichoke/helpers/plotly_helper_functions.py +++ b/src/rtichoke/helpers/plotly_helper_functions.py @@ -1744,6 +1744,10 @@ def marker_values_for_cutoff( def _create_plotly_curve_binary(rtichoke_curve_list: dict[str, Any]) -> go.Figure: + initial_cutoff = ( + rtichoke_curve_list["cutoffs"][0] if rtichoke_curve_list["cutoffs"] else None + ) + non_interactive_curve = [ go.Scatter( x=rtichoke_curve_list["performance_data_ready_for_curve"] @@ -1773,11 +1777,53 @@ def _create_plotly_curve_binary(rtichoke_curve_list: dict[str, Any]) -> go.Figur for group in rtichoke_curve_list["reference_group_keys"] ] + def xy_at_cutoff(group, c): + row = ( + rtichoke_curve_list["performance_data_ready_for_curve"] + .filter( + (pl.col("reference_group") == group) + & (pl.col("chosen_cutoff") == c) + & pl.col("x").is_not_null() + & pl.col("y").is_not_null() + ) + .select(["x", "y", "text"]) + .limit(1) + ) + if row.height == 0: + return None, None, None + r = row.row(0) + return r[0], r[1], r[2] + + def marker_values_for_cutoff( + cutoff: float, + ) -> tuple[list[list], list[list], list[list]]: + marker_values = [ + xy_at_cutoff(group, cutoff) + for group in rtichoke_curve_list["reference_group_keys"] + ] + + xs = [[x] if x is not None else [] for x, _, _ in marker_values] + ys = [[y] if y is not None else [] for _, y, _ in marker_values] + texts = [[text] if text is not None else [] for _, _, text in marker_values] + + return xs, ys, texts + + initial_xs, initial_ys, initial_texts = ( + marker_values_for_cutoff(initial_cutoff) + if initial_cutoff is not None + else ( + [[] for _ in rtichoke_curve_list["reference_group_keys"]], + [[] for _ in rtichoke_curve_list["reference_group_keys"]], + [[] for _ in rtichoke_curve_list["reference_group_keys"]], + ) + ) + initial_interactive_markers = [ go.Scatter( - x=[], - y=[], - text=[], + x=initial_xs[idx], + y=initial_ys[idx], + text=initial_texts[idx], + hovertext=initial_texts[idx], mode="markers", marker={ "size": 12, @@ -1798,7 +1844,7 @@ def _create_plotly_curve_binary(rtichoke_curve_list: dict[str, Any]) -> go.Figur showlegend=False, hoverinfo="text", ) - for group in rtichoke_curve_list["reference_group_keys"] + for idx, group in enumerate(rtichoke_curve_list["reference_group_keys"]) ] reference_traces = [ @@ -1838,47 +1884,24 @@ def _create_plotly_curve_binary(rtichoke_curve_list: dict[str, Any]) -> go.Figur ) ) - def xy_at_cutoff(group, c): - row = ( - rtichoke_curve_list["performance_data_ready_for_curve"] - .filter( - (pl.col("reference_group") == group) - & (pl.col("chosen_cutoff") == c) - & pl.col("x").is_not_null() - & pl.col("y").is_not_null() - ) - .select(["x", "y", "text"]) - .limit(1) + steps = [] + for cutoff in rtichoke_curve_list["cutoffs"]: + xs, ys, texts = marker_values_for_cutoff(cutoff) + steps.append( + { + "method": "restyle", + "args": [ + { + "x": xs, + "y": ys, + "text": texts, + "hovertext": texts, + }, + dyn_idx, + ], + "label": f"{cutoff:g}", + } ) - if row.height == 0: - return None, None, None - r = row.row(0) - return r[0], r[1], r[2] - - steps = [ - { - "method": "restyle", - "args": [ - { - "x": [ - [xy_at_cutoff(group, cutoff)[0]] - if xy_at_cutoff(group, cutoff)[0] is not None - else [] - for group in rtichoke_curve_list["reference_group_keys"] - ], - "y": [ - [xy_at_cutoff(group, cutoff)[1]] - if xy_at_cutoff(group, cutoff)[1] is not None - else [] - for group in rtichoke_curve_list["reference_group_keys"] - ], - }, - dyn_idx, - ], - "label": f"{cutoff:g}", - } - for cutoff in rtichoke_curve_list["cutoffs"] - ] slider_dict = _create_slider_dict( rtichoke_curve_list["animation_slider_prefix"], steps From eddfe6d11b9ada02928619d10ddc0807933243e9 Mon Sep 17 00:00:00 2001 From: Uriah Finkel Date: Sun, 14 Dec 2025 22:09:11 +0200 Subject: [PATCH 6/8] fix: close #233 --- src/rtichoke/helpers/plotly_helper_functions.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/src/rtichoke/helpers/plotly_helper_functions.py b/src/rtichoke/helpers/plotly_helper_functions.py index 099b787..de3f2d8 100644 --- a/src/rtichoke/helpers/plotly_helper_functions.py +++ b/src/rtichoke/helpers/plotly_helper_functions.py @@ -1823,7 +1823,7 @@ def marker_values_for_cutoff( x=initial_xs[idx], y=initial_ys[idx], text=initial_texts[idx], - hovertext=initial_texts[idx], + # hovertext=initial_texts[idx], mode="markers", marker={ "size": 12, @@ -1837,9 +1837,15 @@ def marker_values_for_cutoff( name=f"{group} @ cutoff", legendgroup=group, hoverlabel=dict( - bgcolor=rtichoke_curve_list["colors_dictionary"].get(group), - bordercolor=rtichoke_curve_list["colors_dictionary"].get(group), - font_color="white", + bgcolor="#f6e3be" + if not rtichoke_curve_list["multiple_reference_groups"] + else rtichoke_curve_list["colors_dictionary"].get(group), + bordercolor="#f6e3be" + if not rtichoke_curve_list["multiple_reference_groups"] + else rtichoke_curve_list["colors_dictionary"].get(group), + font_color="black" + if not rtichoke_curve_list["multiple_reference_groups"] + else "white", ), showlegend=False, hoverinfo="text", @@ -1895,7 +1901,7 @@ def marker_values_for_cutoff( "x": xs, "y": ys, "text": texts, - "hovertext": texts, + # "hovertext": texts, }, dyn_idx, ], From 2bf19f30ed0855448ada16a063a7aca8ccc570c8 Mon Sep 17 00:00:00 2001 From: Uriah Finkel Date: Mon, 15 Dec 2025 09:00:54 +0200 Subject: [PATCH 7/8] fix: close #234 --- src/rtichoke/helpers/plotly_helper_functions.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/rtichoke/helpers/plotly_helper_functions.py b/src/rtichoke/helpers/plotly_helper_functions.py index de3f2d8..d3184a1 100644 --- a/src/rtichoke/helpers/plotly_helper_functions.py +++ b/src/rtichoke/helpers/plotly_helper_functions.py @@ -1573,7 +1573,7 @@ def _xy_at_cutoff( font_color="white", ), hoverinfo="text", - showlegend=fixed_time_horizon == initial_fixed_time_horizon, + showlegend=rtichoke_curve_list["multiple_reference_groups"], visible=fixed_time_horizon == initial_fixed_time_horizon, ) ) @@ -1735,6 +1735,7 @@ def marker_values_for_cutoff( axes_ranges=rtichoke_curve_list["axes_ranges"], x_label=rtichoke_curve_list["x_label"], y_label=rtichoke_curve_list["y_label"], + show_legend=rtichoke_curve_list["multiple_reference_groups"], ) return go.Figure( @@ -1772,7 +1773,7 @@ def _create_plotly_curve_binary(rtichoke_curve_list: dict[str, Any]) -> go.Figur font_color="white", ), hoverinfo="text", - showlegend=True, + showlegend=rtichoke_curve_list["multiple_reference_groups"], ) for group in rtichoke_curve_list["reference_group_keys"] ] @@ -1919,6 +1920,7 @@ def marker_values_for_cutoff( axes_ranges=rtichoke_curve_list["axes_ranges"], x_label=rtichoke_curve_list["x_label"], y_label=rtichoke_curve_list["y_label"], + show_legend=rtichoke_curve_list["multiple_reference_groups"], ) return go.Figure( @@ -1933,6 +1935,7 @@ def _create_curve_layout( axes_ranges: dict[str, list[float]] | None = None, x_label: str | None = None, y_label: str | None = None, + show_legend: bool = True, ) -> dict[str, Any]: sliders = slider_dict if isinstance(slider_dict, list) else [slider_dict] From 91e2043b6fef6f689273df5823f034006f332013 Mon Sep 17 00:00:00 2001 From: Uriah Finkel Date: Mon, 15 Dec 2025 09:57:07 +0200 Subject: [PATCH 8/8] fix: close #235 --- src/rtichoke/helpers/plotly_helper_functions.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/src/rtichoke/helpers/plotly_helper_functions.py b/src/rtichoke/helpers/plotly_helper_functions.py index d3184a1..ab475be 100644 --- a/src/rtichoke/helpers/plotly_helper_functions.py +++ b/src/rtichoke/helpers/plotly_helper_functions.py @@ -25,6 +25,21 @@ "ppcr": "Predicted Positives", } +DEFAULT_MODEBAR_BUTTONS_TO_REMOVE = [ + "zoom2d", + "pan2d", + "select2d", + "lasso2d", + "zoomIn2d", + "zoomOut2d", + "autoScale2d", + "resetScale2d", + "hoverClosestCartesian", + "hoverCompareCartesian", + "toggleSpikelines", + "toImage", +] + def _create_rtichoke_plotly_curve_binary( probs: Dict[str, np.ndarray], @@ -1987,6 +2002,7 @@ def _create_curve_layout( } ], "sliders": sliders, + "modebar": {"remove": list(DEFAULT_MODEBAR_BUTTONS_TO_REMOVE)}, } if axes_ranges is not None: