---
base_model: black-forest-labs/FLUX.1-dev
library_name: diffusers
base_model_relation: quantized
tags:
- quantization
---
# Visual comparison of FLUX.1-dev model outputs using BF16 and torchao float8_weight_only quantization
<table>
  <tr>
    <td style="text-align: center;">
      BF16<br>
      <medium-zoom background="rgba(0,0,0,.7)"><img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/quantization-backends-diffusers/combined_flux-dev_bf16_combined.png" alt="Flux-dev output with BF16: Baroque, Futurist, Noir styles"></medium-zoom>
    </td>
    <td style="text-align: center;">
      torchao float8_weight_only<br>
      <medium-zoom background="rgba(0,0,0,.7)"><img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/quantization-backends-diffusers/combined_flux-dev_torchao_fp8_combined.png" alt="Flux-dev output with torchao float8_weight_only: Baroque, Futurist, Noir styles"></medium-zoom>
    </td>
  </tr>
</table>
# Usage with Diffusers
To use this quantized FLUX.1 [dev] checkpoint, you need to install the 🧨 diffusers and torchao libraries:
```
pip install -U diffusers
pip install -U torchao
```
After installing the required libraries, you can run the following script:
```python
import torch

from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
    "diffusers/FLUX.1-dev-torchao-fp8",
    torch_dtype=torch.bfloat16,
    use_safetensors=False,
    device_map="balanced"
)

prompt = "Baroque style, a lavish palace interior with ornate gilded ceilings, intricate tapestries, and dramatic lighting over a grand staircase."
pipe_kwargs = {
    "prompt": prompt,
    "height": 1024,
    "width": 1024,
    "guidance_scale": 3.5,
    "num_inference_steps": 50,
    "max_sequence_length": 512,
}

image = pipe(
    **pipe_kwargs, generator=torch.manual_seed(0),
).images[0]
image.save("flux.png")
```
# How to generate this quantized checkpoint?
This checkpoint was created with the following script, starting from the "black-forest-labs/FLUX.1-dev" checkpoint:
```python
import torch
from diffusers import FluxPipeline
from diffusers.quantizers import PipelineQuantizationConfig
from diffusers import TorchAoConfig as DiffusersTorchAoConfig
from transformers import TorchAoConfig as TransformersTorchAoConfig
from torchao.quantization import Float8WeightOnlyConfig
pipeline_quant_config = PipelineQuantizationConfig(
    quant_mapping={
        "transformer": DiffusersTorchAoConfig("float8_weight_only"),
        "text_encoder_2": TransformersTorchAoConfig(Float8WeightOnlyConfig()),
    }
)

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    quantization_config=pipeline_quant_config,
    torch_dtype=torch.bfloat16,
    device_map="balanced"
)

# `safe_serialization` is set to `False` because torchao-quantized models
# cannot be saved in the safetensors format.
pipe.save_pretrained("FLUX.1-dev-torchao-fp8", safe_serialization=False)
```