diff --git a/src/mistral_inference/lora.py b/src/mistral_inference/lora.py
index 30924290..ce4afd17 100644
--- a/src/mistral_inference/lora.py
+++ b/src/mistral_inference/lora.py
@@ -69,8 +69,9 @@ def ignore_missing_keys(m: nn.Module, incompatible_keys: NamedTuple) -> None:
         self.register_load_state_dict_post_hook(ignore_missing_keys)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        lora = self.lora_B(self.lora_A(x))
-        result: torch.Tensor = self.linear(x) + lora * self.scaling
+        lora_A_result = self.lora_A(x)
+        lora = self.lora_B(lora_A_result)
+        result = self.linear(x) + lora * self.scaling
         return result
 
     def _load_from_state_dict(self, state_dict: Dict[str, Any], prefix: str, *args, **kwargs) -> None:  # type: ignore[no-untyped-def]