diff --git a/AGENTS.md b/AGENTS.md
new file mode 100644
index 0000000..04f3079
--- /dev/null
+++ b/AGENTS.md
@@ -0,0 +1,77 @@
+# rtichoke Agent Information
+
+This document provides guidance for AI agents working on the `rtichoke` repository.
+
+## Development Environment
+
+To set up the development environment, follow these steps:
+
+1. **Install `uv`**: If `uv` is not already installed, follow the [official installation instructions](https://docs.astral.sh/uv/getting-started/installation/).
+2. **Create a virtual environment**: Use `uv venv` to create a virtual environment.
+3. **Install dependencies**: Install the project in editable mode, including the `dev` dependencies, with the following command (the quotes keep shells such as `zsh` from interpreting the square brackets):
+
+   ```bash
+   uv pip install -e ".[dev]"
+   ```
+
+## Running Tests
+
+The test suite is run using `pytest`. To run the tests, use the following command:
+
+```bash
+uv run pytest
+```
+
+## Coding Conventions
+
+### Functional Programming
+
+Strive to use a functional programming style as much as possible. Avoid side effects and mutable state where practical.
+
+### Docstrings
+
+All exported functions must have NumPy-style docstrings. This is to ensure that the documentation is clear, consistent, and can be easily parsed by tools like `quartodoc`.
+
+Example of a NumPy-style docstring:
+
+```python
+def my_function(param1, param2):
+    """Summary of the function's purpose.
+
+    Parameters
+    ----------
+    param1 : int
+        Description of the first parameter.
+    param2 : str
+        Description of the second parameter.
+
+    Returns
+    -------
+    bool
+        Description of the return value.
+    """
+    # function body
+    return True
+```
+
+## Pre-commit Hooks
+
+This repository uses pre-commit hooks to ensure code quality and consistency. The following hooks are configured:
+
+* **`ruff-check`**: A linter to check for common errors and style issues.
+* **`ruff-format`**: A code formatter to ensure a consistent code style.
+* **`uv-lock`**: A hook to keep the `uv.lock` file up to date.
+
+Before committing, please ensure that the pre-commit hooks pass. You can run them manually on all files with `pre-commit run --all-files`.
+
+## Documentation
+
+The documentation for this project is built using `quartodoc`. The documentation is automatically built and deployed via GitHub Actions. There is no need to build the documentation manually.
+
+## Type Checking
+
+This project uses `ty` for type checking. To check for type errors, run the following command:
+
+```bash
+uv run ty check src tests
+```
diff --git a/pyproject.toml b/pyproject.toml
index 7ff976a..2388c6d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -12,9 +12,10 @@ dependencies = [
"polarstate==0.1.8",
"marimo>=0.17.0",
"pyarrow>=21.0.0",
+ "statsmodels>=0.14.0",
]
name = "rtichoke"
-version = "0.1.25"
+version = "0.1.26"
description = "interactive visualizations for performance of predictive models"
readme = "README.md"
diff --git a/src/rtichoke/__init__.py b/src/rtichoke/__init__.py
index cc5b47a..bc84b74 100644
--- a/src/rtichoke/__init__.py
+++ b/src/rtichoke/__init__.py
@@ -30,9 +30,10 @@
)
from rtichoke.discrimination.gains import plot_gains_curve as plot_gains_curve
-# from rtichoke.calibration.calibration import (
-# create_calibration_curve as create_calibration_curve,
-# )
+from rtichoke.calibration.calibration import (
+ create_calibration_curve as create_calibration_curve,
+ create_calibration_curve_times as create_calibration_curve_times,
+)
from rtichoke.utility.decision import (
create_decision_curve as create_decision_curve,
diff --git a/src/rtichoke/calibration/__init__.py b/src/rtichoke/calibration/__init__.py
index 4267999..190e74e 100644
--- a/src/rtichoke/calibration/__init__.py
+++ b/src/rtichoke/calibration/__init__.py
@@ -1,3 +1,7 @@
"""
Subpackage for Calibration
"""
+
+from .calibration import create_calibration_curve, create_calibration_curve_times
+
+__all__ = ["create_calibration_curve", "create_calibration_curve_times"]
diff --git a/src/rtichoke/calibration/calibration.py b/src/rtichoke/calibration/calibration.py
index f9d32f3..4e245a9 100644
--- a/src/rtichoke/calibration/calibration.py
+++ b/src/rtichoke/calibration/calibration.py
@@ -2,21 +2,24 @@
A module for Calibration Curves
"""
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Union
# import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from plotly.graph_objs._figure import Figure
+import polars as pl
+import numpy as np
+
# from rtichoke.helpers.send_post_request_to_r_rtichoke import send_requests_to_rtichoke_r
def create_calibration_curve(
- probs: Dict[str, List[float]],
- reals: Dict[str, List[int]],
+ probs: Dict[str, np.ndarray],
+ reals: Union[np.ndarray, Dict[str, np.ndarray]],
calibration_type: str = "discrete",
- size: Optional[int] = None,
- color_values: Optional[List[str]] = [
+ size: int = 600,
+ color_values: List[str] = [
"#1b9e77",
"#d95f02",
"#7570b3",
@@ -38,7 +41,6 @@ def create_calibration_curve(
"#D1603D",
"#585123",
],
- url_api: str = "http://localhost:4242/",
) -> Figure:
"""Creates Calibration Curve
@@ -55,40 +57,213 @@ def create_calibration_curve(
"""
pass
- # rtichoke_response = send_requests_to_rtichoke_r(
- # dictionary_to_send={
- # "probs": probs,
- # "reals": reals,
- # "size": size,
- # "color_values ": color_values,
- # },
- # url_api=url_api,
- # endpoint="create_calibration_curve_list",
- # )
-
- # calibration_curve_list = rtichoke_response.json()
-
- # calibration_curve_list["deciles_dat"] = pd.DataFrame.from_dict(
- # calibration_curve_list["deciles_dat"]
- # )
- # calibration_curve_list["smooth_dat"] = pd.DataFrame.from_dict(
- # calibration_curve_list["smooth_dat"]
- # )
- # calibration_curve_list["reference_data"] = pd.DataFrame.from_dict(
- # calibration_curve_list["reference_data"]
- # )
- # calibration_curve_list["histogram_for_calibration"] = pd.DataFrame.from_dict(
- # calibration_curve_list["histogram_for_calibration"]
- # )
-
- # calibration_curve = create_plotly_curve_from_calibration_curve_list(
- # calibration_curve_list=calibration_curve_list, calibration_type=calibration_type
- # )
-
- # return calibration_curve
-
-
-def create_plotly_curve_from_calibration_curve_list(
+ calibration_curve_list = _create_calibration_curve_list(
+ probs, reals, size=size, color_values=color_values
+ )
+
+ calibration_curve = _create_plotly_curve_from_calibration_curve_list(
+ calibration_curve_list, calibration_type=calibration_type
+ )
+
+ return calibration_curve
+
+
+def create_calibration_curve_times(
+    probs: Dict[str, np.ndarray],
+    reals: Union[np.ndarray, Dict[str, np.ndarray]],
+    times: Union[np.ndarray, Dict[str, np.ndarray]],
+    fixed_time_horizons: List[float],
+    heuristics_sets: List[Dict[str, str]],
+    calibration_type: str = "discrete",
+    size: int = 600,
+    color_values: List[str] = [
+        "#1b9e77",
+        "#d95f02",
+        "#7570b3",
+        "#e7298a",
+        "#07004D",
+        "#E6AB02",
+        "#FE5F55",
+        "#54494B",
+        "#006E90",
+        "#BC96E6",
+        "#52050A",
+        "#1F271B",
+        "#BE7C4D",
+        "#63768D",
+        "#08A045",
+        "#320A28",
+        "#82FF9E",
+        "#2176FF",
+        "#D1603D",
+        "#585123",
+    ],
+) -> Figure:
+    """Create a time-dependent calibration curve with a time-horizon slider.
+
+    Parameters
+    ----------
+    probs : dict of str to numpy.ndarray
+        Predicted probabilities keyed by model name.
+    reals : numpy.ndarray or dict of str to numpy.ndarray
+        Observed outcomes, either one array or one array per population.
+    times : numpy.ndarray or dict of str to numpy.ndarray
+        Event/censoring times, structured like ``reals``.
+    fixed_time_horizons : list of float
+        Time horizons; each one becomes a step of the slider.
+    heuristics_sets : list of dict
+        Each dict provides the keys ``"censoring_heuristic"`` and
+        ``"competing_heuristic"``.
+    calibration_type : str, default "discrete"
+        ``"discrete"`` plots the decile points; any other value plots the
+        smoothed curve.
+    size : int, default 600
+        Used for both the width and the height of the figure.
+    color_values : list of str
+        Palette from which the group colors are taken.
+
+    Returns
+    -------
+    plotly.graph_objs._figure.Figure
+        The calibration figure with one slider step per time horizon.
+    """
+    # First assemble the data, then hand it to the figure builder.
+    curve_data = _create_calibration_curve_list_times(
+        probs,
+        reals,
+        times,
+        fixed_time_horizons=fixed_time_horizons,
+        heuristics_sets=heuristics_sets,
+        size=size,
+        color_values=color_values,
+    )
+    return _create_plotly_curve_from_calibration_curve_list_times(
+        curve_data, calibration_type=calibration_type
+    )
+
+
+def _create_plotly_curve_from_calibration_curve_list_times(
+    calibration_curve_list: Dict[str, Any], calibration_type: str = "discrete"
+) -> Figure:
+    """Build the plotly figure for time-dependent calibration curves.
+
+    For every time horizon the figure gets one reference-line trace and, per
+    reference group, one calibration trace (top panel) and one histogram
+    trace (bottom panel). A slider toggles which horizon's traces are shown.
+
+    Parameters
+    ----------
+    calibration_curve_list : dict
+        Must provide ``"fixed_time_horizons"``, ``"reference_group_keys"``,
+        ``"colors_dictionary"``, ``"reference_data"``,
+        ``"histogram_for_calibration"``, ``"axes_ranges"``,
+        ``"performance_type"``, ``"size"`` and either ``"deciles_dat"``
+        (discrete) or ``"smooth_dat"`` (smooth).
+    calibration_type : str, default "discrete"
+        ``"discrete"`` uses the decile data; any other value uses the
+        smoothed data.
+
+    Returns
+    -------
+    plotly.graph_objs._figure.Figure
+    """
+    horizons = calibration_curve_list["fixed_time_horizons"]
+    groups = calibration_curve_list["reference_group_keys"]
+    colors = calibration_curve_list["colors_dictionary"]
+    reference = calibration_curve_list["reference_data"]
+    histogram = calibration_curve_list["histogram_for_calibration"]
+
+    is_discrete = calibration_type == "discrete"
+    curve_data = calibration_curve_list["deciles_dat" if is_discrete else "smooth_dat"]
+    reference_color = colors.get("reference_line", ["#BEBEBE"])[0]
+
+    # Plain lists are handed to plotly, as the non-time-dependent builder does.
+    reference_x = reference["x"].to_list()
+    reference_y = reference["y"].to_list()
+    reference_text = reference["text"].to_list()
+
+    fig = make_subplots(
+        rows=2, cols=1, shared_xaxes=True, x_title="Predicted", row_heights=[0.8, 0.2]
+    )
+
+    for horizon_index, horizon in enumerate(horizons):
+        # The slider starts on its first step, so only the first horizon's
+        # traces are visible initially (also correct with repeated values).
+        visible = horizon_index == 0
+
+        fig.add_trace(
+            go.Scatter(
+                x=reference_x,
+                y=reference_y,
+                hovertext=reference_text,
+                name="Perfectly Calibrated",
+                legendgroup="Perfectly Calibrated",
+                hoverinfo="text",
+                line={"width": 2, "dash": "dot", "color": reference_color},
+                showlegend=False,
+                visible=visible,
+            ),
+            row=1,
+            col=1,
+        )
+
+        for group in groups:
+            color = colors[group][0]
+            in_group_and_horizon = (pl.col("reference_group") == group) & (
+                pl.col("fixed_time_horizon") == horizon
+            )
+
+            # A trace is added even when the subset is empty so that every
+            # horizon owns the same number of traces (the slider relies on it).
+            data_subset = curve_data.filter(in_group_and_horizon)
+            if is_discrete or data_subset.height == 1:
+                mode = "lines+markers"
+            else:
+                mode = "lines"
+
+            fig.add_trace(
+                go.Scatter(
+                    x=data_subset["x"].to_list(),
+                    y=data_subset["y"].to_list(),
+                    hovertext=data_subset["text"].to_list(),
+                    name=group,
+                    legendgroup=group,
+                    hoverinfo="text",
+                    mode=mode,
+                    marker={"size": 10, "color": color},
+                    visible=visible,
+                ),
+                row=1,
+                col=1,
+            )
+
+            hist_subset = histogram.filter(in_group_and_horizon)
+            fig.add_trace(
+                go.Bar(
+                    x=hist_subset["mids"].to_list(),
+                    y=hist_subset["counts"].to_list(),
+                    hovertext=hist_subset["text"].to_list(),
+                    name=group,
+                    width=0.01,
+                    legendgroup=group,
+                    hoverinfo="text",
+                    marker_color=color,
+                    showlegend=False,
+                    opacity=0.4,
+                    visible=visible,
+                ),
+                row=2,
+                col=1,
+            )
+
+    # One reference trace plus a curve and a histogram per group.
+    traces_per_horizon = 1 + 2 * len(groups)
+    total_traces = traces_per_horizon * len(horizons)
+
+    steps = [
+        dict(
+            method="restyle",
+            args=[
+                {
+                    "visible": [
+                        trace_index // traces_per_horizon == step_index
+                        for trace_index in range(total_traces)
+                    ]
+                }
+            ],
+            label=str(horizon),
+        )
+        for step_index, horizon in enumerate(horizons)
+    ]
+
+    sliders = [
+        dict(
+            active=0,
+            currentvalue={"prefix": "Time Horizon: "},
+            pad={"t": 50},
+            steps=steps,
+        )
+    ]
+
+    fig.update_layout(
+        sliders=sliders,
+        xaxis={
+            "showgrid": False,
+            "range": calibration_curve_list["axes_ranges"]["xaxis"],
+        },
+        yaxis={
+            "showgrid": False,
+            "range": calibration_curve_list["axes_ranges"]["yaxis"],
+            "title": "Observed",
+        },
+        barmode="overlay",
+        plot_bgcolor="rgba(0, 0, 0, 0)",
+        legend={
+            "orientation": "h",
+            "xanchor": "center",
+            "yanchor": "top",
+            "x": 0.5,
+            "y": 1.3,
+            "bgcolor": "rgba(0, 0, 0, 0)",
+        },
+        showlegend=calibration_curve_list["performance_type"][0] != "one model",
+        width=calibration_curve_list["size"][0][0],
+        height=calibration_curve_list["size"][0][0],
+    )
+
+    return fig
+
+
+def _create_plotly_curve_from_calibration_curve_list(
calibration_curve_list: Dict[str, Any], calibration_type: str = "discrete"
) -> Figure:
"""Create plotly curve from calibration curve list
@@ -124,16 +299,16 @@ def create_plotly_curve_from_calibration_curve_list(
calibration_curve.add_trace(
go.Scatter(
- x=calibration_curve_list["reference_data"]["x"].values.tolist(),
- y=calibration_curve_list["reference_data"]["y"].values.tolist(),
- hovertext=calibration_curve_list["reference_data"]["text"].values.tolist(),
+ x=calibration_curve_list["reference_data"]["x"],
+ y=calibration_curve_list["reference_data"]["y"],
+ hovertext=calibration_curve_list["reference_data"]["text"],
name="Perfectly Calibrated",
legendgroup="Perfectly Calibrated",
hoverinfo="text",
line={
"width": 2,
"dash": "dot",
- "color": calibration_curve_list["group_colors_vec"]["reference_line"][
+ "color": calibration_curve_list["colors_dictionary"]["reference_line"][
0
],
},
@@ -144,111 +319,120 @@ def create_plotly_curve_from_calibration_curve_list(
)
if calibration_type == "discrete":
- for reference_group in list(calibration_curve_list["group_colors_vec"].keys()):
- if any(
- calibration_curve_list["deciles_dat"]["reference_group"]
- == reference_group
- ):
- calibration_curve.add_trace(
- go.Scatter(
- x=calibration_curve_list["deciles_dat"]["x"][
- calibration_curve_list["deciles_dat"]["reference_group"]
- == reference_group
- ].values.tolist(),
- y=calibration_curve_list["deciles_dat"]["y"][
- calibration_curve_list["deciles_dat"]["reference_group"]
- == reference_group
- ].values.tolist(),
- hovertext=calibration_curve_list["deciles_dat"]["text"][
- calibration_curve_list["deciles_dat"]["reference_group"]
- == reference_group
- ].values.tolist(),
- name=reference_group,
- legendgroup=reference_group,
- hoverinfo="text",
- mode="lines+markers",
- marker={
- "size": 10,
- "color": calibration_curve_list["group_colors_vec"][
- reference_group
- ][0],
- },
- ),
- row=1,
- col=1,
- )
+ reference_groups = [
+ k
+ for k in calibration_curve_list["colors_dictionary"].keys()
+ if k != "reference_line"
+ ]
+ for reference_group in reference_groups:
+ dec_sub = calibration_curve_list["deciles_dat"].filter(
+ pl.col("reference_group") == reference_group
+ )
+
+ print(dec_sub)
+
+ calibration_curve.add_trace(
+ go.Scatter(
+ x=dec_sub.get_column("x").to_list(),
+ y=dec_sub.get_column("y").to_list(),
+ hovertext=dec_sub.get_column("text").to_list(),
+ name=reference_group,
+ legendgroup=reference_group,
+ hoverinfo="text",
+ mode="lines+markers",
+ marker={
+ "size": 10,
+ "color": calibration_curve_list["colors_dictionary"][
+ reference_group
+ ][0],
+ },
+ ),
+ row=1,
+ col=1,
+ )
+
+ hist = calibration_curve_list["histogram_for_calibration"]
+
+ for reference_group in reference_groups:
+ hist_sub = hist.filter(pl.col("reference_group") == reference_group)
+ if hist_sub.height == 0:
+ continue
+
+ calibration_curve.add_trace(
+ go.Bar(
+ x=hist_sub.get_column("mids").to_list(),
+ y=hist_sub.get_column("counts").to_list(),
+ hovertext=hist_sub.get_column("text").to_list(),
+ name=reference_group,
+ width=0.01,
+ legendgroup=reference_group,
+ hoverinfo="text",
+ marker_color=calibration_curve_list["colors_dictionary"][
+ reference_group
+ ][0],
+ showlegend=False,
+ opacity=0.4,
+ ),
+ row=2,
+ col=1,
+ )
if calibration_type == "smooth":
- for reference_group in list(calibration_curve_list["group_colors_vec"].keys()):
- if any(
- calibration_curve_list["smooth_dat"]["reference_group"]
- == reference_group
- ):
- calibration_curve.add_trace(
- go.Scatter(
- x=calibration_curve_list["smooth_dat"]["x"][
- calibration_curve_list["smooth_dat"]["reference_group"]
- == reference_group
- ].values.tolist(),
- y=calibration_curve_list["smooth_dat"]["y"][
- calibration_curve_list["smooth_dat"]["reference_group"]
- == reference_group
- ].values.tolist(),
- hovertext=calibration_curve_list["smooth_dat"]["text"][
- calibration_curve_list["smooth_dat"]["reference_group"]
- == reference_group
- ].values.tolist(),
- name=reference_group,
- legendgroup=reference_group,
- hoverinfo="text",
- mode="lines",
- marker={
- "size": 10,
- "color": calibration_curve_list["group_colors_vec"][
- reference_group
- ][0],
- },
- ),
- row=1,
- col=1,
- )
+ smooth_dat = calibration_curve_list["smooth_dat"]
+ reference_groups = [
+ k
+ for k in calibration_curve_list["colors_dictionary"].keys()
+ if k != "reference_line"
+ ]
+
+ for reference_group in reference_groups:
+ smooth_sub = smooth_dat.filter(pl.col("reference_group") == reference_group)
+ if smooth_sub.height == 0:
+ continue
+
+ mode = "lines+markers" if smooth_sub.height == 1 else "lines"
+
+ calibration_curve.add_trace(
+ go.Scatter(
+ x=smooth_sub.get_column("x").to_list(),
+ y=smooth_sub.get_column("y").to_list(),
+ hovertext=smooth_sub.get_column("text").to_list(),
+ name=reference_group,
+ legendgroup=reference_group,
+ hoverinfo="text",
+ mode=mode,
+ marker={
+ "size": 10,
+ "color": calibration_curve_list["colors_dictionary"][
+ reference_group
+ ][0],
+ },
+ ),
+ row=1,
+ col=1,
+ )
+
+ hist = calibration_curve_list["histogram_for_calibration"]
+
+ for reference_group in reference_groups:
+ hist_sub = hist.filter(pl.col("reference_group") == reference_group)
+ if hist_sub.height == 0:
+ continue
- for reference_group in list(calibration_curve_list["group_colors_vec"].keys()):
- if any(
- calibration_curve_list["histogram_for_calibration"]["reference_group"]
- == reference_group
- ):
calibration_curve.add_trace(
go.Bar(
- x=calibration_curve_list["histogram_for_calibration"]["mids"][
- calibration_curve_list["histogram_for_calibration"][
- "reference_group"
- ]
- == reference_group
- ].values.tolist(),
- y=calibration_curve_list["histogram_for_calibration"]["counts"][
- calibration_curve_list["histogram_for_calibration"][
- "reference_group"
- ]
- == reference_group
- ].values.tolist(),
- hovertext=calibration_curve_list["histogram_for_calibration"][
- "text"
- ][
- calibration_curve_list["histogram_for_calibration"][
- "reference_group"
- ]
- == reference_group
- ].values.tolist(),
+ x=hist_sub.get_column("mids").to_list(),
+ y=hist_sub.get_column("counts").to_list(),
+ hovertext=hist_sub.get_column("text").to_list(),
name=reference_group,
width=0.01,
legendgroup=reference_group,
hoverinfo="text",
- marker_color=calibration_curve_list["group_colors_vec"][
+ marker_color=calibration_curve_list["colors_dictionary"][
reference_group
][0],
showlegend=False,
- opacity=calibration_curve_list["histogram_opacity"][0],
+ opacity=0.4,
),
row=2,
col=1,
@@ -284,3 +468,669 @@ def create_plotly_curve_from_calibration_curve_list(
)
return calibration_curve
+
+
+def _make_deciles_dat_binary(
+    probs: Dict[str, np.ndarray],
+    reals: Union[np.ndarray, Dict[str, np.ndarray]],
+    n_bins: int = 10,
+) -> pl.DataFrame:
+    """Aggregate predictions and outcomes into quantile bins ("deciles").
+
+    Parameters
+    ----------
+    probs : dict of str to numpy.ndarray
+        Predicted probabilities keyed by model name. When ``reals`` is a
+        dict, every array must hold the predictions for all populations laid
+        end to end, in the order of ``reals``' keys.
+    reals : numpy.ndarray or dict of str to numpy.ndarray
+        Observed outcomes, either shared by all models or one array per
+        population.
+    n_bins : int, default 10
+        Number of quantile bins per (reference_group, model).
+
+    Returns
+    -------
+    polars.DataFrame
+        One row per (reference_group, model, decile) with the columns ``n``
+        (bin size), ``x`` (mean prediction), ``y`` (mean outcome) and
+        ``n_reals`` (sum of outcomes), sorted by those three keys.
+
+    Raises
+    ------
+    ValueError
+        If a probability array does not match the length of the outcomes,
+        or if there is nothing to aggregate.
+    """
+
+    def _frame(group: str, model: str, p: np.ndarray, y: np.ndarray) -> pl.DataFrame:
+        return pl.DataFrame(
+            {
+                "reference_group": group,
+                "model": model,
+                "prob": p.astype(float, copy=False),
+                "real": y.astype(float, copy=False),
+            }
+        )
+
+    frames: list[pl.DataFrame] = []
+
+    if isinstance(reals, dict):
+        population_names = list(reals.keys())
+        outcomes = [np.asarray(reals[name]).ravel() for name in population_names]
+        lengths = np.array([len(y) for y in outcomes], dtype=np.int64)
+        # offsets[i]:offsets[i + 1] is population i's slice of each model array
+        offsets = np.concatenate([np.array([0], dtype=np.int64), np.cumsum(lengths)])
+        n_total = int(offsets[-1])
+
+        for model, p_all in probs.items():
+            p_all = np.asarray(p_all).ravel()
+            if p_all.shape[0] != n_total:
+                raise ValueError(
+                    f"probs['{model}'] length={p_all.shape[0]} does not match "
+                    f"sum of population sizes={n_total}."
+                )
+            for i, pop in enumerate(population_names):
+                start, end = int(offsets[i]), int(offsets[i + 1])
+                frames.append(_frame(pop, model, p_all[start:end], outcomes[i]))
+    else:
+        y = np.asarray(reals).ravel()
+        n = y.shape[0]
+        for model, p in probs.items():
+            p = np.asarray(p).ravel()
+            if p.shape[0] != n:
+                raise ValueError(
+                    f"probs['{model}'] length={p.shape[0]} does not match reals length={n}."
+                )
+            # With a single outcome vector each model is its own group.
+            frames.append(_frame(model, model, p, y))
+
+    if not frames:
+        raise ValueError("probs and reals must contain at least one array each.")
+
+    df = pl.concat(frames, how="vertical")
+
+    labels = [str(i) for i in range(1, n_bins + 1)]
+
+    df = df.with_columns(
+        [
+            pl.col("prob").cast(pl.Float64),
+            pl.col("real").cast(pl.Float64),
+            pl.col("prob")
+            .qcut(n_bins, labels=labels, allow_duplicates=True)
+            .over(["reference_group", "model"])
+            .alias("decile"),
+        ]
+    ).with_columns(
+        # Go through the label text so the number is the label ("1".."n_bins")
+        # and not the categorical's internal physical code.
+        pl.col("decile").cast(pl.Utf8).cast(pl.Int32)
+    )
+
+    deciles_data = (
+        df.group_by(["reference_group", "model", "decile"])
+        .agg(
+            [
+                pl.len().alias("n"),
+                pl.mean("prob").alias("x"),
+                pl.mean("real").alias("y"),
+                pl.sum("real").alias("n_reals"),
+            ]
+        )
+        .sort(["reference_group", "model", "decile"])
+    )
+
+    return deciles_data
+
+
+def _check_performance_type_by_probs_and_reals(
+    probs: Dict[str, np.ndarray], reals: Union[np.ndarray, Dict[str, np.ndarray]]
+) -> str:
+    """Classify the comparison implied by the shapes of ``probs`` and ``reals``.
+
+    Returns ``"multiple populations"`` when ``reals`` is a dict with more
+    than one entry, otherwise ``"multiple models"`` when ``probs`` has more
+    than one entry, otherwise ``"one model"``.
+    """
+    several_populations = isinstance(reals, dict) and len(reals) > 1
+    if several_populations:
+        return "multiple populations"
+    return "multiple models" if len(probs) > 1 else "one model"
+
+
+def _create_calibration_curve_list(
+    probs: Dict[str, np.ndarray],
+    reals: Union[np.ndarray, Dict[str, np.ndarray]],
+    size: int = 600,
+    color_values: List[str] = [
+        "#1b9e77",
+        "#d95f02",
+        "#7570b3",
+        "#e7298a",
+        "#07004D",
+        "#E6AB02",
+        "#FE5F55",
+        "#54494B",
+        "#006E90",
+        "#BC96E6",
+        "#52050A",
+        "#1F271B",
+        "#BE7C4D",
+        "#63768D",
+        "#08A045",
+        "#320A28",
+        "#82FF9E",
+        "#2176FF",
+        "#D1603D",
+        "#585123",
+    ],
+) -> Dict[str, Any]:
+    """Assemble everything the calibration figure builder needs.
+
+    Parameters
+    ----------
+    probs : dict of str to numpy.ndarray
+        Predicted probabilities keyed by model name.
+    reals : numpy.ndarray or dict of str to numpy.ndarray
+        Observed outcomes, one array or one array per population.
+    size : int, default 600
+        Stored as ``[(size, size)]`` under ``"size"``.
+    color_values : list of str
+        Palette from which the group colors are taken.
+
+    Returns
+    -------
+    dict
+        Keys: ``"deciles_dat"``, ``"smooth_dat"``, ``"reference_data"``,
+        ``"histogram_for_calibration"``, ``"axes_ranges"``,
+        ``"colors_dictionary"``, ``"performance_type"`` and ``"size"``.
+    """
+    performance_type = _check_performance_type_by_probs_and_reals(probs, reals)
+    deciles_data = _make_deciles_dat_binary(probs, reals)
+
+    # The smoother is run once; its result already carries the hover text
+    # that ends up in the returned dictionary.
+    smooth_dat = _calculate_smooth_curve(probs, reals, performance_type)
+
+    # Only the deciles need hover text from this helper; the smooth frame it
+    # returns is discarded so "smooth_dat" keeps the smoother's own text.
+    deciles_data, _ = _add_hover_text_to_calibration_data(
+        deciles_data, smooth_dat, performance_type
+    )
+
+    # Keep first-seen order so each group gets the same color on every run.
+    reference_groups = (
+        deciles_data["reference_group"].unique(maintain_order=True).to_list()
+    )
+    colors_dictionary = _create_colors_dictionary_for_calibration(
+        reference_groups, color_values, performance_type
+    )
+
+    limits = _define_limits_for_calibration_plot(deciles_data)
+
+    return {
+        "deciles_dat": deciles_data,
+        "smooth_dat": smooth_dat,
+        "reference_data": _create_reference_data_for_calibration_curve(),
+        "histogram_for_calibration": _create_histogram_for_calibration(probs),
+        "axes_ranges": {"xaxis": limits, "yaxis": limits},
+        "colors_dictionary": colors_dictionary,
+        "performance_type": [performance_type],
+        "size": [(size, size)],
+    }
+
+
+def _create_reference_data_for_calibration_curve() -> pl.DataFrame:
+ """Build the "perfectly calibrated" reference line.
+
+ Returns
+ -------
+ polars.DataFrame
+ Columns ``x`` and ``y`` (the same 101 evenly spaced values from 0 to
+ 1) and ``text`` (the hover label built from both values).
+ """
+ # y equals x: on the reference line the observed rate matches the prediction
+ x_ref = np.linspace(0, 1, 101)
+ reference_data = pl.DataFrame({"x": x_ref, "y": x_ref})
+ # NOTE(review): the two label literals below are split across physical
+ # lines in this patch; presumably a line-break marker was lost - confirm
+ # against the original file.
+ reference_data = reference_data.with_columns(
+ pl.concat_str(
+ [
+ pl.lit("Perfectly Calibrated
Predicted: "),
+ # values are rendered with three decimals
+ pl.col("x").map_elements(lambda x: f"{x:.3f}", return_dtype=pl.Utf8),
+ pl.lit("
Observed: "),
+ pl.col("y").map_elements(lambda y: f"{y:.3f}", return_dtype=pl.Utf8),
+ ]
+ ).alias("text")
+ )
+ return reference_data
+
+
+def _calculate_smooth_curve(
+ probs: Dict[str, np.ndarray],
+ reals: Union[np.ndarray, Dict[str, np.ndarray]],
+ performance_type: str,
+) -> pl.DataFrame:
+ """
+ Calculate the smoothed calibration curve using lowess.
+
+ Parameters
+ ----------
+ probs : dict of str to numpy.ndarray
+ Predicted probabilities keyed by name.
+ reals : numpy.ndarray or dict of str to numpy.ndarray
+ Observed outcomes, one array or a dict of arrays.
+ performance_type : str
+ When different from ``"one model"`` the hover text is prefixed with
+ the reference group.
+
+ Returns
+ -------
+ polars.DataFrame
+ Columns ``x``, ``y``, ``reference_group`` and ``text``. Empty (with
+ that schema) when no curve could be produced.
+ """
+ # imported lazily, inside the function
+ from statsmodels.nonparametric.smoothers_lowess import lowess
+
+ smooth_frames = []
+
+ # Helper function to process a single probability and real array
+ def process_single_array(p, r, group_name):
+ # A constant prediction cannot be smoothed: return the single point
+ # (that prediction, mean outcome) instead.
+ if len(np.unique(p)) == 1:
+ return pl.DataFrame(
+ {
+ "x": [np.unique(p)[0]],
+ "y": [np.mean(r)],
+ "reference_group": [group_name],
+ }
+ )
+ else:
+ # lowess returns a 2D array where the first column is x and the second is y
+ smoothed = lowess(r, p, it=0)
+ # resample the smoothed curve on a fixed grid of 101 points in [0, 1]
+ xout = np.linspace(0, 1, 101)
+ yout = np.interp(xout, smoothed[:, 0], smoothed[:, 1])
+ return pl.DataFrame(
+ {"x": xout, "y": yout, "reference_group": [group_name] * len(xout)}
+ )
+
+ if isinstance(reals, dict):
+ for model_name, prob_array in probs.items():
+ # This logic assumes that for multiple populations, one model's probs are evaluated against multiple real outcomes.
+ # This might need adjustment based on the exact structure for multiple models and populations.
+ # NOTE(review): here the whole prob_array is paired with each
+ # population's outcomes, so it presumably expects equal lengths - confirm.
+ if len(probs) == 1 and len(reals) > 1: # One model, multiple populations
+ for pop_name, real_array in reals.items():
+ frame = process_single_array(prob_array, real_array, pop_name)
+ smooth_frames.append(frame)
+ else: # Multiple models, potentially multiple populations
+ # NOTE(review): this inner loop does not use model_name/prob_array,
+ # so it appends the same frames once per entry of probs - verify
+ # that the repetition is intended.
+ for group_name in reals.keys():
+ # only names present in both dictionaries are processed
+ if group_name in probs:
+ frame = process_single_array(
+ probs[group_name], reals[group_name], group_name
+ )
+ smooth_frames.append(frame)
+
+ else: # reals is a single numpy array
+ for group_name, prob_array in probs.items():
+ frame = process_single_array(prob_array, reals, group_name)
+ smooth_frames.append(frame)
+
+ # Nothing was produced: return an empty frame with the expected columns.
+ if not smooth_frames:
+ return pl.DataFrame(
+ schema={
+ "x": pl.Float64,
+ "y": pl.Float64,
+ "reference_group": pl.Utf8,
+ "text": pl.Utf8,
+ }
+ )
+
+ smooth_dat = pl.concat(smooth_frames)
+
+ # NOTE(review): several label literals below are split across physical
+ # lines (and one is empty) in this patch; presumably markup was lost in
+ # extraction - confirm against the original file.
+ if performance_type != "one model":
+ # several groups: the hover text names the reference group
+ smooth_dat = smooth_dat.with_columns(
+ pl.concat_str(
+ [
+ pl.lit(""),
+ pl.col("reference_group"),
+ pl.lit("
Predicted: "),
+ pl.col("x").map_elements(
+ lambda x: f"{x:.3f}", return_dtype=pl.Utf8
+ ),
+ pl.lit("
Observed: "),
+ pl.col("y").map_elements(
+ lambda y: f"{y:.3f}", return_dtype=pl.Utf8
+ ),
+ ]
+ ).alias("text")
+ )
+ else:
+ # single model: the hover text shows only the two values
+ smooth_dat = smooth_dat.with_columns(
+ pl.concat_str(
+ [
+ pl.lit("Predicted: "),
+ pl.col("x").map_elements(
+ lambda x: f"{x:.3f}", return_dtype=pl.Utf8
+ ),
+ pl.lit("
Observed: "),
+ pl.col("y").map_elements(
+ lambda y: f"{y:.3f}", return_dtype=pl.Utf8
+ ),
+ ]
+ ).alias("text")
+ )
+ return smooth_dat
+
+
+def _add_hover_text_to_calibration_data(
+ deciles_dat: pl.DataFrame,
+ smooth_dat: pl.DataFrame,
+ performance_type: str,
+) -> tuple[pl.DataFrame, pl.DataFrame]:
+ """Adds hover text to the deciles and smooth dataframes.
+
+ Parameters
+ ----------
+ deciles_dat : polars.DataFrame
+ Needs the columns ``x``, ``y``, ``n_reals`` and ``n`` (plus
+ ``reference_group`` unless ``performance_type`` is ``"one model"``).
+ smooth_dat : polars.DataFrame
+ Needs the columns ``x`` and ``y`` (plus ``reference_group`` unless
+ ``performance_type`` is ``"one model"``).
+ performance_type : str
+ When different from ``"one model"`` the text is prefixed with the
+ reference group.
+
+ Returns
+ -------
+ tuple of polars.DataFrame
+ ``(deciles_dat, smooth_dat)``, each with a ``text`` column.
+ """
+ # NOTE(review): several label literals below are split across physical
+ # lines (and some are empty) in this patch; presumably markup was lost in
+ # extraction - confirm against the original file.
+ if performance_type != "one model":
+ # decile text: group, rounded values and "( n_reals / n )"
+ deciles_dat = deciles_dat.with_columns(
+ pl.concat_str(
+ [
+ pl.lit(""),
+ pl.col("reference_group"),
+ pl.lit("
Predicted: "),
+ pl.col("x").round(3),
+ pl.lit("
Observed: "),
+ pl.col("y").round(3),
+ pl.lit(" ( "),
+ pl.col("n_reals"),
+ pl.lit(" / "),
+ pl.col("n"),
+ pl.lit(" )"),
+ ]
+ ).alias("text")
+ )
+ # smooth text: group and rounded values only
+ smooth_dat = smooth_dat.with_columns(
+ pl.concat_str(
+ [
+ pl.lit(""),
+ pl.col("reference_group"),
+ pl.lit("
Predicted: "),
+ pl.col("x").round(3),
+ pl.lit("
Observed: "),
+ pl.col("y").round(3),
+ ]
+ ).alias("text")
+ )
+ else:
+ # single model: same texts without the group prefix
+ deciles_dat = deciles_dat.with_columns(
+ pl.concat_str(
+ [
+ pl.lit("Predicted: "),
+ pl.col("x").round(3),
+ pl.lit("
Observed: "),
+ pl.col("y").round(3),
+ pl.lit(" ( "),
+ pl.col("n_reals"),
+ pl.lit(" / "),
+ pl.col("n"),
+ pl.lit(" )"),
+ ]
+ ).alias("text")
+ )
+ smooth_dat = smooth_dat.with_columns(
+ pl.concat_str(
+ [
+ pl.lit("Predicted: "),
+ pl.col("x").round(3),
+ pl.lit("
Observed: "),
+ pl.col("y").round(3),
+ ]
+ ).alias("text")
+ )
+ return deciles_dat, smooth_dat
+
+
+def _create_colors_dictionary_for_calibration(
+    reference_groups: List[str],
+    color_values: List[str],
+    performance_type: str = "one model",
+) -> Dict[str, List[str]]:
+    """Map the reference line and every reference group to a color.
+
+    Parameters
+    ----------
+    reference_groups : list of str
+        Group names, in the order in which colors are assigned.
+    color_values : list of str
+        Palette; it is cycled when there are more groups than colors.
+    performance_type : str, default "one model"
+        With ``"one model"`` every group is drawn in black and the palette
+        is ignored.
+
+    Returns
+    -------
+    dict of str to list of str
+        ``"reference_line"`` plus one single-element list per group.
+
+    Raises
+    ------
+    ValueError
+        If groups must be colored from an empty palette.
+    """
+    palette = ["black"] if performance_type == "one model" else list(color_values)
+
+    # An empty palette used to surface as a ZeroDivisionError from the modulo.
+    if reference_groups and not palette:
+        raise ValueError("color_values must contain at least one color.")
+
+    colors_dictionary = {"reference_line": ["#BEBEBE"]}
+    for position, group in enumerate(reference_groups):
+        colors_dictionary[group] = [palette[position % len(palette)]]
+    return colors_dictionary
+
+
+def _create_histogram_for_calibration(probs: Dict[str, np.ndarray]) -> pl.DataFrame:
+    """Count the predictions of every group in bins of width 0.01.
+
+    Returns a frame with the columns ``mids`` (bin centers), ``counts``,
+    ``reference_group`` (the key of ``probs``) and ``text`` (hover label).
+    """
+    bin_edges = np.arange(0, 1.01, 0.01)
+    half_width = 0.005
+
+    # Hover label, e.g. "<count> observations in [<left>, <right>]".
+    label = (
+        pl.col("counts").cast(str)
+        + " observations in ["
+        + (pl.col("mids") - half_width).round(3).cast(str)
+        + ", "
+        + (pl.col("mids") + half_width).round(3).cast(str)
+        + "]"
+    ).alias("text")
+
+    per_group = []
+    for group, prob_values in probs.items():
+        counts, edges = np.histogram(prob_values, bins=bin_edges)
+        binned = pl.DataFrame(
+            {
+                # left edge plus half a bin is the center of the bin
+                "mids": edges[:-1] + half_width,
+                "counts": counts,
+                "reference_group": group,
+            }
+        )
+        per_group.append(binned.with_columns(label))
+
+    return pl.concat(per_group)
+
+
+def _define_limits_for_calibration_plot(deciles_dat: pl.DataFrame) -> List[float]:
+    """Return ``[lower, upper]`` axis limits with a 5% margin on each side.
+
+    A single-row frame uses the bounds 0 and 1; otherwise the bounds are the
+    smallest (floored at 0) and largest value found in ``x`` and ``y``.
+    """
+    if deciles_dat.height == 1:
+        low, high = 0.0, 1.0
+    else:
+        smallest = min(deciles_dat["x"].min(), deciles_dat["y"].min())
+        largest = max(deciles_dat["x"].max(), deciles_dat["y"].max())
+        low = float(max(0, smallest))
+        high = float(largest)
+
+    margin = (high - low) * 0.05
+    return [low - margin, high + margin]
+
+
+def _build_initial_df_for_times(
+    probs: Dict[str, np.ndarray],
+    reals: Union[np.ndarray, Dict[str, np.ndarray]],
+    times: Union[np.ndarray, Dict[str, np.ndarray]],
+) -> pl.DataFrame:
+    """Build the initial DataFrame for time-dependent calibration curves.
+
+    Parameters
+    ----------
+    probs : dict of str to numpy.ndarray
+        Predicted probabilities keyed by model name.
+    reals : numpy.ndarray or dict of str to numpy.ndarray
+        Observed outcomes; a bare array is treated as one population named
+        ``"single_population"``.
+    times : numpy.ndarray or dict of str to numpy.ndarray
+        Observed times, structured like ``reals``.
+
+    Returns
+    -------
+    polars.DataFrame
+        Columns ``reference_group``, ``real``, ``time``, ``prob`` and
+        ``model`` for every supported input layout.
+
+    Raises
+    ------
+    ValueError
+        If keys or lengths do not match, or if the combination of inputs is
+        not one of: a single model; one model per population (same keys);
+        several models on a single population.
+    """
+    # Convert all inputs to dictionaries of arrays to unify processing
+    if not isinstance(reals, dict):
+        reals = {"single_population": np.asarray(reals)}
+    if not isinstance(times, dict):
+        times = {"single_population": np.asarray(times)}
+
+    if reals.keys() != times.keys():
+        raise ValueError("Keys in reals and times dictionaries do not match.")
+    for key in reals:
+        if len(reals[key]) != len(times[key]):
+            raise ValueError(
+                f"Length mismatch for population '{key}' in reals and times."
+            )
+
+    # All populations stacked in the order of the reals dictionary
+    base_df = pl.concat(
+        [
+            pl.DataFrame(
+                {
+                    "reference_group": key,
+                    "real": reals[key],
+                    "time": times[key],
+                }
+            )
+            for key in reals
+        ]
+    )
+
+    # Single model: its predictions cover all stacked observations
+    if len(probs) == 1:
+        model_name, prob_array = next(iter(probs.items()))
+        if len(prob_array) != base_df.height:
+            raise ValueError(
+                f"Length of probabilities for model '{model_name}' does not match total number of observations."
+            )
+        return base_df.with_columns(
+            pl.Series("prob", prob_array), pl.lit(model_name).alias("model")
+        )
+
+    # One model per population (keys must match)
+    if probs.keys() == reals.keys():
+        prob_frames = []
+        for model_name, prob_array in probs.items():
+            pop_df = base_df.filter(pl.col("reference_group") == model_name)
+            if len(prob_array) != pop_df.height:
+                raise ValueError(
+                    f"Length of probabilities for model '{model_name}' does not match population size."
+                )
+            prob_frames.append(
+                pop_df.with_columns(
+                    pl.Series("prob", prob_array), pl.lit(model_name).alias("model")
+                )
+            )
+        return pl.concat(prob_frames)
+
+    # Multiple models on a single population
+    if len(reals) == 1:
+        final_frames = []
+        for model_name, prob_array in probs.items():
+            if len(prob_array) != base_df.height:
+                raise ValueError(
+                    f"Length of probabilities for model '{model_name}' does not match population size."
+                )
+            final_frames.append(
+                base_df.with_columns(
+                    pl.Series("prob", prob_array),
+                    # Each model becomes its own reference group ...
+                    pl.lit(model_name).alias("reference_group"),
+                    # ... and also gets the "model" column the other layouts have.
+                    pl.lit(model_name).alias("model"),
+                )
+            )
+        return pl.concat(final_frames)
+
+    raise ValueError("Unsupported combination of probs, reals, and times structures.")
+
+
+def _apply_heuristics_and_censoring(
+    df: pl.DataFrame,
+    horizon: float,
+    censoring_heuristic: str,
+    competing_heuristic: str,
+) -> pl.DataFrame:
+    """Recode or drop rows according to the heuristics for one time horizon.
+
+    Rows whose ``time`` lies beyond ``horizon`` get ``real`` set to 0. Among
+    the remaining rows (``time <= horizon``), ``real == 0`` rows are dropped
+    when ``censoring_heuristic`` is ``"excluded"``, and ``real == 2`` rows
+    are dropped (``"excluded"``), recoded to 0 (``"adjusted_as_negative"``)
+    or recoded to 1 (``"adjusted_as_composite"``) depending on
+    ``competing_heuristic``. Other heuristic values leave the rows as they are.
+    """
+    within_horizon = pl.col("time") <= horizon
+    censored_in_time = (pl.col("real") == 0) & within_horizon
+    competing_in_time = (pl.col("real") == 2) & within_horizon
+
+    # Administrative censoring: outcomes after horizon are negative
+    adjusted = df.with_columns(
+        pl.when(pl.col("time") > horizon)
+        .then(0)
+        .otherwise(pl.col("real"))
+        .alias("real")
+    )
+
+    if censoring_heuristic == "excluded":
+        adjusted = adjusted.filter(~censored_in_time)
+
+    if competing_heuristic == "excluded":
+        return adjusted.filter(~competing_in_time)
+
+    # Value that replaces a competing event under the "adjusted_as_*" heuristics
+    recoded_as = {"adjusted_as_negative": 0, "adjusted_as_composite": 1}
+    if competing_heuristic in recoded_as:
+        adjusted = adjusted.with_columns(
+            pl.when(competing_in_time)
+            .then(recoded_as[competing_heuristic])
+            .otherwise(pl.col("real"))
+            .alias("real")
+        )
+
+    return adjusted
+
+
+def _create_calibration_curve_list_times(
+    probs: Dict[str, np.ndarray],
+    reals: Union[np.ndarray, Dict[str, np.ndarray]],
+    times: Union[np.ndarray, Dict[str, np.ndarray]],
+    fixed_time_horizons: List[float],
+    heuristics_sets: List[Dict[str, str]],
+    size: int = 600,
+    color_values: List[str] = [
+        "#1b9e77",
+        "#d95f02",
+        "#7570b3",
+        "#e7298a",
+        "#07004D",
+        "#E6AB02",
+        "#FE5F55",
+        "#54494B",
+        "#006E90",
+        "#BC96E6",
+        "#52050A",
+        "#1F271B",
+        "#BE7C4D",
+        "#63768D",
+        "#08A045",
+        "#320A28",
+        "#82FF9E",
+        "#2176FF",
+        "#D1603D",
+        "#585123",
+    ],
+) -> Dict[str, Any]:
+    """Create the data structures for a time-dependent calibration plot.
+
+    Parameters
+    ----------
+    probs : dict of str to numpy.ndarray
+        Predicted probabilities keyed by model name.
+    reals : numpy.ndarray or dict of str to numpy.ndarray
+        Observed outcomes, one array or one array per population.
+    times : numpy.ndarray or dict of str to numpy.ndarray
+        Observed times, structured like ``reals``.
+    fixed_time_horizons : list of float
+        Horizons for which calibration data are produced.
+    heuristics_sets : list of dict
+        Each dict provides ``"censoring_heuristic"`` and
+        ``"competing_heuristic"``. Sets with censoring ``"adjusted"`` or
+        competing ``"adjusted_as_censored"`` are skipped.
+    size : int, default 600
+        Stored as ``[(size, size)]`` under ``"size"``.
+    color_values : list of str
+        Palette from which the group colors are taken.
+
+    Returns
+    -------
+    dict
+        Keys: ``"deciles_dat"``, ``"smooth_dat"``, ``"reference_data"``,
+        ``"histogram_for_calibration"``, ``"axes_ranges"``,
+        ``"colors_dictionary"``, ``"performance_type"``, ``"size"``,
+        ``"fixed_time_horizons"`` and ``"reference_group_keys"``. The three
+        data frames carry a ``fixed_time_horizon`` column.
+
+    Raises
+    ------
+    ValueError
+        If no rows are left for any horizon / heuristics combination.
+    """
+    initial_df = _build_initial_df_for_times(probs, reals, times)
+    performance_type = _check_performance_type_by_probs_and_reals(probs, reals)
+    single_array_input = not isinstance(reals, dict) and len(probs) == 1
+
+    all_deciles = []
+    all_smooth = []
+    all_histograms = []
+
+    for horizon in fixed_time_horizons:
+        horizon_column = pl.lit(horizon).alias("fixed_time_horizon")
+
+        for heuristics in heuristics_sets:
+            censoring_heuristic = heuristics["censoring_heuristic"]
+            competing_heuristic = heuristics["competing_heuristic"]
+
+            if (
+                censoring_heuristic == "adjusted"
+                or competing_heuristic == "adjusted_as_censored"
+            ):
+                continue
+
+            df_adj = _apply_heuristics_and_censoring(
+                initial_df, horizon, censoring_heuristic, competing_heuristic
+            )
+            if df_adj.height == 0:
+                continue
+
+            # Re-create the probs and reals dicts for the helpers in a single
+            # ordered pass, so both dicts list the groups in the same order.
+            probs_adj: Dict[str, np.ndarray] = {}
+            reals_by_group: Dict[str, np.ndarray] = {}
+            for group_key, group_df in df_adj.group_by(
+                "reference_group", maintain_order=True
+            ):
+                group_name = group_key[0]
+                probs_adj[group_name] = group_df["prob"].to_numpy()
+                reals_by_group[group_name] = group_df["real"].to_numpy()
+
+            # If single population initially, reals_adj should be an array
+            reals_adj: Union[np.ndarray, Dict[str, np.ndarray]] = (
+                next(iter(reals_by_group.values()))
+                if single_array_input
+                else reals_by_group
+            )
+
+            all_deciles.append(
+                _make_deciles_dat_binary(probs_adj, reals_adj).with_columns(
+                    horizon_column
+                )
+            )
+            all_smooth.append(
+                _calculate_smooth_curve(
+                    probs_adj, reals_adj, performance_type
+                ).with_columns(horizon_column)
+            )
+            all_histograms.append(
+                _create_histogram_for_calibration(probs_adj).with_columns(
+                    horizon_column
+                )
+            )
+
+    if not all_deciles:
+        raise ValueError(
+            "No data remaining after applying heuristics and time horizons."
+        )
+
+    deciles_dat_final, smooth_dat_final = _add_hover_text_to_calibration_data(
+        pl.concat(all_deciles), pl.concat(all_smooth), performance_type
+    )
+
+    # Keep first-seen order so each group gets the same color on every run.
+    reference_groups = (
+        deciles_dat_final["reference_group"].unique(maintain_order=True).to_list()
+    )
+    colors_dictionary = _create_colors_dictionary_for_calibration(
+        reference_groups, color_values, performance_type
+    )
+    limits = _define_limits_for_calibration_plot(deciles_dat_final)
+
+    return {
+        "deciles_dat": deciles_dat_final,
+        "smooth_dat": smooth_dat_final,
+        "reference_data": _create_reference_data_for_calibration_curve(),
+        "histogram_for_calibration": pl.concat(all_histograms),
+        "axes_ranges": {"xaxis": limits, "yaxis": limits},
+        "colors_dictionary": colors_dictionary,
+        "performance_type": [performance_type],
+        "size": [(size, size)],
+        "fixed_time_horizons": fixed_time_horizons,
+        "reference_group_keys": reference_groups,
+    }
diff --git a/src/rtichoke/discrimination/gains.py b/src/rtichoke/discrimination/gains.py
index 59366f4..8cf7c8d 100644
--- a/src/rtichoke/discrimination/gains.py
+++ b/src/rtichoke/discrimination/gains.py
@@ -4,7 +4,7 @@
from typing import Dict, List, Sequence, Union
from plotly.graph_objs._figure import Figure
-from rtichoke.helpers.plotly_helper_functions import (
+from rtichoke.processing.plotly_helper_functions import (
_create_rtichoke_plotly_curve_times,
_create_rtichoke_plotly_curve_binary,
_plot_rtichoke_curve_binary,
diff --git a/src/rtichoke/discrimination/lift.py b/src/rtichoke/discrimination/lift.py
index 5f358af..e3e394c 100644
--- a/src/rtichoke/discrimination/lift.py
+++ b/src/rtichoke/discrimination/lift.py
@@ -4,7 +4,7 @@
from typing import Dict, List, Sequence, Union
from plotly.graph_objs._figure import Figure
-from rtichoke.helpers.plotly_helper_functions import (
+from rtichoke.processing.plotly_helper_functions import (
_create_rtichoke_plotly_curve_times,
_create_rtichoke_plotly_curve_binary,
_plot_rtichoke_curve_binary,
diff --git a/src/rtichoke/discrimination/precision_recall.py b/src/rtichoke/discrimination/precision_recall.py
index 565cf5c..1a3d7a0 100644
--- a/src/rtichoke/discrimination/precision_recall.py
+++ b/src/rtichoke/discrimination/precision_recall.py
@@ -4,7 +4,7 @@
from typing import Dict, List, Sequence, Union
from plotly.graph_objs._figure import Figure
-from rtichoke.helpers.plotly_helper_functions import (
+from rtichoke.processing.plotly_helper_functions import (
_create_rtichoke_plotly_curve_times,
_create_rtichoke_plotly_curve_binary,
_plot_rtichoke_curve_binary,
diff --git a/src/rtichoke/discrimination/roc.py b/src/rtichoke/discrimination/roc.py
index d8a8ed0..9bcc653 100644
--- a/src/rtichoke/discrimination/roc.py
+++ b/src/rtichoke/discrimination/roc.py
@@ -4,7 +4,7 @@
from typing import Dict, List, Union, Sequence
from plotly.graph_objs._figure import Figure
-from rtichoke.helpers.plotly_helper_functions import (
+from rtichoke.processing.plotly_helper_functions import (
_create_rtichoke_plotly_curve_times,
_create_rtichoke_plotly_curve_binary,
_plot_rtichoke_curve_binary,
diff --git a/src/rtichoke/helpers/sandbox_observable_helpers.py b/src/rtichoke/helpers/sandbox_observable_helpers.py
deleted file mode 100644
index d3cc352..0000000
--- a/src/rtichoke/helpers/sandbox_observable_helpers.py
+++ /dev/null
@@ -1,1770 +0,0 @@
-# from lifelines import AalenJohansenFitter
-import pandas as pd
-import numpy as np
-import polars as pl
-from polarstate import predict_aj_estimates
-from polarstate import prepare_event_table
-from typing import Dict, Union
-from collections.abc import Sequence
-
-
-def _enum_dataframe(column_name: str, values: Sequence[str]) -> pl.DataFrame:
- """Create a single-column DataFrame with an enum dtype."""
- enum_values = list(dict.fromkeys(values))
- enum_dtype = pl.Enum(enum_values)
- return pl.DataFrame({column_name: pl.Series(values, dtype=enum_dtype)})
-
-
-# def extract_aj_estimate(data_to_adjust, fixed_time_horizons):
-# """
-# Python implementation of the R extract_aj_estimate function for Aalen-Johansen estimation.
-
-# Parameters:
-# data_to_adjust (pd.DataFrame): DataFrame containing survival data
-# fixed_time_horizons (list or float): Time points at which to evaluate the survival
-
-# Returns:
-# pd.DataFrame: DataFrame with Aalen-Johansen estimates
-# """
-
-# # Ensure fixed_time_horizons is a list
-# if not isinstance(fixed_time_horizons, list):
-# fixed_time_horizons = [fixed_time_horizons]
-
-# # Create a categorical version of reals for stratification
-# data = data_to_adjust.copy()
-# data["reals_cat"] = pd.Categorical(
-# data["reals_labels"],
-# categories=[
-# "real_negatives",
-# "real_positives",
-# "real_competing",
-# "real_censored",
-# ],
-# ordered=True,
-# )
-
-# # Get unique strata values
-# strata_values = data["strata"].unique()
-
-# event_map = {
-# "real_negatives": 0, # Treat as censored
-# "real_positives": 1, # Event of interest
-# "real_competing": 2, # Competing risk
-# "real_censored": 0, # Censored
-# }
-
-# data["event_code"] = data["reals_labels"].map(event_map)
-
-# # Initialize result dataframes
-# results = []
-
-# # For each stratum, fit Aalen-Johansen model
-# for stratum in strata_values:
-# # Filter data for current stratum
-# stratum_data = data.loc[data["strata"] == stratum]
-
-# # Initialize Aalen-Johansen fitter
-# ajf = AalenJohansenFitter()
-# ajf_competing = AalenJohansenFitter()
-
-# # Fit the model
-# ajf.fit(stratum_data["times"], stratum_data["event_code"], event_of_interest=1)
-
-# ajf_competing.fit(
-# stratum_data["times"], stratum_data["event_code"], event_of_interest=2
-# )
-
-# # Calculate cumulative incidence at fixed time horizons
-# for t in fixed_time_horizons:
-# n = len(stratum_data)
-# real_positives_est = ajf.predict(t)
-# real_competing_est = ajf_competing.predict(t)
-# real_negatives_est = 1 - real_positives_est - real_competing_est
-
-# states = ["real_negatives", "real_positives", "real_competing"]
-# estimates = [real_negatives_est, real_positives_est, real_competing_est]
-
-# for state, estimate in zip(states, estimates):
-# results.append(
-# {
-# "strata": stratum,
-# "reals": state,
-# "fixed_time_horizon": t,
-# "reals_estimate": estimate * n,
-# }
-# )
-
-# # Convert to DataFrame
-# result_df = pd.DataFrame(results)
-
-# # Convert strata to categorical if needed
-# result_df["strata"] = pd.Categorical(result_df["strata"])
-
-# return result_df
-
-
-def add_cutoff_strata(data: pl.DataFrame, by: float, stratified_by) -> pl.DataFrame:
- def transform_group(group: pl.DataFrame, by: float) -> pl.DataFrame:
- probs = group["probs"].to_numpy()
- columns_to_add = []
-
- breaks = create_breaks_values(probs, "probability_threshold", by)
- if "probability_threshold" in stratified_by:
- last_bin_index = len(breaks) - 2
-
- bin_indices = np.digitize(probs, bins=breaks, right=False) - 1
- bin_indices = np.where(probs == 1.0, last_bin_index, bin_indices)
-
- lower_bounds = breaks[bin_indices]
- upper_bounds = breaks[bin_indices + 1]
-
- include_upper_bounds = bin_indices == last_bin_index
-
- strata_prob_labels = np.where(
- include_upper_bounds,
- [f"[{lo:.2f}, {hi:.2f}]" for lo, hi in zip(lower_bounds, upper_bounds)],
- [f"[{lo:.2f}, {hi:.2f})" for lo, hi in zip(lower_bounds, upper_bounds)],
- ).astype(str)
-
- columns_to_add.append(
- pl.Series("strata_probability_threshold", strata_prob_labels)
- )
-
- if "ppcr" in stratified_by:
- # --- Compute strata_ppcr as equal-frequency quantile bins by rank ---
- by = float(by)
- q = int(round(1 / by)) # e.g. 0.2 -> 5 bins
-
- probs = np.asarray(probs, float)
-
- edges = np.quantile(probs, np.linspace(0.0, 1.0, q + 1), method="linear")
-
- edges = np.maximum.accumulate(edges)
-
- edges[0] = 0.0
- edges[-1] = 1.0
-
- bin_idx = np.digitize(probs, bins=edges[1:-1], right=True)
-
- s = str(by)
- decimals = len(s.split(".")[-1]) if "." in s else 0
-
- labels = [f"{x:.{decimals}f}" for x in np.linspace(by, 1.0, q)]
-
- strata_labels = np.array([labels[i] for i in bin_idx], dtype=object)
-
- columns_to_add.append(
- pl.Series("strata_ppcr", strata_labels).cast(pl.Enum(labels))
- )
- return group.with_columns(columns_to_add)
-
- # Apply per-group transformation
- grouped = data.partition_by("reference_group", as_dict=True)
- transformed_groups = [transform_group(group, by) for group in grouped.values()]
- return pl.concat(transformed_groups)
-
-
-def create_strata_combinations(stratified_by: str, by: float, breaks) -> pl.DataFrame:
- s_by = str(by)
- decimals = len(s_by.split(".")[-1]) if "." in s_by else 0
- fmt = f"{{:.{decimals}f}}"
-
- if stratified_by == "probability_threshold":
- upper_bound = breaks[1:] # breaks
- lower_bound = breaks[:-1] # np.roll(upper_bound, 1)
- # lower_bound[0] = 0.0
- mid_point = upper_bound - by / 2
- include_lower_bound = lower_bound > -0.1
- include_upper_bound = upper_bound == 1.0 # upper_bound != 0.0
- # chosen_cutoff = upper_bound
- strata = format_strata_column(
- lower_bound=lower_bound,
- upper_bound=upper_bound,
- include_lower_bound=include_lower_bound,
- include_upper_bound=include_upper_bound,
- decimals=2,
- )
-
- elif stratified_by == "ppcr":
- strata_mid = breaks[1:]
- lower_bound = strata_mid - by / 2
- upper_bound = strata_mid + by / 2
- mid_point = breaks[1:]
- include_lower_bound = np.ones_like(strata_mid, dtype=bool)
- include_upper_bound = np.zeros_like(strata_mid, dtype=bool)
- # chosen_cutoff = strata_mid
- strata = np.array([fmt.format(x) for x in strata_mid], dtype=object)
- else:
- raise ValueError(f"Unsupported stratified_by: {stratified_by}")
-
- bins_df = pl.DataFrame(
- {
- "strata": pl.Series(strata),
- "lower_bound": lower_bound,
- "upper_bound": upper_bound,
- "mid_point": mid_point,
- "include_lower_bound": include_lower_bound,
- "include_upper_bound": include_upper_bound,
- # "chosen_cutoff": chosen_cutoff,
- "stratified_by": [stratified_by] * len(strata),
- }
- )
-
- cutoffs_df = pl.DataFrame({"chosen_cutoff": breaks})
-
- return bins_df.join(cutoffs_df, how="cross")
-
-
-def format_strata_column(
- lower_bound: list[float],
- upper_bound: list[float],
- include_lower_bound: list[bool],
- include_upper_bound: list[bool],
- decimals: int = 3,
-) -> list[str]:
- return [
- f"{'[' if ilb else '('}"
- f"{round(lb, decimals):.{decimals}f}, "
- f"{round(ub, decimals):.{decimals}f}"
- f"{']' if iub else ')'}"
- for lb, ub, ilb, iub in zip(
- lower_bound, upper_bound, include_lower_bound, include_upper_bound
- )
- ]
-
-
-def format_strata_interval(
- lower: float, upper: float, include_lower: bool, include_upper: bool
-) -> str:
- left = "[" if include_lower else "("
- right = "]" if include_upper else ")"
- return f"{left}{lower:.3f}, {upper:.3f}{right}"
-
-
-def create_breaks_values(probs_vec, stratified_by, by):
- if stratified_by != "probability_threshold":
- breaks = np.quantile(probs_vec, np.linspace(1, 0, int(1 / by) + 1))
- else:
- breaks = np.round(
- np.arange(0, 1 + by, by), decimals=len(str(by).split(".")[-1])
- )
- return breaks
-
-
-def _create_aj_data_combinations_binary(
- reference_groups: Sequence[str],
- stratified_by: Sequence[str],
- by: float,
- breaks: Sequence[float],
-) -> pl.DataFrame:
- dfs = [create_strata_combinations(sb, by, breaks) for sb in stratified_by]
-
- strata_combinations = pl.concat(dfs, how="vertical")
-
- strata_cats = (
- strata_combinations.select(pl.col("strata").unique(maintain_order=True))
- .to_series()
- .to_list()
- )
-
- strata_enum = pl.Enum(strata_cats)
- stratified_by_enum = pl.Enum(["probability_threshold", "ppcr"])
-
- strata_combinations = strata_combinations.with_columns(
- [
- pl.col("strata").cast(strata_enum),
- pl.col("stratified_by").cast(stratified_by_enum),
- ]
- )
-
- # Define values for Cartesian product
- reals_labels = ["real_negatives", "real_positives"]
-
- combinations_frames: list[pl.DataFrame] = [
- _enum_dataframe("reference_group", reference_groups),
- strata_combinations,
- _enum_dataframe("reals_labels", reals_labels),
- ]
-
- result = combinations_frames[0]
- for frame in combinations_frames[1:]:
- result = result.join(frame, how="cross")
-
- return result
-
-
-def create_aj_data_combinations(
- reference_groups: Sequence[str],
- heuristics_sets: list[Dict],
- fixed_time_horizons: Sequence[float],
- stratified_by: Sequence[str],
- by: float,
- breaks: Sequence[float],
- risk_set_scope: Sequence[str] = ["within_stratum", "pooled_by_cutoff"],
-) -> pl.DataFrame:
- dfs = [create_strata_combinations(sb, by, breaks) for sb in stratified_by]
- strata_combinations = pl.concat(dfs, how="vertical")
-
- # strata_enum = pl.Enum(strata_combinations["strata"])
-
- strata_cats = (
- strata_combinations.select(pl.col("strata").unique(maintain_order=True))
- .to_series()
- .to_list()
- )
-
- strata_enum = pl.Enum(strata_cats)
- stratified_by_enum = pl.Enum(["probability_threshold", "ppcr"])
-
- strata_combinations = strata_combinations.with_columns(
- [
- pl.col("strata").cast(strata_enum),
- pl.col("stratified_by").cast(stratified_by_enum),
- ]
- )
-
- risk_set_scope_combinations = pl.DataFrame(
- {
- "risk_set_scope": pl.Series(risk_set_scope).cast(
- pl.Enum(["within_stratum", "pooled_by_cutoff"])
- )
- }
- )
-
- # Define values for Cartesian product
- reals_labels = [
- "real_negatives",
- "real_positives",
- "real_competing",
- "real_censored",
- ]
-
- heuristics_combinations = pl.DataFrame(heuristics_sets)
-
- censoring_heuristics_enum = pl.Enum(
- heuristics_combinations["censoring_heuristic"].unique(maintain_order=True)
- )
- competing_heuristics_enum = pl.Enum(
- heuristics_combinations["competing_heuristic"].unique(maintain_order=True)
- )
-
- combinations_frames: list[pl.DataFrame] = [
- _enum_dataframe("reference_group", reference_groups),
- pl.DataFrame(
- {"fixed_time_horizon": pl.Series(fixed_time_horizons, dtype=pl.Float64)}
- ),
- heuristics_combinations.with_columns(
- [
- pl.col("censoring_heuristic").cast(censoring_heuristics_enum),
- pl.col("competing_heuristic").cast(competing_heuristics_enum),
- ]
- ),
- strata_combinations,
- risk_set_scope_combinations,
- _enum_dataframe("reals_labels", reals_labels),
- ]
-
- result = combinations_frames[0]
- for frame in combinations_frames[1:]:
- result = result.join(frame, how="cross")
-
- return result
-
-
-def pivot_longer_strata(data: pl.DataFrame) -> pl.DataFrame:
- # Identify id_vars and value_vars
- id_vars = [col for col in data.columns if not col.startswith("strata_")]
- value_vars = [col for col in data.columns if col.startswith("strata_")]
-
- # Perform the melt (equivalent to pandas.melt)
- data_long = data.melt(
- id_vars=id_vars,
- value_vars=value_vars,
- variable_name="stratified_by",
- value_name="strata",
- )
-
- stratified_by_labels = ["probability_threshold", "ppcr"]
- stratified_by_enum = pl.Enum(stratified_by_labels)
-
- # Remove "strata_" prefix from the 'stratified_by' column
- data_long = data_long.with_columns(
- pl.col("stratified_by").str.replace("^strata_", "").cast(stratified_by_enum)
- )
-
- return data_long
-
-
-def map_reals_to_labels_polars(data: pl.DataFrame) -> pl.DataFrame:
- return data.with_columns(
- [
- pl.when(pl.col("reals") == 0)
- .then("real_negatives")
- .when(pl.col("reals") == 1)
- .then("real_positives")
- .when(pl.col("reals") == 2)
- .then("real_competing")
- .otherwise("real_censored")
- .alias("reals")
- ]
- )
-
-
-def update_administrative_censoring_polars(data: pl.DataFrame) -> pl.DataFrame:
- data = data.with_columns(
- [
- pl.when(
- (pl.col("times") > pl.col("fixed_time_horizon"))
- & (pl.col("reals_labels") == "real_positives")
- )
- .then(pl.lit("real_negatives"))
- .when(
- (pl.col("times") < pl.col("fixed_time_horizon"))
- & (pl.col("reals_labels") == "real_negatives")
- )
- .then(pl.lit("real_censored"))
- .otherwise(pl.col("reals_labels"))
- .alias("reals_labels")
- ]
- )
-
- return data
-
-
-def create_aj_data(
- reference_group_data,
- breaks,
- censoring_heuristic,
- competing_heuristic,
- fixed_time_horizons,
- stratified_by: Sequence[str],
- full_event_table: bool = False,
- risk_set_scope: Sequence[str] = ["within_stratum"],
-):
- """
- Create AJ estimates per strata based on censoring and competing heuristicss.
- """
-
- def aj_estimates_with_cross(df, extra_cols):
- return df.join(pl.DataFrame(extra_cols), how="cross")
-
- exploded = assign_and_explode_polars(reference_group_data, fixed_time_horizons)
-
- event_table = prepare_event_table(reference_group_data)
-
- # TODO: solve strata in the pipeline
-
- excluded_events = _extract_excluded_events(
- event_table, fixed_time_horizons, censoring_heuristic, competing_heuristic
- )
-
- aj_dfs = []
- for rscope in risk_set_scope:
- aj_res = _aj_adjusted_events(
- reference_group_data,
- breaks,
- exploded,
- censoring_heuristic,
- competing_heuristic,
- fixed_time_horizons,
- stratified_by,
- full_event_table,
- rscope,
- )
-
- aj_res = aj_res.select(
- [
- "strata",
- "times",
- "chosen_cutoff",
- "real_negatives_est",
- "real_positives_est",
- "real_competing_est",
- "estimate_origin",
- "fixed_time_horizon",
- "risk_set_scope",
- ]
- )
-
- aj_dfs.append(aj_res)
-
- aj_df = pl.concat(aj_dfs, how="vertical")
-
- result = aj_df.join(excluded_events, on=["fixed_time_horizon"], how="left")
-
- return aj_estimates_with_cross(
- result,
- {
- "censoring_heuristic": censoring_heuristic,
- "competing_heuristic": competing_heuristic,
- },
- ).select(
- [
- "strata",
- "chosen_cutoff",
- "fixed_time_horizon",
- "times",
- "real_negatives_est",
- "real_positives_est",
- "real_competing_est",
- "real_censored_est",
- "censoring_heuristic",
- "competing_heuristic",
- "estimate_origin",
- "risk_set_scope",
- ]
- )
-
-
-def _extract_excluded_events(
- event_table: pl.DataFrame,
- fixed_time_horizons: list[float],
- censoring_heuristic: str,
- competing_heuristic: str,
-) -> pl.DataFrame:
- horizons_df = pl.DataFrame({"times": fixed_time_horizons}).sort("times")
-
- excluded_events = horizons_df.join_asof(
- event_table.with_columns(
- pl.col("count_0").cum_sum().cast(pl.Float64).alias("real_censored_est"),
- pl.col("count_2").cum_sum().cast(pl.Float64).alias("real_competing_est"),
- ).select(
- pl.col("times"),
- pl.col("real_censored_est"),
- pl.col("real_competing_est"),
- ),
- left_on="times",
- right_on="times",
- ).with_columns([pl.col("times").alias("fixed_time_horizon")])
-
- if censoring_heuristic != "excluded":
- excluded_events = excluded_events.with_columns(
- pl.lit(0.0).alias("real_censored_est")
- )
-
- if competing_heuristic != "excluded":
- excluded_events = excluded_events.with_columns(
- pl.lit(0.0).alias("real_competing_est")
- )
-
- return excluded_events
-
-
-def extract_crude_estimate_polars(data: pl.DataFrame) -> pl.DataFrame:
- all_combinations = data.select(["strata", "reals", "fixed_time_horizon"]).unique()
-
- counts = data.group_by(["strata", "reals", "fixed_time_horizon"]).agg(
- pl.count().alias("reals_estimate")
- )
-
- return all_combinations.join(
- counts, on=["strata", "reals", "fixed_time_horizon"], how="left"
- ).with_columns([pl.col("reals_estimate").fill_null(0).cast(pl.Int64)])
-
-
-# def update_administrative_censoring(data_to_adjust: pd.DataFrame) -> pd.DataFrame:
-# pl_df = pl.from_pandas(data_to_adjust)
-
-# # Perform the transformation using polars
-# pl_result = pl_df.with_columns(
-# pl.when(
-# (pl.col("times") > pl.col("fixed_time_horizon")) &
-# (pl.col("reals") == "real_positives")
-# ).then(
-# "real_negatives"
-# ).when(
-# (pl.col("times") < pl.col("fixed_time_horizon")) &
-# (pl.col("reals") == "real_negatives")
-# ).then(
-# "real_censored"
-# ).otherwise(
-# pl.col("reals")
-# ).alias("reals")
-# )
-
-# # Convert back to pandas DataFrame and return
-# result_pandas = pl_result.to_pandas()
-
-# return result_pandas
-
-
-def extract_aj_estimate_by_cutoffs(
- data_to_adjust, horizons, breaks, stratified_by, full_event_table: bool
-):
- # n = data_to_adjust.height
-
- counts_per_strata = (
- data_to_adjust.group_by(
- ["strata", "stratified_by", "upper_bound", "lower_bound"]
- )
- .len(name="strata_count")
- .with_columns(pl.col("strata_count").cast(pl.Float64))
- )
-
- aj_estimates_predicted_positives = pl.DataFrame()
- aj_estimates_predicted_negatives = pl.DataFrame()
-
- for stratification_criteria in stratified_by:
- for chosen_cutoff in breaks:
- if stratification_criteria == "probability_threshold":
- mask_predicted_positives = (pl.col("upper_bound") > chosen_cutoff) & (
- pl.col("stratified_by") == "probability_threshold"
- )
- mask_predicted_negatives = (pl.col("upper_bound") <= chosen_cutoff) & (
- pl.col("stratified_by") == "probability_threshold"
- )
-
- elif stratification_criteria == "ppcr":
- mask_predicted_positives = (
- pl.col("lower_bound") > 1 - chosen_cutoff
- ) & (pl.col("stratified_by") == "ppcr")
- mask_predicted_negatives = (
- pl.col("lower_bound") <= 1 - chosen_cutoff
- ) & (pl.col("stratified_by") == "ppcr")
-
- predicted_positives = data_to_adjust.filter(mask_predicted_positives)
- predicted_negatives = data_to_adjust.filter(mask_predicted_negatives)
-
- counts_per_strata_predicted_positives = counts_per_strata.filter(
- mask_predicted_positives
- )
- counts_per_strata_predicted_negatives = counts_per_strata.filter(
- mask_predicted_negatives
- )
-
- event_table_predicted_positives = prepare_event_table(predicted_positives)
- event_table_predicted_negatives = prepare_event_table(predicted_negatives)
-
- aj_estimate_predicted_positives = (
- (
- predict_aj_estimates(
- event_table_predicted_positives,
- pl.Series(horizons),
- full_event_table,
- )
- .with_columns(
- pl.lit(chosen_cutoff).alias("chosen_cutoff"),
- pl.lit(stratification_criteria)
- .alias("stratified_by")
- .cast(pl.Enum(["probability_threshold", "ppcr"])),
- )
- .join(
- counts_per_strata_predicted_positives,
- on=["stratified_by"],
- how="left",
- )
- .with_columns(
- [
- (
- pl.col("state_occupancy_probability_0")
- * pl.col("strata_count")
- ).alias("real_negatives_est"),
- (
- pl.col("state_occupancy_probability_1")
- * pl.col("strata_count")
- ).alias("real_positives_est"),
- (
- pl.col("state_occupancy_probability_2")
- * pl.col("strata_count")
- ).alias("real_competing_est"),
- ]
- )
- )
- .select(
- [
- "strata",
- # "stratified_by",
- "times",
- "chosen_cutoff",
- "real_negatives_est",
- "real_positives_est",
- "real_competing_est",
- "estimate_origin",
- ]
- )
- .with_columns([pl.col("times").alias("fixed_time_horizon")])
- )
-
- aj_estimate_predicted_negatives = (
- (
- predict_aj_estimates(
- event_table_predicted_negatives,
- pl.Series(horizons),
- full_event_table,
- )
- .with_columns(
- pl.lit(chosen_cutoff).alias("chosen_cutoff"),
- pl.lit(stratification_criteria)
- .alias("stratified_by")
- .cast(pl.Enum(["probability_threshold", "ppcr"])),
- )
- .join(
- counts_per_strata_predicted_negatives,
- on=["stratified_by"],
- how="left",
- )
- .with_columns(
- [
- (
- pl.col("state_occupancy_probability_0")
- * pl.col("strata_count")
- ).alias("real_negatives_est"),
- (
- pl.col("state_occupancy_probability_1")
- * pl.col("strata_count")
- ).alias("real_positives_est"),
- (
- pl.col("state_occupancy_probability_2")
- * pl.col("strata_count")
- ).alias("real_competing_est"),
- ]
- )
- )
- .select(
- [
- "strata",
- # "stratified_by",
- "times",
- "chosen_cutoff",
- "real_negatives_est",
- "real_positives_est",
- "real_competing_est",
- "estimate_origin",
- ]
- )
- .with_columns([pl.col("times").alias("fixed_time_horizon")])
- )
-
- aj_estimates_predicted_negatives = pl.concat(
- [aj_estimates_predicted_negatives, aj_estimate_predicted_negatives],
- how="vertical",
- )
-
- aj_estimates_predicted_positives = pl.concat(
- [aj_estimates_predicted_positives, aj_estimate_predicted_positives],
- how="vertical",
- )
-
- aj_estimate_by_cutoffs = pl.concat(
- [aj_estimates_predicted_negatives, aj_estimates_predicted_positives],
- how="vertical",
- )
-
- return aj_estimate_by_cutoffs
-
-
-def extract_aj_estimate_for_strata(data_to_adjust, horizons, full_event_table: bool):
- n = data_to_adjust.height
-
- event_table = prepare_event_table(data_to_adjust)
-
- aj_estimate_for_strata_polars = predict_aj_estimates(
- event_table, pl.Series(horizons), full_event_table
- )
-
- if len(horizons) == 1:
- aj_estimate_for_strata_polars = aj_estimate_for_strata_polars.with_columns(
- pl.lit(horizons[0]).alias("fixed_time_horizon")
- )
-
- else:
- fixed_df = aj_estimate_for_strata_polars.filter(
- pl.col("estimate_origin") == "fixed_time_horizons"
- ).with_columns([pl.col("times").alias("fixed_time_horizon")])
-
- event_df = (
- aj_estimate_for_strata_polars.filter(
- pl.col("estimate_origin") == "event_table"
- )
- .with_columns([pl.lit(horizons).alias("fixed_time_horizon")])
- .explode("fixed_time_horizon")
- )
-
- aj_estimate_for_strata_polars = pl.concat(
- [fixed_df, event_df], how="vertical"
- ).sort("estimate_origin", "fixed_time_horizon", "times")
-
- return aj_estimate_for_strata_polars.with_columns(
- [
- (pl.col("state_occupancy_probability_0") * n).alias("real_negatives_est"),
- (pl.col("state_occupancy_probability_1") * n).alias("real_positives_est"),
- (pl.col("state_occupancy_probability_2") * n).alias("real_competing_est"),
- pl.col("fixed_time_horizon").cast(pl.Float64),
- pl.lit(data_to_adjust["strata"][0]).alias("strata"),
- ]
- ).select(
- [
- "strata",
- "times",
- "fixed_time_horizon",
- "real_negatives_est",
- "real_positives_est",
- "real_competing_est",
- pl.col("estimate_origin"),
- ]
- )
-
-
-def assign_and_explode_polars(
- data: pl.DataFrame, fixed_time_horizons: list[float]
-) -> pl.DataFrame:
- return (
- data.with_columns(pl.lit(fixed_time_horizons).alias("fixed_time_horizon"))
- .explode("fixed_time_horizon")
- .with_columns(pl.col("fixed_time_horizon").cast(pl.Float64))
- )
-
-
-def _create_list_data_to_adjust_binary(
- aj_data_combinations: pl.DataFrame,
- probs_dict: Dict[str, np.ndarray],
- reals_dict: Union[np.ndarray, Dict[str, np.ndarray]],
- stratified_by,
- by,
-) -> Dict[str, pl.DataFrame]:
- reference_group_labels = list(probs_dict.keys())
-
- if isinstance(reals_dict, dict):
- num_keys_reals = len(reals_dict)
- else:
- num_keys_reals = 1
-
- reference_group_enum = pl.Enum(reference_group_labels)
-
- strata_enum_dtype = aj_data_combinations.schema["strata"]
-
- if len(probs_dict) == 1:
- probs_array = np.asarray(probs_dict[reference_group_labels[0]])
-
- data_to_adjust = pl.DataFrame(
- {
- "reference_group": np.repeat(reference_group_labels, len(probs_array)),
- "probs": probs_array,
- "reals": reals_dict,
- }
- ).with_columns(pl.col("reference_group").cast(reference_group_enum))
-
- elif num_keys_reals == 1:
- data_to_adjust = pl.DataFrame(
- {
- "reference_group": np.repeat(reference_group_labels, len(reals_dict)),
- "probs": np.concatenate(
- [probs_dict[group] for group in reference_group_labels]
- ),
- "reals": np.tile(np.asarray(reals_dict), len(reference_group_labels)),
- }
- ).with_columns(pl.col("reference_group").cast(reference_group_enum))
-
- elif isinstance(reals_dict, dict):
- data_to_adjust = (
- pl.DataFrame(
- {
- "reference_group": list(probs_dict.keys()),
- "probs": list(probs_dict.values()),
- "reals": list(reals_dict.values()),
- }
- )
- .explode(["probs", "reals"])
- .with_columns(pl.col("reference_group").cast(reference_group_enum))
- )
-
- data_to_adjust = add_cutoff_strata(
- data_to_adjust, by=by, stratified_by=stratified_by
- )
-
- data_to_adjust = pivot_longer_strata(data_to_adjust)
-
- data_to_adjust = (
- data_to_adjust.with_columns([pl.col("strata")])
- .with_columns(pl.col("strata").cast(strata_enum_dtype))
- .join(
- aj_data_combinations.select(
- pl.col("strata"),
- pl.col("stratified_by"),
- pl.col("upper_bound"),
- pl.col("lower_bound"),
- ).unique(),
- how="left",
- on=["strata", "stratified_by"],
- )
- )
-
- reals_labels = ["real_negatives", "real_positives"]
-
- reals_enum = pl.Enum(reals_labels)
-
- reals_map = {0: "real_negatives", 1: "real_positives"}
-
- data_to_adjust = data_to_adjust.with_columns(
- pl.col("reals")
- .replace_strict(reals_map, return_dtype=reals_enum)
- .alias("reals_labels")
- )
-
- list_data_to_adjust = {
- group[0]: df
- for group, df in data_to_adjust.partition_by(
- "reference_group", as_dict=True
- ).items()
- }
-
- return list_data_to_adjust
-
-
-def _create_list_data_to_adjust(
- aj_data_combinations: pl.DataFrame,
- probs_dict: Dict[str, np.ndarray],
- reals_dict: Union[np.ndarray, Dict[str, np.ndarray]],
- times_dict: Union[np.ndarray, Dict[str, np.ndarray]],
- stratified_by,
- by,
-) -> Dict[str, pl.DataFrame]:
- # reference_groups = list(probs_dict.keys())
- reference_group_labels = list(probs_dict.keys())
-
- if isinstance(reals_dict, dict):
- num_keys_reals = len(reals_dict)
- else:
- num_keys_reals = 1
-
- # num_reals = len(reals_dict)
-
- reference_group_enum = pl.Enum(reference_group_labels)
-
- strata_enum_dtype = aj_data_combinations.schema["strata"]
-
- if len(probs_dict) == 1:
- probs_array = np.asarray(probs_dict[reference_group_labels[0]])
-
- if isinstance(reals_dict, dict):
- reals_array = np.asarray(reals_dict[0])
- else:
- reals_array = np.asarray(reals_dict)
-
- if isinstance(times_dict, dict):
- times_array = np.asarray(times_dict[0])
- else:
- times_array = np.asarray(times_dict)
-
- data_to_adjust = pl.DataFrame(
- {
- "reference_group": np.repeat(reference_group_labels, len(probs_array)),
- "probs": probs_array,
- "reals": reals_array,
- "times": times_array,
- }
- ).with_columns(pl.col("reference_group").cast(reference_group_enum))
-
- elif num_keys_reals == 1:
- reals_array = np.asarray(reals_dict)
- times_array = np.asarray(times_dict)
- n = len(reals_array)
-
- data_to_adjust = pl.DataFrame(
- {
- "reference_group": np.repeat(reference_group_labels, n),
- "probs": np.concatenate(
- [np.asarray(probs_dict[g]) for g in reference_group_labels]
- ),
- "reals": np.tile(reals_array, len(reference_group_labels)),
- "times": np.tile(times_array, len(reference_group_labels)),
- }
- ).with_columns(pl.col("reference_group").cast(reference_group_enum))
-
- elif isinstance(reals_dict, dict) and isinstance(times_dict, dict):
- data_to_adjust = (
- pl.DataFrame(
- {
- "reference_group": reference_group_labels,
- "probs": list(probs_dict.values()),
- "reals": list(reals_dict.values()),
- "times": list(times_dict.values()),
- }
- )
- .explode(["probs", "reals", "times"])
- .with_columns(pl.col("reference_group").cast(reference_group_enum))
- )
-
- data_to_adjust = add_cutoff_strata(
- data_to_adjust, by=by, stratified_by=stratified_by
- )
-
- data_to_adjust = pivot_longer_strata(data_to_adjust)
-
- data_to_adjust = (
- data_to_adjust.with_columns([pl.col("strata")])
- .with_columns(pl.col("strata").cast(strata_enum_dtype))
- .join(
- aj_data_combinations.select(
- pl.col("strata"),
- pl.col("stratified_by"),
- pl.col("upper_bound"),
- pl.col("lower_bound"),
- ).unique(),
- how="left",
- on=["strata", "stratified_by"],
- )
- )
-
- reals_labels = [
- "real_negatives",
- "real_positives",
- "real_competing",
- "real_censored",
- ]
-
- reals_enum = pl.Enum(reals_labels)
-
- # Map reals values to strings
- reals_map = {0: "real_negatives", 2: "real_competing", 1: "real_positives"}
-
- data_to_adjust = data_to_adjust.with_columns(
- pl.col("reals")
- .replace_strict(reals_map, return_dtype=reals_enum)
- .alias("reals_labels")
- )
-
- # Partition by reference_group
- list_data_to_adjust = {
- group[0]: df
- for group, df in data_to_adjust.partition_by(
- "reference_group", as_dict=True
- ).items()
- }
-
- return list_data_to_adjust
-
-
-def ensure_no_categorical(df: pd.DataFrame) -> pd.DataFrame:
- df = df.copy()
- for col in df.select_dtypes(include="category").columns:
- df[col] = df[col].astype(str)
- return df
-
-
-def extract_aj_estimate_by_heuristics(
- df: pl.DataFrame,
- breaks: Sequence[float],
- heuristics_sets: list[dict],
- fixed_time_horizons: list[float],
- stratified_by: Sequence[str],
- risk_set_scope: Sequence[str] = ["within_stratum"],
-) -> pl.DataFrame:
- aj_dfs = []
-
- for heuristic in heuristics_sets:
- censoring = heuristic["censoring_heuristic"]
- competing = heuristic["competing_heuristic"]
-
- aj_df = create_aj_data(
- df,
- breaks,
- censoring,
- competing,
- fixed_time_horizons,
- stratified_by=stratified_by,
- full_event_table=False,
- risk_set_scope=risk_set_scope,
- ).with_columns(
- [
- pl.lit(censoring).alias("censoring_heuristic"),
- pl.lit(competing).alias("competing_heuristic"),
- ]
- )
-
- aj_dfs.append(aj_df)
-
- aj_estimates_data = pl.concat(aj_dfs).drop(["estimate_origin", "times"])
-
- aj_estimates_unpivoted = aj_estimates_data.unpivot(
- index=[
- "strata",
- "chosen_cutoff",
- "fixed_time_horizon",
- "censoring_heuristic",
- "competing_heuristic",
- "risk_set_scope",
- ],
- variable_name="reals_labels",
- value_name="reals_estimate",
- )
-
- return aj_estimates_unpivoted
-
-
-def _create_adjusted_data_binary(
- list_data_to_adjust: dict[str, pl.DataFrame],
- breaks: Sequence[float],
- stratified_by: Sequence[str],
-) -> pl.DataFrame:
- long_df = pl.concat(list(list_data_to_adjust.values()), how="vertical")
-
- adjusted_data_binary = (
- long_df.group_by(["strata", "stratified_by", "reference_group", "reals_labels"])
- .agg(pl.count().alias("reals_estimate"))
- .join(pl.DataFrame({"chosen_cutoff": breaks}), how="cross")
- )
-
- return adjusted_data_binary
-
-
-def create_adjusted_data(
- list_data_to_adjust: dict[str, pl.DataFrame],
- heuristics_sets: list[dict[str, str]],
- fixed_time_horizons: list[float],
- breaks: Sequence[float],
- stratified_by: Sequence[str],
- risk_set_scope: Sequence[str] = ["within_stratum"],
-) -> pl.DataFrame:
- all_results = []
-
- reference_groups = list(list_data_to_adjust.keys())
- reference_group_enum = pl.Enum(reference_groups)
-
- heuristics_df = pl.DataFrame(heuristics_sets)
- censoring_heuristic_enum = pl.Enum(
- heuristics_df["censoring_heuristic"].unique(maintain_order=True)
- )
- competing_heuristic_enum = pl.Enum(
- heuristics_df["competing_heuristic"].unique(maintain_order=True)
- )
-
- for reference_group, df in list_data_to_adjust.items():
- input_df = df.select(
- ["strata", "reals", "times", "upper_bound", "lower_bound", "stratified_by"]
- )
-
- aj_result = extract_aj_estimate_by_heuristics(
- input_df,
- breaks,
- heuristics_sets=heuristics_sets,
- fixed_time_horizons=fixed_time_horizons,
- stratified_by=stratified_by,
- risk_set_scope=risk_set_scope,
- )
-
- aj_result_with_group = aj_result.with_columns(
- [
- pl.lit(reference_group)
- .cast(reference_group_enum)
- .alias("reference_group")
- ]
- )
-
- all_results.append(aj_result_with_group)
-
- reals_enum_dtype = pl.Enum(
- [
- "real_negatives",
- "real_positives",
- "real_competing",
- "real_censored",
- ]
- )
-
- return (
- pl.concat(all_results)
- .with_columns([pl.col("reference_group").cast(reference_group_enum)])
- .with_columns(
- [
- pl.col("reals_labels").str.replace(r"_est$", "").cast(reals_enum_dtype),
- pl.col("censoring_heuristic").cast(censoring_heuristic_enum),
- pl.col("competing_heuristic").cast(competing_heuristic_enum),
- ]
- )
- )
-
-
-def _cast_and_join_adjusted_data_binary(
- aj_data_combinations: pl.DataFrame, aj_estimates_data: pl.DataFrame
-) -> pl.DataFrame:
- strata_enum_dtype = aj_data_combinations.schema["strata"]
-
- aj_estimates_data = aj_estimates_data.with_columns([pl.col("strata")]).with_columns(
- pl.col("strata").cast(strata_enum_dtype)
- )
-
- final_adjusted_data_polars = (
- (
- aj_data_combinations.with_columns([pl.col("strata")]).join(
- aj_estimates_data,
- on=[
- "strata",
- "stratified_by",
- "reals_labels",
- "reference_group",
- "chosen_cutoff",
- ],
- how="left",
- )
- )
- .with_columns(
- pl.when(
- (
- (pl.col("chosen_cutoff") >= pl.col("upper_bound"))
- & (pl.col("stratified_by") == "probability_threshold")
- )
- | (
- ((1 - pl.col("chosen_cutoff")) >= pl.col("mid_point"))
- & (pl.col("stratified_by") == "ppcr")
- )
- )
- .then(pl.lit("predicted_negatives"))
- .otherwise(pl.lit("predicted_positives"))
- .cast(pl.Enum(["predicted_negatives", "predicted_positives"]))
- .alias("prediction_label")
- )
- .with_columns(
- (
- pl.when(
- (pl.col("prediction_label") == pl.lit("predicted_positives"))
- & (pl.col("reals_labels") == pl.lit("real_positives"))
- )
- .then(pl.lit("true_positives"))
- .when(
- (pl.col("prediction_label") == pl.lit("predicted_positives"))
- & (pl.col("reals_labels") == pl.lit("real_negatives"))
- )
- .then(pl.lit("false_positives"))
- .when(
- (pl.col("prediction_label") == pl.lit("predicted_negatives"))
- & (pl.col("reals_labels") == pl.lit("real_negatives"))
- )
- .then(pl.lit("true_negatives"))
- .when(
- (pl.col("prediction_label") == pl.lit("predicted_negatives"))
- & (pl.col("reals_labels") == pl.lit("real_positives"))
- )
- .then(pl.lit("false_negatives"))
- .cast(
- pl.Enum(
- [
- "true_positives",
- "false_positives",
- "true_negatives",
- "false_negatives",
- ]
- )
- )
- ).alias("classification_outcome")
- )
- ).with_columns(pl.col("reals_estimate").fill_null(0))
-
- return final_adjusted_data_polars
-
-
-def cast_and_join_adjusted_data(
- aj_data_combinations, aj_estimates_data
-) -> pl.DataFrame:
- strata_enum_dtype = aj_data_combinations.schema["strata"]
-
- aj_estimates_data = aj_estimates_data.with_columns([pl.col("strata")]).with_columns(
- pl.col("strata").cast(strata_enum_dtype)
- )
-
- final_adjusted_data_polars = (
- aj_data_combinations.with_columns([pl.col("strata")])
- .join(
- aj_estimates_data,
- on=[
- "strata",
- "fixed_time_horizon",
- "censoring_heuristic",
- "competing_heuristic",
- "reals_labels",
- "reference_group",
- "chosen_cutoff",
- "risk_set_scope",
- ],
- how="left",
- )
- .with_columns(
- pl.when(
- (
- (pl.col("chosen_cutoff") >= pl.col("upper_bound"))
- & (pl.col("stratified_by") == "probability_threshold")
- )
- | (
- ((1 - pl.col("chosen_cutoff")) >= pl.col("mid_point"))
- & (pl.col("stratified_by") == "ppcr")
- )
- )
- .then(pl.lit("predicted_negatives"))
- .otherwise(pl.lit("predicted_positives"))
- .cast(pl.Enum(["predicted_negatives", "predicted_positives"]))
- .alias("prediction_label")
- )
- .with_columns(
- (
- pl.when(
- (pl.col("prediction_label") == pl.lit("predicted_positives"))
- & (pl.col("reals_labels") == pl.lit("real_positives"))
- )
- .then(pl.lit("true_positives"))
- .when(
- (pl.col("prediction_label") == pl.lit("predicted_positives"))
- & (pl.col("reals_labels") == pl.lit("real_negatives"))
- )
- .then(pl.lit("false_positives"))
- .when(
- (pl.col("prediction_label") == pl.lit("predicted_negatives"))
- & (pl.col("reals_labels") == pl.lit("real_negatives"))
- )
- .then(pl.lit("true_negatives"))
- .when(
- (pl.col("prediction_label") == pl.lit("predicted_negatives"))
- & (pl.col("reals_labels") == pl.lit("real_positives"))
- )
- .then(pl.lit("false_negatives"))
- .when(
- (pl.col("prediction_label") == pl.lit("predicted_negatives"))
- & (pl.col("reals_labels") == pl.lit("real_competing"))
- & (pl.col("competing_heuristic") == pl.lit("adjusted_as_negative"))
- )
- .then(pl.lit("true_negatives"))
- .when(
- (pl.col("prediction_label") == pl.lit("predicted_positives"))
- & (pl.col("reals_labels") == pl.lit("real_competing"))
- & (pl.col("competing_heuristic") == pl.lit("adjusted_as_negative"))
- )
- .then(pl.lit("false_positives"))
- .otherwise(pl.lit("excluded")) # or pl.lit(None) if you prefer nulls
- .cast(
- pl.Enum(
- [
- "true_positives",
- "false_positives",
- "true_negatives",
- "false_negatives",
- "excluded",
- ]
- )
- )
- ).alias("classification_outcome")
- )
- )
- return final_adjusted_data_polars
-
-
-def _censored_count(df: pl.DataFrame) -> pl.DataFrame:
- return (
- df.with_columns(
- ((pl.col("times") <= pl.col("fixed_time_horizon")) & (pl.col("reals") == 0))
- .cast(pl.Float64)
- .alias("is_censored")
- )
- .group_by(["strata", "fixed_time_horizon"])
- .agg(pl.col("is_censored").sum().alias("real_censored_est"))
- )
-
-
-def _competing_count(df: pl.DataFrame) -> pl.DataFrame:
- return (
- df.with_columns(
- ((pl.col("times") <= pl.col("fixed_time_horizon")) & (pl.col("reals") == 2))
- .cast(pl.Float64)
- .alias("is_competing")
- )
- .group_by(["strata", "fixed_time_horizon"])
- .agg(pl.col("is_competing").sum().alias("real_competing_est"))
- )
-
-
-def _aj_estimates_by_cutoff_per_horizon(
- df: pl.DataFrame,
- horizons: list[float],
- breaks: Sequence[float],
- stratified_by: Sequence[str],
-) -> pl.DataFrame:
- return pl.concat(
- [
- df.filter(pl.col("fixed_time_horizon") == h)
- .group_by("strata")
- .map_groups(
- lambda group: extract_aj_estimate_by_cutoffs(
- group, [h], breaks, stratified_by, full_event_table=False
- )
- )
- for h in horizons
- ],
- how="vertical",
- )
-
-
-def _aj_estimates_per_horizon(
- df: pl.DataFrame, horizons: list[float], full_event_table: bool
-) -> pl.DataFrame:
- return pl.concat(
- [
- df.filter(pl.col("fixed_time_horizon") == h)
- .group_by("strata")
- .map_groups(
- lambda group: extract_aj_estimate_for_strata(
- group, [h], full_event_table
- )
- )
- for h in horizons
- ],
- how="vertical",
- )
-
-
-def _aj_adjusted_events(
- reference_group_data: pl.DataFrame,
- breaks: Sequence[float],
- exploded: pl.DataFrame,
- censoring: str,
- competing: str,
- horizons: list[float],
- stratified_by: Sequence[str],
- full_event_table: bool = False,
- risk_set_scope: Sequence[str] = ["within_stratum"],
-) -> pl.DataFrame:
- strata_enum_dtype = reference_group_data.schema["strata"]
-
- # Special-case: adjusted censoring + competing adjusted_as_negative supports pooled_by_cutoff
- if censoring == "adjusted" and competing == "adjusted_as_negative":
- if risk_set_scope == "within_stratum":
- adjusted = (
- reference_group_data.group_by("strata")
- .map_groups(
- lambda group: extract_aj_estimate_for_strata(
- group, horizons, full_event_table
- )
- )
- .join(pl.DataFrame({"chosen_cutoff": breaks}), how="cross")
- )
- # preserve the original enum dtype for 'strata' coming from reference_group_data
-
- adjusted = adjusted.with_columns(
- [
- pl.col("strata").cast(strata_enum_dtype),
- pl.lit(risk_set_scope)
- .cast(pl.Enum(["within_stratum", "pooled_by_cutoff"]))
- .alias("risk_set_scope"),
- ]
- )
-
- return adjusted
-
- elif risk_set_scope == "pooled_by_cutoff":
- adjusted = extract_aj_estimate_by_cutoffs(
- reference_group_data, horizons, breaks, stratified_by, full_event_table
- )
- adjusted = adjusted.with_columns(
- pl.lit(risk_set_scope)
- .cast(pl.Enum(["within_stratum", "pooled_by_cutoff"]))
- .alias("risk_set_scope")
- )
- return adjusted
-
- # Special-case: both excluded (faster branch in original)
- if censoring == "excluded" and competing == "excluded":
- non_censored_non_competing = exploded.filter(
- (pl.col("times") > pl.col("fixed_time_horizon")) | (pl.col("reals") == 1)
- )
-
- adjusted = _aj_estimates_per_horizon(
- non_censored_non_competing, horizons, full_event_table
- )
-
- adjusted = adjusted.with_columns(
- [
- pl.col("strata").cast(strata_enum_dtype),
- pl.lit(risk_set_scope)
- .cast(pl.Enum(["within_stratum", "pooled_by_cutoff"]))
- .alias("risk_set_scope"),
- ]
- ).join(pl.DataFrame({"chosen_cutoff": breaks}), how="cross")
-
- return adjusted
-
- # Special-case: competing excluded (handled by filtering out competing events)
- if competing == "excluded":
- # Use exploded to apply filters that depend on fixed_time_horizon consistently
- non_competing = exploded.filter(
- (pl.col("times") > pl.col("fixed_time_horizon")) | (pl.col("reals") != 2)
- ).with_columns(
- pl.when(pl.col("reals") == 2)
- .then(pl.lit(0))
- .otherwise(pl.col("reals"))
- .alias("reals")
- )
-
- if risk_set_scope == "within_stratum":
- adjusted = (
- _aj_estimates_per_horizon(non_competing, horizons, full_event_table)
- # .select(pl.exclude("real_competing_est"))
- .join(pl.DataFrame({"chosen_cutoff": breaks}), how="cross")
- )
-
- elif risk_set_scope == "pooled_by_cutoff":
- adjusted = extract_aj_estimate_by_cutoffs(
- non_competing, horizons, breaks, stratified_by, full_event_table
- )
-
- adjusted = adjusted.with_columns(
- [
- pl.col("strata").cast(strata_enum_dtype),
- pl.lit(risk_set_scope)
- .cast(pl.Enum(["within_stratum", "pooled_by_cutoff"]))
- .alias("risk_set_scope"),
- ]
- )
- return adjusted
-
- # For remaining cases, determine base dataframe depending on censoring rule:
- # - "adjusted": use the full reference_group_data (events censored at horizon are kept/adjusted)
- # - "excluded": remove administratively censored observations (use exploded with filter)
- base_df = (
- exploded.filter(
- (pl.col("times") > pl.col("fixed_time_horizon")) | (pl.col("reals") > 0)
- )
- if censoring == "excluded"
- else reference_group_data
- )
-
- # Apply competing-event transformation if required
- if competing == "adjusted_as_censored":
- base_df = base_df.with_columns(
- pl.when(pl.col("reals") == 2)
- .then(pl.lit(0))
- .otherwise(pl.col("reals"))
- .alias("reals")
- )
- elif competing == "adjusted_as_composite":
- base_df = base_df.with_columns(
- pl.when(pl.col("reals") == 2)
- .then(pl.lit(1))
- .otherwise(pl.col("reals"))
- .alias("reals")
- )
- # competing == "adjusted_as_negative": keep reals as-is (no transform)
-
- # Finally choose aggregation strategy: per-stratum or horizon-wise
- if censoring == "excluded":
- # For excluded censoring we always evaluate per-horizon on the filtered (exploded) dataset
-
- if risk_set_scope == "within_stratum":
- adjusted = _aj_estimates_per_horizon(base_df, horizons, full_event_table)
-
- adjusted = adjusted.join(
- pl.DataFrame({"chosen_cutoff": breaks}), how="cross"
- )
-
- elif risk_set_scope == "pooled_by_cutoff":
- adjusted = _aj_estimates_by_cutoff_per_horizon(
- base_df, horizons, breaks, stratified_by
- )
-
- adjusted = adjusted.with_columns(
- pl.lit(risk_set_scope)
- .cast(pl.Enum(["within_stratum", "pooled_by_cutoff"]))
- .alias("risk_set_scope")
- )
-
- return adjusted.with_columns(pl.col("strata").cast(strata_enum_dtype))
- else:
- # For adjusted censoring we aggregate within strata
-
- if risk_set_scope == "within_stratum":
- adjusted = (
- base_df.group_by("strata")
- .map_groups(
- lambda group: extract_aj_estimate_for_strata(
- group, horizons, full_event_table
- )
- )
- .join(pl.DataFrame({"chosen_cutoff": breaks}), how="cross")
- )
-
- elif risk_set_scope == "pooled_by_cutoff":
- adjusted = extract_aj_estimate_by_cutoffs(
- base_df, horizons, breaks, stratified_by, full_event_table
- )
-
- adjusted = adjusted.with_columns(
- [
- pl.col("strata").cast(strata_enum_dtype),
- pl.lit(risk_set_scope)
- .cast(pl.Enum(["within_stratum", "pooled_by_cutoff"]))
- .alias("risk_set_scope"),
- ]
- )
-
- return adjusted
-
-
-def _calculate_cumulative_aj_data_binary(aj_data: pl.DataFrame) -> pl.DataFrame:
- cumulative_aj_data = (
- aj_data.group_by(
- [
- "reference_group",
- "stratified_by",
- "chosen_cutoff",
- "classification_outcome",
- ]
- )
- .agg([pl.col("reals_estimate").sum()])
- .pivot(on="classification_outcome", values="reals_estimate")
- .with_columns(
- [
- pl.col(col).fill_null(0)
- for col in [
- "true_positives",
- "true_negatives",
- "false_positives",
- "false_negatives",
- ]
- ]
- )
- .with_columns(
- (pl.col("true_positives") + pl.col("false_positives")).alias(
- "predicted_positives"
- ),
- (pl.col("true_negatives") + pl.col("false_negatives")).alias(
- "predicted_negatives"
- ),
- (pl.col("true_positives") + pl.col("false_negatives")).alias(
- "real_positives"
- ),
- (pl.col("false_positives") + pl.col("true_negatives")).alias(
- "real_negatives"
- ),
- (
- pl.col("true_positives")
- + pl.col("true_negatives")
- + pl.col("false_positives")
- + pl.col("false_negatives")
- )
- .alias("n")
- .sum(),
- )
- .with_columns(
- (pl.col("true_positives") + pl.col("false_positives")).alias(
- "predicted_positives"
- ),
- (pl.col("true_negatives") + pl.col("false_negatives")).alias(
- "predicted_negatives"
- ),
- (pl.col("true_positives") + pl.col("false_negatives")).alias(
- "real_positives"
- ),
- (pl.col("false_positives") + pl.col("true_negatives")).alias(
- "real_negatives"
- ),
- (
- pl.col("true_positives")
- + pl.col("true_negatives")
- + pl.col("false_positives")
- + pl.col("false_negatives")
- ).alias("n"),
- )
- )
-
- return cumulative_aj_data
-
-
-def _calculate_cumulative_aj_data(aj_data: pl.DataFrame) -> pl.DataFrame:
- cumulative_aj_data = (
- aj_data.filter(pl.col("risk_set_scope") == "pooled_by_cutoff")
- .group_by(
- [
- "reference_group",
- "fixed_time_horizon",
- "censoring_heuristic",
- "competing_heuristic",
- "stratified_by",
- "chosen_cutoff",
- "classification_outcome",
- ]
- )
- .agg([pl.col("reals_estimate").sum()])
- .pivot(on="classification_outcome", values="reals_estimate")
- .fill_null(0)
- .with_columns(
- (pl.col("true_positives") + pl.col("false_positives")).alias(
- "predicted_positives"
- ),
- (pl.col("true_negatives") + pl.col("false_negatives")).alias(
- "predicted_negatives"
- ),
- (pl.col("true_positives") + pl.col("false_negatives")).alias(
- "real_positives"
- ),
- (pl.col("false_positives") + pl.col("true_negatives")).alias(
- "real_negatives"
- ),
- (
- pl.col("true_positives")
- + pl.col("true_negatives")
- + pl.col("false_positives")
- + pl.col("false_negatives")
- ).alias("n"),
- )
- .with_columns(
- (pl.col("true_positives") + pl.col("false_positives")).alias(
- "predicted_positives"
- ),
- (pl.col("true_negatives") + pl.col("false_negatives")).alias(
- "predicted_negatives"
- ),
- (pl.col("true_positives") + pl.col("false_negatives")).alias(
- "real_positives"
- ),
- (pl.col("false_positives") + pl.col("true_negatives")).alias(
- "real_negatives"
- ),
- (
- pl.col("true_positives")
- + pl.col("true_negatives")
- + pl.col("false_positives")
- + pl.col("false_negatives")
- ).alias("n"),
- )
- )
-
- return cumulative_aj_data
-
-
-def _turn_cumulative_aj_to_performance_data(
- cumulative_aj_data: pl.DataFrame,
-) -> pl.DataFrame:
- performance_data = cumulative_aj_data.with_columns(
- (pl.col("true_positives") / pl.col("real_positives")).alias("sensitivity"),
- (pl.col("true_negatives") / pl.col("real_negatives")).alias("specificity"),
- (pl.col("true_positives") / pl.col("predicted_positives")).alias("ppv"),
- (pl.col("true_negatives") / pl.col("predicted_negatives")).alias("npv"),
- (pl.col("false_positives") / pl.col("real_negatives")).alias(
- "false_positive_rate"
- ),
- (
- (pl.col("true_positives") / pl.col("predicted_positives"))
- / (pl.col("real_positives") / pl.col("n"))
- ).alias("lift"),
- pl.when(pl.col("stratified_by") == "probability_threshold")
- .then(
- (pl.col("true_positives") / pl.col("n"))
- - (pl.col("false_positives") / pl.col("n"))
- * pl.col("chosen_cutoff")
- / (1 - pl.col("chosen_cutoff"))
- )
- .otherwise(None)
- .alias("net_benefit"),
- pl.when(pl.col("stratified_by") == "probability_threshold")
- .then(
- 100 * (pl.col("true_negatives") / pl.col("n"))
- - (pl.col("false_negatives") / pl.col("n"))
- * (1 - pl.col("chosen_cutoff"))
- / pl.col("chosen_cutoff")
- )
- .otherwise(None)
- .alias("net_benefit_interventions_avoided"),
- pl.when(pl.col("stratified_by") == "probability_threshold")
- .then(pl.col("predicted_positives") / pl.col("n"))
- .otherwise(pl.col("chosen_cutoff"))
- .alias("ppcr"),
- )
-
- return performance_data
diff --git a/src/rtichoke/performance_data/performance_data.py b/src/rtichoke/performance_data/performance_data.py
index 3922723..8fa2d30 100644
--- a/src/rtichoke/performance_data/performance_data.py
+++ b/src/rtichoke/performance_data/performance_data.py
@@ -5,13 +5,15 @@
from typing import Dict, Union
import polars as pl
from collections.abc import Sequence
-from rtichoke.helpers.sandbox_observable_helpers import (
+from rtichoke.processing.adjustments import _create_adjusted_data_binary
+from rtichoke.processing.combinations import (
_create_aj_data_combinations_binary,
create_breaks_values,
- _create_list_data_to_adjust_binary,
- _create_adjusted_data_binary,
- _cast_and_join_adjusted_data_binary,
+)
+from rtichoke.processing.transforms import (
_calculate_cumulative_aj_data_binary,
+ _cast_and_join_adjusted_data_binary,
+ _create_list_data_to_adjust_binary,
_turn_cumulative_aj_to_performance_data,
)
import numpy as np
diff --git a/src/rtichoke/performance_data/performance_data_times.py b/src/rtichoke/performance_data/performance_data_times.py
index 901bd59..d1629b4 100644
--- a/src/rtichoke/performance_data/performance_data_times.py
+++ b/src/rtichoke/performance_data/performance_data_times.py
@@ -5,14 +5,16 @@
from typing import Dict, Union
import polars as pl
from collections.abc import Sequence
-from rtichoke.helpers.sandbox_observable_helpers import (
- create_breaks_values,
+from rtichoke.processing.adjustments import create_adjusted_data
+from rtichoke.processing.combinations import (
create_aj_data_combinations,
- _create_list_data_to_adjust,
- create_adjusted_data,
- cast_and_join_adjusted_data,
+ create_breaks_values,
+)
+from rtichoke.processing.transforms import (
_calculate_cumulative_aj_data,
+ _create_list_data_to_adjust,
_turn_cumulative_aj_to_performance_data,
+ cast_and_join_adjusted_data,
)
import numpy as np
diff --git a/src/rtichoke/helpers/__init__.py b/src/rtichoke/processing/__init__.py
similarity index 100%
rename from src/rtichoke/helpers/__init__.py
rename to src/rtichoke/processing/__init__.py
diff --git a/src/rtichoke/processing/adjustments.py b/src/rtichoke/processing/adjustments.py
new file mode 100644
index 0000000..4bb982b
--- /dev/null
+++ b/src/rtichoke/processing/adjustments.py
@@ -0,0 +1,743 @@
+import pandas as pd
+import polars as pl
+from polarstate import predict_aj_estimates
+from polarstate import prepare_event_table
+from collections.abc import Sequence
+from rtichoke.processing.transforms import assign_and_explode_polars
+
+
+def create_aj_data(
+ reference_group_data,
+ breaks,
+ censoring_heuristic,
+ competing_heuristic,
+ fixed_time_horizons,
+ stratified_by: Sequence[str],
+ full_event_table: bool = False,
+ risk_set_scope: Sequence[str] = ["within_stratum"],
+):
+ """
+ Create AJ estimates per strata based on censoring and competing heuristics.
+ """
+
+ def aj_estimates_with_cross(df, extra_cols):
+ return df.join(pl.DataFrame(extra_cols), how="cross")
+
+ exploded = assign_and_explode_polars(reference_group_data, fixed_time_horizons)
+
+ event_table = prepare_event_table(reference_group_data)
+
+ # TODO: solve strata in the pipeline
+
+ excluded_events = _extract_excluded_events(
+ event_table, fixed_time_horizons, censoring_heuristic, competing_heuristic
+ )
+
+ aj_dfs = []
+ for rscope in risk_set_scope:
+ aj_res = _aj_adjusted_events(
+ reference_group_data,
+ breaks,
+ exploded,
+ censoring_heuristic,
+ competing_heuristic,
+ fixed_time_horizons,
+ stratified_by,
+ full_event_table,
+ rscope,
+ )
+
+ aj_res = aj_res.select(
+ [
+ "strata",
+ "times",
+ "chosen_cutoff",
+ "real_negatives_est",
+ "real_positives_est",
+ "real_competing_est",
+ "estimate_origin",
+ "fixed_time_horizon",
+ "risk_set_scope",
+ ]
+ )
+
+ aj_dfs.append(aj_res)
+
+ aj_df = pl.concat(aj_dfs, how="vertical")
+
+ result = aj_df.join(excluded_events, on=["fixed_time_horizon"], how="left")
+
+ return aj_estimates_with_cross(
+ result,
+ {
+ "censoring_heuristic": censoring_heuristic,
+ "competing_heuristic": competing_heuristic,
+ },
+ ).select(
+ [
+ "strata",
+ "chosen_cutoff",
+ "fixed_time_horizon",
+ "times",
+ "real_negatives_est",
+ "real_positives_est",
+ "real_competing_est",
+ "real_censored_est",
+ "censoring_heuristic",
+ "competing_heuristic",
+ "estimate_origin",
+ "risk_set_scope",
+ ]
+ )
+
+
+def _extract_excluded_events(
+ event_table: pl.DataFrame,
+ fixed_time_horizons: list[float],
+ censoring_heuristic: str,
+ competing_heuristic: str,
+) -> pl.DataFrame:
+ horizons_df = pl.DataFrame({"times": fixed_time_horizons}).sort("times")
+
+ excluded_events = horizons_df.join_asof(
+ event_table.with_columns(
+ pl.col("count_0").cum_sum().cast(pl.Float64).alias("real_censored_est"),
+ pl.col("count_2").cum_sum().cast(pl.Float64).alias("real_competing_est"),
+ ).select(
+ pl.col("times"),
+ pl.col("real_censored_est"),
+ pl.col("real_competing_est"),
+ ),
+ left_on="times",
+ right_on="times",
+ ).with_columns([pl.col("times").alias("fixed_time_horizon")])
+
+ if censoring_heuristic != "excluded":
+ excluded_events = excluded_events.with_columns(
+ pl.lit(0.0).alias("real_censored_est")
+ )
+
+ if competing_heuristic != "excluded":
+ excluded_events = excluded_events.with_columns(
+ pl.lit(0.0).alias("real_competing_est")
+ )
+
+ return excluded_events
+
+
+def extract_crude_estimate_polars(data: pl.DataFrame) -> pl.DataFrame:
+ all_combinations = data.select(["strata", "reals", "fixed_time_horizon"]).unique()
+
+ counts = data.group_by(["strata", "reals", "fixed_time_horizon"]).agg(
+ pl.count().alias("reals_estimate")
+ )
+
+ return all_combinations.join(
+ counts, on=["strata", "reals", "fixed_time_horizon"], how="left"
+ ).with_columns([pl.col("reals_estimate").fill_null(0).cast(pl.Int64)])
+
+
+def extract_aj_estimate_by_cutoffs(
+ data_to_adjust, horizons, breaks, stratified_by, full_event_table: bool
+):
+ # n = data_to_adjust.height
+
+ counts_per_strata = (
+ data_to_adjust.group_by(
+ ["strata", "stratified_by", "upper_bound", "lower_bound"]
+ )
+ .len(name="strata_count")
+ .with_columns(pl.col("strata_count").cast(pl.Float64))
+ )
+
+ aj_estimates_predicted_positives = pl.DataFrame()
+ aj_estimates_predicted_negatives = pl.DataFrame()
+
+ for stratification_criteria in stratified_by:
+ for chosen_cutoff in breaks:
+ if stratification_criteria == "probability_threshold":
+ mask_predicted_positives = (pl.col("upper_bound") > chosen_cutoff) & (
+ pl.col("stratified_by") == "probability_threshold"
+ )
+ mask_predicted_negatives = (pl.col("upper_bound") <= chosen_cutoff) & (
+ pl.col("stratified_by") == "probability_threshold"
+ )
+
+ elif stratification_criteria == "ppcr":
+ mask_predicted_positives = (
+ pl.col("lower_bound") > 1 - chosen_cutoff
+ ) & (pl.col("stratified_by") == "ppcr")
+ mask_predicted_negatives = (
+ pl.col("lower_bound") <= 1 - chosen_cutoff
+ ) & (pl.col("stratified_by") == "ppcr")
+
+ predicted_positives = data_to_adjust.filter(mask_predicted_positives)
+ predicted_negatives = data_to_adjust.filter(mask_predicted_negatives)
+
+ counts_per_strata_predicted_positives = counts_per_strata.filter(
+ mask_predicted_positives
+ )
+ counts_per_strata_predicted_negatives = counts_per_strata.filter(
+ mask_predicted_negatives
+ )
+
+ event_table_predicted_positives = prepare_event_table(predicted_positives)
+ event_table_predicted_negatives = prepare_event_table(predicted_negatives)
+
+ aj_estimate_predicted_positives = (
+ (
+ predict_aj_estimates(
+ event_table_predicted_positives,
+ pl.Series(horizons),
+ full_event_table,
+ )
+ .with_columns(
+ pl.lit(chosen_cutoff).alias("chosen_cutoff"),
+ pl.lit(stratification_criteria)
+ .alias("stratified_by")
+ .cast(pl.Enum(["probability_threshold", "ppcr"])),
+ )
+ .join(
+ counts_per_strata_predicted_positives,
+ on=["stratified_by"],
+ how="left",
+ )
+ .with_columns(
+ [
+ (
+ pl.col("state_occupancy_probability_0")
+ * pl.col("strata_count")
+ ).alias("real_negatives_est"),
+ (
+ pl.col("state_occupancy_probability_1")
+ * pl.col("strata_count")
+ ).alias("real_positives_est"),
+ (
+ pl.col("state_occupancy_probability_2")
+ * pl.col("strata_count")
+ ).alias("real_competing_est"),
+ ]
+ )
+ )
+ .select(
+ [
+ "strata",
+ # "stratified_by",
+ "times",
+ "chosen_cutoff",
+ "real_negatives_est",
+ "real_positives_est",
+ "real_competing_est",
+ "estimate_origin",
+ ]
+ )
+ .with_columns([pl.col("times").alias("fixed_time_horizon")])
+ )
+
+ aj_estimate_predicted_negatives = (
+ (
+ predict_aj_estimates(
+ event_table_predicted_negatives,
+ pl.Series(horizons),
+ full_event_table,
+ )
+ .with_columns(
+ pl.lit(chosen_cutoff).alias("chosen_cutoff"),
+ pl.lit(stratification_criteria)
+ .alias("stratified_by")
+ .cast(pl.Enum(["probability_threshold", "ppcr"])),
+ )
+ .join(
+ counts_per_strata_predicted_negatives,
+ on=["stratified_by"],
+ how="left",
+ )
+ .with_columns(
+ [
+ (
+ pl.col("state_occupancy_probability_0")
+ * pl.col("strata_count")
+ ).alias("real_negatives_est"),
+ (
+ pl.col("state_occupancy_probability_1")
+ * pl.col("strata_count")
+ ).alias("real_positives_est"),
+ (
+ pl.col("state_occupancy_probability_2")
+ * pl.col("strata_count")
+ ).alias("real_competing_est"),
+ ]
+ )
+ )
+ .select(
+ [
+ "strata",
+ # "stratified_by",
+ "times",
+ "chosen_cutoff",
+ "real_negatives_est",
+ "real_positives_est",
+ "real_competing_est",
+ "estimate_origin",
+ ]
+ )
+ .with_columns([pl.col("times").alias("fixed_time_horizon")])
+ )
+
+ aj_estimates_predicted_negatives = pl.concat(
+ [aj_estimates_predicted_negatives, aj_estimate_predicted_negatives],
+ how="vertical",
+ )
+
+ aj_estimates_predicted_positives = pl.concat(
+ [aj_estimates_predicted_positives, aj_estimate_predicted_positives],
+ how="vertical",
+ )
+
+ aj_estimate_by_cutoffs = pl.concat(
+ [aj_estimates_predicted_negatives, aj_estimates_predicted_positives],
+ how="vertical",
+ )
+
+ return aj_estimate_by_cutoffs
+
+
+def extract_aj_estimate_for_strata(data_to_adjust, horizons, full_event_table: bool):
+ n = data_to_adjust.height
+
+ event_table = prepare_event_table(data_to_adjust)
+
+ aj_estimate_for_strata_polars = predict_aj_estimates(
+ event_table, pl.Series(horizons), full_event_table
+ )
+
+ if len(horizons) == 1:
+ aj_estimate_for_strata_polars = aj_estimate_for_strata_polars.with_columns(
+ pl.lit(horizons[0]).alias("fixed_time_horizon")
+ )
+
+ else:
+ fixed_df = aj_estimate_for_strata_polars.filter(
+ pl.col("estimate_origin") == "fixed_time_horizons"
+ ).with_columns([pl.col("times").alias("fixed_time_horizon")])
+
+ event_df = (
+ aj_estimate_for_strata_polars.filter(
+ pl.col("estimate_origin") == "event_table"
+ )
+ .with_columns([pl.lit(horizons).alias("fixed_time_horizon")])
+ .explode("fixed_time_horizon")
+ )
+
+ aj_estimate_for_strata_polars = pl.concat(
+ [fixed_df, event_df], how="vertical"
+ ).sort("estimate_origin", "fixed_time_horizon", "times")
+
+ return aj_estimate_for_strata_polars.with_columns(
+ [
+ (pl.col("state_occupancy_probability_0") * n).alias("real_negatives_est"),
+ (pl.col("state_occupancy_probability_1") * n).alias("real_positives_est"),
+ (pl.col("state_occupancy_probability_2") * n).alias("real_competing_est"),
+ pl.col("fixed_time_horizon").cast(pl.Float64),
+ pl.lit(data_to_adjust["strata"][0]).alias("strata"),
+ ]
+ ).select(
+ [
+ "strata",
+ "times",
+ "fixed_time_horizon",
+ "real_negatives_est",
+ "real_positives_est",
+ "real_competing_est",
+ pl.col("estimate_origin"),
+ ]
+ )
+
+
+def ensure_no_categorical(df: pd.DataFrame) -> pd.DataFrame:
+ df = df.copy()
+ for col in df.select_dtypes(include="category").columns:
+ df[col] = df[col].astype(str)
+ return df
+
+
+def extract_aj_estimate_by_heuristics(
+ df: pl.DataFrame,
+ breaks: Sequence[float],
+ heuristics_sets: list[dict],
+ fixed_time_horizons: list[float],
+ stratified_by: Sequence[str],
+ risk_set_scope: Sequence[str] = ["within_stratum"],
+) -> pl.DataFrame:
+ aj_dfs = []
+
+ for heuristic in heuristics_sets:
+ censoring = heuristic["censoring_heuristic"]
+ competing = heuristic["competing_heuristic"]
+
+ aj_df = create_aj_data(
+ df,
+ breaks,
+ censoring,
+ competing,
+ fixed_time_horizons,
+ stratified_by=stratified_by,
+ full_event_table=False,
+ risk_set_scope=risk_set_scope,
+ ).with_columns(
+ [
+ pl.lit(censoring).alias("censoring_heuristic"),
+ pl.lit(competing).alias("competing_heuristic"),
+ ]
+ )
+
+ aj_dfs.append(aj_df)
+
+ aj_estimates_data = pl.concat(aj_dfs).drop(["estimate_origin", "times"])
+
+ aj_estimates_unpivoted = aj_estimates_data.unpivot(
+ index=[
+ "strata",
+ "chosen_cutoff",
+ "fixed_time_horizon",
+ "censoring_heuristic",
+ "competing_heuristic",
+ "risk_set_scope",
+ ],
+ variable_name="reals_labels",
+ value_name="reals_estimate",
+ )
+
+ return aj_estimates_unpivoted
+
+
+def _create_adjusted_data_binary(
+ list_data_to_adjust: dict[str, pl.DataFrame],
+ breaks: Sequence[float],
+ stratified_by: Sequence[str],
+) -> pl.DataFrame:
+ long_df = pl.concat(list(list_data_to_adjust.values()), how="vertical")
+
+ adjusted_data_binary = (
+ long_df.group_by(["strata", "stratified_by", "reference_group", "reals_labels"])
+ .agg(pl.count().alias("reals_estimate"))
+ .join(pl.DataFrame({"chosen_cutoff": breaks}), how="cross")
+ )
+
+ return adjusted_data_binary
+
+
+def create_adjusted_data(
+    list_data_to_adjust: dict[str, pl.DataFrame],
+    heuristics_sets: list[dict[str, str]],
+    fixed_time_horizons: list[float],
+    breaks: Sequence[float],
+    stratified_by: Sequence[str],
+    risk_set_scope: Sequence[str] = ["within_stratum"],
+) -> pl.DataFrame:
+    """Collect heuristic-based estimates for every reference group.
+
+    Parameters
+    ----------
+    list_data_to_adjust : dict[str, pl.DataFrame]
+        Frames keyed by reference group. Each must contain ``strata``,
+        ``reals``, ``times``, ``upper_bound``, ``lower_bound`` and
+        ``stratified_by``.
+    heuristics_sets : list[dict[str, str]]
+        Records with ``censoring_heuristic`` and ``competing_heuristic`` keys.
+    fixed_time_horizons : list[float]
+        Forwarded to ``extract_aj_estimate_by_heuristics``.
+    breaks : Sequence[float]
+        Forwarded to ``extract_aj_estimate_by_heuristics``.
+    stratified_by : Sequence[str]
+        Forwarded to ``extract_aj_estimate_by_heuristics``.
+    risk_set_scope : Sequence[str]
+        Forwarded to ``extract_aj_estimate_by_heuristics``.
+
+    Returns
+    -------
+    pl.DataFrame
+        Concatenated estimates with ``reference_group``, ``reals_labels``,
+        ``censoring_heuristic`` and ``competing_heuristic`` cast to enums.
+    """
+    group_dtype = pl.Enum(list(list_data_to_adjust))
+
+    heuristics = pl.DataFrame(heuristics_sets)
+    censoring_dtype = pl.Enum(
+        heuristics["censoring_heuristic"].unique(maintain_order=True)
+    )
+    competing_dtype = pl.Enum(
+        heuristics["competing_heuristic"].unique(maintain_order=True)
+    )
+    reals_dtype = pl.Enum(
+        ["real_negatives", "real_positives", "real_competing", "real_censored"]
+    )
+
+    needed_columns = [
+        "strata",
+        "reals",
+        "times",
+        "upper_bound",
+        "lower_bound",
+        "stratified_by",
+    ]
+
+    def _estimates_for(group_name: str, frame: pl.DataFrame) -> pl.DataFrame:
+        # Estimate for one reference group and tag the rows with its name.
+        estimates = extract_aj_estimate_by_heuristics(
+            frame.select(needed_columns),
+            breaks,
+            heuristics_sets=heuristics_sets,
+            fixed_time_horizons=fixed_time_horizons,
+            stratified_by=stratified_by,
+            risk_set_scope=risk_set_scope,
+        )
+        return estimates.with_columns(
+            pl.lit(group_name).cast(group_dtype).alias("reference_group")
+        )
+
+    per_group = [
+        _estimates_for(group_name, frame)
+        for group_name, frame in list_data_to_adjust.items()
+    ]
+
+    return pl.concat(per_group).with_columns(
+        pl.col("reference_group").cast(group_dtype),
+        # Estimate columns arrive with an "_est" suffix; strip it before casting.
+        pl.col("reals_labels").str.replace(r"_est$", "").cast(reals_dtype),
+        pl.col("censoring_heuristic").cast(censoring_dtype),
+        pl.col("competing_heuristic").cast(competing_dtype),
+    )
+
+
+def _censored_count(df: pl.DataFrame) -> pl.DataFrame:
+    """Sum, per strata and horizon, rows with ``reals == 0`` at or before the horizon.
+
+    Returns a frame with ``strata``, ``fixed_time_horizon`` and the Float64
+    total ``real_censored_est``.
+    """
+    censored_by_horizon = (pl.col("times") <= pl.col("fixed_time_horizon")) & (
+        pl.col("reals") == 0
+    )
+    flagged = df.with_columns(
+        censored_by_horizon.cast(pl.Float64).alias("is_censored")
+    )
+    totals = flagged.group_by(["strata", "fixed_time_horizon"]).agg(
+        pl.col("is_censored").sum().alias("real_censored_est")
+    )
+    return totals
+
+
+def _competing_count(df: pl.DataFrame) -> pl.DataFrame:
+    """Sum, per strata and horizon, rows with ``reals == 2`` at or before the horizon.
+
+    Returns a frame with ``strata``, ``fixed_time_horizon`` and the Float64
+    total ``real_competing_est``.
+    """
+    competing_by_horizon = (pl.col("times") <= pl.col("fixed_time_horizon")) & (
+        pl.col("reals") == 2
+    )
+    flagged = df.with_columns(
+        competing_by_horizon.cast(pl.Float64).alias("is_competing")
+    )
+    totals = flagged.group_by(["strata", "fixed_time_horizon"]).agg(
+        pl.col("is_competing").sum().alias("real_competing_est")
+    )
+    return totals
+
+
+def _aj_estimates_by_cutoff_per_horizon(
+    df: pl.DataFrame,
+    horizons: list[float],
+    breaks: Sequence[float],
+    stratified_by: Sequence[str],
+) -> pl.DataFrame:
+    """Run ``extract_aj_estimate_by_cutoffs`` per horizon and per strata group.
+
+    For each horizon, only rows whose ``fixed_time_horizon`` equals it are
+    used; the per-horizon results are stacked vertically.
+    """
+
+    def _for_horizon(horizon: float) -> pl.DataFrame:
+        rows_at_horizon = df.filter(pl.col("fixed_time_horizon") == horizon)
+        return rows_at_horizon.group_by("strata").map_groups(
+            lambda group: extract_aj_estimate_by_cutoffs(
+                group, [horizon], breaks, stratified_by, full_event_table=False
+            )
+        )
+
+    per_horizon = [_for_horizon(horizon) for horizon in horizons]
+    return pl.concat(per_horizon, how="vertical")
+
+
+def _aj_estimates_per_horizon(
+    df: pl.DataFrame, horizons: list[float], full_event_table: bool
+) -> pl.DataFrame:
+    """Run ``extract_aj_estimate_for_strata`` per horizon and per strata group.
+
+    For each horizon, only rows whose ``fixed_time_horizon`` equals it are
+    used; the per-horizon results are stacked vertically.
+    """
+
+    def _for_horizon(horizon: float) -> pl.DataFrame:
+        rows_at_horizon = df.filter(pl.col("fixed_time_horizon") == horizon)
+        return rows_at_horizon.group_by("strata").map_groups(
+            lambda group: extract_aj_estimate_for_strata(
+                group, [horizon], full_event_table
+            )
+        )
+
+    per_horizon = [_for_horizon(horizon) for horizon in horizons]
+    return pl.concat(per_horizon, how="vertical")
+
+
+def _aj_adjusted_events(
+    reference_group_data: pl.DataFrame,
+    breaks: Sequence[float],
+    exploded: pl.DataFrame,
+    censoring: str,
+    competing: str,
+    horizons: list[float],
+    stratified_by: Sequence[str],
+    full_event_table: bool = False,
+    risk_set_scope: Sequence[str] = ["within_stratum"],
+) -> pl.DataFrame:
+    """Compute adjusted event estimates for one censoring/competing heuristic pair.
+
+    Parameters
+    ----------
+    reference_group_data : pl.DataFrame
+        Un-exploded data; the dtype of its ``strata`` column is re-applied to
+        the result in most branches.
+    breaks : Sequence[float]
+        Cutoffs, either cross-joined as ``chosen_cutoff`` or forwarded to the
+        by-cutoff estimators.
+    exploded : pl.DataFrame
+        Data carrying ``fixed_time_horizon`` next to ``times`` and ``reals``;
+        used whenever rows are filtered relative to the horizon.
+    censoring : str
+        ``"excluded"`` filters the exploded frame; any other value is handled
+        as adjusted censoring on ``reference_group_data``.
+    competing : str
+        ``"excluded"``, ``"adjusted_as_censored"``, ``"adjusted_as_composite"``
+        or ``"adjusted_as_negative"`` (the last leaves ``reals`` untouched).
+    horizons : list[float]
+        Time horizons forwarded to the estimators.
+    stratified_by : Sequence[str]
+        Forwarded to the by-cutoff estimators.
+    full_event_table : bool
+        Forwarded to the estimators.
+    risk_set_scope : str or one-element sequence of str
+        ``"within_stratum"`` or ``"pooled_by_cutoff"``. A one-element
+        sequence (such as the default) is unwrapped to its single value.
+
+    Returns
+    -------
+    pl.DataFrame
+        Estimates with ``chosen_cutoff`` and an enum ``risk_set_scope`` column.
+
+    Raises
+    ------
+    ValueError
+        If ``risk_set_scope`` is not exactly one supported scope.
+    """
+    # The default is a one-element list while every branch below compares
+    # against a plain string; without unwrapping, no branch would match and
+    # ``adjusted`` would be unbound.
+    if not isinstance(risk_set_scope, str):
+        scopes = list(risk_set_scope)
+        if len(scopes) != 1:
+            raise ValueError(
+                f"risk_set_scope must name exactly one scope, got: {scopes!r}"
+            )
+        risk_set_scope = scopes[0]
+    if risk_set_scope not in ("within_stratum", "pooled_by_cutoff"):
+        raise ValueError(f"Unsupported risk_set_scope: {risk_set_scope!r}")
+
+    pooled = risk_set_scope == "pooled_by_cutoff"
+
+    # Preserve the original enum dtype for 'strata' from reference_group_data.
+    strata_enum_dtype = reference_group_data.schema["strata"]
+    strata_cast = pl.col("strata").cast(strata_enum_dtype)
+    scope_column = (
+        pl.lit(risk_set_scope)
+        .cast(pl.Enum(["within_stratum", "pooled_by_cutoff"]))
+        .alias("risk_set_scope")
+    )
+    cutoffs = pl.DataFrame({"chosen_cutoff": breaks})
+
+    def _within_strata(frame: pl.DataFrame) -> pl.DataFrame:
+        # Estimate per strata group over all horizons, then repeat per cutoff.
+        return (
+            frame.group_by("strata")
+            .map_groups(
+                lambda group: extract_aj_estimate_for_strata(
+                    group, horizons, full_event_table
+                )
+            )
+            .join(cutoffs, how="cross")
+        )
+
+    def _recode_competing(frame: pl.DataFrame, value: int) -> pl.DataFrame:
+        # Replace reals == 2 with `value`, leaving other codes unchanged.
+        return frame.with_columns(
+            pl.when(pl.col("reals") == 2)
+            .then(pl.lit(value))
+            .otherwise(pl.col("reals"))
+            .alias("reals")
+        )
+
+    past_horizon = pl.col("times") > pl.col("fixed_time_horizon")
+
+    # Special case: adjusted censoring with competing events kept as-is.
+    if censoring == "adjusted" and competing == "adjusted_as_negative":
+        if pooled:
+            adjusted = extract_aj_estimate_by_cutoffs(
+                reference_group_data, horizons, breaks, stratified_by, full_event_table
+            )
+            # As in the original, 'strata' is not re-cast on this path.
+            return adjusted.with_columns(scope_column)
+        return _within_strata(reference_group_data).with_columns(
+            [strata_cast, scope_column]
+        )
+
+    # Special case: both excluded. Handled per horizon regardless of scope.
+    if censoring == "excluded" and competing == "excluded":
+        non_censored_non_competing = exploded.filter(
+            past_horizon | (pl.col("reals") == 1)
+        )
+        return (
+            _aj_estimates_per_horizon(
+                non_censored_non_competing, horizons, full_event_table
+            )
+            .with_columns([strata_cast, scope_column])
+            .join(cutoffs, how="cross")
+        )
+
+    # Special case: competing excluded. Drop competing events at or before the
+    # horizon; competing codes that remain are recoded to 0.
+    if competing == "excluded":
+        non_competing = _recode_competing(
+            exploded.filter(past_horizon | (pl.col("reals") != 2)), 0
+        )
+        if pooled:
+            adjusted = extract_aj_estimate_by_cutoffs(
+                non_competing, horizons, breaks, stratified_by, full_event_table
+            )
+        else:
+            adjusted = _aj_estimates_per_horizon(
+                non_competing, horizons, full_event_table
+            ).join(cutoffs, how="cross")
+        return adjusted.with_columns([strata_cast, scope_column])
+
+    # Remaining cases: pick the base frame from the censoring rule.
+    # - "excluded": exploded rows past the horizon or with reals > 0
+    # - otherwise: the full reference_group_data
+    if censoring == "excluded":
+        base_df = exploded.filter(past_horizon | (pl.col("reals") > 0))
+    else:
+        base_df = reference_group_data
+
+    # Apply the competing-event recoding if one is requested.
+    if competing == "adjusted_as_censored":
+        base_df = _recode_competing(base_df, 0)
+    elif competing == "adjusted_as_composite":
+        base_df = _recode_competing(base_df, 1)
+    # competing == "adjusted_as_negative": keep reals as-is.
+
+    if censoring == "excluded":
+        # Excluded censoring is always evaluated per horizon on the filtered frame.
+        if pooled:
+            adjusted = _aj_estimates_by_cutoff_per_horizon(
+                base_df, horizons, breaks, stratified_by
+            )
+        else:
+            adjusted = _aj_estimates_per_horizon(
+                base_df, horizons, full_event_table
+            ).join(cutoffs, how="cross")
+        return adjusted.with_columns(scope_column).with_columns(strata_cast)
+
+    # Adjusted censoring: aggregate across all horizons at once.
+    if pooled:
+        adjusted = extract_aj_estimate_by_cutoffs(
+            base_df, horizons, breaks, stratified_by, full_event_table
+        )
+    else:
+        adjusted = _within_strata(base_df)
+    return adjusted.with_columns([strata_cast, scope_column])
diff --git a/src/rtichoke/processing/combinations.py b/src/rtichoke/processing/combinations.py
new file mode 100644
index 0000000..b790929
--- /dev/null
+++ b/src/rtichoke/processing/combinations.py
@@ -0,0 +1,218 @@
+import numpy as np
+import polars as pl
+from typing import Dict
+from collections.abc import Sequence
+
+
+def _enum_dataframe(column_name: str, values: Sequence[str]) -> pl.DataFrame:
+    """Build a one-column DataFrame whose column is an enum over ``values``.
+
+    Categories are the distinct values in first-seen order; the column keeps
+    every entry of ``values``, duplicates included.
+    """
+    # dict.fromkeys de-duplicates while preserving insertion order.
+    categories = list(dict.fromkeys(values))
+    column = pl.Series(values, dtype=pl.Enum(categories))
+    return pl.DataFrame({column_name: column})
+
+
+def create_strata_combinations(stratified_by: str, by: float, breaks) -> pl.DataFrame:
+    """Describe every stratum for one stratification and pair it with each cutoff.
+
+    Parameters
+    ----------
+    stratified_by : str
+        ``"probability_threshold"`` or ``"ppcr"``.
+    by : float
+        Bin width; also sets the number of decimals in ppcr labels.
+    breaks
+        Break values; must support slicing and element-wise arithmetic.
+        Presumably a NumPy array — confirm against callers.
+
+    Returns
+    -------
+    pl.DataFrame
+        Stratum rows (label, bounds, mid point, bound inclusivity and
+        ``stratified_by``) cross-joined with ``chosen_cutoff`` over ``breaks``.
+
+    Raises
+    ------
+    ValueError
+        If ``stratified_by`` is not one of the two supported values.
+    """
+    if stratified_by not in ("probability_threshold", "ppcr"):
+        raise ValueError(f"Unsupported stratified_by: {stratified_by}")
+
+    if stratified_by == "probability_threshold":
+        lower_bound = breaks[:-1]
+        upper_bound = breaks[1:]
+        mid_point = upper_bound - by / 2
+        include_lower_bound = lower_bound > -0.1
+        # Only the stratum whose upper bound is exactly 1.0 is closed on the right.
+        include_upper_bound = upper_bound == 1.0
+        strata = format_strata_column(
+            lower_bound=lower_bound,
+            upper_bound=upper_bound,
+            include_lower_bound=include_lower_bound,
+            include_upper_bound=include_upper_bound,
+            decimals=2,
+        )
+    else:
+        by_text = str(by)
+        decimals = len(by_text.split(".")[-1]) if "." in by_text else 0
+
+        strata_mid = breaks[1:]
+        half_width = by / 2
+        lower_bound = strata_mid - half_width
+        upper_bound = strata_mid + half_width
+        mid_point = breaks[1:]
+        include_lower_bound = np.ones_like(strata_mid, dtype=bool)
+        include_upper_bound = np.zeros_like(strata_mid, dtype=bool)
+        # ppcr strata are labelled by their mid value, formatted like `by`.
+        strata = np.array([f"{x:.{decimals}f}" for x in strata_mid], dtype=object)
+
+    bins_df = pl.DataFrame(
+        {
+            "strata": pl.Series(strata),
+            "lower_bound": lower_bound,
+            "upper_bound": upper_bound,
+            "mid_point": mid_point,
+            "include_lower_bound": include_lower_bound,
+            "include_upper_bound": include_upper_bound,
+            "stratified_by": [stratified_by] * len(strata),
+        }
+    )
+
+    return bins_df.join(pl.DataFrame({"chosen_cutoff": breaks}), how="cross")
+
+
+def format_strata_column(
+    lower_bound: list[float],
+    upper_bound: list[float],
+    include_lower_bound: list[bool],
+    include_upper_bound: list[bool],
+    decimals: int = 3,
+) -> list[str]:
+    """Format interval labels such as ``"[0.00, 0.10)"`` element by element.
+
+    A square bracket marks an included bound and a round one an excluded
+    bound. Bounds are rounded and printed with ``decimals`` decimal places.
+    """
+    labels = []
+    for low, high, closed_left, closed_right in zip(
+        lower_bound, upper_bound, include_lower_bound, include_upper_bound
+    ):
+        opening = "[" if closed_left else "("
+        closing = "]" if closed_right else ")"
+        low_text = f"{round(low, decimals):.{decimals}f}"
+        high_text = f"{round(high, decimals):.{decimals}f}"
+        labels.append(f"{opening}{low_text}, {high_text}{closing}")
+    return labels
+
+
+def format_strata_interval(
+    lower: float, upper: float, include_lower: bool, include_upper: bool
+) -> str:
+    """Format one interval label with three decimals, e.g. ``"[0.100, 0.200)"``.
+
+    A square bracket marks an included bound and a round one an excluded bound.
+    """
+    opening = "[" if include_lower else "("
+    closing = "]" if include_upper else ")"
+    return "{}{:.3f}, {:.3f}{}".format(opening, lower, upper, closing)
+
+
+def create_breaks_values(probs_vec, stratified_by, by):
+    """Return break values for the requested stratification.
+
+    For ``"probability_threshold"`` the breaks are an evenly spaced grid from
+    0 in steps of ``by``, rounded to the number of characters after the
+    decimal point in ``str(by)``. For any other value they are quantiles of
+    ``probs_vec`` at levels running from 1 down to 0.
+    """
+    if stratified_by == "probability_threshold":
+        n_decimals = len(str(by).split(".")[-1])
+        grid = np.arange(0, 1 + by, by)
+        return np.round(grid, decimals=n_decimals)
+
+    quantile_levels = np.linspace(1, 0, int(1 / by) + 1)
+    return np.quantile(probs_vec, quantile_levels)
+
+
+def _create_aj_data_combinations_binary(
+    reference_groups: Sequence[str],
+    stratified_by: Sequence[str],
+    by: float,
+    breaks: Sequence[float],
+) -> pl.DataFrame:
+    """Build the full grid of reference group x strata/cutoff x outcome label.
+
+    Strata combinations come from ``create_strata_combinations`` for every
+    entry of ``stratified_by``. ``strata`` and ``stratified_by`` are cast to
+    enums, and the outcome labels are ``real_negatives`` and
+    ``real_positives``.
+    """
+    strata_combinations = pl.concat(
+        [create_strata_combinations(sb, by, breaks) for sb in stratified_by],
+        how="vertical",
+    )
+
+    # Enum categories follow the first-seen order of the strata labels.
+    strata_levels = strata_combinations["strata"].unique(maintain_order=True).to_list()
+
+    strata_combinations = strata_combinations.with_columns(
+        pl.col("strata").cast(pl.Enum(strata_levels)),
+        pl.col("stratified_by").cast(pl.Enum(["probability_threshold", "ppcr"])),
+    )
+
+    # Cartesian product, built left to right.
+    return (
+        _enum_dataframe("reference_group", reference_groups)
+        .join(strata_combinations, how="cross")
+        .join(
+            _enum_dataframe("reals_labels", ["real_negatives", "real_positives"]),
+            how="cross",
+        )
+    )
+
+
+def create_aj_data_combinations(
+    reference_groups: Sequence[str],
+    heuristics_sets: list[Dict],
+    fixed_time_horizons: Sequence[float],
+    stratified_by: Sequence[str],
+    by: float,
+    breaks: Sequence[float],
+    risk_set_scope: Sequence[str] = ["within_stratum", "pooled_by_cutoff"],
+) -> pl.DataFrame:
+    """Build the full grid of combinations for time-to-event data.
+
+    The result is the Cartesian product, in this order, of reference groups,
+    fixed time horizons, heuristic pairs, strata/cutoff combinations, risk-set
+    scopes and the four outcome labels. Categorical columns are enums.
+
+    Parameters
+    ----------
+    reference_groups : Sequence[str]
+        Names for the ``reference_group`` column.
+    heuristics_sets : list[Dict]
+        Records with ``censoring_heuristic`` and ``competing_heuristic`` keys.
+    fixed_time_horizons : Sequence[float]
+        Values for the Float64 ``fixed_time_horizon`` column.
+    stratified_by : Sequence[str]
+        Stratifications passed one by one to ``create_strata_combinations``.
+    by : float
+        Forwarded to ``create_strata_combinations``.
+    breaks : Sequence[float]
+        Forwarded to ``create_strata_combinations``.
+    risk_set_scope : Sequence[str]
+        Values for the ``risk_set_scope`` column.
+
+    Returns
+    -------
+    pl.DataFrame
+        One row per combination.
+    """
+    strata_combinations = pl.concat(
+        [create_strata_combinations(sb, by, breaks) for sb in stratified_by],
+        how="vertical",
+    )
+
+    # Enum categories follow the first-seen order of the strata labels.
+    strata_levels = strata_combinations["strata"].unique(maintain_order=True).to_list()
+
+    strata_combinations = strata_combinations.with_columns(
+        pl.col("strata").cast(pl.Enum(strata_levels)),
+        pl.col("stratified_by").cast(pl.Enum(["probability_threshold", "ppcr"])),
+    )
+
+    scope_frame = pl.DataFrame(
+        {
+            "risk_set_scope": pl.Series(risk_set_scope).cast(
+                pl.Enum(["within_stratum", "pooled_by_cutoff"])
+            )
+        }
+    )
+
+    heuristics = pl.DataFrame(heuristics_sets)
+    censoring_dtype = pl.Enum(
+        heuristics["censoring_heuristic"].unique(maintain_order=True)
+    )
+    competing_dtype = pl.Enum(
+        heuristics["competing_heuristic"].unique(maintain_order=True)
+    )
+    heuristics_frame = heuristics.with_columns(
+        pl.col("censoring_heuristic").cast(censoring_dtype),
+        pl.col("competing_heuristic").cast(competing_dtype),
+    )
+
+    horizons_frame = pl.DataFrame(
+        {"fixed_time_horizon": pl.Series(fixed_time_horizons, dtype=pl.Float64)}
+    )
+
+    outcome_labels = [
+        "real_negatives",
+        "real_positives",
+        "real_competing",
+        "real_censored",
+    ]
+
+    # Cartesian product, built left to right.
+    return (
+        _enum_dataframe("reference_group", reference_groups)
+        .join(horizons_frame, how="cross")
+        .join(heuristics_frame, how="cross")
+        .join(strata_combinations, how="cross")
+        .join(scope_frame, how="cross")
+        .join(_enum_dataframe("reals_labels", outcome_labels), how="cross")
+    )
diff --git a/src/rtichoke/helpers/exported_functions.py b/src/rtichoke/processing/exported_functions.py
similarity index 99%
rename from src/rtichoke/helpers/exported_functions.py
rename to src/rtichoke/processing/exported_functions.py
index a273346..778ad91 100644
--- a/src/rtichoke/helpers/exported_functions.py
+++ b/src/rtichoke/processing/exported_functions.py
@@ -4,7 +4,7 @@
import plotly.graph_objects as go
-from rtichoke.helpers.plotly_helper_functions import (
+from rtichoke.processing.plotly_helper_functions import (
create_non_interactive_curve,
create_interactive_marker,
create_reference_lines_for_plotly,
diff --git a/src/rtichoke/helpers/plotly_helper_functions.py b/src/rtichoke/processing/plotly_helper_functions.py
similarity index 99%
rename from src/rtichoke/helpers/plotly_helper_functions.py
rename to src/rtichoke/processing/plotly_helper_functions.py
index ab475be..074fc52 100644
--- a/src/rtichoke/helpers/plotly_helper_functions.py
+++ b/src/rtichoke/processing/plotly_helper_functions.py
@@ -681,6 +681,15 @@ def _htext(title: pl.Expr) -> pl.Expr:
(pl.col("x") >= min_p_threshold) & (pl.col("x") <= max_p_threshold)
)
+ return pl.DataFrame(
+ schema={
+ "reference_group": pl.Utf8,
+ "x": pl.Float64,
+ "y": pl.Float64,
+ "text": pl.Utf8,
+ }
+ )
+
def create_non_interactive_curve_polars(
performance_data_ready_for_curve, reference_group_color, reference_group
@@ -1157,7 +1166,7 @@ def _add_hover_text_to_performance_data(
)
return performance_data.with_columns(
- [pl.col(pl.FLOAT_DTYPES).round(3), hover_text_expr.alias("text")]
+ [pl.col(pl.Float64).round(3), hover_text_expr.alias("text")]
)
diff --git a/src/rtichoke/helpers/send_post_request_to_r_rtichoke.py b/src/rtichoke/processing/send_post_request_to_r_rtichoke.py
similarity index 98%
rename from src/rtichoke/helpers/send_post_request_to_r_rtichoke.py
rename to src/rtichoke/processing/send_post_request_to_r_rtichoke.py
index aaca6f3..f8254e6 100644
--- a/src/rtichoke/helpers/send_post_request_to_r_rtichoke.py
+++ b/src/rtichoke/processing/send_post_request_to_r_rtichoke.py
@@ -4,7 +4,7 @@
# import requests
import pandas as pd
-from rtichoke.helpers.exported_functions import create_plotly_curve
+from rtichoke.processing.exported_functions import create_plotly_curve
def send_requests_to_rtichoke_r(dictionary_to_send, url_api, endpoint):
diff --git a/src/rtichoke/processing/transforms.py b/src/rtichoke/processing/transforms.py
new file mode 100644
index 0000000..4e4339a
--- /dev/null
+++ b/src/rtichoke/processing/transforms.py
@@ -0,0 +1,700 @@
+import numpy as np
+import polars as pl
+from typing import Dict, Union
+from rtichoke.processing.combinations import create_breaks_values
+
+
+def add_cutoff_strata(data: pl.DataFrame, by: float, stratified_by) -> pl.DataFrame:
+    """Attach strata label columns computed separately per reference group.
+
+    Parameters
+    ----------
+    data : pl.DataFrame
+        Must contain ``probs`` and ``reference_group``.
+    by : float
+        Bin width for both stratifications.
+    stratified_by
+        Container checked for ``"probability_threshold"`` and ``"ppcr"``.
+
+    Returns
+    -------
+    pl.DataFrame
+        The per-group frames concatenated, with ``strata_probability_threshold``
+        and/or ``strata_ppcr`` added as requested.
+    """
+
+    def _threshold_column(probs, breaks) -> pl.Series:
+        # Fixed-width bins; a probability of exactly 1.0 goes into the last bin,
+        # which is the only one closed on the right.
+        top_bin = len(breaks) - 2
+        bin_idx = np.digitize(probs, bins=breaks, right=False) - 1
+        bin_idx = np.where(probs == 1.0, top_bin, bin_idx)
+
+        lows = breaks[bin_idx]
+        highs = breaks[bin_idx + 1]
+        closers = np.where(bin_idx == top_bin, "]", ")")
+
+        labels = np.array(
+            [
+                f"[{lo:.2f}, {hi:.2f}{end}"
+                for lo, hi, end in zip(lows, highs, closers)
+            ],
+            dtype=str,
+        )
+        return pl.Series("strata_probability_threshold", labels)
+
+    def _ppcr_column(probs, step) -> pl.Series:
+        # Quantile-based bins, labelled by their upper quantile level.
+        step = float(step)
+        n_bins = int(round(1 / step))  # e.g. 0.2 -> 5 bins
+
+        values = np.asarray(probs, float)
+        edges = np.quantile(values, np.linspace(0.0, 1.0, n_bins + 1), method="linear")
+        # Force the edges to be non-decreasing and to span [0, 1].
+        edges = np.maximum.accumulate(edges)
+        edges[0] = 0.0
+        edges[-1] = 1.0
+
+        bin_idx = np.digitize(values, bins=edges[1:-1], right=True)
+
+        step_text = str(step)
+        decimals = len(step_text.split(".")[-1]) if "." in step_text else 0
+        labels = [f"{x:.{decimals}f}" for x in np.linspace(step, 1.0, n_bins)]
+
+        return pl.Series("strata_ppcr", np.array(labels)[bin_idx]).cast(
+            pl.Enum(labels)
+        )
+
+    def _with_strata(group: pl.DataFrame) -> pl.DataFrame:
+        probs = group["probs"].to_numpy()
+        breaks = create_breaks_values(probs, "probability_threshold", by)
+
+        new_columns = []
+        if "probability_threshold" in stratified_by:
+            new_columns.append(_threshold_column(probs, breaks))
+        if "ppcr" in stratified_by:
+            new_columns.append(_ppcr_column(probs, by))
+        return group.with_columns(new_columns)
+
+    # Each reference group is binned on its own probabilities.
+    partitions = data.partition_by("reference_group", as_dict=True)
+    return pl.concat([_with_strata(part) for part in partitions.values()])
+
+
+def pivot_longer_strata(data: pl.DataFrame) -> pl.DataFrame:
+    """Reshape the ``strata_*`` columns into long format.
+
+    Every column whose name starts with ``strata_`` becomes rows: the column
+    name, without the prefix, goes into ``stratified_by`` (an enum over
+    ``probability_threshold`` and ``ppcr``) and the value into ``strata``.
+    All other columns are kept as identifiers.
+
+    Parameters
+    ----------
+    data : pl.DataFrame
+        Wide frame with one or more ``strata_*`` columns.
+
+    Returns
+    -------
+    pl.DataFrame
+        Long frame with ``stratified_by`` and ``strata`` columns.
+    """
+    id_vars = [col for col in data.columns if not col.startswith("strata_")]
+    value_vars = [col for col in data.columns if col.startswith("strata_")]
+
+    # DataFrame.melt is deprecated; unpivot is its replacement.
+    data_long = data.unpivot(
+        on=value_vars,
+        index=id_vars,
+        variable_name="stratified_by",
+        value_name="strata",
+    )
+
+    stratified_by_enum = pl.Enum(["probability_threshold", "ppcr"])
+
+    # Remove the "strata_" prefix so the values match the enum categories.
+    return data_long.with_columns(
+        pl.col("stratified_by").str.replace(r"^strata_", "").cast(stratified_by_enum)
+    )
+
+
+def map_reals_to_labels_polars(data: pl.DataFrame) -> pl.DataFrame:
+    """Replace the numeric ``reals`` column with outcome label strings.
+
+    ``0`` maps to ``"real_negatives"``, ``1`` to ``"real_positives"``, ``2``
+    to ``"real_competing"`` and every other value to ``"real_censored"``.
+
+    Parameters
+    ----------
+    data : pl.DataFrame
+        Frame with a ``reals`` column.
+
+    Returns
+    -------
+    pl.DataFrame
+        The same frame with ``reals`` holding the labels.
+    """
+    reals = pl.col("reals")
+    # Labels must be wrapped in pl.lit: a bare string passed to then/otherwise
+    # is parsed as a column name, not as a literal value.
+    return data.with_columns(
+        pl.when(reals == 0)
+        .then(pl.lit("real_negatives"))
+        .when(reals == 1)
+        .then(pl.lit("real_positives"))
+        .when(reals == 2)
+        .then(pl.lit("real_competing"))
+        .otherwise(pl.lit("real_censored"))
+        .alias("reals")
+    )
+
+
+def update_administrative_censoring_polars(data: pl.DataFrame) -> pl.DataFrame:
+    """Relabel ``reals_labels`` according to the fixed time horizon.
+
+    ``real_positives`` with ``times`` strictly after the horizon become
+    ``real_negatives``. ``real_negatives`` with ``times`` strictly before the
+    horizon become ``real_censored``. All other rows keep their label.
+    """
+    times = pl.col("times")
+    horizon = pl.col("fixed_time_horizon")
+    label = pl.col("reals_labels")
+
+    positive_after_horizon = (times > horizon) & (label == "real_positives")
+    negative_before_horizon = (times < horizon) & (label == "real_negatives")
+
+    updated_label = (
+        pl.when(positive_after_horizon)
+        .then(pl.lit("real_negatives"))
+        .when(negative_before_horizon)
+        .then(pl.lit("real_censored"))
+        .otherwise(label)
+        .alias("reals_labels")
+    )
+    return data.with_columns(updated_label)
+
+
+def assign_and_explode_polars(
+    data: pl.DataFrame, fixed_time_horizons: list[float]
+) -> pl.DataFrame:
+    """Repeat every row once per horizon in a Float64 ``fixed_time_horizon`` column."""
+    # Attach the full horizon list to each row, then unnest it into rows.
+    with_horizon_list = data.with_columns(
+        pl.lit(fixed_time_horizons).alias("fixed_time_horizon")
+    )
+    one_row_per_horizon = with_horizon_list.explode("fixed_time_horizon")
+    return one_row_per_horizon.with_columns(
+        pl.col("fixed_time_horizon").cast(pl.Float64)
+    )
+
+
+def _create_list_data_to_adjust_binary(
+    aj_data_combinations: pl.DataFrame,
+    probs_dict: Dict[str, np.ndarray],
+    reals_dict: Union[np.ndarray, Dict[str, np.ndarray]],
+    stratified_by,
+    by,
+) -> Dict[str, pl.DataFrame]:
+    """Assemble per-reference-group frames for binary outcomes.
+
+    Parameters
+    ----------
+    aj_data_combinations : pl.DataFrame
+        Supplies the ``strata`` dtype and the ``upper_bound``/``lower_bound``
+        of every (strata, stratified_by) pair.
+    probs_dict : Dict[str, np.ndarray]
+        Predicted probabilities keyed by reference group.
+    reals_dict : np.ndarray or Dict[str, np.ndarray]
+        Outcomes coded 0/1: one array shared by all groups, or a dict with
+        one array per group.
+    stratified_by, by
+        Forwarded to ``add_cutoff_strata``.
+
+    Returns
+    -------
+    Dict[str, pl.DataFrame]
+        One frame per reference group with strata, bounds and an enum
+        ``reals_labels`` column.
+
+    Raises
+    ------
+    ValueError
+        If ``reals_dict`` is an empty dict.
+    """
+    reference_group_labels = list(probs_dict.keys())
+    reference_group_enum = pl.Enum(reference_group_labels)
+    strata_enum_dtype = aj_data_combinations.schema["strata"]
+
+    reals_is_dict = isinstance(reals_dict, dict)
+    if reals_is_dict and not reals_dict:
+        raise ValueError("reals_dict must not be an empty dict")
+
+    if len(probs_dict) == 1:
+        only_label = reference_group_labels[0]
+        probs_array = np.asarray(probs_dict[only_label])
+
+        # A dict of outcomes is resolved to one array: the entry for this
+        # group if present, otherwise the first entry.
+        if reals_is_dict:
+            if only_label in reals_dict:
+                reals_values = reals_dict[only_label]
+            else:
+                reals_values = next(iter(reals_dict.values()))
+        else:
+            reals_values = reals_dict
+
+        data_to_adjust = pl.DataFrame(
+            {
+                "reference_group": np.repeat(reference_group_labels, len(probs_array)),
+                "probs": probs_array,
+                "reals": reals_values,
+            }
+        ).with_columns(pl.col("reference_group").cast(reference_group_enum))
+
+    elif not reals_is_dict or len(reals_dict) == 1:
+        # One outcome vector shared by every model; unwrap a one-entry dict.
+        shared_reals = np.asarray(
+            next(iter(reals_dict.values())) if reals_is_dict else reals_dict
+        )
+
+        data_to_adjust = pl.DataFrame(
+            {
+                "reference_group": np.repeat(reference_group_labels, len(shared_reals)),
+                "probs": np.concatenate(
+                    [probs_dict[group] for group in reference_group_labels]
+                ),
+                "reals": np.tile(shared_reals, len(reference_group_labels)),
+            }
+        ).with_columns(pl.col("reference_group").cast(reference_group_enum))
+
+    else:
+        # One outcome vector per group; entries are paired by insertion order.
+        data_to_adjust = (
+            pl.DataFrame(
+                {
+                    "reference_group": reference_group_labels,
+                    "probs": list(probs_dict.values()),
+                    "reals": list(reals_dict.values()),
+                }
+            )
+            .explode(["probs", "reals"])
+            .with_columns(pl.col("reference_group").cast(reference_group_enum))
+        )
+
+    data_to_adjust = add_cutoff_strata(
+        data_to_adjust, by=by, stratified_by=stratified_by
+    )
+
+    data_to_adjust = pivot_longer_strata(data_to_adjust)
+
+    strata_bounds = aj_data_combinations.select(
+        pl.col("strata"),
+        pl.col("stratified_by"),
+        pl.col("upper_bound"),
+        pl.col("lower_bound"),
+    ).unique()
+
+    data_to_adjust = data_to_adjust.with_columns(
+        pl.col("strata").cast(strata_enum_dtype)
+    ).join(strata_bounds, how="left", on=["strata", "stratified_by"])
+
+    reals_enum = pl.Enum(["real_negatives", "real_positives"])
+    reals_map = {0: "real_negatives", 1: "real_positives"}
+
+    # replace_strict raises for any outcome code missing from the map.
+    data_to_adjust = data_to_adjust.with_columns(
+        pl.col("reals")
+        .replace_strict(reals_map, return_dtype=reals_enum)
+        .alias("reals_labels")
+    )
+
+    # partition_by(as_dict=True) keys are tuples; unwrap to the group name.
+    return {
+        group[0]: df
+        for group, df in data_to_adjust.partition_by(
+            "reference_group", as_dict=True
+        ).items()
+    }
+
+
+def _create_list_data_to_adjust(
+    aj_data_combinations: pl.DataFrame,
+    probs_dict: Dict[str, np.ndarray],
+    reals_dict: Union[np.ndarray, Dict[str, np.ndarray]],
+    times_dict: Union[np.ndarray, Dict[str, np.ndarray]],
+    stratified_by,
+    by,
+) -> Dict[str, pl.DataFrame]:
+    """Assemble per-reference-group frames for time-to-event outcomes.
+
+    Parameters
+    ----------
+    aj_data_combinations : pl.DataFrame
+        Supplies the ``strata`` dtype and the ``upper_bound``/``lower_bound``
+        of every (strata, stratified_by) pair.
+    probs_dict : Dict[str, np.ndarray]
+        Predicted probabilities keyed by reference group.
+    reals_dict : np.ndarray or Dict[str, np.ndarray]
+        Outcomes coded 0/1/2: one shared array, or a dict with one array per
+        group.
+    times_dict : np.ndarray or Dict[str, np.ndarray]
+        Event times, shaped like ``reals_dict``.
+    stratified_by, by
+        Forwarded to ``add_cutoff_strata``.
+
+    Returns
+    -------
+    Dict[str, pl.DataFrame]
+        One frame per reference group with strata, bounds and an enum
+        ``reals_labels`` column.
+
+    Raises
+    ------
+    ValueError
+        If a dict input cannot be resolved to a single array where one is
+        needed, or if ``reals_dict`` and ``times_dict`` shapes do not fit
+        any supported layout.
+    """
+    reference_group_labels = list(probs_dict.keys())
+    reference_group_enum = pl.Enum(reference_group_labels)
+    strata_enum_dtype = aj_data_combinations.schema["strata"]
+
+    def _single_array(values, label, what: str) -> np.ndarray:
+        # Resolve an array-or-dict input to the one array to use.
+        if not isinstance(values, dict):
+            return np.asarray(values)
+        if label in values:
+            return np.asarray(values[label])
+        if 0 in values:
+            # The original code indexed with 0; keep that lookup working.
+            return np.asarray(values[0])
+        if len(values) == 1:
+            return np.asarray(next(iter(values.values())))
+        raise ValueError(f"Cannot pick a single {what} array from keys {list(values)}")
+
+    num_keys_reals = len(reals_dict) if isinstance(reals_dict, dict) else 1
+
+    if len(probs_dict) == 1:
+        only_label = reference_group_labels[0]
+        probs_array = np.asarray(probs_dict[only_label])
+
+        data_to_adjust = pl.DataFrame(
+            {
+                "reference_group": np.repeat(reference_group_labels, len(probs_array)),
+                "probs": probs_array,
+                "reals": _single_array(reals_dict, only_label, "reals"),
+                "times": _single_array(times_dict, only_label, "times"),
+            }
+        ).with_columns(pl.col("reference_group").cast(reference_group_enum))
+
+    elif num_keys_reals == 1:
+        # One outcome/time vector shared by every model.
+        reals_array = _single_array(reals_dict, None, "reals")
+        times_array = _single_array(times_dict, None, "times")
+        n = len(reals_array)
+
+        data_to_adjust = pl.DataFrame(
+            {
+                "reference_group": np.repeat(reference_group_labels, n),
+                "probs": np.concatenate(
+                    [np.asarray(probs_dict[g]) for g in reference_group_labels]
+                ),
+                "reals": np.tile(reals_array, len(reference_group_labels)),
+                "times": np.tile(times_array, len(reference_group_labels)),
+            }
+        ).with_columns(pl.col("reference_group").cast(reference_group_enum))
+
+    elif isinstance(reals_dict, dict) and isinstance(times_dict, dict):
+        # One vector per group; entries are paired by insertion order.
+        data_to_adjust = (
+            pl.DataFrame(
+                {
+                    "reference_group": reference_group_labels,
+                    "probs": list(probs_dict.values()),
+                    "reals": list(reals_dict.values()),
+                    "times": list(times_dict.values()),
+                }
+            )
+            .explode(["probs", "reals", "times"])
+            .with_columns(pl.col("reference_group").cast(reference_group_enum))
+        )
+
+    else:
+        # Previously this fell through and left data_to_adjust unbound.
+        raise ValueError(
+            "reals_dict and times_dict must both be dicts when reals_dict "
+            "has one entry per reference group"
+        )
+
+    data_to_adjust = add_cutoff_strata(
+        data_to_adjust, by=by, stratified_by=stratified_by
+    )
+
+    data_to_adjust = pivot_longer_strata(data_to_adjust)
+
+    strata_bounds = aj_data_combinations.select(
+        pl.col("strata"),
+        pl.col("stratified_by"),
+        pl.col("upper_bound"),
+        pl.col("lower_bound"),
+    ).unique()
+
+    data_to_adjust = data_to_adjust.with_columns(
+        pl.col("strata").cast(strata_enum_dtype)
+    ).join(strata_bounds, how="left", on=["strata", "stratified_by"])
+
+    reals_enum = pl.Enum(
+        ["real_negatives", "real_positives", "real_competing", "real_censored"]
+    )
+
+    # Map outcome codes to labels; replace_strict raises for unmapped codes.
+    reals_map = {0: "real_negatives", 2: "real_competing", 1: "real_positives"}
+
+    data_to_adjust = data_to_adjust.with_columns(
+        pl.col("reals")
+        .replace_strict(reals_map, return_dtype=reals_enum)
+        .alias("reals_labels")
+    )
+
+    # partition_by(as_dict=True) keys are tuples; unwrap to the group name.
+    return {
+        group[0]: df
+        for group, df in data_to_adjust.partition_by(
+            "reference_group", as_dict=True
+        ).items()
+    }
+
+
+def _cast_and_join_adjusted_data_binary(
+    aj_data_combinations: pl.DataFrame, aj_estimates_data: pl.DataFrame
+) -> pl.DataFrame:
+    """Left-join estimates onto the combinations grid and classify each row.
+
+    Adds ``prediction_label`` (from the cutoff versus the stratum bound) and
+    ``classification_outcome`` (from the prediction and ``reals_labels``).
+    Missing ``reals_estimate`` values are filled with 0.
+    """
+    strata_dtype = aj_data_combinations.schema["strata"]
+    estimates = aj_estimates_data.with_columns(pl.col("strata").cast(strata_dtype))
+
+    cutoff = pl.col("chosen_cutoff")
+    stratification = pl.col("stratified_by")
+
+    # A stratum is predicted negative when it lies below the chosen cutoff.
+    below_threshold = (cutoff >= pl.col("upper_bound")) & (
+        stratification == "probability_threshold"
+    )
+    below_ppcr = ((1 - cutoff) >= pl.col("mid_point")) & (stratification == "ppcr")
+
+    prediction_label = (
+        pl.when(below_threshold | below_ppcr)
+        .then(pl.lit("predicted_negatives"))
+        .otherwise(pl.lit("predicted_positives"))
+        .cast(pl.Enum(["predicted_negatives", "predicted_positives"]))
+        .alias("prediction_label")
+    )
+
+    # (prediction, reality) -> outcome, evaluated in this order.
+    outcome_rules = [
+        ("predicted_positives", "real_positives", "true_positives"),
+        ("predicted_positives", "real_negatives", "false_positives"),
+        ("predicted_negatives", "real_negatives", "true_negatives"),
+        ("predicted_negatives", "real_positives", "false_negatives"),
+    ]
+    outcome = None
+    for predicted, real, label in outcome_rules:
+        matches = (pl.col("prediction_label") == pl.lit(predicted)) & (
+            pl.col("reals_labels") == pl.lit(real)
+        )
+        branch = pl.when(matches) if outcome is None else outcome.when(matches)
+        outcome = branch.then(pl.lit(label))
+
+    classification_outcome = outcome.cast(
+        pl.Enum(
+            [
+                "true_positives",
+                "false_positives",
+                "true_negatives",
+                "false_negatives",
+            ]
+        )
+    ).alias("classification_outcome")
+
+    join_keys = [
+        "strata",
+        "stratified_by",
+        "reals_labels",
+        "reference_group",
+        "chosen_cutoff",
+    ]
+
+    return (
+        aj_data_combinations.join(estimates, on=join_keys, how="left")
+        .with_columns(prediction_label)
+        .with_columns(classification_outcome)
+        .with_columns(pl.col("reals_estimate").fill_null(0))
+    )
+
+
+def cast_and_join_adjusted_data(
+    aj_data_combinations, aj_estimates_data
+) -> pl.DataFrame:
+    """Left-join estimates onto the combinations grid and classify each row.
+
+    Adds ``prediction_label`` (from the cutoff versus the stratum bound) and
+    ``classification_outcome``. Competing events count as negatives only
+    under the ``adjusted_as_negative`` heuristic; every row that matches no
+    rule is labelled ``excluded``.
+    """
+    strata_dtype = aj_data_combinations.schema["strata"]
+    estimates = aj_estimates_data.with_columns(pl.col("strata").cast(strata_dtype))
+
+    cutoff = pl.col("chosen_cutoff")
+    stratification = pl.col("stratified_by")
+
+    # A stratum is predicted negative when it lies below the chosen cutoff.
+    below_threshold = (cutoff >= pl.col("upper_bound")) & (
+        stratification == "probability_threshold"
+    )
+    below_ppcr = ((1 - cutoff) >= pl.col("mid_point")) & (stratification == "ppcr")
+
+    prediction_label = (
+        pl.when(below_threshold | below_ppcr)
+        .then(pl.lit("predicted_negatives"))
+        .otherwise(pl.lit("predicted_positives"))
+        .cast(pl.Enum(["predicted_negatives", "predicted_positives"]))
+        .alias("prediction_label")
+    )
+
+    predicted = pl.col("prediction_label")
+    real = pl.col("reals_labels")
+    competing_as_negative = pl.col("competing_heuristic") == pl.lit(
+        "adjusted_as_negative"
+    )
+
+    # (prediction, reality) -> outcome, evaluated in this order.
+    outcome_rules = [
+        ("predicted_positives", "real_positives", "true_positives"),
+        ("predicted_positives", "real_negatives", "false_positives"),
+        ("predicted_negatives", "real_negatives", "true_negatives"),
+        ("predicted_negatives", "real_positives", "false_negatives"),
+    ]
+    outcome = None
+    for predicted_label, real_label, label in outcome_rules:
+        matches = (predicted == pl.lit(predicted_label)) & (
+            real == pl.lit(real_label)
+        )
+        branch = pl.when(matches) if outcome is None else outcome.when(matches)
+        outcome = branch.then(pl.lit(label))
+
+    classification_outcome = (
+        outcome.when(
+            (predicted == pl.lit("predicted_negatives"))
+            & (real == pl.lit("real_competing"))
+            & competing_as_negative
+        )
+        .then(pl.lit("true_negatives"))
+        .when(
+            (predicted == pl.lit("predicted_positives"))
+            & (real == pl.lit("real_competing"))
+            & competing_as_negative
+        )
+        .then(pl.lit("false_positives"))
+        .otherwise(pl.lit("excluded"))
+        .cast(
+            pl.Enum(
+                [
+                    "true_positives",
+                    "false_positives",
+                    "true_negatives",
+                    "false_negatives",
+                    "excluded",
+                ]
+            )
+        )
+        .alias("classification_outcome")
+    )
+
+    join_keys = [
+        "strata",
+        "fixed_time_horizon",
+        "censoring_heuristic",
+        "competing_heuristic",
+        "reals_labels",
+        "reference_group",
+        "chosen_cutoff",
+        "risk_set_scope",
+    ]
+
+    return (
+        aj_data_combinations.join(estimates, on=join_keys, how="left")
+        .with_columns(prediction_label)
+        .with_columns(classification_outcome)
+    )
+
+
+def _calculate_cumulative_aj_data_binary(aj_data: pl.DataFrame) -> pl.DataFrame:
+    """Aggregate estimates into confusion-matrix counts per cutoff.
+
+    ``reals_estimate`` is summed per reference group, stratification, cutoff
+    and classification outcome, then pivoted so each outcome is a column.
+    Null outcome counts become 0 and the marginal totals are added.
+
+    Parameters
+    ----------
+    aj_data : pl.DataFrame
+        Must contain ``reference_group``, ``stratified_by``,
+        ``chosen_cutoff``, ``classification_outcome`` and ``reals_estimate``.
+
+    Returns
+    -------
+    pl.DataFrame
+        One row per group/stratification/cutoff with the four outcome counts
+        plus ``predicted_positives``, ``predicted_negatives``,
+        ``real_positives``, ``real_negatives`` and ``n``.
+    """
+    outcome_columns = [
+        "true_positives",
+        "true_negatives",
+        "false_positives",
+        "false_negatives",
+    ]
+
+    true_pos = pl.col("true_positives")
+    true_neg = pl.col("true_negatives")
+    false_pos = pl.col("false_positives")
+    false_neg = pl.col("false_negatives")
+
+    # The original ran this derivation twice. The first pass also summed `n`
+    # over all rows, but the second pass overwrote every derived column, so a
+    # single pass gives the same result.
+    return (
+        aj_data.group_by(
+            [
+                "reference_group",
+                "stratified_by",
+                "chosen_cutoff",
+                "classification_outcome",
+            ]
+        )
+        .agg([pl.col("reals_estimate").sum()])
+        .pivot(on="classification_outcome", values="reals_estimate")
+        .with_columns([pl.col(col).fill_null(0) for col in outcome_columns])
+        .with_columns(
+            (true_pos + false_pos).alias("predicted_positives"),
+            (true_neg + false_neg).alias("predicted_negatives"),
+            (true_pos + false_neg).alias("real_positives"),
+            (false_pos + true_neg).alias("real_negatives"),
+            (true_pos + true_neg + false_pos + false_neg).alias("n"),
+        )
+    )
+
+
+def _calculate_cumulative_aj_data(aj_data: pl.DataFrame) -> pl.DataFrame:
+    """Aggregate pooled-by-cutoff estimates into confusion-matrix counts.
+
+    Only rows with ``risk_set_scope == "pooled_by_cutoff"`` are used.
+    ``reals_estimate`` is summed per reference group, horizon, heuristic
+    pair, stratification, cutoff and classification outcome, then pivoted so
+    each outcome is a column. Nulls become 0 and the marginal totals are
+    added.
+
+    Parameters
+    ----------
+    aj_data : pl.DataFrame
+        Must contain the grouping columns listed above plus
+        ``risk_set_scope`` and ``reals_estimate``.
+
+    Returns
+    -------
+    pl.DataFrame
+        One row per grouping key with the outcome counts plus
+        ``predicted_positives``, ``predicted_negatives``, ``real_positives``,
+        ``real_negatives`` and ``n``.
+    """
+    true_pos = pl.col("true_positives")
+    true_neg = pl.col("true_negatives")
+    false_pos = pl.col("false_positives")
+    false_neg = pl.col("false_negatives")
+
+    # The original repeated this derivation in two identical with_columns
+    # calls; one pass gives the same result.
+    return (
+        aj_data.filter(pl.col("risk_set_scope") == "pooled_by_cutoff")
+        .group_by(
+            [
+                "reference_group",
+                "fixed_time_horizon",
+                "censoring_heuristic",
+                "competing_heuristic",
+                "stratified_by",
+                "chosen_cutoff",
+                "classification_outcome",
+            ]
+        )
+        .agg([pl.col("reals_estimate").sum()])
+        .pivot(on="classification_outcome", values="reals_estimate")
+        .fill_null(0)
+        .with_columns(
+            (true_pos + false_pos).alias("predicted_positives"),
+            (true_neg + false_neg).alias("predicted_negatives"),
+            (true_pos + false_neg).alias("real_positives"),
+            (false_pos + true_neg).alias("real_negatives"),
+            # `n` counts only the four classified outcomes.
+            (true_pos + true_neg + false_pos + false_neg).alias("n"),
+        )
+    )
+
+
+def _turn_cumulative_aj_to_performance_data(
+    cumulative_aj_data: pl.DataFrame,
+) -> pl.DataFrame:
+    """Derive performance metrics from confusion-matrix counts.
+
+    Adds ``sensitivity``, ``specificity``, ``ppv``, ``npv``,
+    ``false_positive_rate``, ``lift``, ``net_benefit``,
+    ``net_benefit_interventions_avoided`` and ``ppcr``. The two net-benefit
+    columns are null unless ``stratified_by`` is ``probability_threshold``.
+    For other stratifications ``ppcr`` is the chosen cutoff itself.
+    """
+    true_pos = pl.col("true_positives")
+    true_neg = pl.col("true_negatives")
+    false_pos = pl.col("false_positives")
+    false_neg = pl.col("false_negatives")
+    real_pos = pl.col("real_positives")
+    real_neg = pl.col("real_negatives")
+    pred_pos = pl.col("predicted_positives")
+    pred_neg = pl.col("predicted_negatives")
+    total = pl.col("n")
+    cutoff = pl.col("chosen_cutoff")
+
+    by_threshold = pl.col("stratified_by") == "probability_threshold"
+
+    net_benefit = (true_pos / total) - (false_pos / total) * cutoff / (1 - cutoff)
+    # NOTE(review): the factor 100 scales only the first term here — confirm
+    # that this is the intended formula.
+    interventions_avoided = 100 * (true_neg / total) - (false_neg / total) * (
+        1 - cutoff
+    ) / cutoff
+
+    return cumulative_aj_data.with_columns(
+        (true_pos / real_pos).alias("sensitivity"),
+        (true_neg / real_neg).alias("specificity"),
+        (true_pos / pred_pos).alias("ppv"),
+        (true_neg / pred_neg).alias("npv"),
+        (false_pos / real_neg).alias("false_positive_rate"),
+        ((true_pos / pred_pos) / (real_pos / total)).alias("lift"),
+        pl.when(by_threshold).then(net_benefit).otherwise(None).alias("net_benefit"),
+        pl.when(by_threshold)
+        .then(interventions_avoided)
+        .otherwise(None)
+        .alias("net_benefit_interventions_avoided"),
+        pl.when(by_threshold).then(pred_pos / total).otherwise(cutoff).alias("ppcr"),
+    )
diff --git a/src/rtichoke/summary_report/summary_report.py b/src/rtichoke/summary_report/summary_report.py
index 9549260..8506794 100644
--- a/src/rtichoke/summary_report/summary_report.py
+++ b/src/rtichoke/summary_report/summary_report.py
@@ -2,8 +2,10 @@
A module for Summary Report
"""
-from rtichoke.helpers.send_post_request_to_r_rtichoke import send_requests_to_rtichoke_r
-from rtichoke.helpers.sandbox_observable_helpers import (
+from rtichoke.processing.send_post_request_to_r_rtichoke import (
+ send_requests_to_rtichoke_r,
+)
+from rtichoke.processing.transforms import (
_create_list_data_to_adjust,
)
import subprocess
diff --git a/src/rtichoke/utility/decision.py b/src/rtichoke/utility/decision.py
index 50f0e6d..57fdf53 100644
--- a/src/rtichoke/utility/decision.py
+++ b/src/rtichoke/utility/decision.py
@@ -4,7 +4,7 @@
from typing import Dict, List, Sequence, Union
from plotly.graph_objs._figure import Figure
-from rtichoke.helpers.plotly_helper_functions import (
+from rtichoke.processing.plotly_helper_functions import (
_create_rtichoke_plotly_curve_binary,
_create_rtichoke_plotly_curve_times,
_plot_rtichoke_curve_binary,
diff --git a/tests/test_calibration.py b/tests/test_calibration.py
new file mode 100644
index 0000000..4e79687
--- /dev/null
+++ b/tests/test_calibration.py
@@ -0,0 +1,28 @@
+import numpy as np
+from rtichoke.calibration.calibration import create_calibration_curve
+
+
+def test_create_calibration_curve_smooth():
+    """Smooth calibration plot has three traces, reference line first."""
+    rng = np.random.default_rng(0)  # fixed seed keeps the test reproducible
+    probs = {"model_1": np.linspace(0, 1, 100)}
+    reals = rng.integers(0, 2, 100)
+    fig = create_calibration_curve(probs, reals, calibration_type="smooth")
+
+    # Three traces are expected: reference line, smooth curve and histogram.
+    assert len(fig.data) == 3
+    # The first trace is the reference line.
+    assert fig.data[0].name == "Perfectly Calibrated"
+
+
+def test_create_calibration_curve_smooth_single_point():
+    """With one distinct probability, trace 1 is drawn as "lines+markers"."""
+    rng = np.random.default_rng(0)  # fixed seed keeps the test reproducible
+    probs = {"model_1": np.full(100, 0.5)}
+    reals = rng.integers(0, 2, 100)
+    fig = create_calibration_curve(probs, reals, calibration_type="smooth")
+
+    # The second trace must use the "lines+markers" plot mode.
+    assert fig.data[1].mode == "lines+markers"
+    # The third trace is the histogram, rendered as a bar trace.
+    assert fig.data[2].type == "bar"
diff --git a/tests/test_calibration_times.py b/tests/test_calibration_times.py
new file mode 100644
index 0000000..b6570c9
--- /dev/null
+++ b/tests/test_calibration_times.py
@@ -0,0 +1,24 @@
+import numpy as np
+from rtichoke.calibration import create_calibration_curve_times
+
+
+def test_create_calibration_curve_times():
+    """Smoke test: the time-horizon calibration figure has traces and sliders."""
+    n_obs = 10
+    event_times = np.arange(1, n_obs + 1)
+    heuristics = {
+        "censoring_heuristic": "excluded",
+        "competing_heuristic": "excluded",
+    }
+
+    fig = create_calibration_curve_times(
+        {"model_1": event_times / n_obs},  # 0.1, 0.2, ..., 1.0
+        np.repeat([0, 1], [4, 6]),  # four 0s followed by six 1s
+        event_times,
+        fixed_time_horizons=[5, 10],
+        heuristics_sets=[heuristics],
+    )
+
+    assert fig is not None
+    assert len(fig.data) > 0
+    assert len(fig.layout.sliders) > 0
diff --git a/tests/test_heuristics.py b/tests/test_heuristics.py
new file mode 100644
index 0000000..6730c1c
--- /dev/null
+++ b/tests/test_heuristics.py
@@ -0,0 +1,97 @@
+import pytest
+import polars as pl
+from polars.testing import assert_frame_equal
+from rtichoke.calibration.calibration import _apply_heuristics_and_censoring
+
+
+@pytest.fixture
+def sample_data():
+    """Seven-row frame with integer ``real`` and ``time`` columns."""
+    columns = {
+        "real": [1, 0, 2, 1, 2, 0, 1],
+        "time": [1, 2, 3, 8, 9, 10, 12],
+    }
+    return pl.DataFrame(columns)
+
+
+def test_competing_as_negative_logic(sample_data):
+    """Horizon 15 exceeds every time, so only competing events are recoded."""
+    horizon = 15
+    expected_real = [1, 0, 0, 1, 0, 0, 1]  # competing events (t=3, 9) -> 0
+    expected_time = [1, 2, 3, 8, 9, 10, 12]
+
+    observed = _apply_heuristics_and_censoring(
+        sample_data, horizon, "adjusted", "adjusted_as_negative"
+    )
+
+    assert_frame_equal(
+        observed, pl.DataFrame({"real": expected_real, "time": expected_time})
+    )
+
+
+def test_admin_censoring(sample_data):
+    """Rows with time > 7 are administratively censored (real becomes 0)."""
+    observed = _apply_heuristics_and_censoring(
+        sample_data, 7, "adjusted", "adjusted_as_negative"
+    )
+
+    # The competing event at time=3 is also recoded to 0.
+    expected_columns = {
+        "real": [1, 0, 0, 0, 0, 0, 0],
+        "time": [1, 2, 3, 8, 9, 10, 12],
+    }
+    assert_frame_equal(observed, pl.DataFrame(expected_columns))
+
+
+def test_censoring_excluded(sample_data):
+    """Censored rows (times 2 and 10) are dropped under "excluded"."""
+    observed = _apply_heuristics_and_censoring(
+        sample_data, 10, "excluded", "adjusted_as_negative"
+    )
+
+    # Of the remaining rows, competing events (times 3, 9) become 0 and
+    # time 12 > 10 is administratively censored to 0.
+    expected = pl.DataFrame({"real": [1, 0, 1, 0, 0], "time": [1, 3, 8, 9, 12]})
+
+    by_time = "time"  # sort both frames so row order is ignored
+    assert_frame_equal(observed.sort(by_time), expected.sort(by_time))
+
+
+def test_competing_excluded(sample_data):
+    """Competing events (times 3 and 9) are dropped under "excluded"."""
+    censoring, competing = "adjusted", "excluded"
+    observed = _apply_heuristics_and_censoring(sample_data, 10, censoring, competing)
+
+    # Time 12 exceeds the horizon of 10, so it is administratively censored.
+    expected = pl.DataFrame({"real": [1, 0, 1, 0, 0], "time": [1, 2, 8, 10, 12]})
+
+    # Sort both frames so the comparison ignores row order.
+    assert_frame_equal(observed.sort("time"), expected.sort("time"))
+
+
+def test_competing_as_negative(sample_data):
+    """Competing events are recoded to 0 and times beyond 10 are censored."""
+    expected = pl.DataFrame(
+        {
+            "real": [1, 0, 0, 1, 0, 0, 0],  # times 3, 9: 2 -> 0; time 12: 1 -> 0
+            "time": [1, 2, 3, 8, 9, 10, 12],
+        }
+    )
+    observed = _apply_heuristics_and_censoring(
+        sample_data, 10, "adjusted", "adjusted_as_negative"
+    )
+    assert_frame_equal(observed, expected)
+
+
+def test_competing_as_composite(sample_data):
+    """Competing events are recoded to 1 and times beyond 10 are censored."""
+    horizon, competing = 10, "adjusted_as_composite"
+    observed = _apply_heuristics_and_censoring(
+        sample_data, horizon, "adjusted", competing
+    )
+
+    # Times 3 and 9 change 2 -> 1; time 12 exceeds the horizon and changes 1 -> 0.
+    merged_real = [1, 0, 1, 1, 1, 0, 0]
+    unchanged_time = [1, 2, 3, 8, 9, 10, 12]
+    expected = pl.DataFrame({"real": merged_real, "time": unchanged_time})
+    assert_frame_equal(observed, expected)
diff --git a/tests/test_rtichoke.py b/tests/test_rtichoke.py
index 0ff1b91..1dd7916 100644
--- a/tests/test_rtichoke.py
+++ b/tests/test_rtichoke.py
@@ -2,7 +2,7 @@
A module for tests
"""
-from rtichoke.helpers.sandbox_observable_helpers import (
+from rtichoke.processing.adjustments import (
extract_aj_estimate_for_strata,
)
diff --git a/uv.lock b/uv.lock
index 7cf5135..a23d1da 100644
--- a/uv.lock
+++ b/uv.lock
@@ -1,5 +1,5 @@
version = 1
-revision = 2
+revision = 3
requires-python = ">=3.9"
resolution-markers = [
"python_full_version >= '3.13'",
@@ -3888,7 +3888,7 @@ wheels = [
[[package]]
name = "rtichoke"
-version = "0.1.25"
+version = "0.1.26"
source = { editable = "." }
dependencies = [
{ name = "marimo", version = "0.17.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" },
@@ -3899,6 +3899,7 @@ dependencies = [
{ name = "polarstate" },
{ name = "pyarrow", version = "21.0.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" },
{ name = "pyarrow", version = "22.0.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.10'" },
+ { name = "statsmodels" },
{ name = "typing" },
]
@@ -3933,6 +3934,7 @@ requires-dist = [
{ name = "polars", specifier = ">=1.28.0" },
{ name = "polarstate", specifier = "==0.1.8" },
{ name = "pyarrow", specifier = ">=21.0.0" },
+ { name = "statsmodels", specifier = ">=0.14.0" },
{ name = "typing", specifier = ">=3.7.4.3" },
]