diff --git a/__init__.py b/__init__.py
index a03726e..c66514e 100644
--- a/__init__.py
+++ b/__init__.py
@@ -7,3 +7,6 @@ from .nodes import NODE_CLASS_MAPPINGS
 NODE_DISPLAY_NAME_MAPPINGS = {k:v.TITLE for k,v in NODE_CLASS_MAPPINGS.items()}
 
 __all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS']
+
+
+WEB_DIRECTORY = "./js"
\ No newline at end of file
diff --git a/js/clip_loader.js b/js/clip_loader.js
new file mode 100644
index 0000000..04cdc2c
--- /dev/null
+++ b/js/clip_loader.js
@@ -0,0 +1,33 @@
+import { app } from "../../scripts/app.js"
+
+app.registerExtension({
+    name: "ComfyUI_GGUF.clip_loader_gguf",
+    async beforeRegisterNodeDef(nodeType, nodeData, app) {
+        if (nodeData.name === "CLIPLoaderGGUF") {
+            const onNodeCreated = nodeType.prototype.onNodeCreated;
+
+            nodeType.prototype.onNodeCreated = function () {
+                onNodeCreated?.apply(this, arguments);
+                const node = this;
+
+                const typeWidget = node.widgets.find((w) => w.name === "type");
+
+                const updateWidgets = (type) => {
+                    const existingWidget = node.widgets.find((w) => w.name === "mmproj_path");
+                    if (existingWidget) {
+                        existingWidget.hidden = (type !== "qwen_image_edit")
+                    }
+                    const newSize = node.computeSize();
+                    node.size = newSize;
+                    app.graph.setDirtyCanvas(true, true);
+                };
+                typeWidget.callback = (value) => {
+                    updateWidgets(value);
+                };
+                setTimeout(() => {
+                    updateWidgets(typeWidget.value);
+                }, 1);
+            };
+        }
+    },
+});
diff --git a/loader.py b/loader.py
index fd35e13..6155bf9 100644
--- a/loader.py
+++ b/loader.py
@@ -210,7 +210,7 @@ def strip_quant_suffix(name):
         name = name[:match.start()]
     return name
 
-def gguf_mmproj_loader(path):
+def gguf_mmproj_loader(path, mmproj_path):
     # Reverse version of Qwen2VLVisionModel.modify_tensors
     logging.info("Attenpting to find mmproj file for text encoder...")
 
@@ -219,27 +219,10 @@ def gguf_mmproj_loader(path):
     tenc = os.path.splitext(tenc_fname)[0].lower()
     tenc = strip_quant_suffix(tenc)
 
-    # try and find matching mmproj
-    target = []
-    root = os.path.dirname(path)
-    for fname in os.listdir(root):
-        name, ext = os.path.splitext(fname)
-        if ext.lower() != ".gguf":
-            continue
-        if "mmproj" not in name.lower():
-            continue
-        if tenc in name.lower():
-            target.append(fname)
-
-    if len(target) == 0:
+    if mmproj_path is None or not os.path.exists(mmproj_path):
         logging.error(f"Error: Can't find mmproj file for '{tenc_fname}' (matching:'{tenc}')! Qwen-Image-Edit will be broken!")
         return {}
 
-    if len(target) > 1:
-        logging.error(f"Ambiguous mmproj for text encoder '{tenc_fname}', will use first match.")
-
-    logging.info(f"Using mmproj '{target[0]}' for text encoder '{tenc_fname}'.")
-    target = os.path.join(root, target[0])
-    vsd = gguf_sd_loader(target, is_text_model=True)
+    vsd = gguf_sd_loader(mmproj_path, is_text_model=True)
     # concat 4D to 5D
     if "v.patch_embd.weight.1" in vsd:
@@ -324,7 +307,7 @@ def gguf_tokenizer_loader(path, temb_shape):
     del reader
     return torch.ByteTensor(list(spm.SerializeToString()))
 
-def gguf_clip_loader(path):
+def gguf_clip_loader(path, **kwargs):
     sd, arch = gguf_sd_loader(path, return_arch=True, is_text_model=True)
     if arch in {"t5", "t5encoder"}:
         temb_key = "token_embd.weight"
@@ -346,7 +329,8 @@ def gguf_clip_loader(path):
     if arch == "llama":
         sd = llama_permute(sd, 32, 8) # L3
     if arch == "qwen2vl":
-        vsd = gguf_mmproj_loader(path)
+        mmproj_path = kwargs.get("mmproj_path", None)
+        vsd = gguf_mmproj_loader(path, mmproj_path)
         sd.update(vsd)
     else:
         pass
diff --git a/nodes.py b/nodes.py
index 4159142..78995a2 100644
--- a/nodes.py
+++ b/nodes.py
@@ -181,7 +181,10 @@ def INPUT_TYPES(s):
         return {
             "required": {
                 "clip_name": (s.get_filename_list(),),
-                "type": base["required"]["type"],
+                "type": [(*base["required"]["type"][0], "qwen_image_edit",)],
+            },
+            "optional": {
+                "mmproj_path": (s.get_filename_list(),)
             }
         }
 
@@ -197,11 +200,15 @@ def get_filename_list(s):
         files += folder_paths.get_filename_list("clip_gguf")
         return sorted(files)
 
-    def load_data(self, ckpt_paths):
+    def load_data(self, ckpt_paths, type, **kwargs):
         clip_data = []
         for p in ckpt_paths:
             if p.endswith(".gguf"):
-                sd = gguf_clip_loader(p)
+                if type == "qwen_image_edit":
+                    mmproj_path = kwargs.get("mmproj_path", None)
+                    sd = gguf_clip_loader(p, mmproj_path=mmproj_path)
+                else:
+                    sd = gguf_clip_loader(p)
             else:
                 sd = comfy.utils.load_torch_file(p, safe_load=True)
             if "scaled_fp8" in sd: # NOTE: Scaled FP8 would require different custom ops, but only one can be active
@@ -222,10 +229,10 @@ def load_patcher(self, clip_paths, clip_type, clip_data):
         clip.patcher = GGUFModelPatcher.clone(clip.patcher)
         return clip
 
-    def load_clip(self, clip_name, type="stable_diffusion"):
+    def load_clip(self, clip_name, type="stable_diffusion", **kwargs):
         clip_path = folder_paths.get_full_path("clip", clip_name)
         clip_type = getattr(comfy.sd.CLIPType, type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION)
-        return (self.load_patcher([clip_path], clip_type, self.load_data([clip_path])),)
+        return (self.load_patcher([clip_path], clip_type, self.load_data([clip_path], type, **kwargs)),)
 
 class DualCLIPLoaderGGUF(CLIPLoaderGGUF):
     @classmethod