An error occurred during loading.

#1
by democrash - opened
ValueError: Cannot load /mnt/data/yinwenbin/flux.1-schnell-nf4 because context_embedder.weight expected shape tensor(..., device='meta', size=(3072, 4096)), but got torch.Size([6291456, 1]). If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example.

I encountered this problem. Below is the code I used.

transformer = FluxTransformer2DModel.from_pretrained("flux.1-schnell-nf4", torch_dtype=torch.bfloat16)
Owner

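The quantization is adapted from https://github.com/huggingface/diffusers/issues/9165. The weights in this repo are stored in packed NF4 form (which is why the error reports `torch.Size([6291456, 1])` instead of `(3072, 4096)`), so a plain `from_pretrained` call cannot load them. Use something like this to run it: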

from huggingface_hub import hf_hub_download
from accelerate.utils import set_module_tensor_to_device, compute_module_sizes
from accelerate import init_empty_weights
from diffusers.loaders.single_file_utils import convert_flux_transformer_checkpoint_to_diffusers
from convert_nf4_flux import _replace_with_bnb_linear, create_quantized_param, check_quantized_param
from diffusers import FluxTransformer2DModel, FluxPipeline
import safetensors.torch
import gc
import torch

# All tensors and the pipeline below are cast to this dtype.
dtype = torch.bfloat16

# Download the original single-file checkpoint and read it into a state dict.
ckpt_path = hf_hub_download(
    repo_id="black-forest-labs/flux.1-schnell",
    filename="flux1-schnell.safetensors",
)
original_state_dict = safetensors.torch.load_file(ckpt_path)

# Remap the checkpoint's keys to the diffusers FluxTransformer2DModel layout.
converted_state_dict = convert_flux_transformer_checkpoint_to_diffusers(
    original_state_dict
)

# The original-layout dict is no longer needed; drop it before building the model.
del original_state_dict
gc.collect()

# Build the transformer from its published config inside accelerate's
# init_empty_weights() context, then cast it to the working dtype.
with init_empty_weights():
    config = FluxTransformer2DModel.load_config(
        "black-forest-labs/flux.1-schnell",
        subfolder="transformer",
    )
    model = FluxTransformer2DModel.from_config(config).to(dtype)

# Swap the model's linear layers for bitsandbytes "nf4" layers
# (helper from the local convert_nf4_flux module).
_replace_with_bnb_linear(model, "nf4")

# Load every converted tensor onto device 0: parameters that the helper
# identifies as quantized go through create_quantized_param, everything
# else is assigned as-is.
for param_name, tensor in converted_state_dict.items():
    tensor = tensor.to(dtype)
    if check_quantized_param(model, param_name):
        create_quantized_param(model, tensor, param_name, target_device=0)
    else:
        set_module_tensor_to_device(model, param_name, device=0, value=tensor)

# All weights now live in the model; release the source state dict.
del converted_state_dict
gc.collect()

# Report the total model size (the "" key is the whole-model entry) in MiB:
# bytes / 1024 -> KiB, / 1024 -> MiB. The previous divisor of 1204 was a typo
# that understated the size.
print(compute_module_sizes(model)[""] / 1024 / 1024)

# Assemble the full pipeline around the quantized transformer built above.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/flux.1-schnell",
    transformer=model,
    torch_dtype=dtype,
)
pipe.enable_model_cpu_offload()

# Generate one image with a fixed seed so the run is repeatable.
prompt = "A mystic cat with a sign that says hello world!"
generator = torch.manual_seed(0)
result = pipe(
    prompt,
    guidance_scale=3.5,
    num_inference_steps=4,
    generator=generator,
)
image = result.images[0]
image.save("flux-nf4-schnell.png")

# Publish the quantized transformer weights to the Hub.
model.push_to_hub("skimai/flux.1-schnell-nf4")

Sign up or log in to comment