From b581721b95de194d30a914108520b543a82b11eb Mon Sep 17 00:00:00 2001
From: rvorias
Date: Sun, 28 May 2023 17:13:31 +0200
Subject: [PATCH] Add flag to disable LoRA monkey-patching

---
 src/diffusers/loaders.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py
index beb1c380d6cb..3514431716a7 100644
--- a/src/diffusers/loaders.py
+++ b/src/diffusers/loaders.py
@@ -808,11 +808,13 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
             subfolder (`str`, *optional*, defaults to `""`):
                 In case the relevant files are located inside a subfolder of the model repo (either remote in
                 huggingface.co or downloaded locally), you can specify the folder name here.
-
             mirror (`str`, *optional*):
                 Mirror source to accelerate downloads in China. If you are from China and have an accessibility
                 problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
                 Please refer to the mirror site for more information.
+            text_encoder_modify_forwards (`bool`, *optional*, defaults to `True`):
+                Whether or not to monkey-patch the forward pass of the text encoder to use the LoRA layers.
+                Monkey-patching should only happen once, so set this flag to False if you call this function more than once.
@@ -833,6 +835,8 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
         subfolder = kwargs.pop("subfolder", None)
         weight_name = kwargs.pop("weight_name", None)
         use_safetensors = kwargs.pop("use_safetensors", None)
+        text_encoder_modify_forwards = kwargs.pop("text_encoder_modify_forwards", True)
+
         if use_safetensors and not is_safetensors_available():
             raise ValueError(
@@ -921,7 +925,8 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
             attn_procs_text_encoder = self._load_text_encoder_attn_procs(
                 text_encoder_lora_state_dict, network_alpha=network_alpha
             )
-            self._modify_text_encoder(attn_procs_text_encoder)
+            if text_encoder_modify_forwards:
+                self._modify_text_encoder(attn_procs_text_encoder)
             # save lora attn procs of text encoder so that it can be easily retrieved
             self._text_encoder_lora_attn_procs = attn_procs_text_encoder
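
For reference, a minimal usage sketch of the flag this patch introduces, assuming a stock Stable Diffusion pipeline; the LoRA repo ids below are placeholders, not part of the patch:

    import torch
    from diffusers import StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
    )

    # First call: the text encoder forwards are monkey-patched once
    # (text_encoder_modify_forwards defaults to True).
    pipe.load_lora_weights("user/first-lora")  # placeholder repo id

    # Later calls: pass the new flag so the already-patched forwards are
    # not wrapped a second time.
    pipe.load_lora_weights("user/second-lora", text_encoder_modify_forwards=False)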