# Spaces: Runtime error  (page-status banner captured together with this file)
import gradio as gr
import spaces
import torch
from diffusers import FluxPipeline
from torchao.quantization import autoquant
# Baseline pipeline: load the merged FLUX.1 checkpoint in bfloat16, then move
# the whole pipeline to the CUDA device.
pipeline_normal = FluxPipeline.from_pretrained(
    "sayakpaul/FLUX.1-merged", torch_dtype=torch.bfloat16
)
pipeline_normal = pipeline_normal.to("cuda")

# Optimized pipeline (currently disabled). The sketch below loads the same
# checkpoint, switches the transformer to channels_last, compiles it with
# torch.compile and then wraps it with torchao's autoquant.
# NOTE: the sketch uses `DiffusionPipeline`, which is not among the imports
# visible in this file; import it (or use FluxPipeline) before re-enabling.
#
# pipeline_optimized = DiffusionPipeline.from_pretrained(
#     "sayakpaul/FLUX.1-merged",
#     torch_dtype=torch.bfloat16
# ).to("cuda")
# pipeline_optimized.transformer.to(memory_format=torch.channels_last)
# pipeline_optimized.transformer = torch.compile(
#     pipeline_optimized.transformer,
#     mode="max-autotune",
#     fullgraph=True
# )
# pipeline_optimized.transformer = autoquant(
#     pipeline_optimized.transformer,
#     error_on_unseen=False
# )
# On ZeroGPU Spaces a GPU is only attached while a function decorated with
# `spaces.GPU` is running; the decorator is documented as effect-free on other
# hardware, so it is safe to apply unconditionally.
@spaces.GPU
def generate_images(prompt, guidance_scale, num_inference_steps):
    """Generate one image for *prompt* with the baseline pipeline.

    Args:
        prompt: Text prompt forwarded to the pipeline unchanged.
        guidance_scale: Guidance scale forwarded to the pipeline unchanged.
        num_inference_steps: Number of denoising steps; cast to ``int``
            before the call because the UI slider may deliver a float.

    Returns:
        The first entry of the pipeline result's ``images`` sequence
        (presumably a PIL image, matching the ``gr.Image(type="pil")``
        output component -- confirm against the diffusers version in use).
    """
    # generate image with normal pipeline
    image_normal = pipeline_normal(
        prompt=prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=int(num_inference_steps),
    ).images[0]

    # Optimized-pipeline path is disabled together with its module-level
    # setup; re-enable both at once and return the pair of images.
    # image_optimized = pipeline_optimized(
    #     prompt=prompt,
    #     guidance_scale=guidance_scale,
    #     num_inference_steps=int(num_inference_steps)
    # ).images[0]
    # return image_normal, image_optimized

    return image_normal
# Gradio UI: one prompt box and two sliders feed generate_images; the single
# image output shows the baseline pipeline's result.
prompt_box = gr.Textbox(
    lines=2,
    placeholder="Enter your prompt here...",
    label="Prompt",
)
guidance_slider = gr.Slider(
    minimum=1.0, maximum=10.0, step=0.5, value=3.5, label="Guidance Scale"
)
steps_slider = gr.Slider(
    minimum=10, maximum=100, step=1, value=50, label="Number of Inference Steps"
)
normal_output = gr.Image(type="pil", label="Normal FluxPipeline")
# Second output, to be restored together with the optimized pipeline:
# gr.Image(type="pil", label="Optimized FluxPipeline")

demo = gr.Interface(
    fn=generate_images,
    inputs=[prompt_box, guidance_slider, steps_slider],
    outputs=[normal_output],
    title="FluxPipeline Comparison",
    description="Compare images generated by the normal FluxPipeline and the optimized one using torchao and torch.compile().",
)

# Start the web server for the app.
demo.launch()