---
base_model: black-forest-labs/FLUX.1-dev
library_name: diffusers
base_model_relation: quantized
tags:
- quantization
---
# Visual comparison of Flux-dev model outputs using BF16 and torchao float8_weight_only quantization

<table>
  <tr>
    <td style="text-align: center;">
      BF16<br>
      <medium-zoom background="rgba(0,0,0,.7)"><img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/quantization-backends-diffusers/combined_flux-dev_bf16_combined.png" alt="Flux-dev output with BF16: Baroque, Futurist, Noir styles"></medium-zoom>
    </td>
    <td style="text-align: center;">
      torchao fp8_weight_only<br>
      <medium-zoom background="rgba(0,0,0,.7)"><img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/quantization-backends-diffusers/combined_flux-dev_torchao_fp8_combined.png" alt="Flux-dev output with torchao fp8_weight_only: Baroque, Futurist, Noir styles"></medium-zoom>
    </td>
  </tr>
</table>
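For a rough sense of the benefit: float8 weight-only quantization stores each transformer weight in one byte instead of the two used by bfloat16, roughly halving weight memory. A back-of-the-envelope sketch (the ~12B parameter count used below is an approximate figure for the FLUX.1-dev transformer, assumed for illustration):

```python
# Approximate weight-memory savings from float8 weight-only quantization.
# The 12B parameter count is an assumption for illustration only.
params = 12e9
bf16_gib = params * 2 / 1024**3  # bfloat16: 2 bytes per weight
fp8_gib = params * 1 / 1024**3   # float8: 1 byte per weight
print(f"bf16 weights: ~{bf16_gib:.1f} GiB, fp8 weights: ~{fp8_gib:.1f} GiB")
```

Activations, the text encoders, and the VAE still consume memory on top of this, so the end-to-end savings are smaller than the 2x on transformer weights alone.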

# Usage with Diffusers

To use this quantized FLUX.1 [dev] checkpoint, you need to install the 🧨 diffusers and torchao libraries:

```
pip install -U diffusers torchao
```

After installing the required libraries, you can run the following script:

```python
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
    "diffusers/FLUX.1-dev-torchao-fp8",
    torch_dtype=torch.bfloat16,
    use_safetensors=False,
    device_map="balanced"
)
prompt = "Baroque style, a lavish palace interior with ornate gilded ceilings, intricate tapestries, and dramatic lighting over a grand staircase."
pipe_kwargs = {
    "prompt": prompt,
    "height": 1024,
    "width": 1024,
    "guidance_scale": 3.5,
    "num_inference_steps": 50,
    "max_sequence_length": 512,
}
image = pipe(
    **pipe_kwargs, generator=torch.manual_seed(0),
).images[0]
image.save("flux.png")
```

# How to generate this quantized checkpoint?

This checkpoint was created with the following script, starting from the `black-forest-labs/FLUX.1-dev` checkpoint:

```python
import torch
from diffusers import FluxPipeline, TorchAoConfig as DiffusersTorchAoConfig
from diffusers.quantizers import PipelineQuantizationConfig
from transformers import TorchAoConfig as TransformersTorchAoConfig
from torchao.quantization import Float8WeightOnlyConfig

pipeline_quant_config = PipelineQuantizationConfig(
    quant_mapping={
        "transformer": DiffusersTorchAoConfig("float8_weight_only"),
        "text_encoder_2": TransformersTorchAoConfig(Float8WeightOnlyConfig()),
    }
)
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    quantization_config=pipeline_quant_config,
    torch_dtype=torch.bfloat16,
    device_map="balanced"
)
# safe_serialization is set to False because torchao-quantized models
# can't be saved in the safetensors format
pipe.save_pretrained("FLUX.1-dev-torchao-fp8", safe_serialization=False)
```