This repository is the official implementation of "CCL: Causal-aware In-context Learning for Out-of-Distribution Generalization".
In this study, we focus on constructing a robust demonstration set to enhance the generalization of LLMs in out-of-distribution (OOD) scenarios. Inspired by causal representation learning (CRL), we propose a novel demonstration selection method, causal-aware in-context learning (CCL). CCL learns causal representations that remain invariant across environments and ranks candidates higher when their causal representations are similar to that of the target query. Under the causal mechanism, we theoretically demonstrate that the demonstration set selected by CCL comprises candidates that are more closely related to the underlying problem addressed by the target query, rather than merely matching its context. The problem-level invariance of CCL ensures generalization performance for the target query even in unseen environments. We empirically validate that CCL operates robustly in OOD scenarios and demonstrates superior generalization performance on both synthetic and real datasets.
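The ranking step described above can be illustrated with a minimal sketch. It is not the code from this repository: it assumes the causal representations have already been computed and are available as plain vectors, and it assumes cosine similarity as the similarity measure. The names `cosine_similarity` and `rank_candidates` are hypothetical.

```python
import math


def cosine_similarity(a, b):
    """Cosine similarity between two equal-length vectors (0.0 if either is all zeros)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0
    return dot / (norm_a * norm_b)


def rank_candidates(query_repr, candidate_reprs, k):
    """Return the indices of the k candidates most similar to the query.

    Assumption: `query_repr` and each entry of `candidate_reprs` are
    precomputed representation vectors; how they are learned is not shown.
    """
    scored = [
        (cosine_similarity(query_repr, cand), idx)
        for idx, cand in enumerate(candidate_reprs)
    ]
    # Highest similarity first; ties are broken by the original candidate order.
    scored.sort(key=lambda pair: (-pair[0], pair[1]))
    return [idx for _, idx in scored[:k]]


if __name__ == "__main__":
    query = [1.0, 0.0]
    candidates = [[0.0, 1.0], [0.9, 0.1], [1.0, 0.0], [-1.0, 0.0]]
    print(rank_candidates(query, candidates, k=2))
```

The returned indices would then be used to pick the demonstrations that are placed in the prompt for the target query.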
To install requirements:
`pip install -r requirements.txt`
Datasets: the MGSM dataset and the OOD NLP datasets (the same datasets used for training and evaluating LLM-R).
To train CCL, please refer here for the detailed command scripts.
To run few-shot or zero-shot prompt generation, please refer here for the detailed command scripts.
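For orientation, the difference between few-shot and zero-shot prompt generation can be sketched as follows. This is an illustration only, not the repository's script: the function name `build_prompt` and the `Question:` / `Answer:` template are assumptions, and the actual prompt format used here may differ.

```python
def build_prompt(demonstrations, query, instruction=""):
    """Assemble a prompt from (question, answer) demonstrations and a target query.

    Passing an empty `demonstrations` list yields a zero-shot prompt.
    The template is a hypothetical example, not the one used by CCL.
    """
    parts = []
    if instruction:
        parts.append(instruction)
    for question, answer in demonstrations:
        parts.append(f"Question: {question}\nAnswer: {answer}")
    # The target query comes last, with the answer left open for the model.
    parts.append(f"Question: {query}\nAnswer:")
    return "\n\n".join(parts)


if __name__ == "__main__":
    demos = [("What is 2 + 3?", "5")]
    print(build_prompt(demos, "What is 4 + 4?"))  # few-shot
    print(build_prompt([], "What is 4 + 4?"))     # zero-shot
```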
| Method | Total | ID | OOD |
|---|---|---|---|
| ZS | 87.71 | 89.43 | 84.70 |
| ICL (Fix.) | 91.20 | 91.26 | 91.10 |
| ICL (KNN) | 94.07 | 95.83 | 91.00 |
| CCL | 94.55 | 96.11 | 91.80 |
| Language Model | Retrieval Method | QNLI | PIQA | WSC273 | YELP | Avg. |
|---|---|---|---|---|---|---|
| Llama 3.2 3B IT | ZS | 43.36 | 71.33 | 55.31 | 88.98 | 64.75 |
| | LLM-R | 29.93 | 69.91 | 61.17 | 79.48 | 60.12 |
| | ICL (K-means) | 68.13 | 69.04 | 49.82 | 75.81 | 65.70 |
| | CCL | 75.18 | 70.46 | 61.91 | 95.44 | 75.74 |
| Phi-4 mini IT | ZS | 86.34 | 76.01 | 64.10 | 95.76 | 80.55 |
| | LLM-R | 85.21 | 74.10 | 65.93 | 96.37 | 80.40 |
| | ICL (K-means) | 83.18 | 74.81 | 71.06 | 96.25 | 81.33 |
| | CCL | 82.26 | 75.73 | 71.43 | 96.33 | 81.44 |
| GPT-4o | ZS | 91.30 | 94.07 | 90.84 | 97.47 | 93.42 |
| | LLM-R | 90.32 | 94.23 | 92.67 | 98.27 | 93.87 |
| | ICL (K-means) | 88.28 | 93.04 | 87.55 | 98.17 | 91.76 |
| | CCL | 90.77 | 93.15 | 93.77 | 98.36 | 94.01 |
Sentence embedding model: OpenAI’s text-embedding-3-small
The CCL code is based on this repository.