# NOTE: removed a scraped Hugging Face Spaces page header ("Spaces" /
# "Runtime error") that was not valid Python and broke the file.
import gradio as gr
import spaces
import torch
from diffusers import FluxPipeline
from torchao.quantization import autoquant
# Load the merged FLUX.1 checkpoint in bfloat16 and move it to the GPU.
pipe = FluxPipeline.from_pretrained(
    "sayakpaul/FLUX.1-merged",
    torch_dtype=torch.bfloat16,
).to("cuda")

# channels_last memory format generally improves throughput for the
# convolution/attention kernels used by the transformer.
pipe.transformer.to(memory_format=torch.channels_last)

# Compile the transformer ahead of quantization; "max-autotune" searches
# for the fastest kernels and fullgraph=True forbids graph breaks.
pipe.transformer = torch.compile(
    pipe.transformer, mode="max-autotune", fullgraph=True
)

# Automatically quantize the compiled transformer with torchao.
# error_on_unseen=False lets layers without a known quantization recipe
# fall back to their original implementation instead of raising.
pipe.transformer = autoquant(
    pipe.transformer,
    error_on_unseen=False,
)
def generate_images(prompt, guidance_scale, num_inference_steps):
    """Generate one image with the optimized Flux pipeline.

    Args:
        prompt: Text prompt describing the desired image.
        guidance_scale: Classifier-free guidance strength; higher values
            follow the prompt more closely.
        num_inference_steps: Number of denoising steps. Cast to int
            because the Gradio slider may deliver a float.

    Returns:
        The first generated image (PIL image) from the pipeline output.
    """
    image_optimized = pipe(
        prompt=prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=int(num_inference_steps),
    ).images[0]
    return image_optimized
# Set up the Gradio interface: prompt + sampling controls in, one image out.
# NOTE: the original title/description advertised a side-by-side comparison,
# but the "normal" pipeline path was removed; the UI text now matches what
# the app actually does.
demo = gr.Interface(
    fn=generate_images,
    inputs=[
        gr.Textbox(lines=2, placeholder="Enter your prompt here...", label="Prompt"),
        gr.Slider(1.0, 10.0, step=0.5, value=3.5, label="Guidance Scale"),
        gr.Slider(10, 100, step=1, value=50, label="Number of Inference Steps"),
    ],
    outputs=[
        gr.Image(type="pil", label="Optimized FluxPipeline"),
    ],
    title="Optimized FluxPipeline",
    description=(
        "Generate images with the merged FLUX.1 pipeline optimized using "
        "torchao autoquant and torch.compile()."
    ),
)

demo.launch()