Synthetic EHR Pipeline

Overview

This repository provides a pipeline for generating realistic synthetic electronic health record (EHR) data using HALO or ConSequence as the core model engine. This implementation builds on code from btheodorou/KnowledgeInfusion. We are grateful to the original HALO and ConSequence authors for developing these models and making their work publicly available.

The pipeline wraps these models in a modular preprocessing, training, and post-processing framework designed to:

  • simplify integration and extension,
  • improve reproducibility across datasets, and
  • enhance realism of continuous and temporal variables through KDE-based post-processing.

It is applied to three clinically diverse datasets to demonstrate its robustness across different scales and data structures.

The typical workflow is:

  • Preprocess raw data into the expected format (01_preprocessing.py).
  • Train the model (02_training.py) using HALO or ConSequence.
  • Generate synthetic samples (03_sampling.py).
  • Apply KDE post-processing (04_kde_preprocessing.py / 05_kde_sampling.py) to reconstruct continuous variables and timestamps.
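
For a single dataset, these stages can be chained with a small driver script. The sketch below is illustrative only; it assumes the scripts run without extra command-line arguments, which may not hold for every dataset.

# Hypothetical driver sketch: runs the ADNI pipeline stages in sequence.
# Each script reads its settings from the dataset's config files, so paths in
# adni/config.yml must be set up first.
import subprocess
import sys

STAGES = [
    "adni/01_preprocessing.py",
    "adni/02_training.py",
    "adni/03_sampling.py",
    "adni/04_kde_preprocessing.py",
    "adni/05_kde_sampling.py",
]

for script in STAGES:
    print(f"Running {script} ...")
    subprocess.run([sys.executable, script], check=True)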

1. Setup

conda env create -f environment.yml
conda activate synthetic-ehr-pipeline

2. Execution environments

The pipeline can be run in different environments, such as your local machine, an HPC cluster, or Kubeflow, depending on the configuration specified in each dataset’s config.yml.

For the ADNI and MIMIC-III datasets, models were trained using two NVIDIA V100 GPUs, whereas the Queensland Health model was trained using four NVIDIA H100 GPUs. For the Queensland Health dataset, synthetic data generation was parallelised across multiple nodes with each node using a different random seed to generate independent synthetic datasets.
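
As a hedged illustration of that seeding scheme (the actual generation scripts may take the seed from config.yml or a scheduler variable instead), each parallel job can be given its own seed:

# Hypothetical per-node seeding sketch: every node receives a distinct --seed
# so that the synthetic records it generates are independent of other nodes.
import argparse
import random

import numpy as np
import torch

parser = argparse.ArgumentParser()
parser.add_argument("--seed", type=int, required=True)
args = parser.parse_args()

random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)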

3. Repository structure

Each dataset folder contains an end-to-end workflow using the shared core modules.

synthetic-ehr-pipeline/
│
├── core/                       # Generic, reusable components
│   ├── data_processing.py      # Input preparation
│   ├── generation.py           # Synthetic data generation
│   ├── kde.py                  # KDE-based post-processing utilities
│   ├── models.py               # Adapted HALO / ConSequence model
│   ├── rules.py                # Rules generation for ConSequence
│   ├── training.py             # Training loop
│   └── utils.py                # Helper functions
│
├── adni/                       # ADNI-specific pipeline
│   ├── 01_preprocessing.py
│   ├── 02_training.py
│   ├── 03_sampling.py
│   ├── 04_kde_preprocessing.py
│   ├── 05_kde_sampling.py
│   ├── config.py
│   ├── config.yml
│   ├── data_processing.py
│   ├── generation.py
│   └── kde.py
│
├── mimic/                      # MIMIC-III-specific pipeline
│   ├── 01_preprocessing.py
│   ├── 02_training.py
│   ├── 03_sampling.py
│   ├── 04_kde_preprocessing.py
│   ├── 05_kde_sampling.py
│   ├── config.py
│   ├── config.yml
│   ├── data_processing.py
│   ├── generation.py
│   └── kde.py
│
└── queensland_health/          # Queensland Health-specific pipeline
    ├── 01_preprocessing.py
    ├── 02_training.py
    ├── 03_sampling.py
    ├── 03b_conditional_sampling.py
    ├── 04_kde_preprocessing.py
    ├── 05_kde_sampling.py
    ├── config.py
    ├── config.yml
    ├── data_processing.py
    ├── generation.py
    └── kde.py

The core/ directory contains the reusable building blocks used by all datasets:

  • data_processing.py handles input loading, cleaning, and feature engineering.
  • training.py defines the training loop.
  • models.py defines the adapted HALO/ConSequence architecture.
  • rules.py defines functions used to generate rules for ConSequence models.
  • generation.py manages sampling from the trained model.
  • kde.py implements post-processing for reconstructing continuous variables.
  • utils.py provides helper functions, e.g. for file and directory management and for reading in data.

These modules are dataset-agnostic and can be reused across datasets or extended for new projects.

4. Model implementation

The pipeline uses HALO and ConSequence as the underlying model engines.

Technical adaptations

The HALO and ConSequence model implementations were adapted to simplify integration and improve reproducibility:

  • Replaced custom Conv1D layers with standard linear layers
  • Generated causal masks dynamically at runtime
  • Modified forward pass to return raw logits (for BCEWithLogitsLoss)
  • Added optional mixed-precision training for modern GPUs

These changes improve numerical stability, reproducibility, and compatibility with privacy libraries such as Opacus.
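
A minimal sketch of how the raw-logit forward pass, BCEWithLogitsLoss, and optional mixed precision fit together is shown below; it is illustrative only and is not the training loop in core/training.py (the model, batch, and optimiser are stand-ins).

# Illustrative training step only: a stand-in linear "model" returns raw
# logits, BCEWithLogitsLoss consumes them, and autocast/GradScaler provide
# optional mixed precision when a GPU is available.
import torch
from torch import nn

device = "cuda" if torch.cuda.is_available() else "cpu"
use_amp = device == "cuda"

model = nn.Linear(16, 16).to(device)           # stand-in for the HALO model
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.BCEWithLogitsLoss()             # expects raw logits
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

batch = torch.rand(8, 16, device=device)       # dummy input batch
targets = torch.randint(0, 2, (8, 16), device=device).float()

optimizer.zero_grad()
with torch.autocast(device_type=device, enabled=use_amp):
    logits = model(batch)
    loss = criterion(logits, targets)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()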

Additionally, an EHRParquetDataset class was implemented for use with the PyTorch DataLoader, enabling efficient batching and multi-worker data loading.
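
The sketch below illustrates the general pattern of a Parquet-backed Dataset used with DataLoader; it is a simplified stand-in, not the EHRParquetDataset implemented in this repository (the file path and column name are assumptions).

# Simplified Parquet-backed Dataset for use with the PyTorch DataLoader.
# The real EHRParquetDataset in core/ may differ substantially.
import pandas as pd
import torch
from torch.utils.data import DataLoader, Dataset


class SimpleParquetDataset(Dataset):
    def __init__(self, path: str, feature_column: str = "codes"):
        self.frame = pd.read_parquet(path)        # load the whole file up front
        self.feature_column = feature_column

    def __len__(self) -> int:
        return len(self.frame)

    def __getitem__(self, idx: int) -> torch.Tensor:
        row = self.frame.iloc[idx]
        return torch.as_tensor(row[self.feature_column], dtype=torch.float32)


# Multi-worker batching; "train.parquet" is a placeholder path.
loader = DataLoader(SimpleParquetDataset("train.parquet"),
                    batch_size=64, shuffle=True, num_workers=4)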

Code attribution

Some code was adapted from the repository associated with the ConSequence manuscript. Specifically:

  • core/data_processing.py includes modified code from evaluate_generationSpeed_base.py and genDataset.py
  • core/models.py includes modified code from model.py and consequenceModel.py
  • core/rules.py includes modified code from find_rules.py
  • core/generation.py includes modified code from evaluate_generationSpeed_base.py
  • core/utils.py includes modified code from train_model.py

5. Data sources

This repository includes code for generating synthetic data from three datasets:

  • The Alzheimer's Disease Neuroimaging Initiative (ADNI): A publicly available dataset from an ongoing longitudinal study including ~2.4k participants.
  • MIMIC-III: A publicly available dataset with data from ~46k ICU patients.
  • Queensland Health: A dataset including ED and inpatient encounters from ~3.5M patients in Queensland hospitals.

To run the code for the ADNI and MIMIC-III datasets, you will need to apply for access to these datasets and then download the relevant files:

  • ADNI: ADNIMERGE.csv
  • MIMIC-III: ADMISSIONS.csv.gz, PATIENTS.csv.gz, DIAGNOSES_ICD.csv, PROCEDURES_ICD.csv

While the Queensland Health dataset is not publicly available, the code is included for completeness and to illustrate how the pipeline works on a relatively large dataset.

6. Dataset configuration

Each dataset folder (e.g. adni/, mimic/, queensland_health/) includes its own configuration files that control both paths and pipeline behaviour.

config.yml defines environment-specific paths (data directories, output locations) that must be configured for your system before running the pipeline. Examples are provided.

config.py provides Python-level constants and variable definitions for dataset-specific logic:

  • Defines which variables are modelled with KDE, their stratification, filters, and bandwidths.
  • May also include dataset-specific column mappings used by the core modules.

Each dataset directory (e.g. adni/, mimic/, queensland_health/) includes dataset-specific scripts and helper functions that adapt the shared core pipeline to that dataset’s structure and schema. The modules in core/ provide all common functionality, while dataset-specific code remains self-contained within each folder. New datasets can be incorporated by creating a new directory with the required configuration files and any additional custom functions.
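
For illustration only, a new dataset's config.py might start from a skeleton like the one below; the column names and KDE settings are placeholders, not requirements of the core modules (see Section 7 for the full set of KDE_CONFIG options).

# Hypothetical config.py skeleton for a new dataset directory.
# Every name below is a placeholder; adapt it to the dataset's schema.
COLUMN_MAPPINGS = {
    "patient_id": "person_id",
    "visit_date": "admission_date",
}

KDE_CONFIG = {
    "num_cores": -1,
    "variable_config": {
        "age": {
            "method": "kde_range",
            "data_source": "patients",
            "grid_params": {"method": "range", "use_data_bounds": True,
                            "step": 1, "offset": 2},
            "bandwidth": [0.25],
        },
    },
}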

7. KDE-based post-processing

To improve the utility of the synthetic data, the pipeline applies a kernel density estimation (KDE) step after data generation. The KDE functionality is implemented in core/kde.py, which provides functions for generating probability grids (generate_kde_grids) and sampling from them (sample_from_kde_grid).

The KDE_CONFIG dictionary in the dataset-specific config.py controls how KDE is used to model distributions of variables in your data. This configuration drives the generate_kde_grids() function to create probability grids that can later be sampled from during synthetic data generation.

The KDE configuration must be adapted for each new dataset. Variables differ in scale, distribution shape, and correlations, so you’ll need to determine which variables should use KDE smoothing, whether to stratify by demographic or clinical covariates, and suitable grid parameters and bandwidths for each variable or variable pair.

KDE_CONFIG structure

KDE_CONFIG = {
    "num_cores": -1,            # Number of CPU cores (-1 = all available cores)
    "variable_config": {        # Dictionary of variable configurations
        "variable_name": { ... },
        ...
    }
}

Each variable in variable_config has the following structure:

  • method (string): the KDE method to use ("kde_range", "kde_joint", or "value_counts")
  • data_source (string): the data source ("patients" or "visits")

The following are optional fields, depending on the method:

  • stratify_by (list of strings): column names to stratify the KDE by (creates separate distributions)
  • filter (dict): row filters to apply before computing the KDE
  • grid_params (dict): grid generation parameters (required for KDE methods)
  • bandwidth (list of floats): KDE bandwidth parameters (required for KDE methods)

Methods

1. kde_range: Single variable KDE

Use this for continuous variables where you want to model the full distribution.

Required fields:

  • grid_params: Parameters for creating the evaluation grid
  • bandwidth: Single-element list with KDE bandwidth

Example:

"age": {
    "method": "kde_range",
    "data_source": "patients",
    "stratify_by": ["gender", "event_type"],
    "grid_params": {
        "method": "range",
        "use_data_bounds": True,
        "step": 1,
        "offset": 2,
    },
    "bandwidth": [0.25],
}

2. kde_joint: Multivariate KDE

Use this for modeling joint distributions of two variables.

Required fields:

  • grid_params: Dictionary with grid parameters for each variable
  • bandwidth: List with one bandwidth per variable

Example:

"los_triage_joint": {
    "method": "kde_joint",
    "data_source": "visits",
    "filter": {"event_type": "ED"},
    "stratify_by": ["triage_category"],
    "grid_params": {
        "los": {
            "method": "fine_bins",
            "binned_variable": "los_binned",
            "filter": {"threshold": 24 * 56, "modulo": 24},
        },
        "los_triage": {
            "method": "fine_bins",
            "binned_variable": "los_triage_binned",
        },
    },
    "bandwidth": [0.5, 0.5],
}

3. value_counts: Discrete distribution

Use this for categorical variables or when you want exact empirical frequencies.

Required fields: None (only method and data_source)

Example:

"event_day": {
    "method": "value_counts",
    "data_source": "visits",
    "stratify_by": ["event_type"],
}

Grid parameters (grid_params)

Grid parameters control how the grid is constructed for KDE methods.

1. Grid Method: range

Creates an integer grid using Python's range().

"grid_params": {
    "method": "range",
    "min": 0,                    # Optional: minimum value (default: data min)
    "max": 100,                  # Optional: maximum value (default: data max)
    "use_data_bounds": True,     # If True, use data min/max
    "step": 1,                   # Step size
    "offset": 2,                 # Added to max value
}

2. Grid Method: arange

Creates a float grid using NumPy's arange().

"grid_params": {
    "method": "arange",
    "min": 0.0,                  # Optional: minimum value
    "max": 10.0,                 # Optional: maximum value
    "use_data_bounds": True,     # If True, use data min/max
    "step": 0.5,                 # Step size (can be float)
    "offset": 1,                 # Added to max value
}

3. Grid Method: fine_bins

Creates a fine-grained grid from existing bin strings (e.g., "1-2", "2.5-3.5").

"grid_params": {
    "method": "fine_bins",
    "binned_variable": "los_binned",  # Name of column containing bin strings
    "filter": {                        # Optional: filter grid values
        "threshold": 1344,              # 56 days * 24 hours
        "modulo": 24,                   # Keep only daily intervals above threshold
    }
}

Note that you can filter the grid itself by specifying a threshold and modulo.

4. Grid Method: custom

Provides an explicit list of grid points.

"grid_params": {
    "method": "custom",
    "cut_points": [0, 1, 5, 10, 20, 50, 100]
}

Other features

  • Stratification (stratify_by): Stratification creates separate KDEs for each unique combination of stratification variables.
  • Filtering (filter): Filters apply row-level restrictions before computing the KDE (e.g., to include only rows where event_type is "Inpatient", where age is non-null, or where event_type matches a value in a supplied list).
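
For reference, the scalar filter form appears in the kde_joint example above; the list form below is only an assumption about how multiple allowed values might be supplied, so check core/kde.py for the exact syntax.

# Illustrative filter specifications (hedged, not verified against core/kde.py).
inpatient_only = {"event_type": "Inpatient"}             # keep inpatient rows only
ed_or_inpatient = {"event_type": ["ED", "Inpatient"]}    # hypothetical list form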

Outputs

Running generate_kde_grids() with this configuration creates parquet files:

{write_dir_model}/grid_{variable_name}_kde.parquet

Each file contains:

  • Grid points (with _start and _end columns for ranges)
  • Probability or range_probability column
  • Stratification columns (if specified)

These grids are later used by sample_from_kde_grid() during synthetic data generation.
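
To make that step concrete, here is a minimal, self-contained sketch of drawing samples from one of these grid files; it approximates the idea behind sample_from_kde_grid() but is not the repository's implementation (the file name and the "age"/"probability" column names are assumptions).

# Minimal sketch: sample values from a saved KDE probability grid.
import numpy as np
import pandas as pd

grid = pd.read_parquet("grid_age_kde.parquet")           # hypothetical grid file
probs = grid["probability"].to_numpy()
probs = probs / probs.sum()                              # normalise to sum to 1

rng = np.random.default_rng(42)
sampled_ages = rng.choice(grid["age"].to_numpy(), size=1000, p=probs)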
