Fix OOM in contrastive pair generation with streaming approach #627
Summary
This PR fixes out-of-memory issues when training SetFit on large datasets by replacing eager O(n²) pair generation with streaming pair generation.
Problem
Training with contrastive loss (e.g., CosineSimilarityLoss) on datasets with a large enough number of samples causes OOM before training even starts. The root cause is three layers of O(n²) memory allocation:
- shuffle_combinations() creates all pair indices upfront
- ContrastiveDataset stores all pairs in pos_pairs / neg_pairs lists
- Dataset.from_list(list(...)) materializes the iterator again
Solution
sampler.py:
- Replace shuffle_combinations() with on-the-fly random pair sampling
- ContrastiveDataset now stores only a label_to_indices mapping (O(n))
- __iter__() generates pairs on-the-fly with set-based uniqueness tracking
- ContrastiveDistillationDataset is updated as well
trainer.py / trainer_distillation.py:
- Use IterableDataset.from_generator() instead of Dataset.from_list(list(...))
- Set max_steps automatically for IterableDataset compatibility
Memory Comparison
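As a rough illustration of where the savings come from, the sketch below keeps only an O(n) label-to-indices mapping and yields unique positive pairs lazily, deduplicating with a set. It is not the PR's implementation: build_label_to_indices and stream_positive_pairs are hypothetical names, only positive pairs are handled, and choosing a label group uniformly at random is a simplifying assumption.

```python
import math
import random
from collections import defaultdict
from typing import Dict, Iterator, List, Tuple


def build_label_to_indices(labels: List[int]) -> Dict[int, List[int]]:
    """Map each label to the indices of the samples carrying it (O(n) memory)."""
    mapping: Dict[int, List[int]] = defaultdict(list)
    for idx, label in enumerate(labels):
        mapping[label].append(idx)
    return dict(mapping)


def stream_positive_pairs(
    label_to_indices: Dict[int, List[int]],
    num_pairs: int,
    seed: int = 0,
) -> Iterator[Tuple[int, int]]:
    """Lazily yield up to num_pairs distinct same-label index pairs.

    Hypothetical helper: memory is bounded by the mapping plus the set of
    pairs already emitted, never by the full n*(n-1)/2 combinations.
    """
    rng = random.Random(seed)
    # Only labels with at least two samples can form a positive pair.
    eligible = [idxs for idxs in label_to_indices.values() if len(idxs) >= 2]
    # Cap the target at the number of distinct pairs that exist, so the
    # rejection-sampling loop below cannot run forever.
    max_unique = sum(math.comb(len(idxs), 2) for idxs in eligible)
    target = min(num_pairs, max_unique)

    seen = set()
    while len(seen) < target:
        idxs = rng.choice(eligible)
        a, b = rng.sample(idxs, 2)
        pair = (min(a, b), max(a, b))  # normalize so (a, b) == (b, a)
        if pair not in seen:
            seen.add(pair)
            yield pair


if __name__ == "__main__":
    labels = [0, 0, 0, 1, 1, 2]
    mapping = build_label_to_indices(labels)
    for pair in stream_positive_pairs(mapping, num_pairs=3):
        print(pair)
```

For scale, an eager approach over n = 100,000 samples would have math.comb(n, 2) = 4,999,950,000 candidate pairs to index, whereas the mapping above holds 100,000 indices and the seen set grows only with the number of pairs actually requested. Rejection sampling slows down as the target approaches the total number of distinct pairs, which is an accepted trade-off in this sketch.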
Breaking Changes
- ContrastiveDataset.pos_pairs and neg_pairs attributes removed
- ContrastiveDataset.len_pos_pairs / len_neg_pairs now represent target counts, not stored counts
- Added estimated_num_pairs property for logging
Testing