abreza commited on
Commit
b27b04c
β€’
1 Parent(s): 47b1e0f

remove sdxl

Browse files
Files changed (1) hide show
  1. app.py +28 -86
app.py CHANGED
@@ -1,21 +1,18 @@
1
  import os
2
  import shutil
3
  import tempfile
4
- import time
5
- from os import path
6
 
7
  import gradio as gr
8
  import numpy as np
9
  import rembg
10
  import spaces
11
  import torch
12
- from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler, StableDiffusionXLPipeline, LCMScheduler
13
  from einops import rearrange
14
  from huggingface_hub import hf_hub_download
15
  from omegaconf import OmegaConf
16
  from PIL import Image
17
  from pytorch_lightning import seed_everything
18
- from safetensors.torch import load_file
19
  from torchvision.transforms import v2
20
  from tqdm import tqdm
21
 
@@ -25,26 +22,6 @@ from src.utils.infer_util import (remove_background, resize_foreground)
25
  from src.utils.mesh_util import save_glb, save_obj
26
  from src.utils.train_util import instantiate_from_config
27
 
28
- cache_path = path.join(path.dirname(path.abspath(__file__)), "models")
29
- os.environ["TRANSFORMERS_CACHE"] = cache_path
30
- os.environ["HF_HUB_CACHE"] = cache_path
31
- os.environ["HF_HOME"] = cache_path
32
-
33
- torch.backends.cuda.matmul.allow_tf32 = True
34
-
35
-
36
- class timer:
37
- def __init__(self, method_name="timed process"):
38
- self.method = method_name
39
-
40
- def __enter__(self):
41
- self.start = time.time()
42
- print(f"{self.method} starts")
43
-
44
- def __exit__(self, exc_type, exc_val, exc_tb):
45
- end = time.time()
46
- print(f"{self.method} took {str(round(end - self.start, 2))}s")
47
-
48
 
49
  def find_cuda():
50
  cuda_home = os.environ.get('CUDA_HOME') or os.environ.get('CUDA_PATH')
@@ -75,7 +52,7 @@ def get_render_cameras(batch_size=1, M=120, radius=2.5, elevation=10.0, is_flexi
75
 
76
  def check_input_image(input_image):
77
  if input_image is None:
78
- raise gr.Error("No image selected!")
79
 
80
 
81
  def preprocess(input_image, do_remove_background):
@@ -148,21 +125,6 @@ def make3d(images):
148
  return mesh_fpath, mesh_glb_fpath
149
 
150
 
151
- @spaces.GPU
152
- def process_image(num_images, prompt):
153
- global pipe
154
- with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16), timer("inference"):
155
- return pipe(
156
- prompt=[prompt]*num_images,
157
- generator=torch.Generator().manual_seed(123),
158
- num_inference_steps=1,
159
- guidance_scale=0.,
160
- height=int(512),
161
- width=int(512),
162
- timesteps=[800]
163
- ).images
164
-
165
-
166
  # Configuration
167
  cuda_path = find_cuda()
168
  config_path = 'configs/instant-mesh-large.yaml'
@@ -204,21 +166,6 @@ model.load_state_dict(state_dict, strict=True)
204
 
205
  model = model.to(device)
206
 
