support for encoder from gemma-3-4b-it #42
base: main
Conversation
llava/model/llava_arch.py
```python
# ====================================================================================
# ================== FIX 2: Brute-force clip the features ============================
image_features = torch.clamp(image_features, min=-10.0, max=10.0)
```
Do we need to do this only for SigLIP from Gemma 3, so that other encoders are unchanged?
Yes, I will add a check to do it only for vision towers from Gemma.
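Following up on this thread, here is a minimal sketch of how the clamp could be gated to Gemma's tower. The helper name `maybe_clamp_features` and the substring-matching heuristic are hypothetical, not code from the PR:

```python
import torch

def maybe_clamp_features(image_features: torch.Tensor, vision_tower_name: str) -> torch.Tensor:
    """Hypothetical helper: clamp features only for Gemma-3's SigLIP tower.

    Other vision towers pass through unchanged, so the fix cannot affect them.
    """
    name = vision_tower_name.lower()
    if "gemma" in name or "siglip" in name:
        return torch.clamp(image_features, min=-10.0, max=10.0)
    return image_features
```

With this gating, the clamp is a no-op for, e.g., a CLIP vision tower, which addresses the concern about leaving other encoders unchanged.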
```python
class LLaVATrainer(Trainer):

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
```
For reference, this is what `compute_loss` looks like in the HF `Trainer` class: https://github.com/huggingface/transformers/blob/052e652d6d53c2b26ffde87e039b723949a53493/src/transformers/trainer.py#L3618
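For comparison, a self-contained sketch of the manual shifted cross-entropy such an override would perform. The helper name is ours; the shift-by-one and `ignore_index=-100` follow the standard Hugging Face causal-LM convention:

```python
import torch
import torch.nn.functional as F

def manual_causal_lm_loss(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # Causal-LM shift: the logits at position t predict the token at t+1.
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    # Flatten to (batch*seq, vocab) vs. (batch*seq,); -100 marks masked labels.
    return F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
        ignore_index=-100,
    )
```

An override of `compute_loss` along these lines would call the model, take `outputs.logits`, and apply this function instead of the model's internally computed loss.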
```python
class LLaVATrainer(Trainer):

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
```
Can we also make sure this `compute_loss` is bypassed for other encoders? I am still not sure why exactly we need a custom loss here. Is it to handle the sequence-mismatch problem? If so, why is it not a problem with other encoders? Or is there some other reason? Unless those points are clear, let's make sure we apply this custom `compute_loss` only for gemma3-siglip.
In your `if 'siglip' in self.data_args.image_processor.image_processor_type.lower():` line, you need to update the condition to also match Gemma:
`if 'siglip' in self.data_args.image_processor.image_processor_type.lower() or 'gemma' in self.data_args.image_processor.image_processor_type.lower():`
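One Python pitfall worth flagging when combining these substring checks: writing `'siglip' or 'gemma' in name` does not test both keywords, so each keyword needs its own `in` test. A minimal demonstration:

```python
# Why each keyword needs its own membership test.
name = "CLIPImageProcessor".lower()

# Buggy form: parses as `'siglip' or ('gemma' in name)`; the non-empty
# literal 'siglip' is always truthy, so this matches EVERY processor.
buggy = bool('siglip' or 'gemma' in name)

# Correct form: one `in` test per keyword.
correct = 'siglip' in name or 'gemma' in name
```

Here `buggy` is `True` even for a CLIP processor, while `correct` is `False`.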
This PR adds support for using the `gemma3-siglip-encoder` (from `google/gemma-3-4b-it`) as a vision tower for LLaVA pre-training with a Vicuna-based LLM.

1. Numerical Instability Issue - NaN Loss

Initial attempts to pre-train using the `gemma3-siglip-encoder` resulted in a persistent `NaN` loss. Debugging revealed that the encoder produces feature outputs with an extremely large numerical magnitude. This triggered a low-level bug deep inside the language model's `CrossEntropyLoss` function, causing it to fail even when all inputs (`logits` and `labels`) were valid.

2. Implementation Details

To enable stable training, the following two-part solution was implemented:

- Feature Clipping: A `torch.clamp` call was added to the `encode_images` method in `llava/model/llava_arch.py`. This controls the extreme magnitude of the `gemma` features by ensuring they are within a stable `[-10, 10]` range before being passed to the language model.
- Manual Loss Calculation: The `compute_loss` method in `llava/train/llava_trainer.py` was overridden to bypass the model's unstable internal loss function. This implementation takes the clean `logits` from the model and performs a stable, manual `CrossEntropyLoss` calculation.
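The description does not pin down the exact low-level failure inside `CrossEntropyLoss`. One common mechanism when features have extreme magnitude is float16 overflow under mixed-precision training; the sketch below illustrates that hypothesis (it is not a reproduction of the actual bug):

```python
import torch

# float16's largest finite value is ~65504, so very large activations overflow.
feats = torch.tensor([7.0e4, 7.0e4])
half = feats.to(torch.float16)              # both entries overflow to inf
probs = torch.softmax(half.float(), dim=0)  # inf - inf inside softmax -> nan

# Clamping first keeps the values representable and the downstream math finite.
clamped = torch.clamp(feats, min=-10.0, max=10.0).to(torch.float16)
```

Once a `nan` or `inf` enters the logits, it propagates through the loss, which is consistent with the observed `NaN` loss even when `labels` were valid.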