multimodalart (HF staff) committed
Commit 91dd651
1 Parent(s): de56cd9

Update app.py

Files changed (1):
  1. app.py (+49 -58)
app.py CHANGED
@@ -1,11 +1,9 @@
 import gradio as gr
-from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, EulerDiscreteScheduler, LCMScheduler, AutoencoderKL
-import torch
-from huggingface_hub import hf_hub_download
-from safetensors.torch import load_file
-import gc
-import spaces
+from gradio_client import Client
+import concurrent.futures
 
+client_lightning = Client("AP123/SDXL-Lightning")
+client_hyper = Client("ByteDance/Hyper-SDXL-1Step-T2I")
 
 vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
 
@@ -15,64 +13,57 @@ pipe_turbo = StableDiffusionXLPipeline.from_pretrained("stabilityai/sdxl-turbo",
     torch_dtype=torch.float16,
     variant="fp16"
 )
-#pipe_turbo.to("cuda")
+pipe_turbo.to("cuda")
 
-### SDXL Lightning ###
-base = "stabilityai/stable-diffusion-xl-base-1.0"
-repo = "ByteDance/SDXL-Lightning"
-ckpt = "sdxl_lightning_1step_unet_x0.safetensors"
 
-unet = UNet2DConditionModel.from_config(base, subfolder="unet").to(torch.float16)
-unet.load_state_dict(load_file(hf_hub_download(repo, ckpt)))
-pipe_lightning = StableDiffusionXLPipeline.from_pretrained(base,
-    unet=unet,
-    vae=vae,
-    text_encoder=pipe_turbo.text_encoder,
-    text_encoder_2=pipe_turbo.text_encoder_2,
-    tokenizer=pipe_turbo.tokenizer,
-    tokenizer_2=pipe_turbo.tokenizer_2,
-    torch_dtype=torch.float16,
-    variant="fp16"
-)#.to("cuda")
-del unet
-gc.collect()
-pipe_lightning.scheduler = EulerDiscreteScheduler.from_config(pipe_lightning.scheduler.config, timestep_spacing="trailing", prediction_type="sample")
-#pipe_lightning.to("cuda")
+def get_lighting_result(prompt):
+    result_lighting = client_lightning.predict(
+        prompt, # Your prompt
+        "1-Step", # Number of inference steps
+        api_name="/generate_image"
+    )
+    return result_lighting
 
-### Hyper SDXL ###
-repo_name = "ByteDance/Hyper-SD"
-ckpt_name = "Hyper-SDXL-1step-Unet.safetensors"
-
-unet = UNet2DConditionModel.from_config(base, subfolder="unet").to(torch.float16)
-unet.load_state_dict(load_file(hf_hub_download(repo_name, ckpt_name)))
-pipe_hyper = StableDiffusionXLPipeline.from_pretrained(base,
-    unet=unet,
-    vae=vae,
-    text_encoder=pipe_turbo.text_encoder,
-    text_encoder_2=pipe_turbo.text_encoder_2,
-    tokenizer=pipe_turbo.tokenizer,
-    tokenizer_2=pipe_turbo.tokenizer_2,
-    torch_dtype=torch.float16,
-    variant="fp16"
-)#.to("cuda")
-pipe_hyper.scheduler = LCMScheduler.from_config(pipe_hyper.scheduler.config)
-#pipe_hyper.to("cuda")
-del unet
-gc.collect()
+def get_hyper_result(prompt):
+    result_hyper = client_hyper.predict(
+        num_images=1,
+        height=1024,
+        width=1024,
+        prompt=prompt,
+        seed=3413,
+        api_name="/process_image"
+    )
+    return result_hyper
 
 @spaces.GPU
-def run_comparison(prompt):
-    image_turbo.to("cuda")
-    image_turbo=pipe_turbo(prompt=prompt, num_inference_steps=1, guidance_scale=0).images[0]
-    image_turbo.to("cpu")
-    image_lightning.to("cuda")
-    image_lightning=pipe_lightning(prompt=prompt, num_inference_steps=1, guidance_scale=0).images[0]
-    image_lightning.to("cpu")
-    image_hyper.to("cuda")
-    image_hyper=pipe_hyper(prompt=prompt, num_inference_steps=1, guidance_scale=0, timesteps=[800]).images[0]
-    image_turbo.to("cpu")
-    return image_turbo, image_lightning, image_hyper
+def get_turbo_result(prompt):
+    image_turbo = pipe_turbo(prompt=prompt, num_inference_steps=1, guidance_scale=0).images[0]
+    return image_turbo
+
+def run_in_parallel(prompt):
+    with concurrent.futures.ThreadPoolExecutor() as executor:
+        # Submit tasks to the executor
+        future_lighting = executor.submit(get_lighting_result, prompt)
+        future_hyper = executor.submit(get_hyper_result, prompt)
+        future_turbo = executor.submit(get_turbo_result, prompt)
+
+        # Wait for all futures to complete
+        results = concurrent.futures.wait(
+            [future_lighting, future_hyper, future_turbo],
+            return_when=concurrent.futures.ALL_COMPLETED
+        )
+
+        # Extract results from futures
+        result_lighting = future_lighting.result()
+        result_hyper = future_hyper.result()
+        image_turbo = future_turbo.result()
+        print(result_lighting)
+        print(result_hyper)
+        return image_turbo, result_lighting, result_hyper
 
+# Example usage
+prompt = "Enter your prompt here"
+image_turbo, result_lighting, result_hyper = run_in_parallel(prompt)
 
 css = '''
 .gradio-container{max-width: 768px !important}
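
The diff is cut off at the start of the css block, so the Gradio UI wiring that consumes run_in_parallel is not shown here. Purely as an illustrative sketch under assumptions (the Space's actual component names, layout, and return types may differ; gradio_client results are assumed to be image-like values that gr.Image can display), the three outputs could be wired into a Blocks interface along these lines:

import gradio as gr

css = '''
.gradio-container{max-width: 768px !important}
'''

# Hypothetical wiring, assuming run_in_parallel (defined above) returns
# three displayable results (PIL image or file path) in the order
# turbo, lightning, hyper.
with gr.Blocks(css=css) as demo:
    prompt = gr.Textbox(label="Prompt")
    run_btn = gr.Button("Run comparison")
    with gr.Row():
        out_turbo = gr.Image(label="SDXL Turbo")
        out_lightning = gr.Image(label="SDXL Lightning")
        out_hyper = gr.Image(label="Hyper SDXL")
    run_btn.click(run_in_parallel, inputs=prompt,
                  outputs=[out_turbo, out_lightning, out_hyper])

if __name__ == "__main__":
    demo.launch()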