levihsu commited on
Commit
4bce9fa
β€’
1 Parent(s): 5a486d6

Update run/gradio_ootd.py

Browse files
Files changed (1) hide show
  1. run/gradio_ootd.py +5 -2
run/gradio_ootd.py CHANGED
@@ -10,7 +10,6 @@ from utils_ootd import get_mask_location
10
  PROJECT_ROOT = Path(__file__).absolute().parents[1].absolute()
11
  sys.path.insert(0, str(PROJECT_ROOT))
12
 
13
- import time
14
  from preprocess.openpose.run_openpose import OpenPose
15
  from preprocess.humanparsing.run_parsing import Parsing
16
  from ootd.inference_ootd_hd import OOTDiffusionHD
@@ -36,6 +35,10 @@ garment_hd = os.path.join(example_path, 'garment/03244_00.jpg')
36
  model_dc = os.path.join(example_path, 'model/model_8.png')
37
  garment_dc = os.path.join(example_path, 'garment/048554_1.jpg')
38
 
 
 
 
 
39
  def process_hd(vton_img, garm_img, n_samples, n_steps, image_scale, seed):
40
  model_type = 'hd'
41
  category = 0 # 0:upperbody; 1:lowerbody; 2:dress
@@ -257,4 +260,4 @@ with block:
257
  ips_dc = [vton_img_dc, garm_img_dc, category_dc, n_samples_dc, n_steps_dc, image_scale_dc, seed_dc]
258
  run_button_dc.click(fn=process_dc, inputs=ips_dc, outputs=[result_gallery_dc])
259
 
260
- block.launch(server_name='0.0.0.0', server_port=7865)
 
10
  PROJECT_ROOT = Path(__file__).absolute().parents[1].absolute()
11
  sys.path.insert(0, str(PROJECT_ROOT))
12
 
 
13
  from preprocess.openpose.run_openpose import OpenPose
14
  from preprocess.humanparsing.run_parsing import Parsing
15
  from ootd.inference_ootd_hd import OOTDiffusionHD
 
35
  model_dc = os.path.join(example_path, 'model/model_8.png')
36
  garment_dc = os.path.join(example_path, 'garment/048554_1.jpg')
37
 
38
+
39
+ import spaces
40
+
41
+ @spaces.GPU
42
  def process_hd(vton_img, garm_img, n_samples, n_steps, image_scale, seed):
43
  model_type = 'hd'
44
  category = 0 # 0:upperbody; 1:lowerbody; 2:dress
 
260
  ips_dc = [vton_img_dc, garm_img_dc, category_dc, n_samples_dc, n_steps_dc, image_scale_dc, seed_dc]
261
  run_button_dc.click(fn=process_dc, inputs=ips_dc, outputs=[result_gallery_dc])
262
 
263
+ block.launch()