Generalized Contrastive Alignment (GCA) provides a robust framework for self-supervised learning tasks, supporting various datasets and augmentation methods.
This is the repository for the paper *Your Contrastive Learning Problem is Secretly a Distribution Alignment Problem*.
To set up the required environment, follow these steps:
```bash
# Create and activate the environment
conda create -n GCA python=3.11.9 -y
conda activate GCA

# Install dependencies
pip install hydra-core numpy==1.26.4 matplotlib seaborn scikit-image scikit-learn \
    pytorch-lightning==1.9.5 torch==2.2.1 torchaudio==2.2.1 \
    torchmetrics==1.4.2 torchvision==0.17.1
```

This repository builds on the following implementations:

- SimCLR CIFAR10 implementation by Damrich et al.: GitHub Link
- SimCLR by Ting Chen et al.: GitHub Link
- IOT from Liangliang Shi et al., "Understanding and generalizing contrastive learning from the inverse optimal transport perspective," ICML 2023.
The framework supports the following tasks: `simclr`, `hs_ince`, `gca_ince`, `rince`, `gca_rince`, and `gca_uot`.
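For intuition on the distribution-alignment view, the sketch below computes an entropically regularized transport plan over a batch similarity matrix with Sinkhorn iterations. This is an illustration of the general idea only, not the repository's implementation; the function name and defaults are our own.

```python
import torch

def sinkhorn_coupling(sim, lam=0.1, n_iters=50):
    """Illustrative entropic-OT coupling over a batch similarity matrix.

    `lam` acts as the entropic regularization weight (loosely analogous
    to the `lam` hyperparameter in the training scripts below).
    """
    n, m = sim.shape
    K = torch.exp(sim / lam)            # Gibbs kernel (use log-domain for very small lam)
    a = torch.full((n,), 1.0 / n)       # uniform source marginal
    b = torch.full((m,), 1.0 / m)       # uniform target marginal
    u, v = torch.ones(n), torch.ones(m)
    for _ in range(n_iters):            # alternate marginal-matching scalings
        u = a / (K @ v)
        v = b / (K.T @ u)
    return u[:, None] * K * v[None, :]  # transport plan P (rows sum to a, columns to b)
```

With `sim` as the cosine-similarity matrix between two augmented views of a batch, each row of the returned plan is a soft matching over candidates; InfoNCE's softmax normalizes rows only, whereas a coupling like this also matches the column marginals.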
The following datasets are supported: `SVHN`, `imagenet100`, `cifar100`, and `cifar10`.
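For the torchvision datasets, something like the following can pre-populate the data directory (a sketch only; the repository's expected layout may differ, and `imagenet100` is not a torchvision dataset, so it must be prepared separately):

```python
from torchvision import datasets

root = "./datasets"  # should match the dataset_dir argument used below
datasets.CIFAR10(root, train=True, download=True)
datasets.CIFAR100(root, train=True, download=True)
datasets.SVHN(root, split="train", download=True)
```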
You can configure data augmentation using the `strong_DA` option: `None` (standard augmentation), `large_erase`, `brightness`, or `strong_crop`.
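As a rough illustration of what these options might correspond to in torchvision terms (the exact transforms and parameters live in the repository's config and may differ):

```python
import torchvision.transforms as T

# Hypothetical counterparts of the strong_DA options; all parameters are guesses.
strong_da_examples = {
    "large_erase": T.RandomErasing(p=1.0, scale=(0.1, 0.4)),    # erase a large random patch
    "brightness":  T.ColorJitter(brightness=0.8),               # aggressive brightness jitter
    "strong_crop": T.RandomResizedCrop(32, scale=(0.08, 0.3)),  # keep only a small crop
}
# Note: RandomErasing operates on tensors, so it would be applied after ToTensor().
```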
To pretrain a model using self-supervised learning, run the following script:
```bash
python ssl_pretrain.py \
    --config-name "simclr_cifar10.yaml" \
    --config-path "./config/" \
    task=gca_uot \
    dataset=CIFAR10 \
    dataset_dir="./datasets" \
    batch_size=512 \
    seed=32 \
    backbone=resnet18 \
    projection_dim=128 \
    strong_DA=None \
    gpus=1 \
    workers=16 \
    optimizer='SGD' \
    learning_rate=0.03 \
    momentum=0.9 \
    weight_decay=1e-6 \
    lam=0.01 \
    q=0.6 \
    max_epochs=500 \
    r1=1 \
    r2=0.2
```
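The `key=value` arguments are Hydra overrides applied on top of the YAML config. If you want to inspect the resolved configuration programmatically, Hydra's compose API can do so (a sketch assuming the config layout above):

```python
from hydra import compose, initialize
from omegaconf import OmegaConf

with initialize(config_path="./config", version_base=None):
    cfg = compose(
        config_name="simclr_cifar10.yaml",
        overrides=["task=gca_uot", "batch_size=512", "lam=0.01"],
    )
print(OmegaConf.to_yaml(cfg))  # prints the fully resolved config
```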
To evaluate the pretrained model with a linear classifier, use the following script:

```bash
python linear_evaluation.py \
    --config-name="simclr_cifar10.yaml" \
    --config-path="./config/" \
    task=gca_uot \
    dataset=cifar10 \
    batch_size=512 \
    seed=64 \
    backbone=resnet18 \
    projection_dim=128 \
    strong_DA=None \
    lam=0.01 \
    q=0.6 \
    load_epoch=500
```

Key options:

- Task: specify the self-supervised learning task (e.g., `gca_uot`).
- Dataset: choose from the supported datasets (e.g., `cifar10`).
- Data augmentation: use `strong_DA` to set the augmentation type.
- Training parameters:
  - `batch_size`: batch size for training.
  - `backbone`: backbone architecture (e.g., `resnet18`).
  - `projection_dim`: dimension of the projection head.
  - `lam` and `q`: regularization and scaling parameters.
  - `max_epochs`: maximum number of epochs for training.
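For reference, linear evaluation conventionally freezes the pretrained encoder and fits a linear classifier on its features. The sketch below shows that standard recipe; it is our own simplification, not the repository's `linear_evaluation.py`.

```python
import torch
import torch.nn as nn

def linear_probe(encoder, train_loader, num_classes, feat_dim, epochs=10, device="cuda"):
    """Standard linear evaluation: freeze the encoder, train only a linear head."""
    encoder.eval().to(device)
    for p in encoder.parameters():
        p.requires_grad = False                  # backbone stays frozen
    head = nn.Linear(feat_dim, num_classes).to(device)
    opt = torch.optim.SGD(head.parameters(), lr=0.1, momentum=0.9)
    loss_fn = nn.CrossEntropyLoss()
    for _ in range(epochs):
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            with torch.no_grad():
                feats = encoder(x)               # features from the frozen backbone
            loss = loss_fn(head(feats), y)
            opt.zero_grad()
            loss.backward()
            opt.step()
    return head
```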
Notes:

- Ensure that `dataset_dir` contains the datasets in the correct structure.
- Customize parameters in the scripts to fit your experimental needs.
- We modified `linear_evaluation.py` so that it also runs on the newer PyTorch version 2.9.0.

As an example, to pretrain on SVHN:
```bash
python ssl_pretrain.py \
    --config-name "simclr_svhn.yaml" \
    --config-path "./config/" \
    task=gca_uot \
    dataset=SVHN \
    dataset_dir="./datasets" \
    batch_size=512 \
    seed=48 \
    backbone=resnet18 \
    projection_dim=128 \
    strong_DA=None \
    gpus=1 \
    workers=16 \
    optimizer='Adam' \
    learning_rate=0.03 \
    momentum=0.9 \
    weight_decay=1e-6 \
    lam=0.01 \
    q=0.6 \
    max_epochs=500 \
    r1=1 \
    r2=0.01
```
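Note that `r1` and `r2` differ between the CIFAR-10 and SVHN examples (0.2 vs. 0.01). Our reading, which is a guess based on the `gca_uot` task name rather than something stated in this README, is that they weight the relaxation of the two marginal constraints in an unbalanced optimal-transport formulation. Under that assumption, the earlier Sinkhorn sketch would change as follows:

```python
import torch

def unbalanced_sinkhorn(sim, lam=0.1, r1=1.0, r2=0.01, n_iters=50):
    """Unbalanced variant of the earlier sketch (our assumption about r1/r2):
    KL marginal penalties with weights r1, r2 soften the marginal constraints,
    damping each Sinkhorn update by an exponent r_i / (r_i + lam)."""
    n, m = sim.shape
    K = torch.exp(sim / lam)
    a, b = torch.full((n,), 1.0 / n), torch.full((m,), 1.0 / m)
    u, v = torch.ones(n), torch.ones(m)
    f1, f2 = r1 / (r1 + lam), r2 / (r2 + lam)   # damped scaling exponents
    for _ in range(n_iters):
        u = (a / (K @ v)) ** f1
        v = (b / (K.T @ u)) ** f2
    return u[:, None] * K * v[None, :]          # plan with relaxed marginals
```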
If you find this repository helpful for your research, please cite our paper:

```bibtex
@article{chen2025your,
  title={Your contrastive learning problem is secretly a distribution alignment problem},
  author={Chen, Zihao and Lin, Chi-Heng and Liu, Ran and Xiao, Jingyun and Dyer, Eva},
  journal={Advances in Neural Information Processing Systems},
  volume={37},
  pages={91597--91617},
  year={2025}
}
```