⚡ Flash Diffusion: FlashPixart ⚡

Flash Diffusion is a diffusion distillation method proposed in Flash Diffusion: Accelerating Any Conditional Diffusion Model for Few Steps Image Generation by Clément Chadebec, Onur Tasar, Eyal Benaroche, and Benjamin Aubin from Jasper Research. This model is a 66.5M LoRA distilled version of the PixArt-α model that can generate 1024x1024 images in 4 steps. See our live demo and official GitHub repo.

How to use?

The model can be used directly with the PixArtAlphaPipeline from the diffusers library. It reduces the number of required sampling steps to 4.

import torch
from diffusers import PixArtAlphaPipeline, Transformer2DModel, LCMScheduler
from peft import PeftModel

# Load LoRA
transformer = Transformer2DModel.from_pretrained(
  "PixArt-alpha/PixArt-XL-2-1024-MS",
  subfolder="transformer",
  torch_dtype=torch.float16
)
transformer = PeftModel.from_pretrained(
  transformer,
  "jasperai/flash-pixart"
)

# Pipeline
pipe = PixArtAlphaPipeline.from_pretrained(
  "PixArt-alpha/PixArt-XL-2-1024-MS",
  transformer=transformer,
  torch_dtype=torch.float16
)

# Scheduler
pipe.scheduler = LCMScheduler.from_pretrained(
  "PixArt-alpha/PixArt-XL-2-1024-MS",
  subfolder="scheduler",
  timestep_spacing="trailing",
)

pipe.to("cuda")

prompt = "A raccoon reading a book in a lush forest."

image = pipe(prompt, num_inference_steps=4, guidance_scale=0).images[0]

Training Details

The model was trained for 40k iterations on 4 H100 GPUs (representing approximately 188 hours of training). Please refer to the paper for further parameter details.

Metrics on COCO 2014 validation (Table 4)

  • FID-10k: 29.30 (4 NFE)
  • CLIP Score: 0.303 (4 NFE)
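
As a rough illustration of what "4 NFE" (number of function evaluations) means here, the sketch below counts denoiser forward passes for one sampling run. It assumes that classifier-free guidance is only active when guidance_scale is greater than 1 and that a guided step counts as two evaluations (conditional and unconditional). The helper name count_nfe and the 20-step, 4.5-guidance comparison values are hypothetical and chosen only for illustration. They are not taken from the paper.

```python
def count_nfe(num_inference_steps: int, guidance_scale: float) -> int:
    """Estimate denoiser forward passes (NFE) for one sampling run.

    Assumption: classifier-free guidance is active only when
    guidance_scale > 1, and a guided step costs two evaluations
    (one conditional, one unconditional).
    """
    if num_inference_steps < 1:
        raise ValueError("num_inference_steps must be at least 1")
    evals_per_step = 2 if guidance_scale > 1.0 else 1
    return num_inference_steps * evals_per_step


# Settings from the usage example above: 4 steps, guidance disabled.
flash_nfe = count_nfe(num_inference_steps=4, guidance_scale=0)

# Hypothetical guided run, for comparison only.
guided_nfe = count_nfe(num_inference_steps=20, guidance_scale=4.5)

print(flash_nfe, guided_nfe)
```

Under these assumptions, the usage example's settings (num_inference_steps=4, guidance_scale=0) correspond to 4 evaluations, which is the budget quoted for the metrics above.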

Citation

If you find this work useful or use it in your research, please consider citing us:

@misc{chadebec2024flash,
      title={Flash Diffusion: Accelerating Any Conditional Diffusion Model for Few Steps Image Generation}, 
      author={Clement Chadebec and Onur Tasar and Eyal Benaroche and Benjamin Aubin},
      year={2024},
      eprint={2406.02347},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}

License

This model is released under the Creative Commons BY-NC license.
