Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions __init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,6 @@
from .nodes import NODE_CLASS_MAPPINGS
NODE_DISPLAY_NAME_MAPPINGS = {k:v.TITLE for k,v in NODE_CLASS_MAPPINGS.items()}
__all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS']


WEB_DIRECTORY = "./js"
33 changes: 33 additions & 0 deletions js/clip_loader.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import { app } from "../../scripts/app.js"

// Extension: show/hide the optional "mmproj_path" widget on CLIPLoaderGGUF
// depending on the selected "type" value. Only "qwen_image_edit" needs an
// mmproj companion file, so the widget is hidden for every other type.
app.registerExtension({
    name: "ComfyUI_GGUF.clip_loader_gguf",
    async beforeRegisterNodeDef(nodeType, nodeData, app) {
        if (nodeData.name === "CLIPLoaderGGUF") {
            const onNodeCreated = nodeType.prototype.onNodeCreated;

            nodeType.prototype.onNodeCreated = function () {
                onNodeCreated?.apply(this, arguments);
                const node = this;

                const typeWidget = node.widgets?.find((w) => w.name === "type");
                if (!typeWidget) {
                    // Node layout changed upstream; nothing to toggle.
                    return;
                }

                // Toggle mmproj_path visibility and resize the node to fit.
                const updateWidgets = (type) => {
                    const mmprojWidget = node.widgets.find((w) => w.name === "mmproj_path");
                    if (mmprojWidget) {
                        mmprojWidget.hidden = (type !== "qwen_image_edit");
                    }
                    node.size = node.computeSize();
                    app.graph.setDirtyCanvas(true, true);
                };

                // Chain, rather than clobber, any callback already installed by
                // core or by other extensions on the same widget.
                const prevCallback = typeWidget.callback;
                typeWidget.callback = function (value, ...rest) {
                    const result = prevCallback?.call(this, value, ...rest);
                    updateWidgets(value);
                    return result;
                };

                // Defer the initial sync until after widget values have been
                // restored (e.g. when loading a saved workflow).
                setTimeout(() => {
                    updateWidgets(typeWidget.value);
                }, 1);
            };
        }
    },
});
28 changes: 6 additions & 22 deletions loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def strip_quant_suffix(name):
name = name[:match.start()]
return name

def gguf_mmproj_loader(path):
def gguf_mmproj_loader(path, mmproj_path):
# Reverse version of Qwen2VLVisionModel.modify_tensors
logging.info("Attempting to find mmproj file for text encoder...")

Expand All @@ -219,27 +219,10 @@ def gguf_mmproj_loader(path):
tenc = os.path.splitext(tenc_fname)[0].lower()
tenc = strip_quant_suffix(tenc)

# try and find matching mmproj
target = []
root = os.path.dirname(path)
for fname in os.listdir(root):
name, ext = os.path.splitext(fname)
if ext.lower() != ".gguf":
continue
if "mmproj" not in name.lower():
continue
if tenc in name.lower():
target.append(fname)

if len(target) == 0:
if mmproj_path is None or not os.path.exists(mmproj_path):
logging.error(f"Error: Can't find mmproj file for '{tenc_fname}' (matching:'{tenc}')! Qwen-Image-Edit will be broken!")
return {}
if len(target) > 1:
logging.error(f"Ambiguous mmproj for text encoder '{tenc_fname}', will use first match.")

logging.info(f"Using mmproj '{target[0]}' for text encoder '{tenc_fname}'.")
target = os.path.join(root, target[0])
vsd = gguf_sd_loader(target, is_text_model=True)
vsd = gguf_sd_loader(mmproj_path, is_text_model=True)

# concat 4D to 5D
if "v.patch_embd.weight.1" in vsd:
Expand Down Expand Up @@ -324,7 +307,7 @@ def gguf_tokenizer_loader(path, temb_shape):
del reader
return torch.ByteTensor(list(spm.SerializeToString()))

def gguf_clip_loader(path):
def gguf_clip_loader(path, **kwargs):
sd, arch = gguf_sd_loader(path, return_arch=True, is_text_model=True)
if arch in {"t5", "t5encoder"}:
temb_key = "token_embd.weight"
Expand All @@ -346,7 +329,8 @@ def gguf_clip_loader(path):
if arch == "llama":
sd = llama_permute(sd, 32, 8) # L3
if arch == "qwen2vl":
vsd = gguf_mmproj_loader(path)
mmproj_path = kwargs.get("mmproj_path", None)
vsd = gguf_mmproj_loader(path, mmproj_path)
sd.update(vsd)
else:
pass
Expand Down
17 changes: 12 additions & 5 deletions nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,10 @@ def INPUT_TYPES(s):
return {
"required": {
"clip_name": (s.get_filename_list(),),
"type": base["required"]["type"],
"type": [(*base["required"]["type"][0], "qwen_image_edit",)],
},
"optional": {
"mmproj_path": (s.get_filename_list(),)
}
}

Expand All @@ -197,11 +200,15 @@ def get_filename_list(s):
files += folder_paths.get_filename_list("clip_gguf")
return sorted(files)

def load_data(self, ckpt_paths):
def load_data(self, ckpt_paths, type, **kwargs):
clip_data = []
for p in ckpt_paths:
if p.endswith(".gguf"):
sd = gguf_clip_loader(p)
if type == "qwen_image_edit":
mmproj_path = kwargs.get("mmproj_path", None)
sd = gguf_clip_loader(p, mmproj_path=mmproj_path)
else:
sd = gguf_clip_loader(p)
else:
sd = comfy.utils.load_torch_file(p, safe_load=True)
if "scaled_fp8" in sd: # NOTE: Scaled FP8 would require different custom ops, but only one can be active
Expand All @@ -222,10 +229,10 @@ def load_patcher(self, clip_paths, clip_type, clip_data):
clip.patcher = GGUFModelPatcher.clone(clip.patcher)
return clip

def load_clip(self, clip_name, type="stable_diffusion"):
def load_clip(self, clip_name, type="stable_diffusion", **kwargs):
clip_path = folder_paths.get_full_path("clip", clip_name)
clip_type = getattr(comfy.sd.CLIPType, type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION)
return (self.load_patcher([clip_path], clip_type, self.load_data([clip_path])),)
return (self.load_patcher([clip_path], clip_type, self.load_data([clip_path], type, **kwargs)),)

class DualCLIPLoaderGGUF(CLIPLoaderGGUF):
@classmethod
Expand Down