
Summary

This PR fixes out-of-memory issues when training SetFit on large datasets by replacing eager O(n²) pair generation with a streaming approach.

Problem

Training with contrastive loss (e.g., CosineSimilarityLoss) on sufficiently large datasets runs out of memory before training even starts. The root cause is three layers of O(n²) memory allocation (sketched after the list):

  1. shuffle_combinations() creates all pair indices upfront
  2. ContrastiveDataset stores all pairs in pos_pairs/neg_pairs lists
  3. Dataset.from_list(list(...)) materializes the iterator again
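For context, a minimal sketch of the eager pattern (illustrative only, not the exact library code):

```python
import itertools

import numpy as np

def shuffle_combinations_eager(indices, seed=42):
    # Every C(n, 2) index pair is materialized up front: O(n^2) memory
    # allocated before a single training step runs.
    pairs = np.array(list(itertools.combinations(indices, 2)))
    np.random.default_rng(seed).shuffle(pairs)
    return pairs

# For n = 100_000 samples this is ~5e9 pairs of int64 indices,
# i.e. tens of gigabytes of RAM just for the index arrays.
```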

Solution

sampler.py:

  • Replace shuffle_combinations() with on-the-fly random pair sampling
  • ContrastiveDataset now stores only a label_to_indices mapping (O(n))
  • __iter__() generates pairs lazily, with set-based uniqueness tracking
  • The same changes apply to ContrastiveDistillationDataset (see the sketch below)
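A minimal sketch of the streaming idea, using hypothetical names (the real sampler.py also draws negative pairs and supports multilabel data):

```python
import random
from collections import defaultdict

class StreamingPairSampler:
    """Illustrative stand-in for the reworked ContrastiveDataset."""

    def __init__(self, labels, num_pairs, seed=42):
        # Only an O(n) label -> indices mapping is stored.
        self.label_to_indices = defaultdict(list)
        for idx, label in enumerate(labels):
            self.label_to_indices[label].append(idx)
        self.num_pairs = num_pairs  # target count, capped in the real code
        self.rng = random.Random(seed)  # seeded RNG for reproducibility

    def __iter__(self):
        seen = set()  # O(num_pairs) uniqueness tracking
        usable = [l for l, idxs in self.label_to_indices.items() if len(idxs) >= 2]
        while len(seen) < self.num_pairs:
            # Positive pair: two distinct indices sharing a label.
            label = self.rng.choice(usable)
            i, j = self.rng.sample(self.label_to_indices[label], 2)
            key = (min(i, j), max(i, j))
            if key in seen:
                continue  # resample instead of storing all pairs
            seen.add(key)
            yield {"idx_1": i, "idx_2": j, "label": 1.0}
```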

trainer.py / trainer_distillation.py:

  • Use IterableDataset.from_generator() instead of Dataset.from_list(list(...))
  • Compute max_steps up front, since an IterableDataset has no length for the trainer to infer step counts from (see the sketch below)
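Roughly how the trainer side changes, assuming the 🤗 datasets API and the sampler sketched above:

```python
from datasets import IterableDataset

# labels: list of per-sample class labels from the training set
sampler = StreamingPairSampler(labels, num_pairs=100_000)

# Stream pairs into training instead of Dataset.from_list(list(...)),
# which would materialize every pair in memory at once.
train_dataset = IterableDataset.from_generator(lambda: iter(sampler))

# An IterableDataset has no __len__, so the trainer cannot infer the
# number of steps from the dataset size; compute max_steps explicitly.
batch_size = 16
num_epochs = 1
max_steps = (sampler.num_pairs // batch_size) * num_epochs
```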

Memory Comparison

Component        Before    After
Index arrays     O(n²)     0
Pair lists       O(n²)     0
Label grouping   0         O(n)
Uniqueness set   0         O(num_pairs)
Dataset copy     O(n²)     0

Breaking Changes

  • ContrastiveDataset.pos_pairs and neg_pairs attributes are removed
  • ContrastiveDataset.len_pos_pairs / len_neg_pairs now represent target pair counts, not counts of stored pairs
  • An estimated_num_pairs property was added for logging (see the sketch below)
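In practice the migration looks roughly like this (hedged sketch; the constructor arguments shown are assumptions, not the exact API):

```python
# Hypothetical usage -- constructor arguments are illustrative.
ds = ContrastiveDataset(examples, multilabel=False)

# Before this PR, stored pairs could be indexed directly:
#     first = ds.pos_pairs[0]      # AttributeError after this PR

# After this PR, pairs are generated lazily; iterate instead:
first = next(iter(ds))

# len_pos_pairs is now a target, and estimated_num_pairs is for logging:
print(ds.len_pos_pairs, ds.estimated_num_pairs)
```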

Testing

  • Verified training completes on large datasets that previously OOM'd
  • Pairs maintain uniqueness via set-based tracking
  • Reproducibility preserved via seeded RNG
