From dd8bf462e7536c9f086985a1bb3b20c44fea8cb8 Mon Sep 17 00:00:00 2001 From: Mandlin Sarah Date: Fri, 9 Aug 2024 16:23:15 -0700 Subject: [PATCH] Optimized linear layer operations for efficiency --- src/mistral_inference/lora.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/mistral_inference/lora.py b/src/mistral_inference/lora.py index 30924290..ce4afd17 100644 --- a/src/mistral_inference/lora.py +++ b/src/mistral_inference/lora.py @@ -69,8 +69,9 @@ def ignore_missing_keys(m: nn.Module, incompatible_keys: NamedTuple) -> None: self.register_load_state_dict_post_hook(ignore_missing_keys) def forward(self, x: torch.Tensor) -> torch.Tensor: - lora = self.lora_B(self.lora_A(x)) - result: torch.Tensor = self.linear(x) + lora * self.scaling + lora_A_result = self.lora_A(x) + lora = self.lora_B(lora_A_result) + result = self.linear(x) + lora * self.scaling return result def _load_from_state_dict(self, state_dict: Dict[str, Any], prefix: str, *args, **kwargs) -> None: # type: ignore[no-untyped-def]