Quantization Method?

#7
by vyralsurfer - opened

Hello! First off, thank you for providing these; the original models barely loaded on my system since I have low system RAM.

My question for Kijai is how were these models quantized? Just looking to learn something new and also prepare for future models. Thank you!

Owner

It's nothing too complicated: just loading the weights, casting to fp8, and saving. Initially I made the mistake of not including the metadata; that doesn't affect functionality, but some interfaces can show it and it's useful, not to mention it's a way to include the license and credits.

from safetensors.torch import load_file, save_file
import torch
import json

path = "flux1-dev.sft" # input file

# read safetensors metadata
def read_safetensors_metadata(path):
    with open(path, 'rb') as f:
        # safetensors layout: the first 8 bytes give the little-endian length of the JSON header
        header_size = int.from_bytes(f.read(8), 'little')
        header_json = f.read(header_size).decode('utf-8')
        header = json.loads(header_json)
        metadata = header.get('__metadata__', {})
        return metadata

metadata = read_safetensors_metadata(path)
print(json.dumps(metadata, indent=4)) #show metadata

sd_pruned = dict() #initialize empty dict

state_dict = load_file(path) #load safetensors file
for key in state_dict: #for each key in the safetensors file
    sd_pruned[key] = state_dict[key].to(torch.float8_e4m3fn) #convert to fp8

# save the pruned safetensors file
save_file(sd_pruned, "flux1-dev-fp8.safetensors", metadata={"format": "pt", **metadata})
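
As a quick sanity check (a minimal sketch, assuming the output file produced by the script above), you can reload the converted file and confirm the tensors really are fp8:

from safetensors.torch import load_file
import torch

sd = load_file("flux1-dev-fp8.safetensors")  # reload the converted file
first = next(iter(sd.values()))
print(len(sd), "tensors, first dtype:", first.dtype)  # expect torch.float8_e4m3fn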

Thanks for your work. Is this model suitable for quantization to int4?

Owner

I don't think there's a way to do inference with int4 in PyTorch, at least. From what I understood, image models in general would lose too much quality from that anyway.

Thanks for your answer.

Why not try GPTQ for better quality?

Perhaps using a calibration set after quantization, as is done for LLM quantization, could reduce the loss introduced by quantization. After all, the core of a DiT is also transformers.
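
For illustration, a minimal sketch of the simplest version of that idea: per-tensor absmax scaling before the fp8 cast, so more of the value range survives. This is not GPTQ; a GPTQ-style approach would additionally use calibration activations to choose the rounding.

import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # largest representable e4m3fn value (~448)

def quantize_scaled_fp8(weight: torch.Tensor):
    scale = weight.abs().max().clamp(min=1e-12) / FP8_MAX  # per-tensor scale factor
    q = (weight / scale).to(torch.float8_e4m3fn)           # cast the rescaled weight to fp8
    return q, scale                                        # keep the scale for dequantization

def dequantize_fp8(q: torch.Tensor, scale: torch.Tensor):
    return q.to(torch.float32) * scale                     # approximate original weight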

I combined this code with kohya-ss's mem_eff_save_file.py to create a memory-efficient fp8 conversion script. Thanks for sharing the code!
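
For anyone after the same thing, the general idea of a memory-efficient conversion can be sketched with safetensors' safe_open, reading tensors one at a time instead of loading the whole state dict at once (this is just a sketch of the idea, not kohya-ss's actual script):

from safetensors import safe_open
from safetensors.torch import save_file
import torch

sd_fp8 = {}
with safe_open("flux1-dev.sft", framework="pt", device="cpu") as f:
    meta = f.metadata() or {}  # preserve existing metadata if present
    for key in f.keys():
        # cast each tensor as it is read, so the full-precision state dict
        # never has to sit in memory all at once
        sd_fp8[key] = f.get_tensor(key).to(torch.float8_e4m3fn)

save_file(sd_fp8, "flux1-dev-fp8.safetensors", metadata={"format": "pt", **meta})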
