import os
import torch
import numpy as np
import gguf # This needs to be the llama.cpp one specifically!
import argparse
from tqdm import tqdm
from safetensors.torch import load_file
QUANTIZATION_THRESHOLD = 1024
REARRANGE_THRESHOLD = 512
MAX_TENSOR_NAME_LENGTH = 127
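
# How these are used below: float tensors with at most QUANTIZATION_THRESHOLD
# elements are kept in F32, tensors with at least REARRANGE_THRESHOLD elements
# may be reshaped for 256-element alignment, and tensor names longer than
# MAX_TENSOR_NAME_LENGTH characters are rejected.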

class QuantError(Exception):
    pass

class quants:
    @staticmethod
    def quantize(data, data_qtype):
        # Minimal fallback: plain numpy casts for the types this script emits
        if data_qtype == gguf.GGMLQuantizationType.F32:
            return data.astype(np.float32)
        elif data_qtype == gguf.GGMLQuantizationType.F16:
            return data.astype(np.float16)
        elif data_qtype == gguf.GGMLQuantizationType.BF16:
            # NumPy has no native bfloat16; keep the upper 16 bits of each
            # float32 value (truncation rather than round-to-nearest)
            return (data.astype(np.float32).view(np.uint32) >> 16).astype(np.uint16)
        else:
            raise QuantError("Unsupported quantization type")
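
# Hypothetical usage of the fallback shim:
#   x = np.random.rand(64, 64).astype(np.float32)
#   y = quants.quantize(x, gguf.GGMLQuantizationType.F16)  # float16 array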

class ModelTemplate:
    arch = "invalid"   # string describing architecture
    shape_fix = False  # whether to reshape tensors
    keys_detect = []   # list of lists to match in state dict
    keys_banned = []   # list of keys that should mark model as invalid for conversion

class ModelFlux(ModelTemplate):
    arch = "flux"
    keys_detect = [
        ("transformer_blocks.0.attn.norm_added_k.weight",),
        ("double_blocks.0.img_attn.proj.weight",),
    ]
    keys_banned = ["transformer_blocks.0.attn.norm_added_k.weight",]

class ModelSD3(ModelTemplate):
    arch = "sd3"
    keys_detect = [
        ("transformer_blocks.0.attn.add_q_proj.weight",),
        ("joint_blocks.0.x_block.attn.qkv.weight",),
    ]
    keys_banned = ["transformer_blocks.0.attn.add_q_proj.weight",]

class ModelSDXL(ModelTemplate):
    arch = "sdxl"
    shape_fix = True
    keys_detect = [
        ("down_blocks.0.downsamplers.0.conv.weight", "add_embedding.linear_1.weight",),
        (
            "input_blocks.3.0.op.weight", "input_blocks.6.0.op.weight",
            "output_blocks.2.2.conv.weight", "output_blocks.5.2.conv.weight",
        ), # Non-diffusers
        ("label_emb.0.0.weight",),
    ]

class ModelSD1(ModelTemplate):
    arch = "sd1"
    shape_fix = False
    keys_detect = [
        ("down_blocks.0.downsamplers.0.conv.weight",),
        (
            "input_blocks.3.0.op.weight", "input_blocks.6.0.op.weight", "input_blocks.9.0.op.weight",
            "output_blocks.2.1.conv.weight", "output_blocks.5.2.conv.weight", "output_blocks.8.2.conv.weight",
        ), # Non-diffusers
    ]
# Prioritize ModelSD3 over ModelFlux
arch_list = [ModelSD3, ModelFlux, ModelSDXL, ModelSD1]

def is_model_arch(model, state_dict):
    # check whether the state dict matches this model architecture
    matched = False
    invalid = False
    for match_list in model.keys_detect:
        if all(key in state_dict for key in match_list):
            matched = True
            invalid = any(key in state_dict for key in model.keys_banned)
            break
    assert not invalid, "Model architecture not allowed for conversion! (i.e. reference VS diffusers format)"
    return matched
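
# A checkpoint that hits both a detect tuple and a banned key is in diffusers
# layout, which this converter rejects via the assert above.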

def detect_arch(state_dict):
    model_arch = None
    for arch in arch_list:
        if is_model_arch(arch, state_dict):
            model_arch = arch
            break
    assert model_arch is not None, "Unknown model architecture!"
    return model_arch
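
# Illustrative usage (file name is hypothetical):
#   sd = load_state_dict("sdxl_base.safetensors")
#   arch = detect_arch(sd)  # e.g. ModelSDXL
#   print(arch.arch)        # "sdxl"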

def parse_args():
    parser = argparse.ArgumentParser(description="Generate F16 GGUF files from single UNET")
    parser.add_argument("--src", required=True, help="Source model ckpt file.")
    parser.add_argument("--dst", help="Output unet gguf file.")
    args = parser.parse_args()
    if not os.path.isfile(args.src):
        parser.error("No input provided!")
    return args

def load_state_dict(path):
    if any(path.endswith(x) for x in [".ckpt", ".pt", ".bin", ".pth"]):
        state_dict = torch.load(path, map_location="cpu", weights_only=True)
        state_dict = state_dict.get("model", state_dict)
    else:
        state_dict = load_file(path)

    # only keep unet with no prefix!
    sd = {}
    has_prefix = any("model.diffusion_model." in x for x in state_dict.keys())
    for k, v in state_dict.items():
        if has_prefix and "model.diffusion_model." not in k:
            continue
        if has_prefix:
            k = k.replace("model.diffusion_model.", "")
        sd[k] = v
    return sd
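
# e.g. "model.diffusion_model.input_blocks.0.0.weight" is returned as
# "input_blocks.0.0.weight"; when the prefix is present, non-unet keys
# (VAE, text encoder) are dropped entirely.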

def handle_tensors(args, writer, state_dict, model_arch):
    name_lengths = tuple(sorted(
        ((key, len(key)) for key in state_dict.keys()),
        key=lambda item: item[1],
        reverse=True,
    ))
    if not name_lengths:
        return
    max_name_len = name_lengths[0][1]
    if max_name_len > MAX_TENSOR_NAME_LENGTH:
        bad_list = ", ".join(f"{key!r} ({namelen})" for key, namelen in name_lengths if namelen > MAX_TENSOR_NAME_LENGTH)
        raise ValueError(f"Can only handle tensor names up to {MAX_TENSOR_NAME_LENGTH} characters. Tensors exceeding the limit: {bad_list}")

    for key, data in tqdm(state_dict.items()):
        old_dtype = data.dtype

        if data.dtype == torch.bfloat16:
            data = data.to(torch.float32).numpy()
        # this is so we don't break torch 2.0.X
        elif data.dtype in [getattr(torch, "float8_e4m3fn", "_invalid"), getattr(torch, "float8_e5m2", "_invalid")]:
            data = data.to(torch.float16).numpy()
        else:
            data = data.numpy()

        n_dims = len(data.shape)
        data_shape = data.shape
        data_qtype = getattr(
            gguf.GGMLQuantizationType,
            "BF16" if old_dtype == torch.bfloat16 else "F16"
        )

        # get number of parameters (AKA elements) in this tensor
        n_params = 1
        for dim_size in data_shape:
            n_params *= dim_size

        # keys to keep as max precision
        blacklist = {
            "time_embedding.",
            "add_embedding.",
            "time_in.",
            "txt_in.",
            "vector_in.",
            "img_in.",
            "guidance_in.",
            "final_layer.",
        }

        if old_dtype in (torch.float32, torch.bfloat16):
            if n_dims == 1:
                # one-dimensional tensors should be kept in F32
                # also speeds up inference due to not dequantizing
                data_qtype = gguf.GGMLQuantizationType.F32
            elif n_params <= QUANTIZATION_THRESHOLD:
                # very small tensors
                data_qtype = gguf.GGMLQuantizationType.F32
            elif ".weight" in key and any(x in key for x in blacklist):
                data_qtype = gguf.GGMLQuantizationType.F32

        if (model_arch.shape_fix  # NEVER reshape for models such as flux
            and n_dims > 1  # Skip one-dimensional tensors
            and n_params >= REARRANGE_THRESHOLD  # Only rearrange tensors meeting the size requirement
            and (n_params / 256).is_integer()  # Rearranging only makes sense if total elements is divisible by 256
            and not (data.shape[-1] / 256).is_integer()  # Only need to rearrange if the last dimension is not divisible by 256
        ):
            orig_shape = data.shape
            data = data.reshape(n_params // 256, 256)
            writer.add_array(f"comfy.gguf.orig_shape.{key}", tuple(int(dim) for dim in orig_shape))
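
        # e.g. an illustrative (320, 320, 3, 3) conv weight has 921,600 elements;
        # 921,600 / 256 = 3,600 but the last dim (3) is not divisible by 256, so
        # it is stored as (3600, 256) and the original shape is kept in metadata.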

        # cast the data to the dtype recorded in data_qtype before writing;
        # without this step the raw bytes would not match the declared type
        data = quants.quantize(data, data_qtype)

        new_name = key # do we need to rename?

        shape_str = f"{{{', '.join(str(n) for n in reversed(data.shape))}}}"
        tqdm.write(f"{f'%-{max_name_len + 4}s' % f'{new_name}'} {old_dtype} --> {data_qtype.name}, shape = {shape_str}")

        writer.add_tensor(new_name, data, raw_dtype=data_qtype)

def load_model(path):
    state_dict = load_state_dict(path)
    model_arch = detect_arch(state_dict)
    print(f"* Architecture detected from input: {model_arch.arch}")
    return state_dict, model_arch
...

if __name__ == "__main__":
    args = parse_args()
    path = args.src
    state_dict, model_arch = load_model(path)

    if next(iter(state_dict.values())).dtype == torch.bfloat16:
        out_path = f"{os.path.splitext(path)[0]}-BF16.gguf"
    else:
        out_path = f"{os.path.splitext(path)[0]}-F16.gguf"

    out_path = args.dst or out_path
    if os.path.isfile(out_path):
        input("Output exists; press enter to continue or ctrl+c to abort!")

    writer = gguf.GGUFWriter(path=out_path, arch=model_arch.arch)
    writer.add_quantization_version(1)

    handle_tensors(args, writer, state_dict, model_arch)
    writer.write_header_to_file()
    writer.write_kv_data_to_file()
    writer.write_tensors_to_file()
    writer.close()
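
# Example invocation (file names are illustrative, assuming this script is
# saved as convert.py):
#   python convert.py --src ./model.safetensors --dst ./model-F16.gguf
# If --dst is omitted, the output name is derived from the source file, with
# an -F16 or -BF16 suffix depending on the checkpoint dtype.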