A PyTorch implementation of the DeFoG model for training and sampling discrete graph flows. (Please pull the latest commit: recent fixes have been applied.)
Oral Presentation: https://icml.cc/virtual/2025/oral/47238
Working with directed graphs? Consider using DIRECTO, a discrete flow matching framework for directed graph generation.
For an updated development environment with modernized dependencies, see the `updated_env` branch. The `main` branch remains the reference implementation, based on Python 3.9 and older package versions.
We provide two alternative installation methods: Docker and Conda.
We provide a Dockerfile to run DeFoG in a container.
- Build the Docker image:
```
docker build --platform=linux/amd64 -t defog-image .
```
- Run `pip install -e .` to make all the repository modules visible.
- Install Conda (we used version 25.1.1) and create DeFoG's environment:
```
conda env create -f environment.yaml
conda activate defog
```
- Run the following commands to check that the main packages were installed successfully:
```
python -c "import sys; print('Python version:', sys.version)"
python -c "import rdkit; print('RDKit version:', rdkit.__version__)"
python -c "import graph_tool as gt; print('Graph-Tool version:', gt.__version__)"
python -c "import torch; print(f'PyTorch version: {torch.__version__}, CUDA version (via PyTorch): {torch.version.cuda}')"
python -c "import torch_geometric as tg; print('PyTorch Geometric version:', tg.__version__)"
```
If you see no errors, the installation was successful and you can proceed to the next step.
- Compile the ORCA evaluator:
```
cd src/analysis/orca
g++ -O2 -std=c++11 -o orca orca.cpp
```
All commands use `python main.py` with Hydra overrides. Note that `main.py` is inside the `src` directory.
Use this command to quickly test the code:
```
python main.py +experiment=debug
```
To train on a specific dataset, run:
```
python main.py +experiment=<dataset> dataset=<dataset>
```
- QM9 (no H): `+experiment=qm9_no_h dataset=qm9`
- Planar: `+experiment=planar dataset=planar`
- SBM: `+experiment=sbm dataset=sbm`
- Tree: `+experiment=tree dataset=tree`
- Comm20: `+experiment=comm20 dataset=comm20`
- Guacamol: `+experiment=guacamol dataset=guacamol`
- MOSES: `+experiment=moses dataset=moses`
- QM9 (with H): `+experiment=qm9_with_h dataset=qm9`
- TLS (conditional): `+experiment=tls dataset=tls`
- ZINC: `+experiment=zinc dataset=zinc`
Sampling from DeFoG is typically done in two steps:
- Sampling Optimization → find the best sampling configuration
- Final Sampling → sample and measure performance under that configuration
To perform 5 runs (mean ± std), set `general.num_sample_fold=5`.
For the rest of this section, we use the Planar dataset as an example:
```
python main.py +experiment=planar dataset=planar general.test_only=<path/to/checkpoint> sample.eta=0 sample.omega=0 sample.time_distortion=identity
```
Note that if you run:
```
python main.py +experiment=planar dataset=planar general.test_only=<path/to/checkpoint>
```
it will run with the sampling parameters (η, ω, time distortion) that we obtained after sampling optimization (see the next section) and that are reported in the paper.
To search over the optimal inference hyperparameters (η, ω, distortion), use the `sample.search` flag, which saves a CSV file with the results (see the snippet after this list for one way to inspect it).
- Non-grid search (independent search for each component):
```
python main.py +experiment=planar dataset=planar general.test_only=<path/to/checkpoint> sample.search=all
```
- Component-wise: set `sample.search` to one of `target_guidance`, `distortion`, or `stochasticity` in the command above.
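Once the search finishes, you can rank the tested configurations directly from the saved CSV. The sketch below is not part of the repository: the file name and the metric column are assumptions, so adapt them to the actual output.

```python
# Minimal sketch for inspecting the sampling-search CSV.
# The file name and the "validity" column are assumptions; check the actual CSV header.
import pandas as pd

df = pd.read_csv("sampling_search_results.csv")  # hypothetical file name
# Sort by the metric you care about and show the best configurations.
print(df.sort_values("validity", ascending=False).head(10))
```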
Use the optimal η, ω, and time distortion resulting from the search:
```
python main.py +experiment=planar dataset=planar general.test_only=<path/to/checkpoint> sample.eta=<η> sample.omega=<ω> sample.time_distortion=<distortion>
```
Start by creating a new file in the `src/datasets` directory. You can refer to the following scripts as examples:
- `spectre_dataset.py`, if you are using unattributed graphs;
- `tls_dataset.py`, if you are using graphs with attributed nodes;
- `qm9_dataset.py` or `guacamol_dataset.py`, if you are using graphs with attributed nodes and edges (e.g., molecular data).
This new file should define a `Dataset` class to handle data processing (refer to the PyG documentation for guidance), as well as a `DatasetInfos` class to specify relevant dataset properties (e.g., number of nodes, edges, etc.). A minimal sketch of such a file is shown below.
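The skeleton below follows standard PyTorch Geometric patterns and is not DeFoG's exact interface: the class names, file names, and `DatasetInfos` attributes are illustrative assumptions, so mirror the example files listed above for the real signatures.

```python
# Illustrative skeleton only; names, files, and attributes are assumptions,
# not DeFoG's exact API. See spectre_dataset.py / qm9_dataset.py for the real interface.
import torch
from torch_geometric.data import Data, InMemoryDataset


class MyGraphDataset(InMemoryDataset):
    """Loads raw graphs and converts them into PyG `Data` objects."""

    def __init__(self, root, transform=None, pre_transform=None):
        super().__init__(root, transform, pre_transform)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        return ["my_graphs.pt"]  # hypothetical raw file

    @property
    def processed_file_names(self):
        return ["data.pt"]

    def process(self):
        # Assumed raw format: a list of (node_features, edge_index) tensor pairs.
        raw_graphs = torch.load(self.raw_paths[0])
        data_list = [Data(x=x, edge_index=edge_index) for x, edge_index in raw_graphs]
        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])


class MyDatasetInfos:
    """Collects dataset statistics the model needs (illustrative attributes)."""

    def __init__(self, dataset):
        self.name = "my_dataset"
        # Empirical distribution over graph sizes (one common example of a property).
        counts = torch.bincount(torch.tensor([d.num_nodes for d in dataset]))
        self.n_nodes = counts.float() / counts.sum()
```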
Once your dataset file is ready, update `main.py` to incorporate the new dataset. Additionally, you can add a corresponding file in the `configs/dataset` directory.
Finally, if you plan to introduce custom metrics, you can create a new file under the `metrics` directory; a minimal example is sketched below.
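As an illustration, a sampling metric can be a plain function over the generated graphs. The example below is hypothetical: the function name, the NetworkX-based input format, and the metric itself are assumptions, not the repository's metric interface.

```python
# Hypothetical custom metric: fraction of generated graphs that are planar.
# The input format (a list of networkx.Graph) is an assumption, not DeFoG's interface.
import networkx as nx


def planar_ratio(generated_graphs):
    if not generated_graphs:
        return 0.0
    flags = [nx.check_planarity(g)[0] for g in generated_graphs]
    return sum(flags) / len(flags)
```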
Checkpoints, along with their corresponding results and generated samples, are shared here.
To run sampling and evaluate generation with a given checkpoint, set the `general.test_only` flag to the path of the checkpoint file (`.ckpt`). To skip sampling and directly evaluate previously generated samples, set the `general.generated_path` flag to the path of the generated samples (`.pkl`).
(Note: The released checkpoints are retrained models from the public repository. Their performance is consistent with the paper’s findings, with minor variations attributable to training/sampling stochasticity.)
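Before launching sampling with a downloaded checkpoint, a quick sanity check is to open it with plain PyTorch and inspect its contents. This is generic PyTorch Lightning-style inspection, not a DeFoG-specific utility:

```python
# Generic checkpoint inspection; the exact keys depend on how the model was saved.
import torch

ckpt = torch.load("path/to/checkpoint.ckpt", map_location="cpu")
print(list(ckpt.keys()))  # Lightning checkpoints usually include 'state_dict', 'hyper_parameters', ...
n_params = sum(v.numel() for v in ckpt["state_dict"].values())
print(f"Parameters in state_dict: {n_params:,}")
```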
- protein / EGO datasets
- FCD score for molecules
- W&B sweeps for sampling optimization
- DiGress: https://github.com/cvignac/DiGress
- Discrete Flow Models: https://github.com/andrew-cr/discrete_flow_models
@inproceedings{qinmadeira2024defog,
title = {DeFoG: Discrete Flow Matching for Graph Generation},
author = {Qin, Yiming and Madeira, Manuel and Thanou, Dorina and Frossard, Pascal},
booktitle = {International Conference on Machine Learning (ICML)},
year = {2025},
}

