
MLAI-Yonsei/FANS


FANS: Function and Noise Separation Framework

Overview

FANS (Function and Noise Separation) is a unified framework for detecting and dissecting causal mechanism shifts in Structural Causal Models (SCMs). Unlike existing methods limited to additive noise models, FANS handles non-additive, non-linear SCMs and distinguishes between:

  • Function shifts: Changes in causal mechanisms
  • Noise shifts: Changes in noise distributions

This dissection capability is critical for applications in biomedical science, manufacturing, and other domains where understanding the root cause of distributional changes is essential.

Key Features

  • Shift Detection: Identifies which variables have undergone distributional shifts between environments
  • Shift Dissection: Distinguishes function shifts from noise shifts using an independence criterion
  • Non-linear SCM Support: Works beyond additive noise models
  • No Retraining Required: The two-stage algorithm analyzes environment 2 with the flow trained on environment 1, without retraining the model
  • Simultaneous Shift Handling: Addresses the complex case of concurrent function and noise shifts

Method

Theoretical Foundation

FANS is grounded in a theoretical independence criterion:

Function shifts induce a statistical dependence between a node's parents and its residual noise.

This insight makes it possible to distinguish function shifts from noise shifts by testing whether a node's parent variables are independent of its estimated noise.
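A toy simulation makes the criterion concrete. The sketch below uses one parent X and a child Y = f(X) + N, and assumes the environment-1 mechanism f is known (FANS instead learns it with a causal normalizing flow and also covers non-additive SCMs). The dependence score, the spread of the residual's mean across quantile bins of X, is a deliberately crude stand-in for a proper independence test. `f_env1`, `simulate` and `residual_dependence` are hypothetical names.

```python
import numpy as np

rng = np.random.default_rng(0)

def f_env1(x):
    # Environment-1 mechanism, assumed known here (FANS would learn it).
    return np.sin(x)

def simulate(n, mechanism, noise_scale):
    x = rng.normal(size=n)                                   # parent variable
    y = mechanism(x) + noise_scale * rng.normal(size=n)      # child with additive noise
    return x, y

def residual_dependence(x, y, n_bins=10):
    """Spread of the residual's conditional mean across quantile bins of the parent.

    Near 0 when the residual's mean does not depend on the parent.
    """
    r = y - f_env1(x)                                        # residual w.r.t. env-1 mechanism
    edges = np.quantile(x, np.linspace(0, 1, n_bins + 1))
    idx = np.clip(np.searchsorted(edges, x, side="right") - 1, 0, n_bins - 1)
    bin_means = np.array([r[idx == b].mean() for b in range(n_bins)])
    return bin_means.std()

n = 20_000
# Noise shift: same mechanism, larger noise scale in environment 2.
x_ns, y_ns = simulate(n, f_env1, noise_scale=2.0)
# Function shift: the mechanism changes, the noise does not.
x_fs, y_fs = simulate(n, lambda x: np.sin(x) + 0.5 * x**2, noise_scale=1.0)

print(f"noise shift:    {residual_dependence(x_ns, y_ns):.3f}")
print(f"function shift: {residual_dependence(x_fs, y_fs):.3f}")
```

Under the noise shift the score stays near zero, while the function shift leaves a parent-dependent pattern in the residual.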

Two-Stage Algorithm

  1. Detection Stage: Transform data through the learned causal normalizing flow and identify shifted variables by comparing conditional distributions across environments
  2. Dissection Stage: For each detected shift, test independence between the node's parents and its residual noise to classify the shift type
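The two stages can be sketched as follows. This is a simplified illustration, not the repository's implementation: the latent samples `z1`/`z2` are assumed to come from passing environment-1 and environment-2 data through the trained flow (here they are simulated), detection uses a per-node two-sample Kolmogorov-Smirnov statistic as a stand-in for the conditional distribution comparison, dissection uses a crude correlation-based dependence score, and the thresholds are arbitrary. All function names are hypothetical.

```python
import numpy as np

def ks_statistic(a, b):
    """Two-sample Kolmogorov-Smirnov statistic: max gap between the empirical CDFs."""
    a, b = np.sort(a), np.sort(b)
    grid = np.concatenate([a, b])
    cdf_a = np.searchsorted(a, grid, side="right") / len(a)
    cdf_b = np.searchsorted(b, grid, side="right") / len(b)
    return float(np.abs(cdf_a - cdf_b).max())

def abs_corr_dependence(pa, z):
    """Crude dependence score: max |Pearson correlation| of z with each parent and its square."""
    feats = np.column_stack([pa, pa**2])
    return max(abs(np.corrcoef(f, z)[0, 1]) for f in feats.T)

def detect_and_dissect(z1, z2, x_env2, parents, ks_thresh=0.05, dep_thresh=0.1,
                       dep_fn=abs_corr_dependence):
    """Return {node: "function" | "noise"} for every node flagged as shifted."""
    results = {}
    for j in range(z1.shape[1]):
        # Stage 1 (detection): compare node j's latent noise between environments.
        if ks_statistic(z1[:, j], z2[:, j]) <= ks_thresh:
            continue
        pa = parents[j]
        if not pa:
            # A root node has no parents, so its shift can only be a noise shift.
            results[j] = "noise"
            continue
        # Stage 2 (dissection): parent/noise dependence indicates a function shift.
        dep = dep_fn(x_env2[:, pa], z2[:, j])
        results[j] = "function" if dep > dep_thresh else "noise"
    return results

# Toy latents standing in for flow outputs; chain 0 -> 1 -> 2.
rng = np.random.default_rng(1)
n = 5000
x0 = rng.normal(size=n)
x1 = x0 + rng.normal(size=n)
x_env2 = np.column_stack([x0, x1, rng.normal(size=n)])
z1 = rng.normal(size=(n, 3))                 # environment 1: standard normal latents
z2 = np.column_stack([
    rng.normal(size=n),                      # node 0: unchanged
    2.0 * rng.normal(size=n),                # node 1: noise shift (larger scale)
    rng.normal(size=n) + x1,                 # node 2: latent leaks its parent (function shift)
])
print(detect_and_dissect(z1, z2, x_env2, parents={0: [], 1: [0], 2: [1]}))
```

In this toy setup node 1 is flagged as a noise shift and node 2 as a function shift. The sketch labels any detected shift with parent dependence as a function shift, so unlike FANS it makes no attempt to separate simultaneous function and noise shifts.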

Installation

Prerequisites

Create a new conda environment with Python 3.9.12:

conda create --name fans python=3.9.12 --no-default-packages

Activate the conda environment:

conda activate fans

Install Dependencies

Install PyTorch 1.13.1 and torchvision 0.14.1 built for CUDA 11.7:

pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 --extra-index-url https://download.pytorch.org/whl/cu117

Install additional requirements:

pip install -r requirements.txt
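After installation, a quick check that the pinned versions are the ones actually active can save debugging time. The script below is not part of the repository (`check_environment` is a hypothetical name); it only reads installed package metadata through the standard library and does not import PyTorch.

```python
import sys
from importlib import metadata

# Versions pinned by the installation steps above.
EXPECTED = {"torch": "1.13.1+cu117", "torchvision": "0.14.1+cu117"}

def check_environment(expected=EXPECTED, python=(3, 9)):
    """Return a list of human-readable problems (empty if everything matches)."""
    problems = []
    if sys.version_info[:2] != python:
        problems.append(
            f"Python {sys.version.split()[0]} found, expected {python[0]}.{python[1]}.x")
    for pkg, want in expected.items():
        try:
            have = metadata.version(pkg)
        except metadata.PackageNotFoundError:
            problems.append(f"{pkg} is not installed")
            continue
        if have != want:
            problems.append(f"{pkg} {have} found, expected {want}")
    return problems

if __name__ == "__main__":
    issues = check_environment()
    print("\n".join(issues) if issues else "Environment matches the pinned versions.")
```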

Quick Start

Train a FANS Model

Train a FANS model on a synthetic dataset with 10 nodes and Erdős-Rényi (ER) graph structure:

CUDA_VISIBLE_DEVICES=0 python main.py \
    --config_file causal_nf/configs/data_small/nodes_10/ER/causal_nf_nodes_10_ER_adj_1.yaml \
    --wandb_mode offline \
    --project causal_nf

What this does:

  • Trains a causal normalizing flow on environment 1 data
  • Evaluates shift detection performance on environment 2 data
  • Saves results to the results/ directory
  • Generates visualizations of detected shifts
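When scripting several runs, the same command can be built programmatically. The helper below is hypothetical (not part of the repository); it only assembles the flags shown above and sets `CUDA_VISIBLE_DEVICES` in a copy of the current environment, leaving the actual launch to `subprocess.run`.

```python
import os
import shlex

def build_train_command(config_file, wandb_mode="offline", project="causal_nf", gpu=0):
    """Return (argv, env) for a training run.

    gpu=None keeps whatever CUDA_VISIBLE_DEVICES value is inherited.
    """
    argv = ["python", "main.py",
            "--config_file", config_file,
            "--wandb_mode", wandb_mode,
            "--project", project]
    env = dict(os.environ)                     # copy, so the caller's environment is untouched
    if gpu is not None:
        env["CUDA_VISIBLE_DEVICES"] = str(gpu)
    return argv, env

argv, env = build_train_command(
    "causal_nf/configs/data_small/nodes_10/ER/causal_nf_nodes_10_ER_adj_1.yaml")
print(shlex.join(argv))
# To launch from the repository root: subprocess.run(argv, env=env, check=True)
```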

FANS Method

The FANS (Function and Noise Separation) method uses a trained causal normalizing flow to detect and classify distributional shifts between two environments.

Core Methodology

  1. Training: Learn a causal normalizing flow on environment 1 data that maps observations X to noise variables Z following the causal graph structure

  2. Shift Detection: Transform environment 2 data through the learned flow and test for independence violations in the noise space

  3. Statistical Testing: Apply independence tests, such as distance correlation, to identify shifted variables

  4. Visualization: Generate comparative plots showing distributional differences
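For variables with finite means, distance correlation is zero in the population exactly when the two variables are independent, which makes it a natural statistic for the parent/noise test. The sketch below is a plain NumPy implementation (biased V-statistic form, 1-D inputs) with a permutation p-value. It is not taken from the repository, whose test details may differ, and it builds n-by-n distance matrices, so it suits a few thousand samples at most.

```python
import numpy as np

def _double_centered(v):
    """Double-centered matrix of pairwise absolute differences for a 1-D sample."""
    d = np.abs(v[:, None] - v[None, :])
    return d - d.mean(axis=0) - d.mean(axis=1)[:, None] + d.mean()

def distance_correlation(x, y):
    """Sample distance correlation (V-statistic form) between two 1-D samples."""
    a = _double_centered(np.asarray(x, float))
    b = _double_centered(np.asarray(y, float))
    dcov2 = (a * b).mean()
    denom = np.sqrt((a * a).mean() * (b * b).mean())
    return float(np.sqrt(max(dcov2, 0.0) / denom)) if denom > 0 else 0.0

def dcor_permutation_test(x, y, n_perm=200, seed=0):
    """Return (statistic, p-value) for H0: x and y are independent, by permuting y."""
    rng = np.random.default_rng(seed)
    observed = distance_correlation(x, y)
    null = [distance_correlation(x, rng.permutation(y)) for _ in range(n_perm)]
    p_value = (1 + sum(s >= observed for s in null)) / (1 + n_perm)
    return observed, p_value

rng = np.random.default_rng(0)
pa = rng.normal(size=300)                     # parent samples
r_noise = 2.0 * rng.normal(size=300)          # residual under a noise shift: independent of pa
r_func = pa**2 + rng.normal(size=300)         # residual under a function shift: depends on pa
print("noise shift:    dcor=%.3f  p=%.3f" % dcor_permutation_test(pa, r_noise))
print("function shift: dcor=%.3f  p=%.3f" % dcor_permutation_test(pa, r_func))
```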

Data Structure

Synthetic Data

Synthetic datasets are organized by node count and graph type:

data/data_small/
├── nodes_10/
│   ├── ER/          # Erdős-Rényi random graphs
│   │   ├── adj_1.npy           # Adjacency matrix
│   │   ├── data_env1_1.npy     # Environment 1 data
│   │   ├── data_env2_1.npy     # Environment 2 data
│   │   └── metadata_1.json     # Shift information
│   └── SF/          # Scale-free graphs
├── nodes_20/
├── nodes_30/
├── nodes_40/
└── nodes_50/
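A small loader for this layout might look like the following. It assumes every nodes_*/ER and nodes_*/SF folder uses the same file naming as the nodes_10/ER example, that the .npy files were written with `np.save` (no pickled objects), and it returns the metadata as a plain dictionary because its schema is not documented here. `dataset_paths` and `load_dataset` are hypothetical helpers.

```python
import json
from pathlib import Path

import numpy as np

def dataset_paths(root, nodes, graph, index):
    """Paths for one synthetic dataset, following the directory layout above."""
    base = Path(root) / f"nodes_{nodes}" / graph
    return {
        "adj": base / f"adj_{index}.npy",            # adjacency matrix
        "env1": base / f"data_env1_{index}.npy",     # environment 1 samples
        "env2": base / f"data_env2_{index}.npy",     # environment 2 samples
        "metadata": base / f"metadata_{index}.json", # shift information
    }

def load_dataset(root, nodes, graph, index):
    """Load the arrays with np.load and the metadata with json."""
    paths = dataset_paths(root, nodes, graph, index)
    out = {k: np.load(p) for k, p in paths.items() if p.suffix == ".npy"}
    out["metadata"] = json.loads(paths["metadata"].read_text())
    return out

if __name__ == "__main__":
    root = Path("data/data_small")
    if root.exists():
        ds = load_dataset(root, nodes=10, graph="ER", index=1)
        print({k: v.shape for k, v in ds.items() if k != "metadata"})
```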

Real Datasets

  • Morpho-MNIST: Located in data/morpho_mnist/
  • Sachs: Located in data/sachs/

Running Experiments

Training FANS Model

Basic Usage

python main.py \
    --config_file <CONFIG_PATH> \
    --wandb_mode <MODE> \
    --project <PROJECT_NAME>

Example: Train on 30-node scale-free graph

CUDA_VISIBLE_DEVICES=1 python main.py \
    --config_file causal_nf/configs/data_small/nodes_30/SF/causal_nf_nodes_30_SF_adj_5.yaml \
    --wandb_mode online \
    --project fans_experiments
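The config file names in the two examples follow a consistent pattern, nodes_<N>/<GRAPH>/causal_nf_nodes_<N>_<GRAPH>_adj_<I>.yaml. Assuming that pattern holds for every dataset (and treating indices 1 through 5 as a placeholder range, since the available indices are not listed here), a sweep of training commands could be generated as below. The helper names are hypothetical.

```python
from itertools import product

def config_path(nodes, graph, index):
    """Config path, following the pattern inferred from the README examples (assumption)."""
    return (f"causal_nf/configs/data_small/nodes_{nodes}/{graph}/"
            f"causal_nf_nodes_{nodes}_{graph}_adj_{index}.yaml")

def sweep_commands(nodes=(10, 20, 30, 40, 50), graphs=("ER", "SF"), indices=range(1, 6),
                   wandb_mode="offline", project="fans_experiments"):
    """Yield one training command line per (nodes, graph, index) combination."""
    for n, g, i in product(nodes, graphs, indices):
        yield (f"python main.py --config_file {config_path(n, g, i)} "
               f"--wandb_mode {wandb_mode} --project {project}")

# Reproduces the 30-node scale-free example (with offline logging).
for cmd in sweep_commands(nodes=(30,), graphs=("SF",), indices=(5,)):
    print(cmd)
```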

Running Baseline Methods

Run baseline shift detection methods for comparison:

python experiments/experiment_script.py --model <MODEL_NAME> [OPTIONS]

Available Models:

  • splitkci: Kernel Conditional Independence Test
  • prediter: PreDITEr method
  • iscan: Independence-based shift detection
  • linearccp: Linear CCP
  • gpr: Gaussian Process Regression

Options:

| Option | Description | Default |
|---|---|---|
| --nodes | Node counts to process (space-separated) | 10 20 30 40 50 |
| --gpu | GPU device ID (-1 for CPU) | -1 |
| --output_dir | Results directory | auto-generated |
| --config_type | Graph type filter (ER, SF, or all) | all |
| --dataset_indices | Dataset index range (e.g., "1-30") | all |

Examples

Run SplitKCI on GPU 0 for all node sizes, using only the first 5 datasets:

python experiments/experiment_script.py \
    --model splitkci \
    --dataset_indices "1-5" \
    --gpu 0

Run ISCAN on CPU for scale-free (SF) graphs only:

python experiments/experiment_script.py \
    --model iscan \
    --config_type SF \
    --gpu -1
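To run every baseline with the same settings, the documented options can be assembled in a loop. The helper below is hypothetical and only builds command lines from the flags in the options table; options left as `None` are omitted so the script's own defaults apply.

```python
import shlex

BASELINES = ["splitkci", "prediter", "iscan", "linearccp", "gpr"]

def baseline_command(model, nodes=None, gpu=-1, config_type="all", dataset_indices=None,
                     output_dir=None):
    """Build an argv list for experiments/experiment_script.py from the documented options."""
    if model not in BASELINES:
        raise ValueError(f"unknown model {model!r}; expected one of {BASELINES}")
    argv = ["python", "experiments/experiment_script.py", "--model", model,
            "--gpu", str(gpu), "--config_type", config_type]
    if nodes:
        argv += ["--nodes", *map(str, nodes)]          # space-separated node counts
    if dataset_indices:
        argv += ["--dataset_indices", dataset_indices]
    if output_dir:
        argv += ["--output_dir", output_dir]
    return argv

for model in BASELINES:
    print(shlex.join(baseline_command(model, dataset_indices="1-5")))
```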

Results and Analysis

Analysis

Generate a unified results CSV:

python experiments/analysis/analysis.py
