"""Side-by-side comparison of one-step SDXL variants: Turbo, Lightning, Hyper.

SDXL Turbo runs locally through a diffusers pipeline; SDXL Lightning and
Hyper-SDXL are queried remotely through ``gradio_client``.
"""
import concurrent.futures

import gradio as gr
import spaces
import torch
from diffusers import AutoencoderKL, StableDiffusionXLPipeline
from gradio_client import Client

# Remote Gradio apps for the two models that are not run locally.
client_lightning = Client("AP123/SDXL-Lightning")
client_hyper = Client("ByteDance/Hyper-SDXL-1Step-T2I")

# float16 VAE ("fp16-fix" checkpoint) injected into the local Turbo pipeline.
vae = AutoencoderKL.from_pretrained(
    "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
)

### SDXL Turbo ####
pipe_turbo = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/sdxl-turbo",
    vae=vae,
    torch_dtype=torch.float16,
    variant="fp16",
)
pipe_turbo.to("cuda")


def get_lighting_result(prompt):
    """Generate an image for *prompt* with the remote SDXL Lightning app.

    Returns whatever the remote ``/generate_image`` endpoint returns; the
    value is passed straight to a ``gr.Image`` output.
    """
    # NOTE: the "lighting" spelling is kept so existing references still work.
    result_lighting = client_lightning.predict(
        prompt,  # Your prompt
        "1-Step",  # Number of inference steps
        api_name="/generate_image",
    )
    return result_lighting


def get_hyper_result(prompt):
    """Generate one 1024x1024 image for *prompt* with the remote Hyper-SDXL app.

    A fixed seed (3413) is sent, so repeated prompts are requested with the
    same seed. Returns whatever ``/process_image`` returns.
    """
    result_hyper = client_hyper.predict(
        num_images=1,
        height=1024,
        width=1024,
        prompt=prompt,
        seed=3413,
        api_name="/process_image",
    )
    return result_hyper


@spaces.GPU
def get_turbo_result(prompt):
    """Generate an image for *prompt* locally with SDXL Turbo.

    Uses a single inference step with guidance disabled (``guidance_scale=0``)
    and returns the first image of the pipeline output.
    """
    image_turbo = pipe_turbo(
        prompt=prompt, num_inference_steps=1, guidance_scale=0
    ).images[0]
    return image_turbo


def run_comparison(prompt):
    """Run all three generators concurrently for the same *prompt*.

    Returns ``(image_turbo, result_lighting, result_hyper)``, in the same
    order as the UI outputs. Any exception raised by a generator is re-raised
    here by ``Future.result()``.
    """
    # One worker per task so the two remote calls and the local run overlap.
    with concurrent.futures.ThreadPoolExecutor(max_workers=3) as executor:
        future_lighting = executor.submit(get_lighting_result, prompt)
        future_hyper = executor.submit(get_hyper_result, prompt)
        future_turbo = executor.submit(get_turbo_result, prompt)

        # .result() blocks until each task finishes, so no separate wait()
        # is needed.
        result_lighting = future_lighting.result()
        result_hyper = future_hyper.result()
        image_turbo = future_turbo.result()

    print(result_lighting)
    print(result_hyper)
    return image_turbo, result_lighting, result_hyper


css = '''
.gradio-container{max-width: 768px !important}
'''

with gr.Blocks(css=css) as demo:
    prompt = gr.Textbox(label="Prompt")
    run = gr.Button("Run")
    with gr.Row():
        image_turbo = gr.Image(label="SDXL Turbo")
        image_lightning = gr.Image(label="SDXL Lightning")
        # label must be a keyword: gr.Image's first positional parameter is
        # `value`, so a bare string there would be loaded as an image source.
        image_hyper = gr.Image(label="Hyper SDXL")
    run.click(
        fn=run_comparison,
        inputs=prompt,
        outputs=[image_turbo, image_lightning, image_hyper],
    )

demo.launch()