Add Matryoshka Representation Learning (MRL) support and tests #39
Resolves #8
This pull request introduces Matryoshka Representation Learning (MRL) to the codebase, enabling models to produce multi-granular embeddings that can be flexibly truncated to various dimensions for downstream tasks. It adds configuration options, utility functions, and a comprehensive test suite for MRL, and integrates the MRL loss into both training and validation workflows.
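To illustrate what "flexibly truncated" means in practice, here is a minimal, dependency-free sketch of how a downstream consumer might shorten an MRL embedding. The helper name `truncate_embedding` is hypothetical and does not come from this PR; re-normalizing after truncation is an assumption that suits cosine-similarity retrieval.

```python
import math


def truncate_embedding(embedding, dim):
    """Keep the first `dim` components and L2-normalize the result.

    Hypothetical helper for illustration only; not part of this PR.
    """
    if not 0 < dim <= len(embedding):
        raise ValueError(f"dim must be in 1..{len(embedding)}, got {dim}")
    prefix = list(embedding[:dim])
    norm = math.sqrt(sum(x * x for x in prefix))
    # Leave an all-zero prefix untouched to avoid dividing by zero.
    return [x / norm for x in prefix] if norm > 0 else prefix


full = [3.0, 4.0, 12.0, 0.5]
short = truncate_embedding(full, 2)  # keeps [3.0, 4.0], then normalizes
```

With MRL training, the leading components are meant to carry most of the information, so the shortened vector should remain usable for similarity search at a lower storage cost.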
Matryoshka Representation Learning (MRL) Integration:
- Added `use_mrl`, `mrl_dimensions`, and `mrl_temperature` parameters in the config and arguments.
- Added a `matryoshka_loss` function in `utils.py` to compute the MRL loss, which encourages high-quality embeddings at multiple dimensions.

Configuration and Argument Enhancements:
- Added the MRL options (`use_mrl`, `mrl_dimensions`, `mrl_temperature`) to the `Args` class and provided an example configuration file, `config_mrl.json`.

Training and Validation Pipeline Updates:
- Integrated the MRL loss into the `inbatch_loss` and `hard_loss` functions, and updated the training (`accelerate_train`) and validation (`validate`) routines to use the new MRL parameters when enabled.

Testing and Documentation:
- Added `test_mrl.py` to validate MRL loss computation, integration with existing losses, and embedding truncation behavior.
- Updated `README.md` with an explanation of MRL, configuration steps, and quick validation instructions.
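For readers unfamiliar with the idea, the following is a rough sketch of how an MRL loss can be assembled: an in-batch contrastive loss is computed on each truncated prefix of the embeddings and the results are averaged. This is not the `matryoshka_loss` from `utils.py`, whose signature is not shown in this description. The names `info_nce` and `matryoshka_loss_sketch`, the equal weighting of dimensions, and the default temperature are all assumptions, and the code uses plain Python lists rather than tensors to stay self-contained.

```python
import math


def l2_normalize(vec):
    """Scale a vector to unit length (all-zero vectors are returned as-is)."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec] if norm > 0 else list(vec)


def info_nce(queries, passages, temperature):
    """In-batch contrastive loss: passage i is the positive for query i."""
    total = 0.0
    for i, q in enumerate(queries):
        logits = [sum(a * b for a, b in zip(q, p)) / temperature for p in passages]
        # Numerically stable log-sum-exp over all passages in the batch.
        peak = max(logits)
        log_z = peak + math.log(sum(math.exp(l - peak) for l in logits))
        total += log_z - logits[i]
    return total / len(queries)


def matryoshka_loss_sketch(queries, passages, dimensions, temperature=0.05):
    """Average the contrastive loss over several truncated embedding sizes.

    Illustrative only: equal weights per dimension are an assumption.
    """
    losses = []
    for dim in dimensions:
        q_d = [l2_normalize(q[:dim]) for q in queries]
        p_d = [l2_normalize(p[:dim]) for p in passages]
        losses.append(info_nce(q_d, p_d, temperature))
    return sum(losses) / len(losses)


queries = [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]]
passages = [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]]
loss = matryoshka_loss_sketch(queries, passages, dimensions=[2, 4])
```

Because every prefix length contributes to the objective, the model is pushed to keep the most discriminative information in the earliest components, which is what makes later truncation viable.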