# Spaces: Running on Zero
import concurrent.futures

import gradio as gr
from gradio_client import Client
from diffusers import AutoencoderKL, StableDiffusionXLPipeline
import torch
import spaces
# Gradio clients for the remotely hosted SDXL Lightning and Hyper-SDXL apps.
client_lightning = Client("AP123/SDXL-Lightning")
client_hyper = Client("ByteDance/Hyper-SDXL-1Step-T2I")

# Half-precision VAE handed to the local SDXL Turbo pipeline below.
vae = AutoencoderKL.from_pretrained(
    "madebyollin/sdxl-vae-fp16-fix",
    torch_dtype=torch.float16,
)

### SDXL Turbo ####
pipe_turbo = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/sdxl-turbo",
    vae=vae,
    torch_dtype=torch.float16,
    variant="fp16",
)
pipe_turbo.to("cuda")
def get_lighting_result(prompt):
    """Request an image for *prompt* from the SDXL Lightning client.

    Calls the ``/generate_image`` endpoint with the "1-Step" setting and
    returns the endpoint's response unmodified.
    """
    return client_lightning.predict(
        prompt,  # Your prompt
        "1-Step",  # Number of inference steps
        api_name="/generate_image",
    )
def get_hyper_result(prompt):
    """Request an image for *prompt* from the Hyper-SDXL client.

    Calls the ``/process_image`` endpoint asking for a single 1024x1024
    image with a fixed seed, and returns the endpoint's response unmodified.
    """
    return client_hyper.predict(
        num_images=1,
        height=1024,
        width=1024,
        prompt=prompt,
        seed=3413,
        api_name="/process_image",
    )
# `spaces.GPU` requests a GPU for the duration of the call when the app runs
# on ZeroGPU hardware; per the Hugging Face docs it has no effect elsewhere.
# NOTE(review): this function is invoked from a worker thread in
# run_comparison -- confirm the decorator behaves as expected there.
@spaces.GPU
def get_turbo_result(prompt):
    """Generate one image for *prompt* with the local SDXL Turbo pipeline.

    Runs a single inference step with ``guidance_scale=0`` and returns the
    first image produced by the pipeline.
    """
    output = pipe_turbo(prompt=prompt, num_inference_steps=1, guidance_scale=0)
    return output.images[0]
def run_comparison(prompt):
    """Run the three one-step SDXL generators on *prompt* concurrently.

    Each generator is submitted to a thread pool so the two remote calls and
    the local pipeline overlap instead of running back to back.

    Returns:
        A ``(image_turbo, result_lighting, result_hyper)`` tuple, in the
        order the click handler lists its output components.

    Raises:
        Whatever exception any of the three generator calls raised.
    """
    with concurrent.futures.ThreadPoolExecutor() as executor:
        # Submit tasks to the executor
        future_lighting = executor.submit(get_lighting_result, prompt)
        future_hyper = executor.submit(get_hyper_result, prompt)
        future_turbo = executor.submit(get_turbo_result, prompt)

        # Future.result() blocks until that task has finished, so no
        # separate wait() call is needed before collecting the values.
        result_lighting = future_lighting.result()
        result_hyper = future_hyper.result()
        image_turbo = future_turbo.result()

    print(result_lighting)
    print(result_hyper)
    return image_turbo, result_lighting, result_hyper
# Cap the width of the app container.
css = '''
.gradio-container{max-width: 768px !important}
'''

# UI: a prompt box, a Run button, and one image slot per model side by side.
with gr.Blocks(css=css) as demo:
    prompt = gr.Textbox(label="Prompt")
    run = gr.Button("Run")
    with gr.Row():
        image_turbo = gr.Image(label="SDXL Turbo")
        image_lightning = gr.Image(label="SDXL Lightning")
        # The first positional argument of gr.Image is the initial value,
        # so the caption has to be passed as label=.
        image_hyper = gr.Image(label="Hyper SDXL")
    # Output order matches the tuple returned by run_comparison.
    run.click(fn=run_comparison, inputs=prompt, outputs=[image_turbo, image_lightning, image_hyper])

demo.launch()