From 724b297f9033a3e552fba2ce85e5dadd9ff02873 Mon Sep 17 00:00:00 2001
From: Mandlin Sarah
Date: Thu, 8 Aug 2024 19:42:45 -0700
Subject: [PATCH] Improve dtype assertion in LoRALoaderMixin for clarity

---
 src/mistral_inference/lora.py | 17 ++++++++++-------
 1 file changed, 10 insertions(+), 7 deletions(-)

diff --git a/src/mistral_inference/lora.py b/src/mistral_inference/lora.py
index 30924290..e019b394 100644
--- a/src/mistral_inference/lora.py
+++ b/src/mistral_inference/lora.py
@@ -102,12 +102,15 @@ def load_lora(self, lora_path: Union[Path, str], scaling: float = 2.0) -> None:
 
     def _load_lora_state_dict(self, lora_state_dict: Dict[str, torch.Tensor], scaling: float = 2.0) -> None:
         """Loads LoRA state_dict"""
-        lora_dtypes = set([p.dtype for p in lora_state_dict.values()])
-        assert (
-            len(lora_dtypes) == 1
-        ), f"LoRA weights have multiple different dtypes {lora_dtypes}. All weights need to have the same dtype"
-        lora_dtype = lora_dtypes.pop()
-        assert lora_dtype == self.dtype, f"LoRA weights dtype differs from model's dtype {lora_dtype} != {self.dtype}"  # type: ignore[attr-defined]
+        lora_dtypes = set(p.dtype for p in lora_state_dict.values())
+        assert len(lora_dtypes) == 1, (
+            f"LoRA weights have multiple different dtypes {lora_dtypes}. "
+            "All weights need to have the same dtype"
+        )
+        lora_dtype = next(iter(lora_dtypes))
+        assert lora_dtype == self.dtype, (  # type: ignore[attr-defined]
+            f"LoRA weights dtype ({lora_dtype}) differs from model's dtype ({self.dtype})"
+        )
         assert all("lora" in key for key in lora_state_dict.keys())
 
         # move tensors to device
@@ -152,4 +155,4 @@ def _load_lora_state_dict(self, lora_state_dict: Dict[str, torch.Tensor], scalin
             self.pipeline_rank,  # type: ignore[attr-defined]
         )
 
-        self.load_state_dict(state_dict, strict=True)  # type: ignore[attr-defined]
+        self.load_state_dict(state_dict, strict=True)  # type: ignore[attr-defined]
\ No newline at end of file
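
For illustration only, a minimal standalone sketch of the checks this patch tightens. The check_lora_dtypes helper and the sample tensor name below are hypothetical and not part of the patch or the mistral_inference API; they simply mirror the new assertion logic under the assumption that torch is installed.

import torch
from typing import Dict


def check_lora_dtypes(lora_state_dict: Dict[str, torch.Tensor], model_dtype: torch.dtype) -> None:
    # Hypothetical standalone mirror of the assertions added in this patch.
    lora_dtypes = set(p.dtype for p in lora_state_dict.values())
    assert len(lora_dtypes) == 1, (
        f"LoRA weights have multiple different dtypes {lora_dtypes}. "
        "All weights need to have the same dtype"
    )
    lora_dtype = next(iter(lora_dtypes))
    assert lora_dtype == model_dtype, (
        f"LoRA weights dtype ({lora_dtype}) differs from model's dtype ({model_dtype})"
    )


# Passes: a single dtype that matches the model's dtype.
check_lora_dtypes(
    {"layers.0.attention.wq.lora_A.weight": torch.zeros(8, 8, dtype=torch.bfloat16)},
    model_dtype=torch.bfloat16,
)

# Would raise AssertionError with the clearer message: LoRA dtype differs from the model's dtype.
# check_lora_dtypes(
#     {"layers.0.attention.wq.lora_A.weight": torch.zeros(8, 8, dtype=torch.float32)},
#     model_dtype=torch.bfloat16,
# )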