diff --git a/tools/convert.py b/tools/convert.py index 5029c87..e81ddd5 100644 --- a/tools/convert.py +++ b/tools/convert.py @@ -15,29 +15,27 @@ class ModelTemplate: arch = "invalid" # string describing architecture shape_fix = False # whether to reshape tensors + ndims_fix = False # whether to save fix file for tensors exceeding max dims keys_detect = [] # list of lists to match in state dict keys_banned = [] # list of keys that should mark model as invalid for conversion keys_hiprec = [] # list of keys that need to be kept in fp32 for some reason keys_ignore = [] # list of strings to ignore keys by when found - def handle_nd_tensor(self, key, data): - raise NotImplementedError(f"Tensor detected that exceeds dims supported by C++ code! ({key} @ {data.shape})") - class ModelFlux(ModelTemplate): arch = "flux" keys_detect = [ - ("transformer_blocks.0.attn.norm_added_k.weight",), + ("single_transformer_blocks.0.attn.norm_k.weight",), ("double_blocks.0.img_attn.proj.weight",), ] - keys_banned = ["transformer_blocks.0.attn.norm_added_k.weight",] + keys_banned = ["single_transformer_blocks.0.attn.norm_k.weight",] class ModelSD3(ModelTemplate): arch = "sd3" keys_detect = [ - ("transformer_blocks.0.attn.add_q_proj.weight",), + ("transformer_blocks.0.ff_context.net.0.proj.weight",), ("joint_blocks.0.x_block.attn.qkv.weight",), ] - keys_banned = ["transformer_blocks.0.attn.add_q_proj.weight",] + keys_banned = ["transformer_blocks.0.ff_context.net.0.proj.weight",] class ModelAura(ModelTemplate): arch = "aura" @@ -61,7 +59,7 @@ class ModelHiDream(ModelTemplate): "img_emb.emb_pos" ] -class CosmosPredict2(ModelTemplate): +class ModelCosmosPredict2(ModelTemplate): arch = "cosmos" keys_detect = [ ( @@ -72,8 +70,19 @@ class CosmosPredict2(ModelTemplate): keys_hiprec = ["pos_embedder"] keys_ignore = ["_extra_state", "accum_"] +class ModelQwenImage(ModelTemplate): + arch = "qwen_image" + keys_detect = [ + ( + "time_text_embed.timestep_embedder.linear_2.weight", + "transformer_blocks.0.attn.norm_added_q.weight", + "transformer_blocks.0.img_mlp.net.0.proj.weight", + ) + ] + class ModelHyVid(ModelTemplate): arch = "hyvid" + ndims_fix = True keys_detect = [ ( "double_blocks.0.img_attn_proj.weight", @@ -81,17 +90,9 @@ class ModelHyVid(ModelTemplate): ) ] - def handle_nd_tensor(self, key, data): - # hacky but don't have any better ideas - path = f"./fix_5d_tensors_{self.arch}.safetensors" # TODO: somehow get a path here?? - if os.path.isfile(path): - raise RuntimeError(f"5D tensor fix file already exists! {path}") - fsd = {key: torch.from_numpy(data)} - tqdm.write(f"5D key found in state dict! Manual fix required! 
- {key} {data.shape}") - save_file(fsd, path) - -class ModelWan(ModelHyVid): +class ModelWan(ModelTemplate): arch = "wan" + ndims_fix = True keys_detect = [ ( "blocks.0.self_attn.norm_q.weight", @@ -100,7 +101,11 @@ class ModelWan(ModelHyVid): ) ] keys_hiprec = [ - ".modulation" # nn.parameter, can't load from BF16 ver + ".modulation", # nn.parameter, can't load from BF16 ver + ".encoder.padding_tokens", # nn.parameter, specific to S2V + "trainable_cond_mask", # used directly w/ .weight + "casual_audio_encoder.weights", # nn.parameter, specific to S2V + "casual_audio_encoder.encoder.conv", # CausalConv1d doesn't use ops.py for now ] class ModelLTXV(ModelTemplate): @@ -144,9 +149,17 @@ class ModelLumina2(ModelTemplate): keys_detect = [ ("cap_embedder.1.weight", "context_refiner.0.attention.qkv.weight") ] + keys_hiprec = [ + # Z-Image specific + "x_pad_token", + "cap_pad_token", + ] -arch_list = [ModelFlux, ModelSD3, ModelAura, ModelHiDream, CosmosPredict2, - ModelLTXV, ModelHyVid, ModelWan, ModelSDXL, ModelSD1, ModelLumina2] +# The architectures are checked in order and the first successful match terminates the search. +arch_list = [ + ModelFlux, ModelSD3, ModelAura, ModelHiDream, ModelCosmosPredict2, ModelQwenImage, + ModelLTXV, ModelHyVid, ModelWan, ModelSDXL, ModelSD1, ModelLumina2 +] def is_model_arch(model, state_dict): # check if model is correct @@ -157,7 +170,7 @@ def is_model_arch(model, state_dict): matched = True invalid = any(key in state_dict for key in model.keys_banned) break - assert not invalid, "Model architecture not allowed for conversion! (i.e. reference VS diffusers format)" + assert not invalid, f"Model architecture not allowed for conversion! (i.e. reference VS diffusers format) [arch:{model.arch}]" return matched def detect_arch(state_dict): @@ -210,6 +223,24 @@ def strip_prefix(state_dict): return sd +def find_main_dtype(state_dict, allow_fp32=False): + # detect most common dtype in input + dtypes = [x.dtype for x in state_dict.values()] + dtypes = {x:dtypes.count(x) for x in set(dtypes)} + main_dtype = max(dtypes, key=dtypes.get) + + if main_dtype == torch.bfloat16: + ftype_name = "BF16" + ftype_gguf = gguf.LlamaFileType.MOSTLY_BF16 + elif main_dtype == torch.float32 and allow_fp32: + ftype_name = "F32" + ftype_gguf = gguf.LlamaFileType.ALL_F32 + else: + ftype_name = "F16" + ftype_gguf = gguf.LlamaFileType.MOSTLY_F16 + + return ftype_name, ftype_gguf + def load_state_dict(path): if any(path.endswith(x) for x in [".ckpt", ".pt", ".bin", ".pth"]): state_dict = torch.load(path, map_location="cpu", weights_only=True) @@ -224,7 +255,7 @@ def load_state_dict(path): return strip_prefix(state_dict) -def handle_tensors(writer, state_dict, model_arch): +def handle_tensors(writer, state_dict, model_arch, allow_fp32=False): name_lengths = tuple(sorted( ((key, len(key)) for key in state_dict.keys()), key=lambda item: item[1], @@ -233,9 +264,13 @@ def handle_tensors(writer, state_dict, model_arch): if not name_lengths: return max_name_len = name_lengths[0][1] + if max_name_len > MAX_TENSOR_NAME_LENGTH: bad_list = ", ".join(f"{key!r} ({namelen})" for key, namelen in name_lengths if namelen > MAX_TENSOR_NAME_LENGTH) raise ValueError(f"Can only handle tensor names up to {MAX_TENSOR_NAME_LENGTH} characters. 
Tensors exceeding the limit: {bad_list}") + + invalid_tensors = {} + quantized_tensors = {} for key, data in tqdm(state_dict.items()): old_dtype = data.dtype @@ -255,14 +290,14 @@ def handle_tensors(writer, state_dict, model_arch): data_shape = data.shape if old_dtype == torch.bfloat16: data_qtype = gguf.GGMLQuantizationType.BF16 - # elif old_dtype == torch.float32: - # data_qtype = gguf.GGMLQuantizationType.F32 + elif old_dtype == torch.float32 and allow_fp32: + data_qtype = gguf.GGMLQuantizationType.F32 else: data_qtype = gguf.GGMLQuantizationType.F16 # The max no. of dimensions that can be handled by the quantization code is 4 if len(data.shape) > MAX_TENSOR_DIMS: - model_arch.handle_nd_tensor(key, data) + invalid_tensors[key] = data continue # needs to be added back later # get number of parameters (AKA elements) in this tensor @@ -296,38 +331,27 @@ def handle_tensors(writer, state_dict, model_arch): try: data = gguf.quants.quantize(data, data_qtype) + quantized_tensors[key] = data_qtype except (AttributeError, gguf.QuantError) as e: tqdm.write(f"falling back to F16: {e}") data_qtype = gguf.GGMLQuantizationType.F16 data = gguf.quants.quantize(data, data_qtype) - - new_name = key # do we need to rename? + quantized_tensors[key] = data_qtype shape_str = f"{{{', '.join(str(n) for n in reversed(data.shape))}}}" - tqdm.write(f"{f'%-{max_name_len + 4}s' % f'{new_name}'} {old_dtype} --> {data_qtype.name}, shape = {shape_str}") + tqdm.write(f"{f'%-{max_name_len + 4}s' % f'{key}'} {old_dtype} --> {data_qtype.name}, shape = {shape_str}") + + writer.add_tensor(key, data, raw_dtype=data_qtype) - writer.add_tensor(new_name, data, raw_dtype=data_qtype) + return quantized_tensors, invalid_tensors -def convert_file(path, dst_path=None, interact=True, overwrite=False): +def convert_file(path, dst_path=None, interact=True, overwrite=False, allow_fp32=False): # load & run model detection logic state_dict = load_state_dict(path) model_arch = detect_arch(state_dict) logging.info(f"* Architecture detected from input: {model_arch.arch}") - # detect & set dtype for output file - dtypes = [x.dtype for x in state_dict.values()] - dtypes = {x:dtypes.count(x) for x in set(dtypes)} - main_dtype = max(dtypes, key=dtypes.get) - - if main_dtype == torch.bfloat16: - ftype_name = "BF16" - ftype_gguf = gguf.LlamaFileType.MOSTLY_BF16 - # elif main_dtype == torch.float32: - # ftype_name = "F32" - # ftype_gguf = None - else: - ftype_name = "F16" - ftype_gguf = gguf.LlamaFileType.MOSTLY_F16 + ftype_name, ftype_gguf = find_main_dtype(state_dict, allow_fp32=allow_fp32) if dst_path is None: dst_path = f"{os.path.splitext(path)[0]}-{ftype_name}.gguf" @@ -346,20 +370,32 @@ def convert_file(path, dst_path=None, interact=True, overwrite=False): if ftype_gguf is not None: writer.add_file_type(ftype_gguf) - handle_tensors(writer, state_dict, model_arch) + quantized_tensors, invalid_tensors = handle_tensors(writer, state_dict, model_arch, allow_fp32=allow_fp32) + if len(invalid_tensors) > 0: + if not model_arch.ndims_fix: # only applies to 5D fix for now, possibly expand to cover more cases? + raise ValueError(f"Tensor(s) detected that exceeds dims supported by C++ code! ({invalid_tensors.keys()})") + + fix_path = os.path.join( + os.path.dirname(dst_path), + f"fix_5d_tensors_{model_arch.arch}.safetensors" + ) + if os.path.isfile(fix_path): + raise RuntimeError(f"Tensor fix file already exists! 
{path}") + + invalid_tensors = {k:torch.from_numpy(v.copy()) for k,v in invalid_tensors.items()} + save_file(invalid_tensors, fix_path) + logging.warning(f"\n### Warning! Fix file found at '{fix_path}'") + logging.warning(" you most likely need to run 'fix_5d_tensors.py' after quantization.") + else: + fix_path = None + writer.write_header_to_file(path=dst_path) writer.write_kv_data_to_file() writer.write_tensors_to_file(progress=True) writer.close() - fix = f"./fix_5d_tensors_{model_arch.arch}.safetensors" - if os.path.isfile(fix): - logging.warning(f"\n### Warning! Fix file found at '{fix}'") - logging.warning(" you most likely need to run 'fix_5d_tensors.py' after quantization.") - - return dst_path, model_arch + return dst_path, model_arch, fix_path if __name__ == "__main__": args = parse_args() convert_file(args.src, args.dst) - diff --git a/tools/fix_5d_tensors.py b/tools/fix_5d_tensors.py index 0e61d1c..0826638 100644 --- a/tools/fix_5d_tensors.py +++ b/tools/fix_5d_tensors.py @@ -30,23 +30,21 @@ def get_file_type(reader): ft = int(field.parts[field.data[-1]]) return gguf.LlamaFileType(ft) -if __name__ == "__main__": - args = get_args() - +def apply_5d_fix(src, dst, fix=None, overwrite=False): # read existing - reader = gguf.GGUFReader(args.src) + reader = gguf.GGUFReader(src) arch = get_arch_str(reader) file_type = get_file_type(reader) print(f"Detected arch: '{arch}' (ftype: {str(file_type)})") # prep fix - if args.fix is None: - args.fix = f"./fix_5d_tensors_{arch}.safetensors" + if fix is None: + fix = f"./fix_5d_tensors_{arch}.safetensors" - if not os.path.isfile(args.fix): - raise OSError(f"No 5D tensor fix file: {args.fix}") + if not os.path.isfile(fix): + raise OSError(f"No 5D tensor fix file: {fix}") - sd5d = load_file(args.fix) + sd5d = load_file(fix) sd5d = {k:v.numpy() for k,v in sd5d.items()} print("5D tensors:", sd5d.keys()) @@ -55,6 +53,7 @@ def get_file_type(reader): writer.add_quantization_version(gguf.GGML_QUANT_VERSION) writer.add_file_type(file_type) + global added added = [] def add_extra_key(writer, key, data): global added @@ -76,7 +75,11 @@ def add_extra_key(writer, key, data): if key not in added: add_extra_key(writer, key, data) - writer.write_header_to_file(path=args.dst) + writer.write_header_to_file(path=dst) writer.write_kv_data_to_file() writer.write_tensors_to_file(progress=True) writer.close() + +if __name__ == "__main__": + args = get_args() + apply_5d_fix(args.src, args.dst, fix=args.fix, overwrite=args.overwrite) diff --git a/tools/lcpp.patch b/tools/lcpp.patch index 92396e1..a341a4d 100644 --- a/tools/lcpp.patch +++ b/tools/lcpp.patch @@ -39,10 +39,10 @@ index b16c462f..6d1568f1 100644 const int idx = gguf_find_tensor(ctx, name); if (idx < 0) { diff --git a/src/llama.cpp b/src/llama.cpp -index 24e1f1f0..25db4c69 100644 +index 24e1f1f0..ee68edfd 100644 --- a/src/llama.cpp +++ b/src/llama.cpp -@@ -205,6 +205,17 @@ enum llm_arch { +@@ -205,6 +205,18 @@ enum llm_arch { LLM_ARCH_GRANITE, LLM_ARCH_GRANITE_MOE, LLM_ARCH_CHAMELEON, @@ -57,10 +57,11 @@ index 24e1f1f0..25db4c69 100644 + LLM_ARCH_HIDREAM, + LLM_ARCH_COSMOS, + LLM_ARCH_LUMINA2, ++ LLM_ARCH_QWEN_IMAGE, LLM_ARCH_UNKNOWN, }; -@@ -258,6 +269,17 @@ static const std::map LLM_ARCH_NAMES = { +@@ -258,6 +270,18 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_GRANITE, "granite" }, { LLM_ARCH_GRANITE_MOE, "granitemoe" }, { LLM_ARCH_CHAMELEON, "chameleon" }, @@ -75,28 +76,30 @@ index 24e1f1f0..25db4c69 100644 + { LLM_ARCH_HIDREAM, "hidream" }, + { LLM_ARCH_COSMOS, "cosmos" }, + { 
LLM_ARCH_LUMINA2, "lumina2" }, ++ { LLM_ARCH_QWEN_IMAGE, "qwen_image" }, { LLM_ARCH_UNKNOWN, "(unknown)" }, }; -@@ -1531,6 +1553,17 @@ static const std::map> LLM_TENSOR_N +@@ -1531,6 +1555,18 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, }, }, -+ { LLM_ARCH_FLUX, {}}, -+ { LLM_ARCH_SD1, {}}, -+ { LLM_ARCH_SDXL, {}}, -+ { LLM_ARCH_SD3, {}}, -+ { LLM_ARCH_AURA, {}}, -+ { LLM_ARCH_LTXV, {}}, -+ { LLM_ARCH_HYVID, {}}, -+ { LLM_ARCH_WAN, {}}, -+ { LLM_ARCH_HIDREAM, {}}, -+ { LLM_ARCH_COSMOS, {}}, -+ { LLM_ARCH_LUMINA2, {}}, ++ { LLM_ARCH_FLUX, {}}, ++ { LLM_ARCH_SD1, {}}, ++ { LLM_ARCH_SDXL, {}}, ++ { LLM_ARCH_SD3, {}}, ++ { LLM_ARCH_AURA, {}}, ++ { LLM_ARCH_LTXV, {}}, ++ { LLM_ARCH_HYVID, {}}, ++ { LLM_ARCH_WAN, {}}, ++ { LLM_ARCH_HIDREAM, {}}, ++ { LLM_ARCH_COSMOS, {}}, ++ { LLM_ARCH_LUMINA2, {}}, ++ { LLM_ARCH_QWEN_IMAGE, {}}, { LLM_ARCH_UNKNOWN, { -@@ -5403,6 +5436,25 @@ static void llm_load_hparams( +@@ -5403,6 +5439,26 @@ static void llm_load_hparams( // get general kv ml.get_key(LLM_KV_GENERAL_NAME, model.name, false); @@ -113,6 +116,7 @@ index 24e1f1f0..25db4c69 100644 + case LLM_ARCH_HIDREAM: + case LLM_ARCH_COSMOS: + case LLM_ARCH_LUMINA2: ++ case LLM_ARCH_QWEN_IMAGE: + model.ftype = ml.ftype; + return; + default: @@ -122,7 +126,7 @@ index 24e1f1f0..25db4c69 100644 // get hparams kv ml.get_key(LLM_KV_VOCAB_SIZE, hparams.n_vocab, false) || ml.get_arr_n(LLM_KV_TOKENIZER_LIST, hparams.n_vocab); -@@ -18016,6 +18068,134 @@ static void llama_tensor_dequantize_internal( +@@ -18016,6 +18072,158 @@ static void llama_tensor_dequantize_internal( workers.clear(); } @@ -158,6 +162,7 @@ index 24e1f1f0..25db4c69 100644 + (name.find(".v.weight") != std::string::npos) || + (name.find(".attn.w1v.weight") != std::string::npos) || + (name.find(".attn.w2v.weight") != std::string::npos) || ++ (name.find(".add_v_proj.weight") != std::string::npos) || + (name.find("_attn.v_proj.weight") != std::string::npos) + ){ + if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) { @@ -197,7 +202,9 @@ index 24e1f1f0..25db4c69 100644 + (name.find(".ff.net.2.weight") != std::string::npos) || + (name.find(".mlp.layer2.weight") != std::string::npos) || + (name.find(".adaln_modulation_mlp.2.weight") != std::string::npos) || -+ (name.find(".feed_forward.w2.weight") != std::string::npos) ++ (name.find(".feed_forward.w2.weight") != std::string::npos) || ++ (name.find(".img_mlp.net.2.weight") != std::string::npos) || ++ (name.find(".txt_mlp.net.2.weight") != std::string::npos) + ) { + // TODO: add back `layer_info` with some model specific logic + logic further down + if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M) { @@ -224,6 +231,27 @@ index 24e1f1f0..25db4c69 100644 + ++qs.i_ffn_down; + } + ++ // first/last block high precision test ++ if (arch == LLM_ARCH_QWEN_IMAGE){ ++ if ( ++ (name.find("transformer_blocks.0.") != std::string::npos) || ++ (name.find("transformer_blocks.59.") != std::string::npos) // this should be dynamic ++ ) { ++ if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_S || ftype == LLAMA_FTYPE_MOSTLY_Q2_K) { ++ new_type = GGML_TYPE_Q4_K; ++ } ++ else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) { ++ new_type = GGML_TYPE_Q4_K; ++ } ++ else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M) { ++ new_type = GGML_TYPE_Q5_K; ++ } ++ else if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) { ++ new_type = GGML_TYPE_Q6_K; ++ } ++ } ++ } ++ + // Sanity check for row shape + bool convert_incompatible_tensor = false; + if (new_type == GGML_TYPE_Q2_K || new_type == GGML_TYPE_Q3_K || new_type == 
GGML_TYPE_Q4_K || @@ -257,7 +285,7 @@ index 24e1f1f0..25db4c69 100644 static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type new_type, const ggml_tensor * tensor, llama_ftype ftype) { const std::string name = ggml_get_name(tensor); -@@ -18513,7 +18693,9 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s +@@ -18513,7 +18721,9 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s if (llama_model_has_encoder(&model)) { n_attn_layer *= 3; } @@ -268,7 +296,7 @@ index 24e1f1f0..25db4c69 100644 } size_t total_size_org = 0; -@@ -18547,6 +18729,51 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s +@@ -18547,6 +18757,71 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s ctx_outs[i_split] = gguf_init_empty(); } gguf_add_tensor(ctx_outs[i_split], tensor); @@ -295,6 +323,20 @@ index 24e1f1f0..25db4c69 100644 + LLAMA_LOG_INFO("\n%s: Correcting register_tokens shape for AuraFlow: [key:%s]\n", __func__, tensor->name); + } + } ++ // not used for actual lumina, Z-Image specific ++ if (model.arch == LLM_ARCH_LUMINA2) { ++ const std::string name = ggml_get_name(tensor); ++ if (name == "x_pad_token" && tensor->ne[1] == 1) { ++ const int n_dim = 2; ++ gguf_set_tensor_ndim(ctx_outs[i_split], "x_pad_token", n_dim); ++ LLAMA_LOG_INFO("\n%s: Correcting x_pad_token shape for Z-Image: [key:%s]\n", __func__, tensor->name); ++ } ++ if (name == "cap_pad_token" && tensor->ne[1] == 1) { ++ const int n_dim = 2; ++ gguf_set_tensor_ndim(ctx_outs[i_split], "cap_pad_token", n_dim); ++ LLAMA_LOG_INFO("\n%s: Correcting cap_pad_token shape for Z-Image: [key:%s]\n", __func__, tensor->name); ++ } ++ } + // conv3d fails due to max dims - unsure what to do here as we never even reach this check + if (model.arch == LLM_ARCH_HYVID) { + const std::string name = ggml_get_name(tensor); @@ -316,11 +358,17 @@ index 24e1f1f0..25db4c69 100644 + gguf_set_tensor_ndim(ctx_outs[i_split], tensor->name, n_dim); + LLAMA_LOG_INFO("\n%s: Correcting shape for Wan FLF2V: [key:%s]\n", __func__, tensor->name); + } ++ // S2V model only ++ if (name == "casual_audio_encoder.weights" || name == "casual_audio_encoder.encoder.padding_tokens") { ++ const int n_dim = 4; ++ gguf_set_tensor_ndim(ctx_outs[i_split], tensor->name, n_dim); ++ LLAMA_LOG_INFO("\n%s: Correcting shape for Wan S2V: [key:%s]\n", __func__, tensor->name); ++ } + } } // Set split info if needed -@@ -18647,6 +18874,110 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s +@@ -18647,6 +18922,131 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s // do not quantize relative position bias (T5) quantize &= name.find("attn_rel_b.weight") == std::string::npos; @@ -334,6 +382,9 @@ index 24e1f1f0..25db4c69 100644 + quantize &= name.find("vector_in.") == std::string::npos; + quantize &= name.find("guidance_in.") == std::string::npos; + quantize &= name.find("final_layer.") == std::string::npos; ++ // flux2 ++ quantize &= name.find("single_stream_modulation.") == std::string::npos; ++ quantize &= name.find("double_stream_modulation_") == std::string::npos; + } + if (model.arch == LLM_ARCH_SD1 || model.arch == LLM_ARCH_SDXL) { + image_model = true; @@ -395,6 +446,12 @@ index 24e1f1f0..25db4c69 100644 + quantize &= name.find("time_embedding.") == std::string::npos; + quantize &= name.find("img_emb.") == std::string::npos; + quantize &= name.find("head.") == std::string::npos; ++ // S2V ++ 
quantize &= name.find("cond_encoder.") == std::string::npos; ++ quantize &= name.find("frame_packer.") == std::string::npos; ++ quantize &= name.find("audio_injector.") == std::string::npos; ++ quantize &= name.find("casual_audio_encoder.") == std::string::npos; ++ quantize &= name.find("trainable_cond_mask.") == std::string::npos; + } + if (model.arch == LLM_ARCH_HIDREAM) { + image_model = true; @@ -422,6 +479,18 @@ index 24e1f1f0..25db4c69 100644 + quantize &= name.find("cap_embedder.") == std::string::npos; + quantize &= name.find("context_refiner.") == std::string::npos; + quantize &= name.find("noise_refiner.") == std::string::npos; ++ // z-image ++ quantize &= name.find("x_pad_token.") == std::string::npos; ++ quantize &= name.find("cap_pad_token.") == std::string::npos; ++ ++ } ++ if (model.arch == LLM_ARCH_QWEN_IMAGE) { ++ image_model = true; ++ quantize &= name.find("img_in.") == std::string::npos; ++ quantize &= name.find("txt_in.") == std::string::npos; ++ quantize &= name.find("time_text_embed.") == std::string::npos; ++ quantize &= name.find("proj_out.") == std::string::npos; ++ quantize &= name.find("norm_out.") == std::string::npos; + } + // ignore 3D/4D tensors for image models as the code was never meant to handle these + if (image_model) { @@ -431,7 +500,7 @@ index 24e1f1f0..25db4c69 100644 enum ggml_type new_type; void * new_data; size_t new_size; -@@ -18655,6 +18986,9 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s +@@ -18655,6 +19055,9 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s new_type = default_type; // get more optimal quantization type based on the tensor shape, layer, etc. @@ -441,7 +510,7 @@ index 24e1f1f0..25db4c69 100644 if (!params->pure && ggml_is_quantized(default_type)) { new_type = llama_tensor_get_type(qs, new_type, tensor, ftype); } -@@ -18664,6 +18998,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s +@@ -18664,6 +19067,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s if (params->output_tensor_type < GGML_TYPE_COUNT && strcmp(tensor->name, "output.weight") == 0) { new_type = params->output_tensor_type; } diff --git a/tools/tool_auto.py b/tools/tool_auto.py new file mode 100644 index 0000000..6f980c1 --- /dev/null +++ b/tools/tool_auto.py @@ -0,0 +1,374 @@ +# (c) City96 || Apache-2.0 (apache.org/licenses/LICENSE-2.0) +import os +import re +import sys +import time +import torch +import logging +import argparse +import subprocess +import huggingface_hub as hf + +logging.getLogger().setLevel(logging.DEBUG) + +qtypes =[ + # "F16", "BF16", + "Q8_0", "Q6_K", + "Q5_K_M", "Q5_K_S", "Q5_1", "Q5_0", + "Q4_K_M", "Q4_K_S", "Q4_1", "Q4_0", + "Q3_K_M", "Q3_K_S", "Q2_K" +] + +dtype_dict = { + "F32": torch.float32, + "F16": torch.float16, + "BF16": torch.bfloat16, + "F8_E4M3": getattr(torch, "float8_e4m3fn", "_invalid"), + "F8_E5M2": getattr(torch, "float8_e5m2", "_invalid"), +} + +# this is pretty jank but I want to be able to run it on a blank instance w/o setup +terraform_dict = { + "repo": "city96/ComfyUI-GGUF", + "target": "auto_convert", + "lcpp_repo": "ggerganov/llama.cpp", + "lcpp_target": "tags/b3962", +} + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--src", required=True, help="Source model file or huggingface repo name") + parser.add_argument("--quants", nargs="+", choices=["all", "base", *qtypes], default=["Q8_0"]) + parser.add_argument("--output-dir", default=None, 
help="Location for output files, defaults to current dir or ComfyUI model dir.") + parser.add_argument("--temp-dir", default=None, help="Location for temp files, defaults to [output_dir]/tmp") + parser.add_argument("--force-update", action="store_true", help="Force update & rebuild entire quantization stack.") + parser.add_argument("--resume", action="store_true", help="Skip over existing files. Will NOT check for broken/interrupted files.") + + args = parser.parse_args() + if args.output_dir is None: + args.output_dir = get_output_dir() + if args.temp_dir is None: + args.temp_dir = os.path.join(args.output_dir, "tmp") + + if os.path.isdir(args.temp_dir) and len(os.listdir(args.temp_dir)) > 0: + raise OSError("Output temp folder not empty!") + + if "all" in args.quants: + args.quants = ["base", *qtypes] + + return args + +def run_cmd(*args, log_error=False): + logging.debug(f"cmd: {args}") + try: + log = subprocess.run(args, capture_output=True, text=True) + except Exception as e: + logging.warning(f"{args[0]}, {e}") + return -1 + if log.returncode != 0 and log_error: + logging.warning(f"{args[0]}: {log.stdout} {log.stderr}") + else: + logging.debug(f"{args[0]}: {repr(log.stdout)} {repr(log.stderr.strip())} RET:{log.returncode}") + return log.returncode + +def setup_utils(force_update=False): + # get ComfyUI-GGUF if missing, then compile patched llama.cpp if required + root = os.path.dirname(os.path.abspath(__file__)) + root = os.path.normpath(root) + + if os.path.split(root)[1] != "tools": + cg_dir = os.path.join(root, "ComfyUI-GGUF") + if not os.path.isdir(cg_dir): + logging.warning(f"Running outside tools folder! Cloning to {cg_dir}") + run_cmd("git", "clone", f"https://github.com/{terraform_dict['repo']}", cg_dir) + need_update = True + else: + need_update = False + + if force_update or need_update: + if terraform_dict['target']: + logging.info(f"Attemtping to check out ComfyUI-GGUF branch {terraform_dict['target']}") + run_cmd("git", "-C", cg_dir, "checkout", terraform_dict['target']) + + logging.info("Attemtping to git pull ComfyUI-GGUF to latest") + run_cmd("git", "-C", cg_dir, "pull") + + tools_dir = os.path.join(root, "ComfyUI-GGUF", "tools") + sys.path.append(tools_dir) # to make import(s) work + else: + # TODO: Git pull here too? + logging.warning(f"Assuming latest ComfyUI-GGUF. Please git pull & check out branch {terraform_dict['target']} manually!") + tools_dir = root + + if not os.path.isdir(tools_dir): + raise OSError(f"Can't find tools subfoder in ComfyUI-GGUF at {tools_dir}") + + convert_path = os.path.join(tools_dir, "convert.py") + if not os.path.isfile(convert_path): + raise OSError(f"Cannot find convert.py at location: {convert_path}") + + lcpp_path = os.path.join(root, "llama.cpp.auto") # avoid messing with regular dir + if not os.path.isdir(lcpp_path): + logging.info(f"Attemtping to clone llama.cpp repo to {lcpp_path}") + run_cmd("git", "clone", f"https://github.com/{terraform_dict['lcpp_repo']}", lcpp_path) + need_update = True + else: + need_update = False + + if force_update or need_update: + # TODO: check reflog and/or git reset before checkout? + logging.info(f"Attemtping to check out llama.cpp target {terraform_dict['lcpp_target']}") + run_cmd("git", "-C", lcpp_path, "checkout", terraform_dict['lcpp_target']) + + # TODO: git reset before patch? 
+ patch_path = os.path.join(tools_dir, "lcpp.patch") + # patch (probably) has wrong file endings: + logging.info("Converting patch file endings") + with open(patch_path, "rb") as file: + content = file.read().replace(b"\r\n", b"\n") + with open(patch_path, "wb") as file: + file.write(content) + + if run_cmd("git", "-C", lcpp_path, "apply", "--check", "-R", patch_path) != 0: + logging.info("Attemtping to apply patch to llama.cpp repo") + run_cmd("git", "-C", lcpp_path, "apply", patch_path) + else: + logging.info("Patch already applied") + + # using cmake here as llama.cpp switched to it completely for new versions + if os.name == "nt": + bin_path = os.path.join(lcpp_path, "build", "bin", "debug", "llama-quantize.exe") + else: + bin_path = os.path.join(lcpp_path, "build", "bin", "llama-quantize") + + if not os.path.isfile(bin_path) or force_update or need_update: + if run_cmd("cmake", "--version") != 0: + raise RuntimeError("Can't find cmake! Make sure you have a working build environment set up") + + build_path = os.path.join(lcpp_path, "build") + os.makedirs(build_path, exist_ok=True) + logging.info("Attempting to build llama.cpp binary from source") + run_cmd("cmake", "-B", build_path, lcpp_path) + run_cmd("cmake", "--build", build_path, "--config", "Debug", "-j4", "--target", "llama-quantize") + if not os.path.isfile(bin_path): + raise RuntimeError("Build failed! Rerun with --debug to see error log.") + else: + logging.info("Binary already present") + + return bin_path + +def get_output_dir(): + root = os.path.dirname(os.path.abspath(__file__)) + root = os.path.normpath(root) + split = os.path.split(root) + while split[1]: + if split[1] == "ComfyUI": + if os.path.isdir(os.path.join(*split, "models", "unet")): # new + root = os.path.join(*split, "models", "unet", "gguf") + logging.info(f"Found ComfyUI, using model folder: {root}") + return root + + if os.path.isdir(os.path.join(*split, "models", "diffusion_models")): # old + root = os.path.join(*split, "models", "diffusion_models", "gguf") + logging.info(f"Found ComfyUI, using model folder: {root}") + return root + + logging.info("Found ComfyUI, but can't find model folder") + break + + split = os.path.split(split[0]) + + root = os.path.join(root, "models") + logging.info(f"Defaulting to [script dir]/models: {root}") + return root + +def get_hf_fake_sd(repo, path, device=torch.device("meta")): + sd = {} + meta = hf.parse_safetensors_file_metadata(repo, path) + for key, raw in meta.tensors.items(): + shape = tuple(raw.shape) + dtype = dtype_dict.get(raw.dtype, torch.float32) + sd[key] = torch.zeros(shape, dtype=dtype, device=device) + return sd + +def get_hf_file_arch(repo, path): + pattern = r'(\d+)-of-(\d+)' + match = re.search(pattern, path) + + if match: + # we need to load it as multipart + if int(match.group(1)) != 1: + return None + sd = {} + for k in range(int(match.group(2))): + shard_path = path.replace(match.group(1), f"{k+1:0{len(match.group(1))}}") + sd.update(get_hf_fake_sd(repo, shard_path)) + else: + sd = get_hf_fake_sd(repo, path) + + # this should raise an error on failure + sd = strip_prefix(sd) + model_arch = detect_arch(sd) + + # this is for SDXL and SD1.5, I want to overhaul this logic to match sd.cpp eventually + assert not model_arch.shape_fix, "Model uses shape fix (SDXL/SD1) - unsupported for now." + return model_arch.arch + +def get_hf_valid_files(repo): + # TODO: probably tweak this? 
+    MIN_SIZE_GB = 1
+    VALID_SRC_EXTS = [".safetensors", ] # ".pt", ".ckpt", ]
+    meta = hf.model_info(repo, files_metadata=True)
+
+    valid = {}
+    for file in meta.siblings:
+        path = file.rfilename
+        fname = os.path.basename(path)
+        name, ext = os.path.splitext(fname)
+
+        if ext.lower() not in VALID_SRC_EXTS:
+            logging.debug(f"Invalid ext: {path} {ext}")
+            continue
+
+        if file.size / (1024 ** 3) < MIN_SIZE_GB:
+            logging.debug(f"File too small: {path} {file.size}")
+            continue
+
+        try:
+            arch = get_hf_file_arch(repo, path)
+        except Exception as e:
+            logging.warning(f"Arch detect fail: {e} ({path})")
+        else:
+            if arch is not None:
+                valid[path] = arch
+                logging.info(f"Found '{arch}' model at path {path}")
+    return valid
+
+def make_base_quant(src, output_dir, temp_dir, final=True, resume=True):
+    name, ext = os.path.splitext(os.path.basename(src))
+    if ext == ".gguf":
+        logging.info("Input file already in gguf, assuming base quant")
+        return None, src, None
+
+    name = name.lower() # comment this out to preserve case in all quants
+    dst_tmp = os.path.join(temp_dir, f"{name}-{{ftype}}.gguf") # ftype is filled in by convert.py
+
+    tmp_path, model_arch, fix_path = convert_file(src, dst_tmp, interact=False, overwrite=False)
+    dst_path = os.path.join(output_dir, os.path.basename(tmp_path))
+    if os.path.isfile(dst_path):
+        if resume:
+            logging.warning("Resuming with interrupted base quant, may be incorrect!")
+            return dst_path, tmp_path, fix_path
+        raise OSError(f"Output already exists! Clear folder? {dst_path}")
+
+    if fix_path is not None and os.path.isfile(fix_path):
+        quant_source = tmp_path
+        if final:
+            apply_5d_fix(tmp_path, dst_path, fix=fix_path, overwrite=False)
+        else:
+            dst_path = None
+    else:
+        fix_path = None
+        if final:
+            os.rename(tmp_path, dst_path)
+            quant_source = dst_path
+        else:
+            dst_path = None
+            quant_source = tmp_path
+
+    return dst_path, quant_source, fix_path
+
+def make_quant(src, output_dir, temp_dir, qtype, quantize_binary, fix_path=None, resume=True):
+    name, ext = os.path.splitext(os.path.basename(src))
+    assert ext.lower() == ".gguf", "Invalid input file"
+
+    src_qtext = [x for x in ["-F32.gguf", "-F16.gguf", "-BF16.gguf"] if x in src]
+    if len(src_qtext) == 1:
+        tmp_path = os.path.join(
+            temp_dir,
+            os.path.basename(src).replace(src_qtext[0], f"-{qtype.upper()}.gguf")
+        )
+    else:
+        tmp_path = os.path.join(
+            temp_dir,
+            f"{name}-{qtype.upper()}.gguf"
+        )
+    tmp_path = os.path.abspath(tmp_path)
+    dst_path = os.path.join(output_dir, os.path.basename(tmp_path))
+    if os.path.isfile(dst_path):
+        if resume:
+            return dst_path
+        raise OSError("Output already exists! Clear folder?")
+
+    r = run_cmd(quantize_binary, src, tmp_path, qtype, log_error=True)
+    time.sleep(2) # leave time for file sync?
+    if r != 0:
+        raise RuntimeError(f"Quantization failed with error code {r}")
+
+    if fix_path is not None:
+        apply_5d_fix(tmp_path, dst_path, fix=fix_path, overwrite=False)
+        if os.path.isfile(dst_path) and os.path.isfile(tmp_path):
+            os.remove(tmp_path)
+    else:
+        os.rename(tmp_path, dst_path)
+
+    return dst_path
+
+if __name__ == "__main__":
+    args = get_args()
+    os.makedirs(args.output_dir, exist_ok=True)
+    os.makedirs(args.temp_dir, exist_ok=True)
+    quantize_binary = setup_utils(args.force_update)
+
+    try:
+        from convert import detect_arch, strip_prefix, convert_file
+        from fix_5d_tensors import apply_5d_fix
+    except (ImportError, ModuleNotFoundError) as e:
+        raise ImportError(f"Can't import required utils: {e}")
+
+    if not os.path.isfile(args.src):
+        # huggingface repo. TODO: file choice
+        if len(args.src.split("/")) != 2:
+            raise OSError(f"Invalid huggingface repo or model path {args.src}")
+        raise NotImplementedError("HF not yet supported")
+        # download then set to temp file
+        # hf_repo = "Lightricks/LTX-Video" # "fal/AuraFlow-v0.3"
+        # get_hf_valid_files(hf_repo)
+        # args.src = ...
+
+    out_files = []
+
+    base_quant, quant_source, fix_path = make_base_quant(
+        args.src,
+        args.output_dir,
+        args.temp_dir,
+        final=("base" in args.quants),
+        resume=args.resume,
+    )
+    if "base" in args.quants:
+        args.quants = [x for x in args.quants if x not in ["base"]]
+        if base_quant is not None:
+            out_files.append(base_quant)
+
+    for qtype in args.quants:
+        out_files.append(make_quant(
+            quant_source,
+            args.output_dir,
+            args.temp_dir,
+            qtype,
+            quantize_binary,
+            fix_path,
+            resume=args.resume,
+        ))
+
+    if fix_path is not None and os.path.isfile(fix_path):
+        os.remove(fix_path)
+
+    if base_quant != quant_source:
+        # make sure our quant source is in the temp folder before removing it
+        cc = os.path.commonpath([os.path.normpath(quant_source), os.path.normpath(args.temp_dir)])
+        if cc == os.path.normpath(args.temp_dir):
+            os.remove(quant_source)
+
+    out_file_str = '\n'.join(out_files)
+    logging.info(f"Output file(s): {out_file_str}")
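Note: the flow that tool_auto.py automates can also be driven by hand through the updated convert.py / fix_5d_tensors.py entry points. The snippet below is a minimal illustrative sketch, not part of the patch: the input filename, output naming and the llama-quantize binary location are placeholder assumptions (the binary path mirrors the build layout used by setup_utils() above).

import subprocess
from convert import convert_file
from fix_5d_tensors import apply_5d_fix

SRC = "wan2.1-t2v-1.3b-bf16.safetensors"                    # hypothetical input checkpoint
QUANTIZE_BIN = "./llama.cpp.auto/build/bin/llama-quantize"  # assumed patched llama.cpp build

# 1. Base conversion; fix_path is only non-None for archs with ndims_fix=True (hyvid/wan),
#    where tensors above MAX_TENSOR_DIMS are moved into a separate safetensors "fix" file.
base_path, model_arch, fix_path = convert_file(SRC, interact=False, allow_fp32=False)
print(f"arch={model_arch.arch} base={base_path} fix={fix_path}")

# 2. Quantize the base file with the patched llama-quantize binary (same call make_quant() issues).
quant_path = base_path.replace("-BF16.gguf", "-Q8_0.gguf").replace("-F16.gguf", "-Q8_0.gguf")
subprocess.run([QUANTIZE_BIN, base_path, quant_path, "Q8_0"], check=True)

# 3. Re-attach any 5D tensors that were split out during conversion.
if fix_path is not None:
    apply_5d_fix(quant_path, quant_path.replace(".gguf", "-fixed.gguf"), fix=fix_path)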