207
- # Load text-to-image model
208
- print('Loading text-to-image model ...')
209
- if not path.exists(cache_path):
210
- os.makedirs(cache_path, exist_ok=True)
211
-
212
- pipe = StableDiffusionXLPipeline.from_pretrained(
213
- "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16)
214
- pipe.to(device="cuda", dtype=torch.bfloat16)
215
-
216
- unet_state = load_file(hf_hub_download(
217
- "ByteDance/Hyper-SD", "Hyper-SDXL-1step-Unet.safetensors"), device="cuda")
218
- pipe.unet.load_state_dict(unet_state)
219
- pipe.scheduler = LCMScheduler.from_config(
220
- pipe.scheduler.config, timestep_spacing="trailing")
221
-
222
  print('Loading Finished!')
223
 
224
  # Gradio UI
@@ -226,23 +173,19 @@ with gr.Blocks() as demo:
226
  with gr.Row(variant="panel"):
227
  with gr.Column():
228
  with gr.Row():
229
- num_images = gr.Slider(
230
- label="Number of Images", minimum=1, maximum=8, step=1, value=4, interactive=True)
231
- prompt = gr.Text(
232
- label="Prompt", value="a photo of a cat", interactive=True)
233
- generate_2d_btn = gr.Button(value="Generate 2D Images")
234
-
235
- with gr.Row():
236
- generated_images = gr.Gallery(height=512)
237
-
238
- with gr.Row():
239
- selected_image = gr.Image(
240
- label="Selected Image",
241
  image_mode="RGBA",
242
  type="pil",
243
  interactive=False
244
  )
245
-
246
  with gr.Row():
247
  with gr.Group():
248
  do_remove_background = gr.Checkbox(
@@ -253,8 +196,18 @@ with gr.Blocks() as demo:
253
  label="Sample Steps", minimum=30, maximum=75, value=75, step=5)
254
 
255
  with gr.Row():
256
- generate_3d_btn = gr.Button(
257
- "Generate 3D Model", elem_id="generate", variant="primary")
 
 
 
 
 
 
 
 
 
 
258
 
259
  with gr.Column():
260
  with gr.Row():
@@ -288,24 +241,13 @@ with gr.Blocks() as demo:
288
 
289
  mv_images = gr.State()
290
 
291
- generate_2d_btn.click(
292
- fn=process_image,
293
- inputs=[num_images, prompt],
294
- outputs=[generated_images]
295
- )
296
-
297
- def select_image(evt: gr.SelectData):
298
- return evt.value['image']['url']
299
-
300
- generated_images.select(select_image, None, selected_image)
301
-
302
- generate_3d_btn.click(fn=check_input_image, inputs=[selected_image]).success(
303
  fn=preprocess,
304
- inputs=[selected_image, do_remove_background],
305
- outputs=[selected_image],
306
  ).success(
307
  fn=generate_mvs,
308
- inputs=[selected_image, sample_steps, sample_seed],
309
  outputs=[mv_images, mv_show_images]
310
  ).success(
311
  fn=make3d,
@@ -313,4 +255,4 @@ with gr.Blocks() as demo:
313
  outputs=[output_model_obj, output_model_glb]
314
  )
315
 
316
- demo.launch()
 
1
  import os
2
  import shutil
3
  import tempfile
 
 
4
 
5
  import gradio as gr
6
  import numpy as np
7
  import rembg
8
  import spaces
9
  import torch
10
+ from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler
11
  from einops import rearrange
12
  from huggingface_hub import hf_hub_download
13
  from omegaconf import OmegaConf
14
  from PIL import Image
15
  from pytorch_lightning import seed_everything
 
16
  from torchvision.transforms import v2
17
  from tqdm import tqdm
18
 
 
22
  from src.utils.mesh_util import save_glb, save_obj
23
  from src.utils.train_util import instantiate_from_config
24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
  def find_cuda():
27
  cuda_home = os.environ.get('CUDA_HOME') or os.environ.get('CUDA_PATH')
 
52
 
53
  def check_input_image(input_image):
54
  if input_image is None:
55
+ raise gr.Error("No image uploaded!")
56
 
57
 
58
  def preprocess(input_image, do_remove_background):
 
125
  return mesh_fpath, mesh_glb_fpath
126
 
127
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
  # Configuration
129
  cuda_path = find_cuda()
130
  config_path = 'configs/instant-mesh-large.yaml'
 
166
 
167
  model = model.to(device)
168
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
169
  print('Loading Finished!')
170
 
171
  # Gradio UI
 
173
  with gr.Row(variant="panel"):
174
  with gr.Column():
175
  with gr.Row():
176
+ input_image = gr.Image(
177
+ label="Input Image",
178
+ image_mode="RGBA",
179
+ sources="upload",
180
+ type="pil",
181
+ elem_id="content_image",
182
+ )
183
+ processed_image = gr.Image(
184
+ label="Processed Image",
 
 
 
185
  image_mode="RGBA",
186
  type="pil",
187
  interactive=False
188
  )
 
189
  with gr.Row():
190
  with gr.Group():
191
  do_remove_background = gr.Checkbox(
 
196
  label="Sample Steps", minimum=30, maximum=75, value=75, step=5)
197
 
198
  with gr.Row():
199
+ submit = gr.Button(
200
+ "Generate", elem_id="generate", variant="primary")
201
+
202
+ with gr.Row(variant="panel"):
203
+ gr.Examples(
204
+ examples=[os.path.join("examples", img_name)
205
+ for img_name in sorted(os.listdir("examples"))],
206
+ inputs=[input_image],
207
+ label="Examples",
208
+ cache_examples=False,
209
+ examples_per_page=16
210
+ )
211
 
212
  with gr.Column():
213
  with gr.Row():
 
241
 
242
  mv_images = gr.State()
243
 
244
+ submit.click(fn=check_input_image, inputs=[input_image]).success(
 
 
 
 
 
 
 
 
 
 
 
245
  fn=preprocess,
246
+ inputs=[input_image, do_remove_background],
247
+ outputs=[processed_image],
248
  ).success(
249
  fn=generate_mvs,
250
+ inputs=[processed_image, sample_steps, sample_seed],
251
  outputs=[mv_images, mv_show_images]
252
  ).success(
253
  fn=make3d,
 
255
  outputs=[output_model_obj, output_model_glb]
256
  )
257
 
258
+ demo.launch()