|
| 1 | +import datetime |
| 2 | +from sqlite3 import Timestamp |
| 3 | +import unittest |
| 4 | +from pyhealth.data.data import Event |
| 5 | +import pandas |
| 6 | + |
| 7 | +from pyhealth.datasets import eICUDataset |
| 8 | +from pyhealth.unittests.test_datasets.utils import EHRDatasetStatAssertion |
| 9 | + |
| 10 | + |
| 11 | +class TesteICUDataset(unittest.TestCase): |
| 12 | + |
| 13 | + # to test the file this path needs to be updated |
| 14 | + DATASET_NAME = "eICU-demo" |
| 15 | + ROOT = "https://storage.googleapis.com/pyhealth/eicu-demo/" |
| 16 | + TABLES = ["diagnosis", "medication", "lab", "treatment", "physicalExam"] |
| 17 | + CODE_MAPPING = {} |
| 18 | + DEV = True # not needed when using demo set since its 100 patients large |
| 19 | + REFRESH_CACHE = True |
| 20 | + |
| 21 | + dataset = eICUDataset( |
| 22 | + dataset_name=DATASET_NAME, |
| 23 | + root=ROOT, |
| 24 | + tables=TABLES, |
| 25 | + code_mapping=CODE_MAPPING, |
| 26 | + dev=DEV, |
| 27 | + refresh_cache=REFRESH_CACHE, |
| 28 | + ) |
| 29 | + |
| 30 | + def setUp(self): |
| 31 | + pass |
| 32 | + |
| 33 | + def test_patient(self): |
| 34 | + # given parametes: |
| 35 | + selected_patient_id = "002-10009+193705" |
| 36 | + selected_visit_index = 0 |
| 37 | + # selected indeces for events defined in `expected_event_data` |
| 38 | + |
| 39 | + # expect: |
| 40 | + # patient data |
| 41 | + expected_birth_datetime = pandas.Timestamp("1938-02-24 00:00:00") |
| 42 | + expected_death_datetime = None |
| 43 | + expected_ethnicity = "Caucasian" |
| 44 | + expected_gender = "Female" |
| 45 | + |
| 46 | + # visit data |
| 47 | + expected_visit_len = 1 |
| 48 | + expected_visit_id = "224606" |
| 49 | + expected_visit_discharge_status = "Alive" |
| 50 | + expected_discharge_time = datetime.datetime(2014, 2, 27, 0, 45) |
| 51 | + expected_encounter_time = datetime.datetime(2014, 2, 24, 2, 59) |
| 52 | + |
| 53 | + # visit attribute dict data |
| 54 | + expected_visit_attr_dict_len = 2 |
| 55 | + expected_visit_hopital_id = 71 |
| 56 | + expected_visit_region = "Midwest" |
| 57 | + |
| 58 | + # event level data |
| 59 | + expected_event_count = 319 |
| 60 | + |
| 61 | + # during a specified visit assert the event data is correct. Event data is parametrized by tables |
| 62 | + # schema: |
| 63 | + # event type (from one of the requested tables) |
| 64 | + # 'length': number of events for that event type |
| 65 | + # 'events': |
| 66 | + # tuple of index of the event in the event array, pyhealth.Event object with hardcoded relevant fields for event at the index |
| 67 | + expected_event_data = { |
| 68 | + "diagnosis": { |
| 69 | + "length": 8, |
| 70 | + "events": [ |
| 71 | + ( |
| 72 | + 0, |
| 73 | + Event( |
| 74 | + code="567.9", |
| 75 | + timestamp=pandas.Timestamp("2014-02-24 03:36:00"), |
| 76 | + vocabulary="ICD9CM", |
| 77 | + ), |
| 78 | + ), |
| 79 | + ( |
| 80 | + 1, |
| 81 | + Event( |
| 82 | + code="K65.0", |
| 83 | + timestamp=pandas.Timestamp("2014-02-24 03:36:00"), |
| 84 | + vocabulary="ICD10CM", |
| 85 | + ), |
| 86 | + ), |
| 87 | + ], |
| 88 | + }, |
| 89 | + "medication": { |
| 90 | + "length": 38, |
| 91 | + "events": [ |
| 92 | + ( |
| 93 | + 0, |
| 94 | + Event( |
| 95 | + code="MORPHINE INJ", |
| 96 | + timestamp=pandas.Timestamp("2014-02-23 21:09:00"), |
| 97 | + vocabulary="eICU_DRUGNAME", |
| 98 | + ), |
| 99 | + ), |
| 100 | + ( |
| 101 | + 5, |
| 102 | + Event( |
| 103 | + code="CIPROFLOXACIN IN D5W 400 MG/200ML IV SOLN", |
| 104 | + timestamp=pandas.Timestamp("2014-02-23 22:43:00"), |
| 105 | + vocabulary="eICU_DRUGNAME", |
| 106 | + ), |
| 107 | + ), |
| 108 | + ], |
| 109 | + }, |
| 110 | + "lab": { |
| 111 | + "length": 251, |
| 112 | + "events": [ |
| 113 | + ( |
| 114 | + 0, |
| 115 | + Event( |
| 116 | + code="sodium", |
| 117 | + timestamp=pandas.Timestamp("2014-02-23 21:04:00"), |
| 118 | + vocabulary="eICU_LABNAME", |
| 119 | + ), |
| 120 | + ), |
| 121 | + ( |
| 122 | + 2, |
| 123 | + Event( |
| 124 | + code="BUN", |
| 125 | + timestamp=pandas.Timestamp("2014-02-23 21:04:00"), |
| 126 | + vocabulary="eICU_LABNAME", |
| 127 | + ), |
| 128 | + ), |
| 129 | + ], |
| 130 | + }, |
| 131 | + "physicalExam": { |
| 132 | + "length": 22, |
| 133 | + "events": [ |
| 134 | + ( |
| 135 | + 0, |
| 136 | + Event( |
| 137 | + code="notes/Progress Notes/Physical Exam/Physical Exam/Neurologic/GCS/Score/scored", |
| 138 | + timestamp=pandas.Timestamp("2014-02-24 03:05:00"), |
| 139 | + vocabulary="eICU_PHYSICALEXAMPATH", |
| 140 | + ), |
| 141 | + ), |
| 142 | + ( |
| 143 | + 1, |
| 144 | + Event( |
| 145 | + code="notes/Progress Notes/Physical Exam/Physical Exam Obtain Options/Performed - Structured", |
| 146 | + timestamp=pandas.Timestamp("2014-02-24 03:05:00"), |
| 147 | + vocabulary="eICU_PHYSICALEXAMPATH", |
| 148 | + ), |
| 149 | + ), |
| 150 | + ], |
| 151 | + }, |
| 152 | + } |
| 153 | + |
| 154 | + # patient level information |
| 155 | + actual_patient = self.dataset.patients[selected_patient_id] |
| 156 | + self.assertEqual(expected_visit_len, len(actual_patient.visits)) |
| 157 | + self.assertEqual(expected_birth_datetime, actual_patient.birth_datetime) |
| 158 | + self.assertEqual(expected_death_datetime, actual_patient.death_datetime) |
| 159 | + self.assertEqual(expected_ethnicity, actual_patient.ethnicity) |
| 160 | + self.assertEqual(expected_gender, actual_patient.gender) |
| 161 | + |
| 162 | + # visit level information |
| 163 | + actual_visit_id = actual_patient.index_to_visit_id[selected_visit_index] |
| 164 | + self.assertEqual(expected_visit_id, actual_visit_id) |
| 165 | + |
| 166 | + actual_visit = actual_patient.visits[actual_visit_id] |
| 167 | + self.assertEqual(expected_event_count, actual_visit.num_events) |
| 168 | + self.assertEqual(expected_visit_discharge_status, actual_visit.discharge_status) |
| 169 | + self.assertEqual(expected_discharge_time, actual_visit.discharge_time) |
| 170 | + self.assertEqual(expected_encounter_time, actual_visit.encounter_time) |
| 171 | + |
| 172 | + # visit attributes |
| 173 | + actual_visit_attributes = actual_visit.attr_dict |
| 174 | + self.assertEqual(expected_visit_attr_dict_len, len(actual_visit_attributes)) |
| 175 | + self.assertEqual( |
| 176 | + expected_visit_hopital_id, actual_visit_attributes["hospital_id"] |
| 177 | + ) |
| 178 | + self.assertEqual(expected_visit_region, actual_visit_attributes["region"]) |
| 179 | + |
| 180 | + # event level information |
| 181 | + actual_event_list_dict = actual_visit.event_list_dict |
| 182 | + for event_key in expected_event_data: |
| 183 | + actual_event_array = actual_event_list_dict[event_key] |
| 184 | + expected_event = expected_event_data[event_key] |
| 185 | + |
| 186 | + self.assertEqual( |
| 187 | + expected_event["length"], |
| 188 | + len(actual_event_array), |
| 189 | + f"incorrect num events for'{event_key}'", |
| 190 | + ) |
| 191 | + for selected_index, expected_pyhealth_Event in expected_event["events"]: |
| 192 | + error_message = f"incorrect event code on '{event_key}' event, selected index: {selected_index}" |
| 193 | + |
| 194 | + actual_event = actual_event_array[selected_index] |
| 195 | + self.assertEqual( |
| 196 | + expected_pyhealth_Event.code, actual_event.code, error_message |
| 197 | + ) |
| 198 | + self.assertEqual( |
| 199 | + expected_pyhealth_Event.timestamp, |
| 200 | + actual_event.timestamp, |
| 201 | + error_message, |
| 202 | + ) |
| 203 | + self.assertEqual( |
| 204 | + expected_pyhealth_Event.vocabulary, |
| 205 | + actual_event.vocabulary, |
| 206 | + error_message, |
| 207 | + ) |
| 208 | + |
| 209 | + def test_statistics(self): |
| 210 | + # self.dataset.stat() |
| 211 | + |
| 212 | + self.assertEqual(sorted(self.TABLES), sorted(self.dataset.available_tables)) |
| 213 | + |
| 214 | + EHRDatasetStatAssertion(self.dataset, 0.01).assertEHRStats( |
| 215 | + expected_num_patients=2174, |
| 216 | + expected_num_visits=2520, |
| 217 | + expected_num_visits_per_patient=1.1592, |
| 218 | + expected_events_per_visit_per_table=[ |
| 219 | + 16.7202, |
| 220 | + 17.8345, |
| 221 | + 172.4841, |
| 222 | + 15.1944, |
| 223 | + 33.3563, |
| 224 | + ], |
| 225 | + ) |
| 226 | + |
| 227 | + |
| 228 | +if __name__ == "__main__": |
| 229 | + unittest.main(verbosity=2) |
0 commit comments