---
license: creativeml-openrail-m
language:
  - en
base_model: black-forest-labs/FLUX.1-schnell
pipeline_tag: text-to-image
library_name: diffusers
tags:
  - bnb
  - nf4
  - flux
---

# Maxwell Model

## License

**creativeml-openrail-m**

This model may be used by individuals for personal and commercial purposes, including generating and selling images. Commercial use by companies or organizations is strictly prohibited.

## Acknowledgements

First, a big thanks to @sayakpaul, who fixed most of the issues we were facing with Diffusers. This model follows his bnb-NF4 quantization approach.
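
For context, bnb-NF4 stores each linear layer's weights as 4-bit "normal float" values via bitsandbytes and dequantizes them on the fly during the forward pass. Below is a minimal, illustrative sketch on a single toy layer; the actual conversion of the Flux transformer is handled by the helpers in `convert_nf4_flux.py`.

```python
# Toy example of NF4 quantization with bitsandbytes (illustrative only).
import torch
import bitsandbytes as bnb

layer = bnb.nn.LinearNF4(64, 64, bias=False)
layer = layer.cuda()  # moving to GPU quantizes the weights to 4-bit NF4

x = torch.randn(1, 64, device="cuda", dtype=torch.float16)
y = layer(x)  # the matmul dequantizes the weights on the fly
print(y.shape)  # torch.Size([1, 64])
```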

## Installation

1. Install the required packages:

   ```bash
   pip install torch accelerate safetensors diffusers huggingface_hub bitsandbytes transformers
   ```

2. Download `convert_nf4_flux.py` and place it at the same level as the generation script below (a download sketch follows this list).
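
If you prefer to fetch the helper script programmatically, something like the following works. This is only a sketch: it assumes `convert_nf4_flux.py` is hosted in this model repository; adjust `repo_id`/`filename` if the script is distributed from somewhere else.

```python
# Sketch: download convert_nf4_flux.py next to your generation script.
# Assumption: the helper is hosted in the ABDALLALSWAITI/Maxwell repo.
import shutil
from huggingface_hub import hf_hub_download

cached_path = hf_hub_download(repo_id="ABDALLALSWAITI/Maxwell", filename="convert_nf4_flux.py")
shutil.copy(cached_path, "convert_nf4_flux.py")
```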

## Usage

Run the following Python code:

```python
# Generation script
from huggingface_hub import hf_hub_download
from accelerate.utils import set_module_tensor_to_device, compute_module_sizes
from accelerate import init_empty_weights
from convert_nf4_flux import replace_with_bnb_linear, create_quantized_param, check_quantized_param
from diffusers import FluxTransformer2DModel, FluxPipeline
import safetensors.torch
import gc
import torch

# Set dtype and check for float8 support
dtype = torch.bfloat16
is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn")

# Download the model checkpoint
ckpt_path = hf_hub_download("ABDALLALSWAITI/Maxwell", filename="diffusion_pytorch_model.safetensors")
original_state_dict = safetensors.torch.load_file(ckpt_path)

# Initialize the model with empty weights
with init_empty_weights():
    config = FluxTransformer2DModel.load_config("ABDALLALSWAITI/Maxwell")
    model = FluxTransformer2DModel.from_config(config).to(dtype)
    expected_state_dict_keys = list(model.state_dict().keys())

# Replace layers with NF4 quantized versions
replace_with_bnb_linear(model, "nf4")

# Load the state dict into the quantized model
for param_name, param in original_state_dict.items():
    if param_name not in expected_state_dict_keys:
        continue
    
    is_param_float8_e4m3fn = is_torch_e4m3fn_available and param.dtype == torch.float8_e4m3fn
    if torch.is_floating_point(param) and not is_param_float8_e4m3fn:
        param = param.to(dtype)
    
    if not check_quantized_param(model, param_name):
        set_module_tensor_to_device(model, param_name, device=0, value=param)
    else:
        create_quantized_param(
            model, param, param_name, target_device=0, state_dict=original_state_dict, pre_quantized=True
        )

# Clean up
del original_state_dict
gc.collect()

# Print the quantized model size in MB
print(compute_module_sizes(model)[""] / 1024 / 1024)

# Initialize the pipeline with the quantized transformer
# (base model matches the `base_model` entry in the metadata above)
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", transformer=model, torch_dtype=dtype)
pipe.enable_model_cpu_offload()  # offload idle submodules to the CPU to save VRAM

# Generate an image from a prompt
prompt = "A mystic tiger playing a guitar, with a sign that says hello world!"
image = pipe(prompt, guidance_scale=0.0, num_inference_steps=4, generator=torch.manual_seed(0)).images[0]
image.save("simple.png")
```

This script downloads the Maxwell checkpoint, loads it into an NF4-quantized Flux transformer, and generates an image from the given prompt.
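
Once built, the pipeline can be reused for further generations without reloading the weights, for example by sweeping seeds with the same settings:

```python
# Reuse the pipeline from the script above across several seeds.
for seed in (1, 2, 3):
    image = pipe(
        prompt,
        guidance_scale=0.0,
        num_inference_steps=4,
        generator=torch.manual_seed(seed),
    ).images[0]
    image.save(f"simple_{seed}.png")
```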