Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyhealth/tasks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,4 @@
MutationPathogenicityPrediction,
VariantClassificationClinVar,
)
from .dreamt_osa import DREAMTOSAClassification
200 changes: 200 additions & 0 deletions pyhealth/tasks/dreamt_osa.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
import logging
import math
from typing import Any, Dict, List, Optional

import polars as pl

from .base_task import BaseTask


class DREAMTOSAClassification(BaseTask):
    """Patient-level OSA outcome tasks for the DREAMT dataset.

    This task assumes you are using `DREAMTDataset`, which exposes patient-level
    metadata such as:

    - age
    - gender
    - bmi
    - oahi
    - ahi
    - mean_sao2
    - arousal_index
    - medical_history
    - sleep_disorders

    We define OSA-related tasks using the AHI/OAHI values:

    Tasks
    -----
    - "ahi_severity_4class"
        Multi-class OSA severity from AHI:
            0: AHI < 5          (normal)
            1: 5 <= AHI < 15    (mild)
            2: 15 <= AHI < 30   (moderate)
            3: AHI >= 30        (severe)

    - "ahi_binary_15"
        Binary classification:
            0: AHI < 15  (no / mild OSA)
            1: AHI >= 15 (moderate / severe)

    - "oahi_binary_5"
        Binary classification:
            0: OAHI < 5  (no sleep apnea by OAHI)
            1: OAHI >= 5 (sleep apnea by OAHI)

    A patient whose AHI/OAHI is missing, not convertible to ``float``, or not
    finite (NaN / infinity) produces no sample.

    Features
    --------
    By default, the feature vector for each patient is a dictionary of
    clinical variables:

        ["age", "gender", "bmi", "mean_sao2", "arousal_index",
         "medical_history", "sleep_disorders"]

    You can override this with `feature_keys` when instantiating the task.

    Example
    -------
    >>> from pyhealth.datasets import DREAMTDataset
    >>> from pyhealth.tasks import DREAMTOSAClassification
    >>>
    >>> dataset = DREAMTDataset(root="/path/to/dreamt/version")
    >>> task = DREAMTOSAClassification(task="ahi_severity_4class")
    >>>
    >>> # later, in your dataloader construction:
    >>> # samples = task(patient)  # where `patient` is from DREAMTDataset
    """

    # Registry of supported tasks
    tasks = {
        "patient_level": [
            "ahi_severity_4class",
            "ahi_binary_15",
            "oahi_binary_5",
        ]
    }

    def __init__(
        self,
        task: str,
        feature_keys: Optional[List[str]] = None,
    ) -> None:
        """Configure the task variant and the features to extract.

        Args:
            task: One of the names listed in ``tasks["patient_level"]``.
            feature_keys: Patient attribute names to copy into each sample's
                feature dictionary. ``None`` or an empty list selects the
                default clinical features.

        Raises:
            ValueError: If ``task`` is not a supported task name.
        """
        if task not in self.tasks["patient_level"]:
            raise ValueError(
                f"Unsupported task '{task}'. "
                f"Choose from: {self.tasks['patient_level']}"
            )

        self.task = task
        self.task_name = f"DREAMTOSA/{task}"

        # Copy the caller's list so later mutation on their side cannot
        # silently change which features this task extracts.
        if feature_keys:
            self.feature_keys = list(feature_keys)
        else:
            # Default clinical features to include
            self.feature_keys = [
                "age",
                "gender",
                "bmi",
                "mean_sao2",
                "arousal_index",
                "medical_history",
                "sleep_disorders",
            ]

        # Patient-level tabular input
        self.input_schema = {"feature": "tabular"}

        # Label type depends on which task is chosen
        if task == "ahi_severity_4class":
            self.output_schema = {"label": "multiclass"}  # classes 0–3
        elif task in {"ahi_binary_15", "oahi_binary_5"}:
            self.output_schema = {"label": "binary"}
        else:
            # Only reachable if the registry gains a task that has no schema
            # defined here; fail loudly rather than guess a label type.
            raise ValueError(f"Unknown task: {task}")

    def pre_filter(self, df: pl.LazyFrame) -> pl.LazyFrame:
        """Optionally filter events before task construction.

        For patient-level OSA tasks, we typically just pass everything through.
        You can customize this to drop unrelated event types if needed.

        Args:
            df: Lazy frame of events.

        Returns:
            The same lazy frame, unchanged.
        """
        return df

    def __call__(self, patient: Any) -> List[Dict[str, Any]]:
        """Build the samples for a single patient.

        Args:
            patient: Object exposing the AHI/OAHI value and the configured
                feature keys as attributes, and optionally a
                ``get_events("splits")`` method.

        Returns:
            A list holding exactly one sample with the keys ``"feature"``,
            ``"label"``, ``"split"`` and ``"patient_id"``, or an empty list
            when no label can be computed for the patient.
        """
        label = self._get_label_from_patient(patient)
        if label is None:
            # If we cannot compute a label (e.g., missing AHI/OAHI), skip
            # this patient.
            return []

        # Missing attributes are kept as None so every sample has the same
        # set of feature keys.
        features: Dict[str, Any] = {
            key: getattr(patient, key, None) for key in self.feature_keys
        }

        sample = {
            "feature": features,
            "label": label,
            "split": self._get_split(patient),
            "patient_id": getattr(patient, "patient_id", None),
        }
        return [sample]

    # ---------------------------------------------------------------------
    # Internal helpers
    # ---------------------------------------------------------------------

    @staticmethod
    def _get_split(patient: Any) -> str:
        """Return the patient's split, defaulting to ``"train"``.

        The split is taken from the ``split`` attribute of the patient's
        single ``"splits"`` event. Any other situation (no such events,
        several of them, no ``split`` attribute, or a lookup failure) yields
        the default.
        """
        try:
            split_events = patient.get_events("splits")
            if len(split_events) == 1 and hasattr(split_events[0], "split"):
                return split_events[0].split
        except Exception:
            # Deliberately broad: split lookup is best-effort and must never
            # prevent a sample from being built. Log it so a genuine failure
            # is still discoverable.
            logging.getLogger(__name__).debug(
                "Could not read 'splits' events; defaulting to 'train'.",
                exc_info=True,
            )
        return "train"

    @staticmethod
    def _to_finite_float(value: Any) -> Optional[float]:
        """Convert ``value`` to a finite float, or return None.

        None, values that ``float()`` rejects, and NaN / infinity are all
        treated as missing.
        """
        if value is None:
            return None
        try:
            number = float(value)
        except (TypeError, ValueError):
            return None
        # NaN compares False against every threshold, so without this check
        # it would fall through to "severe" in the 4-class task and to 0 in
        # the binary tasks.
        return number if math.isfinite(number) else None

    def _get_label_from_patient(self, patient: Any) -> Optional[int]:
        """Compute the task-specific label from patient metadata.

        Returns:
            The integer class label, or None when the required AHI/OAHI value
            is missing, not numeric, or not finite.
        """
        if self.task == "ahi_severity_4class":
            ahi_val = self._to_finite_float(getattr(patient, "ahi", None))
            if ahi_val is None:
                return None
            if ahi_val < 5:
                return 0  # normal
            if ahi_val < 15:
                return 1  # mild
            if ahi_val < 30:
                return 2  # moderate
            return 3  # severe

        if self.task == "ahi_binary_15":
            ahi_val = self._to_finite_float(getattr(patient, "ahi", None))
            if ahi_val is None:
                return None
            return int(ahi_val >= 15.0)  # 1 = moderate/severe

        if self.task == "oahi_binary_5":
            oahi_val = self._to_finite_float(getattr(patient, "oahi", None))
            if oahi_val is None:
                return None
            return int(oahi_val >= 5.0)

        return None
