abreza commited on
Commit
378355d
·
1 Parent(s): 661bc0f

add SDXL-Lightning

Browse files
Files changed (1) hide show
  1. app.py +30 -31
app.py CHANGED
@@ -1,23 +1,20 @@
1
  import os
2
  import shutil
3
  import tempfile
4
- import time
5
- from os import path
6
 
7
  import gradio as gr
8
  import numpy as np
9
  import rembg
10
  import spaces
11
  import torch
12
- from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler, StableDiffusionXLPipeline, LCMScheduler
13
  from einops import rearrange
14
  from huggingface_hub import hf_hub_download
15
  from omegaconf import OmegaConf
16
  from PIL import Image
17
  from pytorch_lightning import seed_everything
18
- from safetensors.torch import load_file
19
  from torchvision.transforms import v2
20
- from tqdm import tqdm
21
 
22
  from src.utils.camera_util import (FOV_to_intrinsics, get_circular_camera_poses,
23
  get_zero123plus_input_cameras)
@@ -25,7 +22,6 @@ from src.utils.infer_util import (remove_background, resize_foreground)
25
  from src.utils.mesh_util import save_glb, save_obj
26
  from src.utils.train_util import instantiate_from_config
27
 
28
- torch.backends.cuda.matmul.allow_tf32 = True
29
 
30
  def find_cuda():
31
  cuda_home = os.environ.get('CUDA_HOME') or os.environ.get('CUDA_PATH')
@@ -130,18 +126,18 @@ def make3d(images):
130
 
131
 
132
  @spaces.GPU
133
- def process_image(num_images, prompt):
134
- global pipe
135
- with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
136
- return pipe(
137
- prompt=[prompt]*num_images,
138
- generator=torch.Generator().manual_seed(123),
139
- num_inference_steps=1,
140
- guidance_scale=0.,
141
- height=int(512),
142
- width=int(512),
143
- timesteps=[800]
144
- ).images
145
 
146
 
147
  # Configuration
@@ -185,23 +181,24 @@ model.load_state_dict(state_dict, strict=True)
185
 
186
  model = model.to(device)
187
 
188
- # # Load text-to-image model
189
- # print('Loading text-to-image model ...')
 
190
 
191
- # pipe = StableDiffusionXLPipeline.from_pretrained(
192
- # "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16)
193
- # pipe.to(device="cuda", dtype=torch.bfloat16)
194
-
195
- # unet_state = load_file(hf_hub_download(
196
- # "ByteDance/Hyper-SD", "Hyper-SDXL-1step-Unet.safetensors"), device="cuda")
197
- # pipe.unet.load_state_dict(unet_state)
198
- # pipe.scheduler = LCMScheduler.from_config(
199
- # pipe.scheduler.config, timestep_spacing="trailing")
200
 
201
  print('Loading Finished!')
202
 
203
  # Gradio UI
204
  with gr.Blocks() as demo:
 
 
 
 
 
 
 
205
  with gr.Row(variant="panel"):
206
  with gr.Column():
207
  with gr.Row():
@@ -228,7 +225,7 @@ with gr.Blocks() as demo:
228
  label="Sample Steps", minimum=30, maximum=75, value=75, step=5)
229
 
230
  with gr.Row():
231
- submit = gr.Button(
232
  "Generate", elem_id="generate", variant="primary")
233
 
234
  with gr.Row(variant="panel"):
@@ -273,7 +270,9 @@ with gr.Blocks() as demo:
273
 
274
  mv_images = gr.State()
275
 
276
- submit.click(fn=check_input_image, inputs=[input_image]).success(
 
 
277
  fn=preprocess,
278
  inputs=[input_image, do_remove_background],
279
  outputs=[processed_image],
 
1
  import os
2
  import shutil
3
  import tempfile
 
 
4
 
5
  import gradio as gr
6
  import numpy as np
7
  import rembg
8
  import spaces
9
  import torch
10
+ from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler, StableDiffusionXLPipeline, EulerDiscreteScheduler
11
  from einops import rearrange
12
  from huggingface_hub import hf_hub_download
13
  from omegaconf import OmegaConf
14
  from PIL import Image
15
  from pytorch_lightning import seed_everything
 
16
  from torchvision.transforms import v2
17
+ from safetensors.torch import load_file
18
 
19
  from src.utils.camera_util import (FOV_to_intrinsics, get_circular_camera_poses,
20
  get_zero123plus_input_cameras)
 
22
  from src.utils.mesh_util import save_glb, save_obj
23
  from src.utils.train_util import instantiate_from_config
24
 
 
25
 
26
  def find_cuda():
27
  cuda_home = os.environ.get('CUDA_HOME') or os.environ.get('CUDA_PATH')
 
126
 
127
 
128
  @spaces.GPU
129
+ def generate_image(prompt):
130
+ checkpoint = "sdxl_lightning_8step_unet.safetensors"
131
+ num_inference_steps = 8
132
+
133
+ pipe.scheduler = EulerDiscreteScheduler.from_config(
134
+ pipe.scheduler.config, timestep_spacing="trailing")
135
+ pipe.unet.load_state_dict(
136
+ load_file(hf_hub_download(repo, checkpoint), device="cuda"))
137
+
138
+ results = pipe(
139
+ prompt, num_inference_steps=num_inference_steps, guidance_scale=0)
140
+ return results.images[0]
141
 
142
 
143
  # Configuration
 
181
 
182
  model = model.to(device)
183
 
184
+ # Load StableDiffusionXL model
185
+ base = "stabilityai/stable-diffusion-xl-base-1.0"
186
+ repo = "ByteDance/SDXL-Lightning"
187
 
188
+ pipe = StableDiffusionXLPipeline.from_pretrained(
189
+ base, torch_dtype=torch.float16, variant="fp16").to("cuda")
 
 
 
 
 
 
 
190
 
191
  print('Loading Finished!')
192
 
193
  # Gradio UI
194
  with gr.Blocks() as demo:
195
+ with gr.Group():
196
+ with gr.Row():
197
+ prompt = gr.Textbox(label='Enter your prompt (English)', scale=8)
198
+ submit_prompt = gr.Button(
199
+ scale=1, variant='primary', label='Generate Image')
200
+ img = gr.Image(label='SDXL-Lightning Generated Image')
201
+
202
  with gr.Row(variant="panel"):
203
  with gr.Column():
204
  with gr.Row():
 
225
  label="Sample Steps", minimum=30, maximum=75, value=75, step=5)
226
 
227
  with gr.Row():
228
+ submit_mesh = gr.Button(
229
  "Generate", elem_id="generate", variant="primary")
230
 
231
  with gr.Row(variant="panel"):
 
270
 
271
  mv_images = gr.State()
272
 
273
+ submit_prompt.click(fn=generate_image, inputs=[prompt], outputs=img)
274
+
275
+ submit_mesh.click(fn=check_input_image, inputs=[input_image]).success(
276
  fn=preprocess,
277
  inputs=[input_image, do_remove_background],
278
  outputs=[processed_image],