From 13287f753869f85be99c1036f0c8f11e06629f2c Mon Sep 17 00:00:00 2001 From: bilalarif Date: Mon, 8 Dec 2025 23:23:10 -0600 Subject: [PATCH 01/10] Migrated TCN model to PyHealth 2.0 architecture --- pyhealth/models/tcn.py | 420 ++++++++++++----------------------------- 1 file changed, 124 insertions(+), 296 deletions(-) diff --git a/pyhealth/models/tcn.py b/pyhealth/models/tcn.py index cf205d9ca..4afc8f70c 100644 --- a/pyhealth/models/tcn.py +++ b/pyhealth/models/tcn.py @@ -1,16 +1,14 @@ -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Tuple, Union import numpy as np import torch import torch.nn as nn -import torch.nn.utils.rnn as rnn_utils -from torch.nn.utils import weight_norm +from torch.nn.utils.parametrizations import weight_norm -from pyhealth.datasets import SampleEHRDataset +from pyhealth.datasets import SampleDataset from pyhealth.models import BaseModel -from pyhealth.models.utils import get_last_visit -# VALID_OPERATION_LEVEL = ["visit", "event"] +from .embedding import EmbeddingModel # From TCN original paper https://github.com/locuslab/TCN @@ -20,6 +18,8 @@ def __init__(self, chomp_size): self.chomp_size = chomp_size def forward(self, x): + if self.chomp_size == 0: + return x return x[:, :, : -self.chomp_size].contiguous() @@ -116,27 +116,64 @@ class TCNLayer(nn.Module): def __init__( self, input_dim: int, - num_channels: int = 128, + num_channels: Union[int, List[int]] = 128, max_seq_length: int = 20, kernel_size: int = 2, dropout: float = 0.5, ): super(TCNLayer, self).__init__() - self.num_channels = num_channels + + # Validate kernel_size + if kernel_size < 2: + raise ValueError( + f"kernel_size must be >= 2 for TCN, got {kernel_size}. " + "kernel_size=1 would result in no temporal modeling." + ) layers = [] # We compute automatically the depth based on the desired seq_length. - if isinstance(num_channels, int) and max_seq_length: + if isinstance(num_channels, int): + if not max_seq_length: + raise ValueError( + "max_seq_length must be provided when num_channels is int" + ) + # Validate max_seq_length + if max_seq_length <= 0: + raise ValueError( + f"max_seq_length must be positive, got {max_seq_length}" + ) + if max_seq_length < 2 * kernel_size: + raise ValueError( + f"max_seq_length must be >= 2 * kernel_size ({2 * kernel_size}) " + f"for automatic depth calculation, got {max_seq_length}. " + f"Either increase max_seq_length or provide num_channels as a list." + ) + # Validate num_channels value + if num_channels <= 0: + raise ValueError( + f"num_channels must be positive, got {num_channels}" + ) num_channels = [num_channels] * int( np.ceil(np.log(max_seq_length / 2) / np.log(kernel_size)) ) - elif isinstance(num_channels, int) and not max_seq_length: - raise Exception( - "a maximum sequence length needs to be provided if num_channels is int" + # num_channels is now always a list + + # Validate num_channels list is not empty and all elements are positive + if not num_channels or len(num_channels) == 0: + raise ValueError( + "num_channels must be a non-empty list or a positive integer" ) - else: - pass + if isinstance(num_channels, list): + for i, nc in enumerate(num_channels): + if nc <= 0: + raise ValueError( + f"All num_channels values must be positive, " + f"got {nc} at index {i}" + ) + + # Store the actual output dimension (last layer's output size) + self.num_channels = num_channels[-1] num_levels = len(num_channels) for i in range(num_levels): @@ -159,9 +196,9 @@ def __init__( def forward( self, - x: torch.tensor, - mask: Optional[torch.tensor] = None, - ) -> Tuple[torch.tensor, torch.tensor]: + x: torch.Tensor, + mask: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: """Forward propagation. Args: @@ -170,14 +207,29 @@ def forward( 1 indicates valid and 0 indicates invalid. Returns: - last_out: a tensor of shape [batch size, hidden size], containing - the output features for the last time step. - out: a tensor of shape [batch size, sequence len, hidden size], + outputs: a tensor of shape [batch size, sequence len, hidden size], containing the output features for each time step. + last_outputs: a tensor of shape [batch size, hidden size], containing + the output features for the last time step. """ - out = self.network(x.permute(0, 2, 1)).permute(0, 2, 1) - last_out = get_last_visit(out, mask) - return last_out, out + batch_size = x.size(0) + # TCN expects (batch, channels, seq_len) so we permute + outputs = self.network(x.permute(0, 2, 1)).permute(0, 2, 1) + + # Extract last valid output using mask (similar to RNN) + if mask is None: + lengths = torch.full( + size=(batch_size,), fill_value=x.size(1), dtype=torch.int64, device=x.device + ) + else: + # Ensure mask is on the same device as x to avoid device mismatch + mask = mask.to(x.device) + lengths = torch.sum(mask.int(), dim=-1) + + # Clamp lengths to at least 1 to handle empty sequences + lengths = torch.clamp(lengths, min=1) + last_outputs = outputs[torch.arange(batch_size, device=x.device), (lengths - 1), :] + return outputs, last_outputs class TCN(BaseModel): @@ -189,165 +241,64 @@ class TCN(BaseModel): Note: We use separate TCN layers for different feature_keys. - Currently, we automatically support different input formats: - - code based input (need to use the embedding table later) - - float/int based value input - We follow the current convention for the TCN model: - - case 1. [code1, code2, code3, ...] - - we will assume the code follows the order; our model will encode - each code into a vector and apply TCN on the code level - - case 2. [[code1, code2]] or [[code1, code2], [code3, code4, code5], ...] - - we will assume the inner bracket follows the order; our model first - use the embedding table to encode each code into a vector and then use - average/mean pooling to get one vector for one inner bracket; then use - TCN one the braket level - - case 3. [[1.5, 2.0, 0.0]] or [[1.5, 2.0, 0.0], [8, 1.2, 4.5], ...] - - this case only makes sense when each inner bracket has the same length; - we assume each dimension has the same meaning; we run TCN directly - on the inner bracket level, similar to case 1 after embedding table - - case 4. [[[1.5, 2.0, 0.0]]] or [[[1.5, 2.0, 0.0], [8, 1.2, 4.5]], ...] - - this case only makes sense when each inner bracket has the same length; - we assume each dimension has the same meaning; we run TCN directly - on the inner bracket level, similar to case 2 after embedding table + Currently, we support two types of input formats: + - Sequence of codes (e.g., diagnosis codes, procedure codes) + - Input format: (batch_size, sequence_length) + - Each code is embedded into a vector and TCN is applied on the sequence + - Timeseries values (e.g., lab tests, vital signs) + - Input format: (batch_size, sequence_length, num_features) + - Each timestep contains a fixed number of measurements + - TCN is applied directly on the timeseries data Args: - dataset: the dataset to train the model. It is used to query certain - information such as the set of all tokens. - feature_keys: list of keys in samples to use as features, - e.g. ["conditions", "procedures"]. - label_key: key in samples to use as label (e.g., "drugs"). - mode: one of "binary", "multiclass", or "multilabel". - embedding_dim: the embedding dimension. Default is 128. - num_channels: the number of channels in the TCN layer. Default is 128. - **kwargs: other parameters for the TCN layer. - - Examples: - >>> from pyhealth.datasets import SampleEHRDataset - >>> samples = [ - ... { - ... "patient_id": "patient-0", - ... "visit_id": "visit-0", - ... "list_codes": ["505800458", "50580045810", "50580045811"], # NDC - ... "list_vectors": [[1.0, 2.55, 3.4], [4.1, 5.5, 6.0]], - ... "list_list_codes": [["A05B", "A05C", "A06A"], ["A11D", "A11E"]], # ATC-4 - ... "list_list_vectors": [ - ... [[1.8, 2.25, 3.41], [4.50, 5.9, 6.0]], - ... [[7.7, 8.5, 9.4]], - ... ], - ... "label": 1, - ... }, - ... { - ... "patient_id": "patient-0", - ... "visit_id": "visit-1", - ... "list_codes": [ - ... "55154191800", - ... "551541928", - ... "55154192800", - ... "705182798", - ... "70518279800", - ... ], - ... "list_vectors": [[1.4, 3.2, 3.5], [4.1, 5.9, 1.7], [4.5, 5.9, 1.7]], - ... "list_list_codes": [["A04A", "B035", "C129"]], - ... "list_list_vectors": [ - ... [[1.0, 2.8, 3.3], [4.9, 5.0, 6.6], [7.7, 8.4, 1.3], [7.7, 8.4, 1.3]], - ... ], - ... "label": 0, - ... }, - ... ] - >>> dataset = SampleEHRDataset(samples=samples, dataset_name="test") - >>> - >>> from pyhealth.models import TCN - >>> model = TCN( - ... dataset=dataset, - ... feature_keys=[ - ... "list_codes", - ... "list_vectors", - ... "list_list_codes", - ... "list_list_vectors", - ... ], - ... label_key="label", - ... mode="binary", - ... ) - >>> - >>> from pyhealth.datasets import get_dataloader - >>> train_loader = get_dataloader(dataset, batch_size=2, shuffle=True) - >>> data_batch = next(iter(train_loader)) - >>> - >>> ret = model(**data_batch) - >>> print(ret) - { - 'loss': tensor(1.1641, grad_fn=), - 'y_prob': tensor([[0.6837], - [0.3081]], grad_fn=), - 'y_true': tensor([[0.], - [1.]]), - 'logit': tensor([[ 0.7706], - [-0.8091]], grad_fn=) - } - >>> - - + dataset (SampleDataset): the dataset to train the model. It is used to query certain + information such as the set of all tokens. The dataset's input_schema and + output_schema define the feature_keys, label_key, and mode. + embedding_dim (int): the embedding dimension. Default is 128. + num_channels (Union[int, List[int]]): the number of channels in the TCN layer. + If int, depth is auto-computed from max_seq_length. If list, specifies + channels for each layer. Default is 128. + **kwargs: other parameters for the TCN layer (e.g., max_seq_length, kernel_size, dropout). """ def __init__( self, - dataset: SampleEHRDataset, - feature_keys: List[str], - label_key: str, - mode: str, + dataset: SampleDataset, embedding_dim: int = 128, - num_channels: int = 128, + num_channels: Union[int, List[int]] = 128, **kwargs ): super(TCN, self).__init__( dataset=dataset, - feature_keys=feature_keys, - label_key=label_key, - mode=mode, ) self.embedding_dim = embedding_dim - self.num_channels = num_channels - + # validate kwargs for TCN layer if "input_dim" in kwargs: raise ValueError("input_dim is determined by embedding_dim") + assert len(self.label_keys) == 1, "Only one label key is supported if TCN is initialized" + self.label_key = self.label_keys[0] + self.mode = self.dataset.output_schema[self.label_key] + + # Validate that we have at least one feature + if not self.feature_keys or len(self.feature_keys) == 0: + raise ValueError( + "TCN requires at least one feature key. " + "Please provide feature_keys in your dataset configuration." + ) - # the key of self.feat_tokenizers only contains the code based inputs - self.feat_tokenizers = {} - self.label_tokenizer = self.get_label_tokenizer() - # the key of self.embeddings only contains the code based inputs - self.embeddings = nn.ModuleDict() - # the key of self.linear_layers only contains the float/int based inputs - self.linear_layers = nn.ModuleDict() - - for feature_key in self.feature_keys: - input_info = self.dataset.input_info[feature_key] - # sanity check - if input_info["type"] not in [str, float, int]: - raise ValueError( - "TCN only supports str code, float and int as input types" - ) - elif (input_info["type"] == str) and (input_info["dim"] not in [2, 3]): - raise ValueError( - "TCN only supports 2-dim or 3-dim str code as input types" - ) - elif (input_info["type"] in [float, int]) and ( - input_info["dim"] not in [2, 3] - ): - raise ValueError( - "TCN only supports 2-dim or 3-dim float and int as input types" - ) - else: - pass - # for code based input, we need Type - # for float/int based input, we need Type, input_dim - self.add_feature_transform_layer(feature_key, input_info) + self.embedding_model = EmbeddingModel(dataset, embedding_dim) self.tcn = nn.ModuleDict() - for feature_key in feature_keys: + for feature_key in self.feature_keys: self.tcn[feature_key] = TCNLayer( input_dim=embedding_dim, num_channels=num_channels, **kwargs ) - output_size = self.get_output_size(self.label_tokenizer) + + # Get the actual output dimension from TCNLayer instances + # All TCNLayers have the same output dimension + self.num_channels = next(iter(self.tcn.values())).num_channels + + output_size = self.get_output_size() self.fc = nn.Linear(len(self.feature_keys) * self.num_channels, output_size) def forward(self, **kwargs) -> Dict[str, torch.Tensor]: @@ -360,152 +311,29 @@ def forward(self, **kwargs) -> Dict[str, torch.Tensor]: all the feature keys and the label key. Returns: - A dictionary with the following keys: - loss: a scalar tensor representing the loss. - y_prob: a tensor representing the predicted probabilities. - y_true: a tensor representing the true labels. + Dict[str, torch.Tensor]: A dictionary with the following keys: + - loss: a scalar tensor representing the loss. + - y_prob: a tensor representing the predicted probabilities. + - y_true: a tensor representing the true labels. + - logit: a tensor representing the logits. + - embed (optional): a tensor representing the patient embeddings if requested. """ patient_emb = [] + embedded = self.embedding_model(kwargs) for feature_key in self.feature_keys: - input_info = self.dataset.input_info[feature_key] - dim_, type_ = input_info["dim"], input_info["type"] - - # for case 1: [code1, code2, code3, ...] - if (dim_ == 2) and (type_ == str): - x = self.feat_tokenizers[feature_key].batch_encode_2d( - kwargs[feature_key] - ) - # (patient, event) - x = torch.tensor(x, dtype=torch.long, device=self.device) - # (patient, event, embedding_dim) - x = self.embeddings[feature_key](x) - # (patient, event) - mask = torch.any(x !=0, dim=2) - - # for case 2: [[code1, code2], [code3, ...], ...] - elif (dim_ == 3) and (type_ == str): - x = self.feat_tokenizers[feature_key].batch_encode_3d( - kwargs[feature_key] - ) - # (patient, visit, event) - x = torch.tensor(x, dtype=torch.long, device=self.device) - # (patient, visit, event, embedding_dim) - x = self.embeddings[feature_key](x) - # (patient, visit, embedding_dim) - x = torch.sum(x, dim=2) - # (patient, visit) - mask = torch.any(x !=0, dim=2) - - # for case 3: [[1.5, 2.0, 0.0], ...] - elif (dim_ == 2) and (type_ in [float, int]): - x, mask = self.padding2d(kwargs[feature_key]) - # (patient, event, values) - x = torch.tensor(x, dtype=torch.float, device=self.device) - # (patient, event, embedding_dim) - x = self.linear_layers[feature_key](x) - # (patient, event) - mask = mask.bool().to(self.device) - - # for case 4: [[[1.5, 2.0, 0.0], [1.8, 2.4, 6.0]], ...] - elif (dim_ == 3) and (type_ in [float, int]): - x, mask = self.padding3d(kwargs[feature_key]) - # (patient, visit, event, values) - x = torch.tensor(x, dtype=torch.float, device=self.device) - # (patient, visit, embedding_dim) - x = torch.sum(x, dim=2) - x = self.linear_layers[feature_key](x) - # (patient, event) - mask = mask[:, :, 0] - mask = mask.bool().to(self.device) - - else: - raise NotImplementedError - - x, _ = self.tcn[feature_key](x, mask) + x = embedded[feature_key] + mask = (x.sum(dim=-1) != 0).int() + _, x = self.tcn[feature_key](x, mask) patient_emb.append(x) patient_emb = torch.cat(patient_emb, dim=1) # (patient, label_size) logits = self.fc(patient_emb) # obtain y_true, loss, y_prob - y_true = self.prepare_labels(kwargs[self.label_key], self.label_tokenizer) + y_true = kwargs[self.label_key].to(self.device) loss = self.get_loss_function()(logits, y_true) y_prob = self.prepare_y_prob(logits) - results = { - "loss": loss, - "y_prob": y_prob, - "y_true": y_true, - "logit": logits, - } + results = {"loss": loss, "y_prob": y_prob, "y_true": y_true, "logit": logits} if kwargs.get("embed", False): results["embed"] = patient_emb return results - - -if __name__ == "__main__": - from pyhealth.datasets import SampleEHRDataset - - samples = [ - { - "patient_id": "patient-0", - "visit_id": "visit-0", - # "single_vector": [1, 2, 3], - "list_codes": ["505800458", "50580045810", "50580045811"], # NDC - "list_vectors": [[1.0, 2.55, 3.4], [4.1, 5.5, 6.0]], - "list_list_codes": [["A05B", "A05C", "A06A"], ["A11D", "A11E"]], # ATC-4 - "list_list_vectors": [ - [[1.8, 2.25, 3.41], [4.50, 5.9, 6.0]], - [[7.7, 8.5, 9.4]], - ], - "label": 1, - }, - { - "patient_id": "patient-0", - "visit_id": "visit-1", - # "single_vector": [1, 5, 8], - "list_codes": [ - "55154191800", - "551541928", - "55154192800", - "705182798", - "70518279800", - ], - "list_vectors": [[1.4, 3.2, 3.5], [4.1, 5.9, 1.7], [4.5, 5.9, 1.7]], - "list_list_codes": [["A04A", "B035", "C129"]], - "list_list_vectors": [ - [[1.0, 2.8, 3.3], [4.9, 5.0, 6.6], [7.7, 8.4, 1.3], [7.7, 8.4, 1.3]], - ], - "label": 0, - }, - ] - - # dataset - dataset = SampleEHRDataset(samples=samples, dataset_name="test") - - # data loader - from pyhealth.datasets import get_dataloader - - train_loader = get_dataloader(dataset, batch_size=2, shuffle=True) - - # model - model = TCN( - dataset=dataset, - feature_keys=[ - "list_codes", - "list_vectors", - "list_list_codes", - "list_list_vectors", - ], - label_key="label", - mode="binary", - ) - - # data batch - data_batch = next(iter(train_loader)) - - # try the model - ret = model(**data_batch) - print(ret) - - # try loss backward - ret["loss"].backward() From 74fc26482a6fa7aebd84ca9af097835ac79943b8 Mon Sep 17 00:00:00 2001 From: bilalarif Date: Mon, 8 Dec 2025 23:24:00 -0600 Subject: [PATCH 02/10] added test cases --- tests/core/test_tcn.py | 204 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 204 insertions(+) create mode 100644 tests/core/test_tcn.py diff --git a/tests/core/test_tcn.py b/tests/core/test_tcn.py new file mode 100644 index 000000000..b398a127a --- /dev/null +++ b/tests/core/test_tcn.py @@ -0,0 +1,204 @@ +import unittest +import torch + +from pyhealth.datasets import SampleDataset, get_dataloader +from pyhealth.models import TCN + + +class TestTCN(unittest.TestCase): + """Test cases for the TCN model.""" + + def setUp(self): + """Set up test data and model.""" + self.samples = [ + { + "patient_id": "patient-0", + "visit_id": "visit-0", + "conditions": ["cond-33", "cond-86", "cond-80", "cond-12"], + "procedures": ["proc-1", "proc-2", "proc-3"], + "label": 0, + }, + { + "patient_id": "patient-0", + "visit_id": "visit-1", + "conditions": ["cond-33", "cond-86", "cond-80"], + "procedures": ["proc-1", "proc-2"], + "label": 1, + }, + ] + + # Define input and output schemas + self.input_schema = { + "conditions": "sequence", # sequence of condition codes + "procedures": "sequence", # sequence of procedure codes + } + self.output_schema = {"label": "binary"} # binary classification + + # Create dataset + self.dataset = SampleDataset( + samples=self.samples, + input_schema=self.input_schema, + output_schema=self.output_schema, + dataset_name="test", + ) + + # Create model + self.model = TCN(dataset=self.dataset) + + def test_model_initialization(self): + """Test that the TCN model initializes correctly.""" + self.assertIsInstance(self.model, TCN) + self.assertEqual(self.model.embedding_dim, 128) + self.assertEqual(self.model.num_channels, 128) + self.assertEqual(len(self.model.feature_keys), 2) + self.assertIn("conditions", self.model.feature_keys) + self.assertIn("procedures", self.model.feature_keys) + self.assertEqual(self.model.label_key, "label") + + def test_model_forward(self): + """Test that the TCN model forward pass works correctly.""" + # Create data loader + train_loader = get_dataloader(self.dataset, batch_size=2, shuffle=True) + data_batch = next(iter(train_loader)) + + # Forward pass + with torch.no_grad(): + ret = self.model(**data_batch) + + # Check output structure + self.assertIn("loss", ret) + self.assertIn("y_prob", ret) + self.assertIn("y_true", ret) + self.assertIn("logit", ret) + + # Check tensor shapes + self.assertEqual(ret["y_prob"].shape[0], 2) # batch size + self.assertEqual(ret["y_true"].shape[0], 2) # batch size + self.assertEqual(ret["logit"].shape[0], 2) # batch size + + # Check that loss is a scalar + self.assertEqual(ret["loss"].dim(), 0) + + def test_model_backward(self): + """Test that the TCN model backward pass works correctly.""" + # Create data loader + train_loader = get_dataloader(self.dataset, batch_size=2, shuffle=True) + data_batch = next(iter(train_loader)) + + # Forward pass + ret = self.model(**data_batch) + + # Backward pass + ret["loss"].backward() + + # Check that at least one parameter has gradients (backward working) + has_gradient = False + for param in self.model.parameters(): + if param.requires_grad and param.grad is not None: + has_gradient = True + break + self.assertTrue( + has_gradient, "No parameters have gradients after backward pass" + ) + + def test_model_with_embedding(self): + """Test that the TCN model returns embeddings when requested.""" + # Create data loader + train_loader = get_dataloader(self.dataset, batch_size=2, shuffle=True) + data_batch = next(iter(train_loader)) + data_batch["embed"] = True + + # Forward pass + with torch.no_grad(): + ret = self.model(**data_batch) + + # Check that embeddings are returned + self.assertIn("embed", ret) + self.assertEqual(ret["embed"].shape[0], 2) # batch size + # Check embedding dimension (2 features * num_channels) + expected_embed_dim = len(self.model.feature_keys) * self.model.num_channels + self.assertEqual(ret["embed"].shape[1], expected_embed_dim) + + def test_custom_hyperparameters(self): + """Test TCN model with custom hyperparameters.""" + model = TCN( + dataset=self.dataset, + embedding_dim=64, + num_channels=64, + kernel_size=3, + dropout=0.3, + ) + + self.assertEqual(model.embedding_dim, 64) + self.assertEqual(model.num_channels, 64) + + # Test forward pass + train_loader = get_dataloader(self.dataset, batch_size=2, shuffle=True) + data_batch = next(iter(train_loader)) + + with torch.no_grad(): + ret = model(**data_batch) + + self.assertIn("loss", ret) + self.assertIn("y_prob", ret) + + def test_num_channels_as_list(self): + """Test TCN model with num_channels as a list.""" + model = TCN( + dataset=self.dataset, + embedding_dim=64, + num_channels=[64, 128, 256], + ) + + # Should use the last value in the list + self.assertEqual(model.num_channels, 256) + + # Test forward pass + train_loader = get_dataloader(self.dataset, batch_size=2, shuffle=True) + data_batch = next(iter(train_loader)) + + with torch.no_grad(): + ret = model(**data_batch) + + self.assertIn("loss", ret) + self.assertIn("y_prob", ret) + + def test_validation_kernel_size(self): + """Test that invalid kernel_size raises ValueError.""" + with self.assertRaises(ValueError) as context: + TCN( + dataset=self.dataset, + kernel_size=1, + ) + self.assertIn("kernel_size must be >= 2", str(context.exception)) + + def test_validation_negative_num_channels(self): + """Test that negative num_channels raises ValueError.""" + with self.assertRaises(ValueError) as context: + TCN( + dataset=self.dataset, + num_channels=-10, + ) + self.assertIn("must be positive", str(context.exception)) + + def test_validation_empty_num_channels_list(self): + """Test that empty num_channels list raises ValueError.""" + with self.assertRaises(ValueError) as context: + TCN( + dataset=self.dataset, + num_channels=[], + ) + self.assertIn("non-empty", str(context.exception)) + + def test_validation_negative_in_num_channels_list(self): + """Test that negative value in num_channels list raises ValueError.""" + with self.assertRaises(ValueError) as context: + TCN( + dataset=self.dataset, + num_channels=[128, -64, 256], + ) + self.assertIn("must be positive", str(context.exception)) + + +if __name__ == "__main__": + unittest.main() From d52a01795da7407345f95e11524288172726f887 Mon Sep 17 00:00:00 2001 From: bilalarif Date: Tue, 9 Dec 2025 00:57:07 -0600 Subject: [PATCH 03/10] Removed extra validation checks --- pyhealth/models/tcn.py | 54 ++---------------------------------------- tests/core/test_tcn.py | 36 ---------------------------- 2 files changed, 2 insertions(+), 88 deletions(-) diff --git a/pyhealth/models/tcn.py b/pyhealth/models/tcn.py index 4afc8f70c..0dc2b04bb 100644 --- a/pyhealth/models/tcn.py +++ b/pyhealth/models/tcn.py @@ -36,8 +36,7 @@ def __init__( stride=stride, padding=padding, dilation=dilation, - ), - dim=None, + ) ) self.chomp1 = Chomp1d(padding) self.relu1 = nn.ReLU() @@ -51,8 +50,7 @@ def __init__( stride=stride, padding=padding, dilation=dilation, - ), - dim=None, + ) ) self.chomp2 = Chomp1d(padding) self.relu2 = nn.ReLU() @@ -123,54 +121,13 @@ def __init__( ): super(TCNLayer, self).__init__() - # Validate kernel_size - if kernel_size < 2: - raise ValueError( - f"kernel_size must be >= 2 for TCN, got {kernel_size}. " - "kernel_size=1 would result in no temporal modeling." - ) - layers = [] # We compute automatically the depth based on the desired seq_length. if isinstance(num_channels, int): - if not max_seq_length: - raise ValueError( - "max_seq_length must be provided when num_channels is int" - ) - # Validate max_seq_length - if max_seq_length <= 0: - raise ValueError( - f"max_seq_length must be positive, got {max_seq_length}" - ) - if max_seq_length < 2 * kernel_size: - raise ValueError( - f"max_seq_length must be >= 2 * kernel_size ({2 * kernel_size}) " - f"for automatic depth calculation, got {max_seq_length}. " - f"Either increase max_seq_length or provide num_channels as a list." - ) - # Validate num_channels value - if num_channels <= 0: - raise ValueError( - f"num_channels must be positive, got {num_channels}" - ) num_channels = [num_channels] * int( np.ceil(np.log(max_seq_length / 2) / np.log(kernel_size)) ) - # num_channels is now always a list - - # Validate num_channels list is not empty and all elements are positive - if not num_channels or len(num_channels) == 0: - raise ValueError( - "num_channels must be a non-empty list or a positive integer" - ) - if isinstance(num_channels, list): - for i, nc in enumerate(num_channels): - if nc <= 0: - raise ValueError( - f"All num_channels values must be positive, " - f"got {nc} at index {i}" - ) # Store the actual output dimension (last layer's output size) self.num_channels = num_channels[-1] @@ -279,13 +236,6 @@ def __init__( self.label_key = self.label_keys[0] self.mode = self.dataset.output_schema[self.label_key] - # Validate that we have at least one feature - if not self.feature_keys or len(self.feature_keys) == 0: - raise ValueError( - "TCN requires at least one feature key. " - "Please provide feature_keys in your dataset configuration." - ) - self.embedding_model = EmbeddingModel(dataset, embedding_dim) self.tcn = nn.ModuleDict() diff --git a/tests/core/test_tcn.py b/tests/core/test_tcn.py index b398a127a..9fd94f016 100644 --- a/tests/core/test_tcn.py +++ b/tests/core/test_tcn.py @@ -163,42 +163,6 @@ def test_num_channels_as_list(self): self.assertIn("loss", ret) self.assertIn("y_prob", ret) - def test_validation_kernel_size(self): - """Test that invalid kernel_size raises ValueError.""" - with self.assertRaises(ValueError) as context: - TCN( - dataset=self.dataset, - kernel_size=1, - ) - self.assertIn("kernel_size must be >= 2", str(context.exception)) - - def test_validation_negative_num_channels(self): - """Test that negative num_channels raises ValueError.""" - with self.assertRaises(ValueError) as context: - TCN( - dataset=self.dataset, - num_channels=-10, - ) - self.assertIn("must be positive", str(context.exception)) - - def test_validation_empty_num_channels_list(self): - """Test that empty num_channels list raises ValueError.""" - with self.assertRaises(ValueError) as context: - TCN( - dataset=self.dataset, - num_channels=[], - ) - self.assertIn("non-empty", str(context.exception)) - - def test_validation_negative_in_num_channels_list(self): - """Test that negative value in num_channels list raises ValueError.""" - with self.assertRaises(ValueError) as context: - TCN( - dataset=self.dataset, - num_channels=[128, -64, 256], - ) - self.assertIn("must be positive", str(context.exception)) - if __name__ == "__main__": unittest.main() From 2f78a3f2e85dc9f632118f05da76b7b2ee0b5869 Mon Sep 17 00:00:00 2001 From: bilalarif Date: Tue, 9 Dec 2025 01:04:34 -0600 Subject: [PATCH 04/10] removed space --- pyhealth/models/tcn.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pyhealth/models/tcn.py b/pyhealth/models/tcn.py index 0dc2b04bb..ebe36f643 100644 --- a/pyhealth/models/tcn.py +++ b/pyhealth/models/tcn.py @@ -120,7 +120,6 @@ def __init__( dropout: float = 0.5, ): super(TCNLayer, self).__init__() - layers = [] # We compute automatically the depth based on the desired seq_length. From aefade2e2908eeb9c7d212d6ffc510f35c2eed10 Mon Sep 17 00:00:00 2001 From: bilalarif Date: Tue, 9 Dec 2025 01:48:19 -0600 Subject: [PATCH 05/10] added notebook for tcn --- examples/tcn_mimic3.ipynb | 207 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 207 insertions(+) create mode 100644 examples/tcn_mimic3.ipynb diff --git a/examples/tcn_mimic3.ipynb b/examples/tcn_mimic3.ipynb new file mode 100644 index 000000000..1cdd7aaee --- /dev/null +++ b/examples/tcn_mimic3.ipynb @@ -0,0 +1,207 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# TCN Model Training on MIMIC-III Dataset\n", + "\n", + "Train the TCN (Temporal Convolutional Networks) model for mortality prediction using the MIMIC-III dataset." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from pyhealth.datasets import MIMIC3Dataset\n", + "\n", + "dataset = MIMIC3Dataset(\n", + " root=\"https://storage.googleapis.com/pyhealth/Synthetic_MIMIC-III\",\n", + " tables=[\"DIAGNOSES_ICD\", \"PROCEDURES_ICD\", \"PRESCRIPTIONS\"],\n", + " dev=True,\n", + ")\n", + "dataset.stats()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Set Mortality Prediction Task\n", + "\n", + "We use the in-hospital mortality prediction task which predicts patient mortality based on conditions and procedures." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from pyhealth.tasks import MortalityPredictionMIMIC3\n", + "\n", + "task = MortalityPredictionMIMIC3()\n", + "samples = dataset.set_task(task)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Split Dataset\n", + "\n", + "Split the dataset into train, validation, and test sets using patient-level splitting." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from pyhealth.datasets import split_by_patient, get_dataloader\n", + "\n", + "train_dataset, val_dataset, test_dataset = split_by_patient(\n", + " samples, ratios=[0.7, 0.15, 0.15]\n", + ")\n", + "\n", + "train_loader = get_dataloader(train_dataset, batch_size=64, shuffle=True)\n", + "val_loader = get_dataloader(val_dataset, batch_size=64, shuffle=False)\n", + "test_loader = get_dataloader(test_dataset, batch_size=64, shuffle=False)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Initialize TCN Model\n", + "\n", + "Create the TCN model with specified hyperparameters." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from pyhealth.models import TCN\n", + "\n", + "model = TCN(\n", + " dataset=samples,\n", + " embedding_dim=128,\n", + " num_channels=128,\n", + " kernel_size=2,\n", + " dropout=0.5,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Train Model\n", + "\n", + "Train the model using the PyHealth Trainer with relevant metrics for mortality prediction." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from pyhealth.trainer import Trainer\n", + "\n", + "trainer = Trainer(\n", + " model=model,\n", + " metrics=[\"roc_auc\", \"pr_auc\", \"f1\", \"accuracy\"],\n", + ")\n", + "\n", + "trainer.train(\n", + " train_dataloader=train_loader,\n", + " val_dataloader=val_loader,\n", + " epochs=10,\n", + " monitor=\"roc_auc\",\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Evaluate on Test Set\n", + "\n", + "Evaluate the trained model on the test set and print the results." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "results = trainer.evaluate(test_loader)\n", + "\n", + "print(\"Test Set Results:\")\n", + "print(f\" ROC-AUC: {results['roc_auc']:.4f}\")\n", + "print(f\" PR-AUC: {results['pr_auc']:.4f}\")\n", + "print(f\" F1 Score: {results['f1']:.4f}\")\n", + "print(f\" Accuracy: {results['accuracy']:.4f}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Custom TCN Configuration\n", + "\n", + "You can customize the TCN architecture by specifying different parameters:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Create TCN with custom architecture\n", + "custom_model = TCN(\n", + " dataset=samples,\n", + " embedding_dim=64,\n", + " num_channels=[64, 128, 256], # List for manual layer specification\n", + " kernel_size=3,\n", + " dropout=0.3,\n", + ")\n", + "\n", + "print(\"Custom TCN architecture:\")\n", + "print(f\"Embedding dim: {custom_model.embedding_dim}\")\n", + "print(f\"Output channels: {custom_model.num_channels}\")\n", + "print(f\"Number of features: {len(custom_model.feature_keys)}\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.0" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} From 42e5e4e387ddcafbac908b0b988f594d3bd88a64 Mon Sep 17 00:00:00 2001 From: bilalarif Date: Tue, 9 Dec 2025 11:48:59 -0600 Subject: [PATCH 06/10] Added notebook --- .../{tcn_mimic3.ipynb => tcn_mimic3_codes.ipynb} | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) rename examples/{tcn_mimic3.ipynb => tcn_mimic3_codes.ipynb} (93%) diff --git a/examples/tcn_mimic3.ipynb b/examples/tcn_mimic3_codes.ipynb similarity index 93% rename from examples/tcn_mimic3.ipynb rename to examples/tcn_mimic3_codes.ipynb index 1cdd7aaee..159aba4bb 100644 --- a/examples/tcn_mimic3.ipynb +++ b/examples/tcn_mimic3_codes.ipynb @@ -31,7 +31,7 @@ "source": [ "## Set Mortality Prediction Task\n", "\n", - "We use the in-hospital mortality prediction task which predicts patient mortality based on conditions and procedures." + "We use the in-hospital mortality prediction task which predicts patient mortality based on diagnosis and procedure codes." ] }, { @@ -117,7 +117,7 @@ "\n", "trainer = Trainer(\n", " model=model,\n", - " metrics=[\"roc_auc\", \"pr_auc\", \"f1\", \"accuracy\"],\n", + " metrics=[\"pr_auc\", \"roc_auc\", \"f1\", \"accuracy\"],\n", ")\n", "\n", "trainer.train(\n", @@ -190,16 +190,8 @@ "name": "python3" }, "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.0" + "version": "3.8.0" } }, "nbformat": 4, From ffc94bb4c6aaf5c68cae9717009acdeec6104d88 Mon Sep 17 00:00:00 2001 From: bilalarif Date: Wed, 24 Dec 2025 11:15:33 -0600 Subject: [PATCH 07/10] Added usage example to docstring --- pyhealth/models/tcn.py | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/pyhealth/models/tcn.py b/pyhealth/models/tcn.py index ebe36f643..58d330ef1 100644 --- a/pyhealth/models/tcn.py +++ b/pyhealth/models/tcn.py @@ -215,6 +215,39 @@ class TCN(BaseModel): If int, depth is auto-computed from max_seq_length. If list, specifies channels for each layer. Default is 128. **kwargs: other parameters for the TCN layer (e.g., max_seq_length, kernel_size, dropout). + + Examples: + >>> from pyhealth.datasets import SampleDataset + >>> from pyhealth.models import TCN + >>> from pyhealth.datasets import get_dataloader + >>> samples = [ + ... { + ... "patient_id": "patient-0", + ... "visit_id": "visit-0", + ... "conditions": ["cond-33", "cond-86", "cond-80"], + ... "procedures": ["proc-12", "proc-45"], + ... "label": 1, + ... }, + ... { + ... "patient_id": "patient-1", + ... "visit_id": "visit-1", + ... "conditions": ["cond-12", "cond-52"], + ... "procedures": ["proc-23"], + ... "label": 0, + ... }, + ... ] + >>> dataset = SampleDataset( + ... samples=samples, + ... input_schema={"conditions": "sequence", "procedures": "sequence"}, + ... output_schema={"label": "binary"}, + ... dataset_name="test_tcn_dataset", + ... ) + >>> train_loader = get_dataloader(dataset, batch_size=2, shuffle=True) + >>> model = TCN(dataset=dataset, embedding_dim=64, num_channels=64, max_seq_length=10) + >>> data_batch = next(iter(train_loader)) + >>> ret = model(**data_batch) + >>> print(ret) + """ def __init__( From c62e29c5953cb341072ac335fc774ea7d94f0437 Mon Sep 17 00:00:00 2001 From: bilalarif Date: Wed, 24 Dec 2025 11:27:39 -0600 Subject: [PATCH 08/10] changed SampleDataset to create_sample_dataset in docstring --- pyhealth/models/tcn.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/pyhealth/models/tcn.py b/pyhealth/models/tcn.py index 58d330ef1..14e90ce39 100644 --- a/pyhealth/models/tcn.py +++ b/pyhealth/models/tcn.py @@ -217,9 +217,8 @@ class TCN(BaseModel): **kwargs: other parameters for the TCN layer (e.g., max_seq_length, kernel_size, dropout). Examples: - >>> from pyhealth.datasets import SampleDataset + >>> from pyhealth.datasets import create_sample_dataset, get_dataloader >>> from pyhealth.models import TCN - >>> from pyhealth.datasets import get_dataloader >>> samples = [ ... { ... "patient_id": "patient-0", @@ -236,7 +235,7 @@ class TCN(BaseModel): ... "label": 0, ... }, ... ] - >>> dataset = SampleDataset( + >>> dataset = create_sample_dataset( ... samples=samples, ... input_schema={"conditions": "sequence", "procedures": "sequence"}, ... output_schema={"label": "binary"}, From 2ddd8f6b8b643b5dd1e220202170ea608a408558 Mon Sep 17 00:00:00 2001 From: bilalarif Date: Wed, 24 Dec 2025 11:47:58 -0600 Subject: [PATCH 09/10] updated docstring --- pyhealth/models/tcn.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/pyhealth/models/tcn.py b/pyhealth/models/tcn.py index 14e90ce39..3569ea70b 100644 --- a/pyhealth/models/tcn.py +++ b/pyhealth/models/tcn.py @@ -217,7 +217,8 @@ class TCN(BaseModel): **kwargs: other parameters for the TCN layer (e.g., max_seq_length, kernel_size, dropout). Examples: - >>> from pyhealth.datasets import create_sample_dataset, get_dataloader + >>> from pyhealth.datasets import create_sample_dataset + >>> from pyhealth.datasets import get_dataloader >>> from pyhealth.models import TCN >>> samples = [ ... { @@ -246,7 +247,12 @@ class TCN(BaseModel): >>> data_batch = next(iter(train_loader)) >>> ret = model(**data_batch) >>> print(ret) - + { + 'loss': tensor(...), + 'y_prob': tensor(...), + 'y_true': tensor(...), + 'logit': tensor(...) + } """ def __init__( From 20e1098dd6f9fe45f8be0d18a5c469839d798520 Mon Sep 17 00:00:00 2001 From: bilalarif Date: Wed, 24 Dec 2025 12:00:55 -0600 Subject: [PATCH 10/10] changed SampleDataset to create_sample_dataset in test_tcn.py --- tests/core/test_tcn.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/core/test_tcn.py b/tests/core/test_tcn.py index 9fd94f016..a8eee4ea4 100644 --- a/tests/core/test_tcn.py +++ b/tests/core/test_tcn.py @@ -1,7 +1,7 @@ import unittest import torch -from pyhealth.datasets import SampleDataset, get_dataloader +from pyhealth.datasets import create_sample_dataset, get_dataloader from pyhealth.models import TCN @@ -35,7 +35,7 @@ def setUp(self): self.output_schema = {"label": "binary"} # binary classification # Create dataset - self.dataset = SampleDataset( + self.dataset = create_sample_dataset( samples=self.samples, input_schema=self.input_schema, output_schema=self.output_schema,