diff --git a/loader.py b/loader.py
index 1948027..e96157b 100644
--- a/loader.py
+++ b/loader.py
@@ -5,12 +5,13 @@
 import gguf
 import re
 import os
+from pathlib import Path
 
 from .ops import GGMLTensor
 from .dequant import is_quantized, dequantize_tensor
 
 IMG_ARCH_LIST = {"flux", "sd1", "sdxl", "sd3", "aura", "hidream", "cosmos", "ltxv", "hyvid", "wan", "lumina2", "qwen_image"}
-TXT_ARCH_LIST = {"t5", "t5encoder", "llama", "qwen2vl", "qwen3", "qwen3vl"}
+TXT_ARCH_LIST = {"t5", "t5encoder", "llama", "qwen2vl", "qwen3", "qwen3vl", "gemma3"}
 VIS_TYPE_LIST = {"clip-vision", "mmproj"}
 
 def get_orig_shape(reader, tensor_name):
@@ -177,6 +178,22 @@ def gguf_sd_loader(path, handle_prefix="model.diffusion_model.", return_arch=Fal
     "output.weight": "lm_head.weight",
 }
 
+GEMMA3_SD_MAP = {
+    "blk.": "model.layers.",
+    "attn_norm": "input_layernorm",
+    "attn_q": "self_attn.q_proj",
+    "attn_k": "self_attn.k_proj",
+    "attn_v": "self_attn.v_proj",
+    "attn_output": "self_attn.o_proj",
+    "ffn_up": "mlp.up_proj",
+    "ffn_down": "mlp.down_proj",
+    "ffn_gate": "mlp.gate_proj",
+    "ffn_norm": "post_attention_layernorm",
+    "token_embd": "model.embed_tokens",
+    "output_norm": "model.norm",
+    "output.weight": "lm_head.weight",
+}
+
 CLIP_VISION_SD_MAP = {
     "mm.": "visual.merger.mlp.",
     "v.post_ln.": "visual.merger.ln_q.",
@@ -217,17 +234,29 @@ def strip_quant_suffix(name):
         name = name[:match.start()]
     return name
 
-def gguf_mmproj_loader(path):
+def gguf_mmproj_loader(path, arch: str = ""):
     # Reverse version of Qwen2VLVisionModel.modify_tensors
-    logging.info("Attenpting to find mmproj file for text encoder...")
+    logging.info("Attempting to find mmproj file for text encoder...")
     # get name to match w/o quant suffix
     tenc_fname = os.path.basename(path)
     tenc = os.path.splitext(tenc_fname)[0].lower()
     tenc = strip_quant_suffix(tenc)
 
     # try and find matching mmproj
     target = []
     root = os.path.dirname(path)
+
+    # Look for the expected gemma3 mmproj file
+    if arch == "gemma3":
+        mmproj_name = next(
+            (f.name for f in Path(root).glob("*.gguf")
+             if "gemma" in f.name.lower() and "12b" in f.name.lower() and "mmproj" in f.name.lower()),
+            None
+        )
+        if mmproj_name:
+            target.append(mmproj_name)
+
+    # Or look for one sharing the same name as the source gguf
     for fname in os.listdir(root):
         name, ext = os.path.splitext(fname)
         if ext.lower() != ".gguf":
@@ -239,14 +268,14 @@ def gguf_mmproj_loader(path):
             target.append(fname)
 
     if len(target) == 0:
-        logging.error(f"Error: Can't find mmproj file for '{tenc_fname}' (matching:'{tenc}')! Qwen-Image-Edit will be broken!")
Vision features will be broken!") return {} if len(target) > 1: logging.error(f"Ambiguous mmproj for text encoder '{tenc_fname}', will use first match.") logging.info(f"Using mmproj '{target[0]}' for text encoder '{tenc_fname}'.") target = os.path.join(root, target[0]) - vsd = gguf_sd_loader(target, is_text_model=True) + vsd, _ = gguf_sd_loader(target, is_text_model=True) # concat 4D to 5D if "v.patch_embd.weight.1" in vsd: @@ -374,8 +403,44 @@ def gguf_tekken_tokenizer_loader(path, temb_shape): del reader return torch.ByteTensor(list(json.dumps(data).encode('utf-8'))) +def load_spiece_from_safetensor(gguf_path): + """Try to load spiece_model from a safetensor file in the same directory.""" + try: + from safetensors import safe_open + except ImportError: + logging.warning("safetensors not available, cannot load external spiece_model") + return None + + from pathlib import Path + + directory = os.path.dirname(gguf_path) + if not directory: + directory = "." + + basename = os.path.splitext(os.path.basename(gguf_path))[0] + basename = strip_quant_suffix(basename).lower() + + # Find all .safetensors files and filter for tokenizer/spiece patterns + path = Path(directory) + for safetensor_file in path.glob("*.safetensors"): + name_lower = safetensor_file.name.lower() + # Check if it matches our patterns + if not (name_lower.startswith(basename) and ('tokenizer' in name_lower or 'spiece' in name_lower)) \ + and not ('tokenizer' in name_lower or 'spiece' in name_lower): + continue + + try: + with safe_open(str(safetensor_file), framework="pt", device="cpu") as f: + if "spiece_model" in f.keys(): + logging.info(f"Loading spiece_model from {safetensor_file.name}") + return f.get_tensor("spiece_model") + except Exception as e: + logging.warning(f"Failed to load spiece_model from {safetensor_file.name}: {e}") + + return None + def gguf_clip_loader(path): - sd, arch = gguf_sd_loader(path, return_arch=True, is_text_model=True) + sd, arch, metadata = gguf_sd_loader(path, return_arch=True, is_text_model=True) if arch in {"t5", "t5encoder"}: temb_key = "token_embd.weight" if temb_key in sd and sd[temb_key].shape == (256384, 4096): @@ -385,7 +450,7 @@ def gguf_clip_loader(path): logging.warning(f"Dequantizing {temb_key} to prevent runtime OOM.") sd[temb_key] = dequantize_tensor(sd[temb_key], dtype=torch.float16) sd = sd_map_replace(sd, T5_SD_MAP) - elif arch in {"llama", "qwen2vl", "qwen3", "qwen3vl"}: + elif arch in {"llama", "qwen2vl", "qwen3", "qwen3vl", "gemma3"}: # TODO: pass model_options["vocab_size"] to loader somehow temb_key = "token_embd.weight" if temb_key in sd and sd[temb_key].shape[0] >= (64 * 1024): @@ -395,12 +460,29 @@ def gguf_clip_loader(path): # See note above for T5. 
logging.warning(f"Dequantizing {temb_key} to prevent runtime OOM.") sd[temb_key] = dequantize_tensor(sd[temb_key], dtype=torch.float16) - sd = sd_map_replace(sd, LLAMA_SD_MAP) - if arch == "llama": - sd = llama_permute(sd, 32, 8) # L3 / Mistral - if arch == "qwen2vl": - vsd = gguf_mmproj_loader(path) + + # Apply appropriate key mapping + if arch == "gemma3": + sd = sd_map_replace(sd, GEMMA3_SD_MAP) + # Gemma-3 uses same head/kv_head counts as config shows + sd = llama_permute(sd, 16, 8) # From config: num_attention_heads=16, num_key_value_heads=8 + else: + sd = sd_map_replace(sd, LLAMA_SD_MAP) + if arch == "llama": + sd = llama_permute(sd, 32, 8) # L3 / Mistral + + # Load mmproj for vision models + if arch in {"qwen2vl", "gemma3"}: + vsd = gguf_mmproj_loader(path, arch) sd.update(vsd) + + # Check if spiece_model is needed but missing + if arch == "gemma3" and "spiece_model" not in sd: + spiece_tensor = load_spiece_from_safetensor(path) + if spiece_tensor is not None: + sd["spiece_model"] = spiece_tensor + else: + logging.warning("spiece_model not found in GGUF or safetensor files. Tokenizer may not work correctly.") else: pass - return sd + return sd \ No newline at end of file