Quite slow to load the fp8 model

#21 opened by gpt3eth

On an Nvidia A6000, I'm using the code below to load the fp8 model:

import torch
from diffusers import FluxTransformer2DModel, FluxPipeline
from transformers import T5EncoderModel
from optimum.quanto import freeze, qfloat8, quantize
import time
import json

# Initialize a dictionary to store stats
stats = {}

# Measure the time taken to load and prepare the model into VRAM
start_time = time.time()
bfl_repo = "black-forest-labs/FLUX.1-schnell"
dtype = torch.bfloat16

transformer = FluxTransformer2DModel.from_single_file(
    "https://huggingface.co/Kijai/flux-fp8/blob/main/flux1-schnell-fp8.safetensors", 
    torch_dtype=dtype
)
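# Quantize the transformer weights to 8-bit float in place, then freeze them;
# this is a one-time pass over all weights and happens on the CPU here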
quantize(transformer, weights=qfloat8)
freeze(transformer)

text_encoder_2 = T5EncoderModel.from_pretrained(bfl_repo, subfolder="text_encoder_2", torch_dtype=dtype)
quantize(text_encoder_2, weights=qfloat8)
freeze(text_encoder_2)

pipe = FluxPipeline.from_pretrained(bfl_repo, transformer=None, text_encoder_2=None, torch_dtype=dtype)
pipe.transformer = transformer
pipe.text_encoder_2 = text_encoder_2
pipe.to("cuda")
stats['model_loading_time'] = time.time() - start_time

It took around 293 s to load the model. Why is it so slow to load?
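
To see where the time goes, here is a minimal timing sketch (same checkpoint and calls as above) that measures each stage on its own:

import time
import torch
from diffusers import FluxTransformer2DModel
from optimum.quanto import freeze, qfloat8, quantize

dtype = torch.bfloat16

# Stage 1: fetch and load the single-file fp8 checkpoint
t0 = time.time()
transformer = FluxTransformer2DModel.from_single_file(
    "https://huggingface.co/Kijai/flux-fp8/blob/main/flux1-schnell-fp8.safetensors",
    torch_dtype=dtype,
)
print(f"from_single_file: {time.time() - t0:.1f}s")

# Stage 2: quantize the weights to qfloat8 (runs on CPU, before any .to("cuda"))
t0 = time.time()
quantize(transformer, weights=qfloat8)
print(f"quantize: {time.time() - t0:.1f}s")

# Stage 3: freeze the quantized weights
t0 = time.time()
freeze(transformer)
print(f"freeze: {time.time() - t0:.1f}s")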

I'm getting this error when running the code:

transformer = FluxTransformer2DModel.from_single_file(...)

AttributeError: type object 'FluxTransformer2DModel' has no attribute 'from_single_file'

Which version of diffusers do you have installed?

Thanks!

Thanks, I think they just released .from_single_file.

I could make it work by installing the latest version from main (I basically just followed all the steps listed at https://huggingface.co/docs/diffusers/installation#install-from-source).
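
For reference, the install from source boils down to:

pip install git+https://github.com/huggingface/diffusers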

885 s on a simple laptop with 16 GB RAM (no GPU support, since Windows + AMD GPU).
The quantizing + freezing take a lot of time.

Why have you quantized it again when it is already in fp8?

You could probably just load it like so:

transformer = FluxTransformer2DModel.from_single_file(
    "https://huggingface.co/Kijai/flux-fp8/blob/main/flux1-schnell-fp8.safetensors",
    torch_dtype=dtype
)
pipe = FluxPipeline.from_pretrained(bfl_repo, transformer=transformer, text_encoder_2=None, torch_dtype=dtype)
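
If I'm not mistaken, though, passing torch_dtype=dtype makes from_single_file cast the fp8 weights to bfloat16 on load, so the quantize step above is what actually keeps them at 8 bits at runtime. A quick check:

# Inspect the dtype that actually ends up in memory after loading
print(next(transformer.parameters()).dtype)  # expect torch.bfloat16, not float8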
