|
|
|
|
|
import os
|
|
|
import gguf
|
|
|
import torch
|
|
|
import argparse
|
|
|
from tqdm import tqdm
|
|
|
from safetensors.torch import load_file
|
|
|
|
|
|
def get_args():
|
|
|
parser = argparse.ArgumentParser()
|
|
|
parser.add_argument("--src", required=True)
|
|
|
parser.add_argument("--dst", required=True)
|
|
|
parser.add_argument("--fix", required=False, help="Defaults to ./fix_5d_tensors_[arch].pt")
|
|
|
parser.add_argument("--overwrite", action="store_true")
|
|
|
args = parser.parse_args()
|
|
|
|
|
|
if not os.path.isfile(args.src):
|
|
|
parser.error(f"Invalid source file '{args.src}'")
|
|
|
if not args.overwrite and os.path.exists(args.dst):
|
|
|
parser.error(f"Output exists, use '--overwrite' ({args.dst})")
|
|
|
|
|
|
return args
|
|
|
|
|
|
def get_arch_str(reader):
|
|
|
field = reader.get_field("general.architecture")
|
|
|
return str(field.parts[field.data[-1]], encoding="utf-8")
|
|
|
|
|
|
def get_file_type(reader):
|
|
|
field = reader.get_field("general.file_type")
|
|
|
ft = int(field.parts[field.data[-1]])
|
|
|
return gguf.LlamaFileType(ft)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
args = get_args()
|
|
|
|
|
|
|
|
|
reader = gguf.GGUFReader(args.src)
|
|
|
arch = get_arch_str(reader)
|
|
|
file_type = get_file_type(reader)
|
|
|
print(f"Detected arch: '{arch}' (ftype: {str(file_type)})")
|
|
|
|
|
|
|
|
|
if args.fix is None:
|
|
|
args.fix = f"./fix_5d_tensors_{arch}.safetensors"
|
|
|
|
|
|
if not os.path.isfile(args.fix):
|
|
|
raise OSError(f"No 5D tensor fix file: {args.fix}")
|
|
|
|
|
|
sd5d = load_file(args.fix)
|
|
|
sd5d = {k:v.numpy() for k,v in sd5d.items()}
|
|
|
print("5D tensors:", sd5d.keys())
|
|
|
|
|
|
|
|
|
writer = gguf.GGUFWriter(path=None, arch=arch)
|
|
|
writer.add_quantization_version(gguf.GGML_QUANT_VERSION)
|
|
|
writer.add_file_type(file_type)
|
|
|
|
|
|
added = []
|
|
|
def add_extra_key(writer, key, data):
|
|
|
global added
|
|
|
data_qtype = gguf.GGMLQuantizationType.F32
|
|
|
data = gguf.quants.quantize(data, data_qtype)
|
|
|
tqdm.write(f"Adding key {key} ({data.shape})")
|
|
|
writer.add_tensor(key, data, raw_dtype=data_qtype)
|
|
|
added.append(key)
|
|
|
|
|
|
|
|
|
for tensor in tqdm(reader.tensors):
|
|
|
writer.add_tensor(tensor.name, tensor.data, raw_dtype=tensor.tensor_type)
|
|
|
key5d = tensor.name.replace(".bias", ".weight")
|
|
|
if key5d in sd5d.keys():
|
|
|
add_extra_key(writer, key5d, sd5d[key5d])
|
|
|
|
|
|
|
|
|
for key, data in sd5d.items():
|
|
|
if key not in added:
|
|
|
add_extra_key(writer, key, data)
|
|
|
|
|
|
writer.write_header_to_file(path=args.dst)
|
|
|
writer.write_kv_data_to_file()
|
|
|
writer.write_tensors_to_file(progress=True)
|
|
|
writer.close()
|
|
|
|