liuyuan-pal commited on
Commit
a500583
1 Parent(s): 979cb09
Files changed (2) hide show
  1. app.py +24 -7
  2. sam_utils.py +1 -1
app.py CHANGED
@@ -82,11 +82,28 @@ def generate(model, batch_view_num, sample_num, cfg_scale, seed, image_input, el
82
  results = np.concatenate(results, 0)
83
  return Image.fromarray(results)
84
 
 
 
 
 
 
 
85
  def sam_predict(predictor, raw_im):
86
- h, w = raw_im.height, raw_im.width
87
- add_margin(raw_im, size=max(h, w))
88
- raw_im.thumbnail([512, 512], Image.Resampling.LANCZOS)
89
- image_sam = sam_out_nosave(predictor, raw_im.convert("RGB"))
 
 
 
 
 
 
 
 
 
 
 
90
  torch.cuda.empty_cache()
91
  return image_sam
92
 
@@ -152,8 +169,8 @@ def run_demo():
152
  input_block = gr.Image(type='pil', image_mode='RGBA', label="Input to SyncDreamer", height=256, interactive=False)
153
  elevation = gr.Slider(-10, 40, 30, step=5, label='Elevation angle', interactive=True)
154
  cfg_scale = gr.Slider(1.0, 5.0, 2.0, step=0.1, label='Classifier free guidance', interactive=True)
155
- # sample_num = gr.Slider(1, 2, 2, step=1, label='Sample Num', interactive=True, info='How many instance (16 images per instance)')
156
- # batch_view_num = gr.Slider(1, 16, 8, step=1, label='', interactive=True)
157
  seed = gr.Number(6033, label='Random seed', interactive=True)
158
  run_btn = gr.Button('Run Generation', variant='primary', interactive=True)
159
  fig1 = gr.Image(value=Image.open('assets/elevation.jpg'), type='pil', image_mode='RGB', height=256, show_label=False, tool=None, interactive=False)
@@ -169,7 +186,7 @@ def run_demo():
169
  crop_btn.click(fn=resize_inputs, inputs=[sam_block, crop_size_slider], outputs=[input_block], queue=False)\
170
  .success(fn=partial(update_guide, _USER_GUIDE2), outputs=[guide_text], queue=False)
171
 
172
- run_btn.click(partial(generate, model, 16, 1), inputs=[cfg_scale, seed, input_block, elevation], outputs=[output_block], queue=False)\
173
  .success(fn=partial(update_guide, _USER_GUIDE3), outputs=[guide_text], queue=False)
174
 
175
  demo.queue().launch(share=False, max_threads=80) # auth=("admin", os.environ['PASSWD'])
 
82
  results = np.concatenate(results, 0)
83
  return Image.fromarray(results)
84
 
85
+ def white_background(img):
86
+ img = np.asarray(img,np.float32)/255
87
+ rgb = img[:,:,3:] * img[:,:,:3] + 1 - img[:,:,3:]
88
+ rgb = (rgb*255).astype(np.uint8)
89
+ return Image.fromarray(rgb)
90
+
91
  def sam_predict(predictor, raw_im):
92
+ raw_im = np.asarray(raw_im)
93
+ raw_rgb = white_background(raw_im)
94
+ h, w = raw_im.raw_rgb, raw_im.raw_rgb
95
+ raw_rgb = add_margin(raw_rgb, color=255, size=max(h, w))
96
+
97
+ raw_rgb.thumbnail([512, 512], Image.Resampling.LANCZOS)
98
+ image_sam = sam_out_nosave(predictor, raw_rgb.convert("RGB"))
99
+
100
+ image_sam = np.asarray(image_sam)
101
+ out_mask = image_sam[:,:,3:]>0
102
+ out_rgb = image_sam[:,:,:3] * out_mask + 1 - out_mask
103
+ out_mask = out_mask.astype(np.uint8) * 255
104
+ out_img = np.concatenate([out_rgb, out_mask], 2)
105
+
106
+ image_sam = Image.fromarray(out_img, mode='RGBA')
107
  torch.cuda.empty_cache()
108
  return image_sam
109
 
 
169
  input_block = gr.Image(type='pil', image_mode='RGBA', label="Input to SyncDreamer", height=256, interactive=False)
170
  elevation = gr.Slider(-10, 40, 30, step=5, label='Elevation angle', interactive=True)
171
  cfg_scale = gr.Slider(1.0, 5.0, 2.0, step=0.1, label='Classifier free guidance', interactive=True)
172
+ sample_num = gr.Slider(1, 2, 1, step=1, label='Sample num', interactive=True, info='How many instance (16 images per instance)')
173
+ batch_view_num = gr.Slider(1, 16, 16, step=1, label='Batch num', interactive=True)
174
  seed = gr.Number(6033, label='Random seed', interactive=True)
175
  run_btn = gr.Button('Run Generation', variant='primary', interactive=True)
176
  fig1 = gr.Image(value=Image.open('assets/elevation.jpg'), type='pil', image_mode='RGB', height=256, show_label=False, tool=None, interactive=False)
 
186
  crop_btn.click(fn=resize_inputs, inputs=[sam_block, crop_size_slider], outputs=[input_block], queue=False)\
187
  .success(fn=partial(update_guide, _USER_GUIDE2), outputs=[guide_text], queue=False)
188
 
189
+ run_btn.click(partial(generate, model), inputs=[batch_view_num, sample_num, cfg_scale, seed, input_block, elevation], outputs=[output_block], queue=False)\
190
  .success(fn=partial(update_guide, _USER_GUIDE3), outputs=[guide_text], queue=False)
191
 
192
  demo.queue().launch(share=False, max_threads=80) # auth=("admin", os.environ['PASSWD'])
sam_utils.py CHANGED
@@ -19,7 +19,7 @@ def sam_init(device_id=0):
19
  def sam_out_nosave(predictor, input_image, ):
20
  image = np.asarray(input_image)
21
  h, w, _ = image.shape
22
- bbox = np.array([0, 0, w, h])
23
 
24
  start_time = time.time()
25
  predictor.set_image(image)
 
19
  def sam_out_nosave(predictor, input_image, ):
20
  image = np.asarray(input_image)
21
  h, w, _ = image.shape
22
+ bbox = np.array([0, 0, h, w])
23
 
24
  start_time = time.time()
25
  predictor.set_image(image)