Description
Hello, and thank you for maintaining the SPLADE repository.
While working with the codebase, I noticed two issues that appear to affect training behavior. I would like to report them here in case they require correction.
1. Typo in `SiameseTransformerTrainer.forward()`: `augment_pairs` vs. `augment_pair`
- File: `splade/tasks/transformer_trainer.py`
- Class: `SiameseTransformerTrainer(TransformerTrainer)`
Inside the `forward` function, the following conditional is used:

```python
if "augment_pairs" in self.config:
    if self.config["augment_pairs"] == "in_batch_negatives":
```
However, the configuration files use the key `augment_pair` (without the trailing "s").
This mismatch means the condition is never true, so the augmentation logic is silently skipped instead of being triggered as intended.
It seems the correct key should likely be:

```python
if "augment_pair" in self.config:
    if self.config["augment_pair"] == "in_batch_negatives":
```
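A minimal reproduction of the mismatch (the dict literal below is a stand-in for the loaded YAML config; its exact structure in the repo is an assumption):

```python
# The config files define "augment_pair" (no trailing "s"), so the
# membership test in forward() looks for a key that never exists.
config = {"augment_pair": "in_batch_negatives"}  # stand-in for the loaded config

print("augment_pairs" in config)  # False: the augmentation branch is skipped
print("augment_pair" in config)   # True: the key the config actually defines
```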
2. Incorrect assignment of `self.fn` in `early_stopping.py` due to missing parentheses
- File: `splade/tasks/base/early_stopping.py`
- Class: `EarlyStopping`
When the monitored key is not `"loss"` (e.g., `"MRR@10"`), the comparator always returns a function object (which is truthy) instead of a boolean, so the check always passes.
The current implementation is:

```python
self.fn = lambda x, y: x < y if mode == "loss" else lambda a, b: a > b
```
Due to operator precedence, the conditional expression is part of the lambda body: when `mode != "loss"`, calling `self.fn` returns the inner lambda object itself (`lambda a, b: a > b`) rather than executing it.
Since function objects in Python are truthy, the condition `if self.fn(val_perf, self.best):` always evaluates to True.
Impact
- Early Stopping never triggers: The counter is reset to 0 at every validation step because the condition is always effectively True.
- "Best" Model Overwrite: Every checkpoint is saved as the "best" model, regardless of whether performance actually improved. The final saved model is simply the last checkpoint, not the best one.
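To make the impact concrete, here is a hedged sketch (not the repository's exact code) of a typical early-stopping update step, showing how a truthy comparator defeats both the patience counter and the best-model check:

```python
# Hypothetical early-stopping update step for illustration only.
def update(counter, best, val_perf, fn):
    if fn(val_perf, best):       # with the buggy fn this is always truthy
        return 0, val_perf       # counter reset, "best" overwritten
    return counter + 1, best     # never reached with the buggy fn

# The buggy comparator, with mode fixed to a non-"loss" value:
buggy_fn = lambda x, y: x < y if "metric" == "loss" else lambda a, b: a > b

counter, best = 0, 0.9
counter, best = update(counter, best, 0.1, buggy_fn)  # clearly worse metric
print(counter, best)  # 0 0.1 -> counter reset and "best" replaced anyway
```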
Example:

```python
mode = "metric"
fn = lambda x, y: x < y if mode == "loss" else lambda a, b: a > b
result = fn(0.5, 0.4)
print(f"Type: {type(result)}")
# Output: Type: <class 'function'> -> effectively True in if-statements
```
Parentheses should be used to enforce the correct precedence:

```python
self.fn = (lambda x, y: x < y) if mode == "loss" else (lambda a, b: a > b)
```
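With the parentheses in place, the comparator behaves as intended for metric mode:

```python
mode = "metric"
fn = (lambda x, y: x < y) if mode == "loss" else (lambda a, b: a > b)

print(fn(0.5, 0.4))        # True: 0.5 > 0.4, the new metric is an improvement
print(fn(0.3, 0.4))        # False: no improvement, the patience counter can advance
print(type(fn(0.5, 0.4)))  # <class 'bool'>, not a function object
```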