199 changes: 199 additions & 0 deletions examples/tcn_mimic3_codes.ipynb
@@ -0,0 +1,199 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# TCN Model Training on MIMIC-III Dataset\n",
"\n",
"Train a TCN (Temporal Convolutional Network) model for in-hospital mortality prediction on the synthetic MIMIC-III dataset."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from pyhealth.datasets import MIMIC3Dataset\n",
"\n",
"dataset = MIMIC3Dataset(\n",
" root=\"https://storage.googleapis.com/pyhealth/Synthetic_MIMIC-III\",\n",
" tables=[\"DIAGNOSES_ICD\", \"PROCEDURES_ICD\", \"PRESCRIPTIONS\"],\n",
" dev=True,\n",
")\n",
"dataset.stats()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Set Mortality Prediction Task\n",
"\n",
"We use the in-hospital mortality prediction task, which predicts patient mortality from diagnosis and procedure codes."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from pyhealth.tasks import MortalityPredictionMIMIC3\n",
"\n",
"task = MortalityPredictionMIMIC3()\n",
"samples = dataset.set_task(task)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Split Dataset\n",
"\n",
"Split the dataset into train, validation, and test sets using patient-level splitting."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from pyhealth.datasets import split_by_patient, get_dataloader\n",
"\n",
"train_dataset, val_dataset, test_dataset = split_by_patient(\n",
" samples, ratios=[0.7, 0.15, 0.15]\n",
")\n",
"\n",
"train_loader = get_dataloader(train_dataset, batch_size=64, shuffle=True)\n",
"val_loader = get_dataloader(val_dataset, batch_size=64, shuffle=False)\n",
"test_loader = get_dataloader(test_dataset, batch_size=64, shuffle=False)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Initialize TCN Model\n",
"\n",
"Create the TCN model with an embedding dimension of 128, 128 channels, a kernel size of 2, and a dropout rate of 0.5."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from pyhealth.models import TCN\n",
"\n",
"model = TCN(\n",
" dataset=samples,\n",
" embedding_dim=128,\n",
" num_channels=128,\n",
" kernel_size=2,\n",
" dropout=0.5,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Train Model\n",
"\n",
"Train the model with the PyHealth Trainer for 10 epochs, tracking PR-AUC, ROC-AUC, F1, and accuracy, and monitoring ROC-AUC on the validation set."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from pyhealth.trainer import Trainer\n",
"\n",
"trainer = Trainer(\n",
" model=model,\n",
" metrics=[\"pr_auc\", \"roc_auc\", \"f1\", \"accuracy\"],\n",
")\n",
"\n",
"trainer.train(\n",
" train_dataloader=train_loader,\n",
" val_dataloader=val_loader,\n",
" epochs=10,\n",
" monitor=\"roc_auc\",\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Evaluate on Test Set\n",
"\n",
"Evaluate the trained model on the test set and print the results."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"results = trainer.evaluate(test_loader)\n",
"\n",
"print(\"Test Set Results:\")\n",
"print(f\" ROC-AUC: {results['roc_auc']:.4f}\")\n",
"print(f\" PR-AUC: {results['pr_auc']:.4f}\")\n",
"print(f\" F1 Score: {results['f1']:.4f}\")\n",
"print(f\" Accuracy: {results['accuracy']:.4f}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Custom TCN Configuration\n",
"\n",
"To customize the TCN architecture, pass different hyperparameters. Here `num_channels` is a list that sets the channel count of each layer manually:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Create TCN with custom architecture\n",
"custom_model = TCN(\n",
" dataset=samples,\n",
" embedding_dim=64,\n",
" num_channels=[64, 128, 256], # List for manual layer specification\n",
" kernel_size=3,\n",
" dropout=0.3,\n",
")\n",
"\n",
"print(\"Custom TCN architecture:\")\n",
"print(f\"Embedding dim: {custom_model.embedding_dim}\")\n",
"print(f\"Output channels: {custom_model.num_channels}\")\n",
"print(f\"Number of features: {len(custom_model.feature_keys)}\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.8.0"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
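The notebook's "Split Dataset" step relies on patient-level splitting, which may be worth illustrating for readers. The sketch below shows the idea in plain Python: samples are grouped by patient before the split, so no patient contributes samples to more than one partition. The helper name `split_by_patient_sketch`, the `patient_id` sample key, and the toy data are assumptions made for illustration. This is not PyHealth's implementation of `split_by_patient`.

```python
import random
from collections import defaultdict


def split_by_patient_sketch(samples, ratios, seed=0):
    """Hypothetical helper: split sample indices so that all samples of a
    patient fall into the same partition (train, val, or test)."""
    assert abs(sum(ratios) - 1.0) < 1e-6, "ratios must sum to 1"

    # Group sample indices by patient (assumes each sample has a "patient_id").
    by_patient = defaultdict(list)
    for idx, sample in enumerate(samples):
        by_patient[sample["patient_id"]].append(idx)

    # Shuffle patients, not samples, with a fixed seed for reproducibility.
    patient_ids = sorted(by_patient)
    random.Random(seed).shuffle(patient_ids)

    # Cut the shuffled patient list according to the requested ratios.
    n_train = int(len(patient_ids) * ratios[0])
    n_val = int(len(patient_ids) * ratios[1])
    groups = (
        patient_ids[:n_train],
        patient_ids[n_train:n_train + n_val],
        patient_ids[n_train + n_val:],
    )

    # Expand each patient group back into its sample indices.
    return [[i for pid in group for i in by_patient[pid]] for group in groups]


# Toy data: 20 patients with 3 visits each (made up for this example).
toy_samples = [{"patient_id": f"p{i % 20}", "visit_id": f"v{i}"} for i in range(60)]
train_idx, val_idx, test_idx = split_by_patient_sketch(
    toy_samples, ratios=[0.7, 0.15, 0.15]
)
```

Splitting by patient rather than by sample avoids leakage: visits from one patient tend to be correlated, so letting them appear in both training and test data would likely inflate the reported ROC-AUC and PR-AUC.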