From 4faccdaa0de691ea67aae0ae3d48bb5406bda260 Mon Sep 17 00:00:00 2001
From: yjjinjie
Date: Fri, 20 Oct 2023 16:06:01 +0800
Subject: [PATCH 1/3] add oneflow to accelerate inference

---
 Dockerfile                   | 63 ++++++++++++++++++++++++++++++++++++
 README.md                    | 30 +++++++++++++++--
 app.py                       | 13 ++++++++
 easyphoto/easyphoto_infer.py | 41 ++++++++++++++++++++---
 easyphoto/sd_diffusers.py    | 24 +++++++++++---
 5 files changed, 160 insertions(+), 11 deletions(-)
 create mode 100644 Dockerfile

diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 0000000..c7682b0
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,63 @@
+
+FROM nvidia/cuda:11.7.1-devel-ubuntu22.04
+
+ARG DEBIAN_FRONTEND=noninteractive
+RUN rm -rf /etc/localtime && ln -s /usr/share/zoneinfo/Asia/Harbin /etc/localtime
+ENV PYTHONUNBUFFERED 1
+
+RUN apt-get update -y && apt-get install -y build-essential apt-utils google-perftools sox \
+    ffmpeg libcairo2 libcairo2-dev zip wget curl vim git ca-certificates kmod \
+    python3-pip python-is-python3 python3.10-venv aria2 && rm -rf /var/lib/apt/lists/*
+
+RUN pip install numpy==1.23.5 Pillow==9.5.0 mpmath>=0.19 networkx sympy -i https://mirrors.aliyun.com/pypi/simple/
+RUN pip install torch==2.0.1 torchvision==0.15.2 --extra-index-url https://download.pytorch.org/whl/cu117 -i https://mirrors.aliyun.com/pypi/simple/
+RUN pip install xformers==0.0.20 mediapipe manimlib svglib fvcore ffmpeg modelscope ultralytics albumentations==0.4.3 -i https://mirrors.aliyun.com/pypi/simple/
+RUN pip install voluptuous toml accelerate>=0.20.3 lion-pytorch chardet lxml pathos cryptography openai boto3 -i https://mirrors.aliyun.com/pypi/simple/
+RUN pip install aliyun-python-sdk-core aliyun-python-sdk-alimt insightface==0.7.3 onnx==1.14.0 dadaptation PyExecJS -i https://mirrors.aliyun.com/pypi/simple/
+RUN pip install pims gradio==3.32.0 setuptools>=42 blendmodes==2022 basicsr==1.4.2 gfpgan==1.3.8 -i https://mirrors.aliyun.com/pypi/simple/
+RUN pip install realesrgan==0.3.0 omegaconf==2.2.3 pytorch_lightning==1.9.4 -i https://mirrors.aliyun.com/pypi/simple/
+RUN pip install scikit-image timm==0.6.7 piexif==1.1.3 einops psutil==5.9.5 jsonmerge==1.8.0 --no-cache-dir -i https://mirrors.aliyun.com/pypi/simple/
+RUN pip install clean-fid==0.1.35 resize-right==0.0.2 torchdiffeq==0.2.3 kornia==0.6.7 segment_anything supervision -i https://mirrors.aliyun.com/pypi/simple/
+RUN pip install lark==1.1.2 inflection==0.5.1 GitPython==3.1.30 safetensors==0.3.1 fairscale numba==0.57.0 moviepy==1.0.2 transforms3d==0.4.1 -i https://mirrors.aliyun.com/pypi/simple/
+RUN pip install httpcore==0.15 fastapi==0.94.0 tomesd==0.1.2 numexpr matplotlib pandas av wandb appdirs lpips dataclasses pyqt6 -i https://mirrors.aliyun.com/pypi/simple/
+RUN pip install imageio-ffmpeg==0.4.2 rich gdown onnxruntime==1.15.0 ifnude pycocoevalcap clip-anytorch sentencepiece tokenizers==0.13.3 -i https://mirrors.aliyun.com/pypi/simple/
+RUN pip install transformers==4.25.1 trampoline==0.1.2 -i https://mirrors.aliyun.com/pypi/simple/
+RUN pip install transparent-background ipython seaborn color_matcher trimesh vispy>=0.13.0 rembg>=2.0.50 py-cpuinfo protobuf -i https://mirrors.aliyun.com/pypi/simple/
+
+RUN wget https://pai-aigc-extension.oss-cn-hangzhou.aliyuncs.com/torchsde.zip -O /tmp/torchsde.zip && \
+    cd /tmp && unzip torchsde.zip && cd torchsde && python3 setup.py install && rm -rf /tmp/torchsde*
+
+RUN pip install diffusers==0.18.2 segmentation-refinement send2trash~=1.8 dynamicprompts[attentiongrabber,magicprompt]~=0.29.0 gradio_client==0.2.7 -i https://mirrors.aliyun.com/pypi/simple/
+RUN pip install opencv-python onnx onnxruntime modelscope
+
+# download the remaining sd-webui requirements
+RUN wget https://pai-vision-data-sh.oss-cn-shanghai.aliyuncs.com/aigc-data/easyphoto/requirements_versions.txt
+RUN pip install -r requirements_versions.txt -i https://mirrors.aliyun.com/pypi/simple/
+
+# download the prebuilt huggingface cache
+RUN mkdir -p /root/.cache/
+RUN curl -o /root/.cache/huggingface.zip https://pai-vision-data-sh.oss-cn-shanghai.aliyuncs.com/aigc-data/easyphoto/huggingface.zip
+RUN unzip /root/.cache/huggingface.zip -d /root/.cache/
+
+RUN pip install controlnet_aux
+
+# Use a torch model in place of the tensorflow one; this requires mmcv.
+# https://mmdetection.readthedocs.io/en/v2.9.0/faq.html
+# https://github.com/open-mmlab/mmdetection/issues/6765
+RUN MMCV_WITH_OPS=1 FORCE_CUDA=1 MMCV_CUDA_ARGS='-gencode=arch=compute_80,code=sm_80' pip install mmcv-full==1.7.0 --index https://pypi.tuna.tsinghua.edu.cn/simple
+RUN pip install mmdet==2.28.2 --index https://pypi.tuna.tsinghua.edu.cn/simple
+
+# The oneflow release has a bug, so we pin the 2023.10.20 nightly build.
+RUN pip uninstall tensorflow tensorflow-cpu xformers accelerate -y
+RUN pip install --pre oneflow -f https://oneflow-staging.oss-cn-beijing.aliyuncs.com/branch/master/cu117 --index https://pypi.tuna.tsinghua.edu.cn/simple
+RUN pip install "transformers==4.27.1" "diffusers[torch]==0.19.3" --index https://pypi.tuna.tsinghua.edu.cn/simple
+# The onediff release has a bug, so we pin the 2023.10.20 version.
+RUN wget -P /root/.cache http://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/webui/onediff.zip
+RUN unzip /root/.cache/onediff.zip -d /root/.cache
+RUN cd /root/.cache/onediff && python3 -m pip install -e .
+
+WORKDIR /workspace
+
+RUN pip cache purge
\ No newline at end of file
diff --git a/README.md b/README.md
index 8152248..312c7c0 100644
--- a/README.md
+++ b/README.md
@@ -83,17 +83,43 @@
 pip install -r requirements.txt
 
 # launch tool
 python app.py
+```
+##### Advanced (Optional)
+Use OneFlow to accelerate inference. The commands below install the cu117 build; swap in the CPU build if you are not on CUDA 11.7. OneFlow may emit some import warnings, which can safely be ignored.
+```
+# Use a torch model in place of the tensorflow one; this requires mmcv.
+# https://mmdetection.readthedocs.io/en/v2.9.0/faq.html
+# https://github.com/open-mmlab/mmdetection/issues/6765
+
+MMCV_WITH_OPS=1 FORCE_CUDA=1 MMCV_CUDA_ARGS='-gencode=arch=compute_80,code=sm_80' pip install mmcv-full==1.7.0 --index https://pypi.tuna.tsinghua.edu.cn/simple
+pip install mmdet==2.28.2 --index https://pypi.tuna.tsinghua.edu.cn/simple
+
+# The oneflow release has a bug, so we pin the 2023.10.20 nightly build.
+# https://github.com/Oneflow-Inc/diffusers
+
+pip uninstall tensorflow tensorflow-cpu xformers accelerate -y
+pip install --pre oneflow -f https://oneflow-staging.oss-cn-beijing.aliyuncs.com/branch/master/cu117 --index https://pypi.tuna.tsinghua.edu.cn/simple
+pip install "transformers==4.27.1" "diffusers[torch]==0.19.3" --index https://pypi.tuna.tsinghua.edu.cn/simple
+
+# The onediff release has a bug, so we pin the 2023.10.20 version.
+wget -P /root/.cache http://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/webui/onediff.zip
+unzip /root/.cache/onediff.zip -d /root/.cache
+cd /root/.cache/onediff && python3 -m pip install -e .
+
 ```
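+To check that OneFlow acceleration is actually picked up, you can run a quick sanity check. This is a minimal sketch: the `use_oneflow` environment variable and the `oneflow_compile` call mirror what `app.py` and `easyphoto/sd_diffusers.py` do, while the pipeline and model id here are only illustrative.
+
+```python
+import os
+
+try:
+    import oneflow  # succeeds only if the oneflow wheel above installed cleanly
+    os.environ['use_oneflow'] = '1'
+except ImportError:
+    print('oneflow unavailable; EasyPhoto will fall back to plain PyTorch inference')
+
+if os.environ.get('use_oneflow'):
+    # onediff must be imported before diffusers
+    from onediff.infer_compiler import oneflow_compile
+    import torch
+    from diffusers import StableDiffusionPipeline
+
+    pipe = StableDiffusionPipeline.from_pretrained(
+        'runwayml/stable-diffusion-v1-5', torch_dtype=torch.float16,
+    ).to('cuda')
+    # The first call compiles the UNet and is slow; later calls reuse the graph.
+    pipe.unet = oneflow_compile(pipe.unet)
+    image = pipe('portrait photo, high quality', num_inference_steps=20).images[0]
+```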
 ### 2. Build from Docker
 ```
 # pull image
-docker pull registry.cn-shanghai.aliyuncs.com/pai-ai-test/eas-service:easyphoto-diffusers-py310-torch201-cu117
+docker pull registry.cn-shanghai.aliyuncs.com/pai-ai-test/eas-service:easyphoto-diffusers-py310-torch201-cu117-oneflow
+
+git clone https://github.com/aigc-apps/EasyPhoto.git
 
 # enter image
-docker run -it --network host --gpus all registry.cn-shanghai.aliyuncs.com/pai-ai-test/eas-service:easyphoto-diffusers-py310-torch201-cu117
+docker run -it --network host -v $(pwd):/paiya_acc --gpus all registry.cn-shanghai.aliyuncs.com/pai-ai-test/eas-service:easyphoto-diffusers-py310-torch201-cu117
+cd /paiya_acc/EasyPhoto
 
 # launch
 python app.py
 ```
diff --git a/app.py b/app.py
index b6a1886..f247403 100644
--- a/app.py
+++ b/app.py
@@ -1,5 +1,18 @@
 import time
 import argparse
+import logging
+import os
+try:
+    # onediff emits warnings on import, so we probe for oneflow first
+    import oneflow
+    os.environ['use_oneflow'] = '1'
+except ImportError:
+    logging.warning('No module named oneflow; running inference without it. See https://github.com/Oneflow-Inc/diffusers.')
+
+# onediff must be imported before diffusers
+if os.environ.get('use_oneflow'):
+    from onediff.infer_compiler import oneflow_compile
+
 from easyphoto.easyphoto_ui import on_ui_tabs
 from easyphoto.easyphoto_utils import reload_javascript
 from easyphoto.api import easyphoto_infer_forward_api, easyphoto_train_forward_api
diff --git a/easyphoto/easyphoto_infer.py b/easyphoto/easyphoto_infer.py
index 5ea1398..acee35f 100644
--- a/easyphoto/easyphoto_infer.py
+++ b/easyphoto/easyphoto_infer.py
@@ -1,8 +1,9 @@
 import copy
 import glob
 import logging
-import os
+import os
+import time
 import cv2
 import numpy as np
 import torch
@@ -209,6 +210,7 @@ def easyphoto_infer_forward(
     seed, crop_face_preprocess, apply_face_fusion_before, apply_face_fusion_after, color_shift_middle, color_shift_last, super_resolution, display_score, \
     background_restore, background_restore_denoising_strength, sd_xl_input_prompt, sd_xl_resolution, tabs, *user_ids,
 ):
+    s1 = time.time()
     # global
     global retinaface_detection, image_face_fusion, skin_retouching, portrait_enhancement, face_skin, face_recognition, check_hash
@@ -252,7 +254,11 @@ def easyphoto_infer_forward(
     if retinaface_detection is None:
         retinaface_detection = pipeline(Tasks.face_detection, 'damo/cv_resnet50_face-detection_retinaface', model_revision='v2.0.2')
     if image_face_fusion is None:
-        image_face_fusion = pipeline(Tasks.image_face_fusion, model='damo/cv_unet-image-face-fusion_damo', model_revision='v1.3')
+        # oneflow conflicts with tensorflow, so use the torch face-fusion model instead
+        if os.environ.get('use_oneflow'):
+            image_face_fusion = pipeline('face_fusion_torch', model='damo/cv_unet_face_fusion_torch', model_revision='v1.0.3')
+        else:
+            image_face_fusion = pipeline(Tasks.image_face_fusion, model='damo/cv_unet-image-face-fusion_damo', model_revision='v1.3')
     if face_skin is None:
         face_skin = Face_Skin(os.path.join(models_path, "Others", "face_skin.pth"))
     if skin_retouching is None:
@@ -454,6 +460,7 @@ def easyphoto_infer_forward(
         input_image = input_image.resize([new_width, new_height], Image.Resampling.LANCZOS)
 
         # Detect the box where the face of the template image is located and obtain its corresponding small mask
+        tmp = time.time()
         logging.info("Start face detect.")
         input_image_retinaface_boxes, input_image_retinaface_keypoints, input_masks = call_face_crop(retinaface_detection, input_image, 1.05, "template")
         input_image_retinaface_box = input_image_retinaface_boxes[0]
@@ -499,15 +506,19 @@ def easyphoto_infer_forward(
         # here we get the retinaface_box, we should use this Input box and face pixel to refine the output face pixel colors
         template_image_original_face_area = np.array(original_input_template)[input_image_retinaface_box[1]:input_image_retinaface_box[3], input_image_retinaface_box[0]:input_image_retinaface_box[2], :]
-
+        print("End face detect.", time.time() - tmp)
+
         # First diffusion, facial reconstruction
         logging.info("Start First diffusion.")
+        tmp = time.time()
         controlnet_pairs = [["canny", input_image, 0.50], ["openpose", replaced_input_image, 0.50], ["color", input_image, 0.85]]
         first_diffusion_output_image = inpaint(input_image, input_mask, controlnet_pairs, diffusion_steps=first_diffusion_steps, denoising_strength=first_denoising_strength, input_prompt=input_prompts[index], hr_scale=1.0, seed=str(seed), sd_model_checkpoint=sd_model_checkpoint, sd_lora_checkpoint=sd_lora_checkpoints[index])
+        print("End First diffusion.", time.time() - tmp)
 
         if color_shift_middle:
             # apply color shift
             logging.info("Start color shift middle.")
+            tmp = time.time()
             first_diffusion_output_image_uint8 = np.uint8(np.array(first_diffusion_output_image))
             # crop image first
             first_diffusion_output_image_crop = Image.fromarray(first_diffusion_output_image_uint8[input_image_retinaface_box[1]:input_image_retinaface_box[3], input_image_retinaface_box[0]:input_image_retinaface_box[2],:])
@@ -522,10 +533,12 @@ def easyphoto_infer_forward(
             first_diffusion_output_image_uint8[input_image_retinaface_box[1]:input_image_retinaface_box[3], input_image_retinaface_box[0]:input_image_retinaface_box[2],:] = \
                 first_diffusion_output_image_crop_color_shift * face_skin_mask + np.array(first_diffusion_output_image_crop) * (1 - face_skin_mask)
             first_diffusion_output_image = Image.fromarray(np.uint8(first_diffusion_output_image_uint8))
+            print("End color shift middle.", time.time() - tmp)
 
         # Second diffusion
         if roop_images[index] is not None and apply_face_fusion_after:
             # Fusion of facial photos with user photos
+            tmp = time.time()
             logging.info("Start second face fusion.")
             fusion_image = image_face_fusion(dict(template=first_diffusion_output_image, user=roop_images[index]))[OutputKeys.OUTPUT_IMG] # swap_face(target_img=output_image, source_img=roop_image, model="inswapper_128.onnx", upscale_options=UpscaleOptions())
             fusion_image = Image.fromarray(cv2.cvtColor(fusion_image, cv2.COLOR_BGR2RGB))
@@ -545,8 +558,11 @@ def easyphoto_infer_forward(
             fusion_image = first_diffusion_output_image
             input_image = first_diffusion_output_image
 
+        print("End second face fusion.", time.time() - tmp)
+
         # Add mouth_mask to avoid some fault lips, close if you dont need
         if need_mouth_fix:
+            tmp = time.time()
             logging.info("Start mouth detect.")
             mouth_mask, face_mask = face_skin(input_image, retinaface_detection, [[4, 5, 12, 13], [1, 2, 3, 4, 5, 10, 12, 13]])
             # Obtain the mask of the area around the face
@@ -557,14 +573,18 @@ def easyphoto_infer_forward(
             if i_h != m_h or i_w != m_w:
                 face_mask = face_mask.resize([m_w, m_h])
             input_mask = Image.fromarray(np.uint8(np.clip(np.float32(face_mask) + np.float32(mouth_mask), 0, 255)))
-
+            print("End mouth detect.", time.time() - tmp)
+
         logging.info("Start Second diffusion.")
+        tmp = time.time()
         controlnet_pairs = [["canny", fusion_image, 1.00], ["tile", fusion_image, 1.00]]
         second_diffusion_output_image = inpaint(input_image, input_mask, controlnet_pairs, input_prompts[index], diffusion_steps=second_diffusion_steps, denoising_strength=second_denoising_strength, hr_scale=default_hr_scale, seed=str(seed), sd_model_checkpoint=sd_model_checkpoint, sd_lora_checkpoint=sd_lora_checkpoints[index])
+        print("End Second diffusion.", time.time() - tmp)
 
         # use original template face area to shift generated face color at last
         if color_shift_last:
             logging.info("Start color shift last.")
+            tmp = time.time()
             # scale box
             rescale_retinaface_box = [int(i * default_hr_scale) for i in input_image_retinaface_box]
             second_diffusion_output_image_uint8 = np.uint8(np.array(second_diffusion_output_image))
@@ -580,10 +600,12 @@ def easyphoto_infer_forward(
             second_diffusion_output_image_uint8[rescale_retinaface_box[1]:rescale_retinaface_box[3], rescale_retinaface_box[0]:rescale_retinaface_box[2],:] = \
                 second_diffusion_output_image_crop_color_shift * face_skin_mask + np.array(second_diffusion_output_image_crop) * (1 - face_skin_mask)
             second_diffusion_output_image = Image.fromarray(second_diffusion_output_image_uint8)
+            print("End color shift last.", time.time() - tmp)
 
         # If it is a large template for cutting, paste the reconstructed image back
         if crop_face_preprocess:
             logging.info("Start paste crop image to origin template.")
+            tmp = time.time()
             origin_loop_template_image = np.array(copy.deepcopy(loop_template_image))
 
             x1,y1,x2,y2 = loop_template_crop_safe_box
@@ -591,6 +613,7 @@ def easyphoto_infer_forward(
             origin_loop_template_image[y1:y2,x1:x2] = np.array(second_diffusion_output_image)
 
             loop_output_image = Image.fromarray(np.uint8(origin_loop_template_image))
+            print("End paste crop image to origin template.", time.time() - tmp)
         else:
             loop_output_image = second_diffusion_output_image
@@ -618,6 +641,7 @@ def easyphoto_infer_forward(
         try:
             if min(len(template_face_safe_boxes), len(user_ids) - len(passed_userid_list)) > 1 or background_restore:
+                tmp = time.time()
                 logging.info("Start Third diffusion for background.")
                 output_image = Image.fromarray(np.uint8(output_image))
                 short_side = min(output_image.width, output_image.height)
@@ -632,6 +656,7 @@ def easyphoto_infer_forward(
                 denoising_strength = background_restore_denoising_strength if background_restore else 0.3
                 controlnet_pairs = [["canny", output_image, 1.00], ["color", output_image, 1.00]]
                 output_image = inpaint(output_image, output_mask, controlnet_pairs, input_prompt_without_lora, 30, denoising_strength=denoising_strength, hr_scale=1, seed=str(seed), sd_model_checkpoint=sd_model_checkpoint)
+                print("End Third diffusion.", time.time() - tmp)
         except Exception as e:
             torch.cuda.empty_cache()
             logging.error(f"Background Restore Failed, Please check the ratio of height and width in template. Error Info: {e}")
@@ -640,17 +665,21 @@ def easyphoto_infer_forward(
         try:
             logging.info("Start Skin Retouching.")
             # Skin Retouching is performed here.
+            tmp = time.time()
             output_image = Image.fromarray(cv2.cvtColor(skin_retouching(output_image)[OutputKeys.OUTPUT_IMG], cv2.COLOR_BGR2RGB))
+            print("End Skin Retouching.", time.time() - tmp)
         except Exception as e:
             torch.cuda.empty_cache()
             logging.error(f"Skin Retouching error: {e}")
 
         try:
             logging.info("Start Portrait enhancement.")
+            tmp = time.time()
             h, w, c = np.shape(np.array(output_image))
             # Super-resolution is performed here.
             if super_resolution:
                 output_image = Image.fromarray(cv2.cvtColor(portrait_enhancement(output_image)[OutputKeys.OUTPUT_IMG], cv2.COLOR_BGR2RGB))
+            print("End Portrait enhancement.", time.time() - tmp)
         except Exception as e:
             torch.cuda.empty_cache()
             logging.error(f"Portrait enhancement error: {e}")
@@ -666,11 +695,13 @@ def easyphoto_infer_forward(
                 loop_message += f"Template {str(template_idx + 1)} Success."
         except Exception as e:
             torch.cuda.empty_cache()
-            logging.error(f"Template {str(template_idx + 1)} error: Error info is {e}, skip it.")
+            logging.exception(f"Template {str(template_idx + 1)} error: Error info is {e}, skip it.")
             if loop_message != "":
                 loop_message += "\n"
             loop_message += f"Template {str(template_idx + 1)} error: Error info is {e}."
 
     torch.cuda.empty_cache()
+    t = time.time() - s1
+    print("Total processing time:", t)
     return "Success", outputs, face_id_outputs
\ No newline at end of file
diff --git a/easyphoto/sd_diffusers.py b/easyphoto/sd_diffusers.py
index 21dfb7b..856ed52 100644
--- a/easyphoto/sd_diffusers.py
+++ b/easyphoto/sd_diffusers.py
@@ -1,8 +1,7 @@
 import logging
 import os
 import re
-from collections import defaultdict
-
+from collections import defaultdict
 import torch
 import torch.utils.checkpoint
 from diffusers import (DPMSolverMultistepScheduler,
@@ -20,6 +19,7 @@
 vae = None
 unet = None
 pipeline = None
+oneflow_unet = None
 sd_model_checkpoint_before = ""
 weight_dtype = torch.float16
 SCHEDULER_LINEAR_START = 0.00085
@@ -215,7 +215,7 @@ def i2i_inpaint_call(
     sd_model_checkpoint="",
     sd_base15_checkpoint="",
 ):
-    global tokenizer, scheduler, text_encoder, vae, unet, sd_model_checkpoint_before, pipeline
+    global tokenizer, scheduler, text_encoder, vae, unet, sd_model_checkpoint_before, pipeline, oneflow_unet
 
     width = int(width // 8 * 8)
     height = int(height // 8 * 8)
@@ -229,6 +229,7 @@ def i2i_inpaint_call(
             sd_base15_checkpoint, subfolder="tokenizer"
         )
 
+
         pipeline = StableDiffusionControlNetInpaintPipeline(
             controlnet=controlnet_units_list,
             unet=unet.to(weight_dtype),
@@ -239,6 +240,7 @@ def i2i_inpaint_call(
             safety_checker=None,
             feature_extractor=None,
         ).to("cuda")
+
         if preload_lora is not None:
             for _preload_lora in preload_lora:
                 merge_lora(pipeline, _preload_lora, 0.60, from_safetensor=True, device="cuda", dtype=weight_dtype)
@@ -252,7 +254,17 @@ def i2i_inpaint_call(
         pipeline.enable_xformers_memory_efficient_attention()
     except:
         logging.warning('No module named xformers. Infer without using xformers. You can run pip install xformers to install it.')
-    
+
+    if os.environ.get('use_oneflow') and oneflow_unet is None:
+        print("unet compile begin")
+        from onediff.infer_compiler import oneflow_compile
+        oneflow_unet = oneflow_compile(pipeline.unet)
+        print("unet compile complete")
+
+    if oneflow_unet:
+        print("use oneflow to infer")
+        pipeline.unet = oneflow_unet
+
     generator = torch.Generator("cuda").manual_seed(int(seed))
 
     pipeline.safety_checker = None
@@ -261,6 +273,9 @@ def i2i_inpaint_call(
         guidance_scale=cfg_scale, num_inference_steps=steps, generator=generator, height=height, width=width, \
         controlnet_conditioning_scale=controlnet_conditioning_scale, guess_mode=True
     ).images[0]
+
+    if oneflow_unet:
+        # restore the torch unet so the LoRA unmerge below edits eager weights
+        pipeline.unet = unet
 
     if len(sd_lora_checkpoint) != 0:
         # Bind LoRANetwork to pipeline.
@@ -269,4 +284,5 @@ def i2i_inpaint_call(
     if preload_lora is not None:
         for _preload_lora in preload_lora:
             unmerge_lora(pipeline, _preload_lora, 0.60, from_safetensor=True, device="cuda", dtype=weight_dtype)
+
     return image
\ No newline at end of file
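The `i2i_inpaint_call` changes above follow a compile-once, swap-per-call pattern: the UNet is compiled lazily into a module-level cache, swapped into the pipeline for the denoising call, and the eager torch UNet is restored afterwards so LoRA merge/unmerge keeps operating on plain torch weights. A condensed sketch of that flow, assuming a diffusers pipeline object; the function name and keyword arguments are illustrative:

```python
import os

oneflow_unet = None  # module-level cache; compiled at most once per process

def run_with_compiled_unet(pipeline, torch_unet, **call_kwargs):
    global oneflow_unet

    if os.environ.get('use_oneflow') and oneflow_unet is None:
        from onediff.infer_compiler import oneflow_compile
        oneflow_unet = oneflow_compile(pipeline.unet)  # first call is slow

    if oneflow_unet is not None:
        pipeline.unet = oneflow_unet  # run the compiled graph

    image = pipeline(**call_kwargs).images[0]

    if oneflow_unet is not None:
        pipeline.unet = torch_unet  # restore eager weights for LoRA (un)merge
    return image
```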
From bf2f981c7ea4521db8d4dd201a5484ae6e9583b0 Mon Sep 17 00:00:00 2001
From: yjjinjie
Date: Fri, 20 Oct 2023 17:18:10 +0800
Subject: [PATCH 2/3] use the oneflow image tag in the docker run command

---
 README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/README.md b/README.md
index 312c7c0..47618d9 100644
--- a/README.md
+++ b/README.md
@@ -117,7 +117,7 @@ docker pull registry.cn-shanghai.aliyuncs.com/pai-ai-test/eas-service:easyphoto-
 git clone https://github.com/aigc-apps/EasyPhoto.git
 
 # enter image
-docker run -it --network host -v $(pwd):/paiya_acc --gpus all registry.cn-shanghai.aliyuncs.com/pai-ai-test/eas-service:easyphoto-diffusers-py310-torch201-cu117
+docker run -it --network host -v $(pwd):/paiya_acc --gpus all registry.cn-shanghai.aliyuncs.com/pai-ai-test/eas-service:easyphoto-diffusers-py310-torch201-cu117-oneflow
 cd /paiya_acc/EasyPhoto
 
 # launch

From 4b39cd18fd01ac1c2a176b91a5eb5ecc8a5c9472 Mon Sep 17 00:00:00 2001
From: yjjinjie
Date: Thu, 26 Oct 2023 17:47:02 +0800
Subject: [PATCH 3/3] add channels_last memory format for unet

---
 easyphoto/sd_diffusers.py | 18 +++++++++---------
 1 file changed, 9 insertions(+), 9 deletions(-)

diff --git a/easyphoto/sd_diffusers.py b/easyphoto/sd_diffusers.py
index 856ed52..a77e133 100644
--- a/easyphoto/sd_diffusers.py
+++ b/easyphoto/sd_diffusers.py
@@ -51,7 +51,6 @@ def merge_lora(pipeline, lora_path, multiplier, from_safetensor=False, device='c
             updates[layer][elem] = value
 
     for layer, elems in updates.items():
-
         if "text" in layer:
             layer_infos = layer.split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
             curr_layer = pipeline.text_encoder
@@ -83,6 +82,7 @@ def merge_lora(pipeline, lora_path, multiplier, from_safetensor=False, device='c
             alpha = 1.0
 
         curr_layer.weight.data = curr_layer.weight.data.to(device)
+
         if len(weight_up.shape) == 4:
             curr_layer.weight.data += multiplier * alpha * torch.mm(weight_up.squeeze(3).squeeze(2),
                                                                     weight_down.squeeze(3).squeeze(2)).unsqueeze(
@@ -119,7 +119,6 @@ def unmerge_lora(pipeline, lora_path, multiplier=1, from_safetensor=False, devic
             updates[layer][elem] = value
 
     for layer, elems in updates.items():
-
         if "text" in layer:
             layer_infos = layer.split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
             curr_layer = pipeline.text_encoder
@@ -156,7 +155,6 @@ def unmerge_lora(pipeline, lora_path, multiplier=1, from_safetensor=False, devic
                                                                     weight_down.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
         else:
             curr_layer.weight.data -= multiplier * alpha * torch.mm(weight_up, weight_down)
-
     return pipeline
 
 def t2i_sdxl_call(
@@ -241,6 +239,14 @@ def i2i_inpaint_call(
             feature_extractor=None,
         ).to("cuda")
 
+        pipeline.unet.to(memory_format=torch.channels_last)
+
+        if os.environ.get('use_oneflow') and oneflow_unet is None:
+            print("unet compile begin")
+            from onediff.infer_compiler import oneflow_compile
+            oneflow_unet = oneflow_compile(pipeline.unet)
+            print("unet compile complete")
+
         if preload_lora is not None:
             for _preload_lora in preload_lora:
                 merge_lora(pipeline, _preload_lora, 0.60, from_safetensor=True, device="cuda", dtype=weight_dtype)
@@ -248,18 +254,12 @@ def i2i_inpaint_call(
         # Bind LoRANetwork to pipeline.
         for _sd_lora_checkpoint in sd_lora_checkpoint:
             merge_lora(pipeline, _sd_lora_checkpoint, 0.90, from_safetensor=True, device="cuda", dtype=weight_dtype)
-
     try:
         import xformers
         pipeline.enable_xformers_memory_efficient_attention()
    except:
         logging.warning('No module named xformers. Infer without using xformers. You can run pip install xformers to install it.')
 
-    if os.environ.get('use_oneflow') and oneflow_unet is None:
-        print("unet compile begin")
-        from onediff.infer_compiler import oneflow_compile
-        oneflow_unet = oneflow_compile(pipeline.unet)
-        print("unet compile complete")
 
     if oneflow_unet:
         print("use oneflow to infer")
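
The net effect of this third patch: the UNet is put into channels_last memory format and compiled right after the pipeline is built, before any LoRA weights are merged, so the graph is traced once against the final memory layout. A minimal sketch of that ordering, assuming a diffusers pipeline already moved to CUDA; the helper name here is illustrative, not a function in the repo:

```python
import os
import torch

def prepare_unet(pipeline):
    # Set the memory format first, then compile, so onediff traces the UNet
    # against the channels_last (NHWC) layout it will actually run with.
    pipeline.unet.to(memory_format=torch.channels_last)
    if os.environ.get('use_oneflow'):
        from onediff.infer_compiler import oneflow_compile
        pipeline.unet = oneflow_compile(pipeline.unet)
    return pipeline
```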