117 changes: 117 additions & 0 deletions tests/core/test_dreamt_osa.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
import numpy as np # not strictly needed, but kept for consistency
import polars as pl

from pyhealth.data import Patient
from pyhealth.tasks import DREAMTOSAClassification


def test_dreamtosa_ahi_severity_4class():
    """An AHI of 22.5 maps to class 2; without a splits event, split is "train"."""
    # Single non-"splits" event, so the task falls back to its default split.
    events = pl.DataFrame(
        {
            "timestamp": [1],
            "patient_id": ["P1"],
            "visit_id": ["V1"],
            "event_type": ["dreamt_sleep"],
        }
    )
    patient = Patient(patient_id="P1", data_source=events)

    # Attach the metadata attributes the task reads from the patient.
    metadata = {
        "ahi": 22.5,  # moderate OSA → class 2
        "age": 50,
        "gender": "F",
        "bmi": 30.1,
        "mean_sao2": 94.2,
        "arousal_index": 10.5,
        "medical_history": "HTN",
        "sleep_disorders": "OSA",
    }
    for attr_name, attr_value in metadata.items():
        setattr(patient, attr_name, attr_value)

    task = DREAMTOSAClassification(task="ahi_severity_4class")
    samples = task(patient)

    # Exactly one sample per patient.
    assert len(samples) == 1
    sample = samples[0]

    # Every expected top-level key is present.
    for top_level_key in ("feature", "label", "split", "patient_id"):
        assert top_level_key in sample

    # AHI=22.5 falls in [15, 30) → 2 (moderate).
    assert sample["label"] == 2

    # The feature dictionary carries all default clinical variables.
    features = sample["feature"]
    assert isinstance(features, dict)
    default_feature_names = (
        "age",
        "gender",
        "bmi",
        "mean_sao2",
        "arousal_index",
        "medical_history",
        "sleep_disorders",
    )
    for feature_name in default_feature_names:
        assert feature_name in features

    # No splits event → default split.
    assert sample["split"] == "train"
    assert sample["patient_id"] == "P1"


def test_dreamtosa_ahi_binary_15_with_splits():
    """An AHI of 18.0 yields label 1 and the split is read from the splits event."""
    # Two events for the same patient; only the "splits" row carries a split.
    columns = {
        "timestamp": [1, 2],
        "patient_id": ["P2", "P2"],
        "visit_id": ["V1", "V1"],
        "event_type": ["dreamt_sleep", "splits"],
        "split": [None, "val"],
    }
    patient = Patient(patient_id="P2", data_source=pl.DataFrame(columns))

    # Attach the metadata attributes the task reads from the patient.
    metadata = {
        "ahi": 18.0,  # >= 15 → label 1
        "age": 60,
        "gender": "M",
        "bmi": 28.0,
        "mean_sao2": 93.0,
        "arousal_index": 15.0,
        "medical_history": "DM2",
        "sleep_disorders": "OSA",
    }
    for attr_name, attr_value in metadata.items():
        setattr(patient, attr_name, attr_value)

    samples = DREAMTOSAClassification(task="ahi_binary_15")(patient)

    assert len(samples) == 1
    sample = samples[0]

    # AHI >= 15 → positive class.
    assert sample["label"] == 1

    # The split comes from the single "splits" event.
    assert sample["split"] == "val"
    assert sample["patient_id"] == "P2"

    features = sample["feature"]
    assert isinstance(features, dict)
    assert features["age"] == 60
    assert features["gender"] == "M"