Conversation

@vidixha commented Aug 26, 2025

This PR adds support for using the gemma3-siglip-encoder (from google/gemma-3-4b-it) as a vision tower for LLaVA pre-training with a Vicuna-based LLM.

1. Numerical Instability Issue: NaN Loss

Initial attempts to pre-train with the gemma3-siglip-encoder resulted in a persistent NaN loss. Debugging revealed that the encoder produces feature outputs with an extremely large numerical magnitude, which triggered a low-level failure inside the language model's CrossEntropyLoss computation even when all of its inputs (logits and labels) were valid.
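
As an illustration, a check along these lines surfaces the problem (vision_tower and images are placeholder names, not the PR's actual variables):

import torch

with torch.no_grad():
    feats = vision_tower(images)  # placeholder forward pass through the encoder
print(f"max |feature| = {feats.abs().max().item():.2f}")    # very large for gemma3-siglip
print(f"non-finite?    {(~torch.isfinite(feats)).any().item()}")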

2. Implementation Details
To enable stable training, the following two-part solution was implemented:

  • Feature Clipping: A torch.clamp call was added to the encode_images method in llava/model/llava_arch.py. It bounds the extreme magnitude of the Gemma features to a stable [-10, 10] range before they are passed to the language model (see the excerpt below).

  • Manual Loss Calculation: The compute_loss method in llava/train/llava_trainer.py was overridden to bypass the model's unstable internal loss computation. The override takes the logits from the model and performs a stable, manual CrossEntropyLoss calculation (see the sketch after this list).
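
For illustration, the standard Hugging Face causal-LM pattern for such a manual cross-entropy looks roughly like the sketch below. This is a sketch under assumptions, not the PR's exact code: in LLaVA the labels must already be aligned with the image-token-expanded sequence, and the upcast to float32 is one common way to keep the loss numerically stable.

import torch.nn as nn
from transformers import Trainer

class LLaVATrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        outputs = model(**inputs)
        logits = outputs.logits
        # Assumption: these labels share the logits' sequence length, i.e. they
        # were already expanded/aligned for the inserted image tokens.
        labels = inputs["labels"]
        # Standard causal-LM shift: position t predicts token t + 1.
        shift_logits = logits[..., :-1, :].contiguous().float()
        shift_labels = labels[..., 1:].contiguous()
        loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
        loss = loss_fct(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1).to(shift_logits.device),
        )
        return (loss, outputs) if return_outputs else loss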


# ====================================================================================
# ================== FIX 2: Brute-force clip the features ============================
# Bound feature magnitudes to [-10, 10] so the downstream loss stays finite.
image_features = torch.clamp(image_features, min=-10.0, max=10.0)

Owner commented:

We need to do this only for the SigLIP encoder from Gemma 3, to make sure the other encoders are unchanged?

@vidixha (Author) replied Aug 26, 2025:

Yes, I will add a check so it applies only to vision towers from Gemma.
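
For example, something along these lines could work (the mm_vision_tower config field is how LLaVA typically records the tower name; treat the exact attribute as an assumption):

# Sketch: clamp only for Gemma-family vision towers; all others pass through.
tower_name = str(getattr(self.config, "mm_vision_tower", "")).lower()
if "gemma" in tower_name:
    image_features = torch.clamp(image_features, min=-10.0, max=10.0)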


class LLaVATrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):

Owner commented:

Can we also make sure this compute_loss is bypassed for the other encoders? I am still not sure why exactly we need a custom loss here. Is it to handle a sequence-length mismatch problem? If so, why is that not a problem for the other encoders? Or is there another reason? Unless those points are clear, let's make sure we apply this custom compute_loss only for gemma3-siglip.
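
For example, a minimal fall-through to the stock Trainer loss could look like this (a sketch; the mm_vision_tower attribute is again an assumption, and model may need unwrapping if it is DDP-wrapped):

def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
    tower_name = str(getattr(model.config, "mm_vision_tower", "")).lower()
    if "gemma" not in tower_name:
        # Every non-Gemma encoder keeps the default Hugging Face loss path.
        return super().compute_loss(model, inputs, return_outputs=return_outputs, **kwargs)
    # ... gemma3-siglip-specific manual loss goes here ...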

Owner commented:

In your if 'siglip' in self.data_args.image_processor.image_processor_type.lower(): line, you need to update the check to cover both families. Note that 'siglip' or 'gemma' in ... would always evaluate truthy in Python, so each substring has to be tested explicitly:

processor_type = self.data_args.image_processor.image_processor_type.lower()
if 'siglip' in processor_type or 'gemma' in processor_type:
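
For reference, a quick snippet (with a hypothetical processor-type string) showing why the unparenthesised form misfires:

s = "clipimageprocessor"               # hypothetical processor type
print('siglip' or 'gemma' in s)        # -> 'siglip': a non-empty string is truthy
print('siglip' in s or 'gemma' in s)   # -> False: the intended membership test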
