Skip to content

MLAI-Yonsei/causal-context-learning

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

3 Commits
 
 
 
 
 
 
 
 

Repository files navigation

CCL: Causal-aware In-context Learning for Out-of-Distribution Generalization

This repository is the official implementation of "CCL: Causal-aware In-context Learning for Out-of-Distribution Generalization".

In this study, we focus on constructing a robust demonstration set to enhance the generalization of LLMs in OOD scenarios. Inspired by CRL, we propose a novel demonstration selection method, causal-aware in-context learning (CCL), which learns causal representations that remain invariant across environments and prioritizes candidates by assigning higher ranks to those with causal representations similar to the target query. Under the causal mechanism, we theoretically demonstrate that the demonstration set selected by CCL comprises candidates that are more closely related to the underlying problem addressed by the target query, rather than merely matching its context. The problem-level invariance of CCL ensures generalization performance for the target query even in unseen environments. We empirically validate that CCL operates robustly in OOD scenarios and demonstrates superior generalization performance on both synthetic and real datasets.

Requirements

To install requirements:

pip install -r requirements.txt

Dataset

MGSM dataset OOD NLP dataset (the same datasets used for training and evaluating LLM-R).

Training CCL

To train CCL, please refer here to get detailed commond scripts

Generating input prompt & evaluating with LMs

To conduct few-shot or zero-shot prompt generation, please refer here to get detailed commond scripts

Results

MGSM

Method Total ID OOD
ZS 87.71 89.43 84.70
ICL (Fix.) 91.20 91.26 91.10
ICL (KNN) 94.07 95.83 91.00
CCL 94.55 96.11 91.80

OOD NLP

Language Model Retrieval Method QNLI PIQA WSC273 YELP Avg.
Llama 3.2 3B IT ZS 43.36 71.33 55.31 88.98 64.75
LLM-R 29.93 69.91 61.17 79.48 60.12
ICL (K-means) 68.13 69.04 49.82 75.81 65.70
CCL 75.18 70.46 61.91 95.44 75.74
Phi-4 mini IT ZS 86.34 76.01 64.10 95.76 80.55
LLM-R 85.21 74.10 65.93 96.37 80.40
ICL (K-means) 83.18 74.81 71.06 96.25 81.33
CCL 82.26 75.73 71.43 96.33 81.44
GPT-4o ZS 91.30 94.07 90.84 97.47 93.42
LLM-R 90.32 94.23 92.67 98.27 93.87
ICL (K-means) 88.28 93.04 87.55 98.17 91.76
CCL 90.77 93.15 93.77 98.36 94.01

Sentence embedding model: OpenAI’s text-embedding-3-small

Base code of this repository.

The CCL code is based on this repository.

About

No description, website, or topics provided.

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published