xichenhku committed
Commit d3283da
1 Parent(s): 196cdd9

gradio_update

Files changed (1)
app.py +27 -5
app.py CHANGED
@@ -49,6 +49,14 @@ if use_interactive_seg:
     iseg_model.load_state_dict(weights, strict= True)
 
 
+
+def process_image_mask(image_np, mask_np):
+    img = torch.from_numpy(image_np.transpose((2, 0, 1)))
+    img = img.float().div(255).unsqueeze(0)
+    mask = torch.from_numpy(mask_np).float().unsqueeze(0).unsqueeze(0)
+    pred = iseg_model(img, mask)['instances'][0,0].detach().numpy() > 0.5
+    return pred.astype(np.uint8)
+
 def crop_back( pred, tar_image, extra_sizes, tar_box_yyxx_crop):
     H1, W1, H2, W2 = extra_sizes
     y1,y2,x1,x2 = tar_box_yyxx_crop
@@ -70,7 +78,6 @@ def crop_back( pred, tar_image, extra_sizes, tar_box_yyxx_crop):
     tar_image[y1+m :y2-m, x1+m:x2-m, :] = pred[m:-m, m:-m]
     return tar_image
 
-
 def inference_single_image(ref_image,
                            ref_mask,
                            tar_image,
@@ -79,7 +86,7 @@ def inference_single_image(ref_image,
                            ddim_steps,
                            scale,
                            seed,
-                           enable_shape_control
+                           enable_shape_control,
                            ):
     raw_background = tar_image.copy()
     item = process_pairs(ref_image, ref_mask, tar_image, tar_mask, enable_shape_control = enable_shape_control)
@@ -248,6 +255,15 @@ def run_local(base,
     ref_mask = np.asarray(ref_mask)
     ref_mask = np.where(ref_mask > 128, 1, 0).astype(np.uint8)
 
+    if ref_mask.sum() == 0:
+        raise gr.Error('No mask for the reference image.')
+
+    if mask.sum() == 0:
+        raise gr.Error('No mask for the background image.')
+
+    if reference_mask_refine:
+        ref_mask = process_image_mask(ref_image, ref_mask)
+
     synthesis = inference_single_image(ref_image.copy(), ref_mask.copy(), image.copy(), mask.copy(), *args)
     synthesis = torch.from_numpy(synthesis).permute(2, 0, 1)
     synthesis = synthesis.permute(1, 2, 0).numpy()
@@ -266,12 +282,16 @@ with gr.Blocks() as demo:
             ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=30, step=1)
             scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=4.5, step=0.1)
             seed = gr.Slider(label="Seed", minimum=-1, maximum=999999999, step=1, value=-1)
-            enable_shape_control = gr.Checkbox(label='Enable Shape Control', value=False)
+            reference_mask_refine = gr.Checkbox(label='Reference Mask Refine', value=True, interactive = True)
+            enable_shape_control = gr.Checkbox(label='Enable Shape Control', value=False, interactive = True)
+
         gr.Markdown("### Guidelines")
         gr.Markdown(" Higher guidance-scale makes higher fidelity, while lower one makes more harmonized blending.")
-        gr.Markdown(" Enable shape control means the generation results would consider user-drawn masks; otherwise it \
+        gr.Markdown(" Users should annotate the mask of the target object, too coarse mask would lead to bad generation.\
+                    Reference Mask Refine provides a segmentation model to refine the coarse mask. ")
+        gr.Markdown(" Enable shape control means the generation results would consider user-drawn masks to control the shape & pose; otherwise it \
                     considers the location and size to adjust automatically.")
-        gr.Markdown(" Users should annotate the mask of the target object, too coarse mask would lead to bad generation.")
+
 
         gr.Markdown("# Upload / Select Images for the Background (left) and Reference Object (right)")
         gr.Markdown("### You could draw coarse masks on the background to indicate the desired location and shape.")
@@ -298,4 +318,6 @@ with gr.Blocks() as demo:
                         ],
                         outputs=[baseline_gallery]
                         )
+
+
 demo.launch()
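For reference, a minimal self-contained sketch of the mask-refinement flow this commit adds. process_image_mask is copied from the diff above; DummyISegModel is an assumed stand-in for the interactive-segmentation model that app.py loads as iseg_model, included only so the sketch runs:

import numpy as np
import torch

class DummyISegModel(torch.nn.Module):
    # Stand-in (assumption) for the real iseg_model: returns a dict with an
    # 'instances' logit map shaped 1x1xHxW, matching how the diff indexes it.
    def forward(self, img, mask):
        return {'instances': torch.zeros(1, 1, img.shape[2], img.shape[3])}

iseg_model = DummyISegModel()

def process_image_mask(image_np, mask_np):
    # HxWx3 uint8 image -> 1x3xHxW float tensor scaled to [0, 1]
    img = torch.from_numpy(image_np.transpose((2, 0, 1)))
    img = img.float().div(255).unsqueeze(0)
    # HxW binary mask -> 1x1xHxW float tensor
    mask = torch.from_numpy(mask_np).float().unsqueeze(0).unsqueeze(0)
    # Threshold the model output to get a refined binary mask
    pred = iseg_model(img, mask)['instances'][0, 0].detach().numpy() > 0.5
    return pred.astype(np.uint8)

ref_image = np.zeros((256, 256, 3), dtype=np.uint8)   # dummy RGB image
ref_mask = np.ones((256, 256), dtype=np.uint8)        # coarse binary mask
refined = process_image_mask(ref_image, ref_mask)     # HxW uint8 in {0, 1}

This mirrors how run_local now calls process_image_mask(ref_image, ref_mask) when the Reference Mask Refine checkbox is enabled.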
 
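The commit also validates that both masks are non-empty before inference, using gr.Error. A minimal sketch of that pattern in isolation (the check_mask handler and the component names below are hypothetical, not from app.py); raising gr.Error inside a Gradio callback aborts the run and surfaces the message in the UI:

import numpy as np
import gradio as gr

def check_mask(mask_img):
    # Hypothetical handler demonstrating the validation pattern from the diff
    if mask_img is None:
        raise gr.Error('No mask for the background image.')
    mask = np.where(np.asarray(mask_img) > 128, 1, 0).astype(np.uint8)
    if mask.sum() == 0:
        # Aborts the callback; Gradio shows the message as an error toast
        raise gr.Error('No mask for the background image.')
    return 'Mask OK: %d foreground pixels' % int(mask.sum())

with gr.Blocks() as demo:
    inp = gr.Image(label='Mask', image_mode='L')
    out = gr.Textbox(label='Result')
    gr.Button('Check').click(check_mask, inputs=[inp], outputs=[out])

demo.launch()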