xichenhku committed
Commit fdd218f
1 Parent(s): d3283da

gradio_back

Files changed (1)
  1. app.py +5 -27
app.py CHANGED
@@ -49,14 +49,6 @@ if use_interactive_seg:
     iseg_model.load_state_dict(weights, strict= True)
 
 
-
-def process_image_mask(image_np, mask_np):
-    img = torch.from_numpy(image_np.transpose((2, 0, 1)))
-    img = img.float().div(255).unsqueeze(0)
-    mask = torch.from_numpy(mask_np).float().unsqueeze(0).unsqueeze(0)
-    pred = iseg_model(img, mask)['instances'][0,0].detach().numpy() > 0.5
-    return pred.astype(np.uint8)
-
 def crop_back( pred, tar_image, extra_sizes, tar_box_yyxx_crop):
     H1, W1, H2, W2 = extra_sizes
     y1,y2,x1,x2 = tar_box_yyxx_crop
@@ -78,6 +70,7 @@ def crop_back( pred, tar_image, extra_sizes, tar_box_yyxx_crop):
     tar_image[y1+m :y2-m, x1+m:x2-m, :] = pred[m:-m, m:-m]
     return tar_image
 
+
 def inference_single_image(ref_image,
                            ref_mask,
                            tar_image,
@@ -86,7 +79,7 @@ def inference_single_image(ref_image,
                            ddim_steps,
                            scale,
                            seed,
-                           enable_shape_control,
+                           enable_shape_control
                            ):
     raw_background = tar_image.copy()
     item = process_pairs(ref_image, ref_mask, tar_image, tar_mask, enable_shape_control = enable_shape_control)
@@ -255,15 +248,6 @@ def run_local(base,
     ref_mask = np.asarray(ref_mask)
     ref_mask = np.where(ref_mask > 128, 1, 0).astype(np.uint8)
 
-    if ref_mask.sum() == 0:
-        raise gr.Error('No mask for the reference image.')
-
-    if mask.sum() == 0:
-        raise gr.Error('No mask for the background image.')
-
-    if reference_mask_refine:
-        ref_mask = process_image_mask(ref_image, ref_mask)
-
     synthesis = inference_single_image(ref_image.copy(), ref_mask.copy(), image.copy(), mask.copy(), *args)
     synthesis = torch.from_numpy(synthesis).permute(2, 0, 1)
     synthesis = synthesis.permute(1, 2, 0).numpy()
@@ -282,16 +266,12 @@ with gr.Blocks() as demo:
         ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=30, step=1)
         scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=4.5, step=0.1)
         seed = gr.Slider(label="Seed", minimum=-1, maximum=999999999, step=1, value=-1)
-        reference_mask_refine = gr.Checkbox(label='Reference Mask Refine', value=True, interactive = True)
-        enable_shape_control = gr.Checkbox(label='Enable Shape Control', value=False, interactive = True)
-
+        enable_shape_control = gr.Checkbox(label='Enable Shape Control', value=False)
         gr.Markdown("### Guidelines")
         gr.Markdown(" Higher guidance-scale makes higher fidelity, while lower one makes more harmonized blending.")
-        gr.Markdown(" Users should annotate the mask of the target object, too coarse mask would lead to bad generation.\
-                    Reference Mask Refine provides a segmentation model to refine the coarse mask. ")
-        gr.Markdown(" Enable shape control means the generation results would consider user-drawn masks to control the shape & pose; otherwise it \
+        gr.Markdown(" Enable shape control means the generation results would consider user-drawn masks; otherwise it \
                     considers the location and size to adjust automatically.")
-
+        gr.Markdown(" Users should annotate the mask of the target object, too coarse mask would lead to bad generation.")
 
     gr.Markdown("# Upload / Select Images for the Background (left) and Reference Object (right)")
     gr.Markdown("### You could draw coarse masks on the background to indicate the desired location and shape.")
@@ -318,6 +298,4 @@ with gr.Blocks() as demo:
             ],
            outputs=[baseline_gallery]
         )
-
-
 demo.launch()
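
This commit drops the reference-mask refinement path (process_image_mask and the Reference Mask Refine checkbox) and the empty-mask guards in run_local. For anyone who still wants that behaviour, the deleted pieces can live outside app.py. A minimal sketch based on the deleted lines, assuming iseg_model is the interactive-segmentation model loaded near the top of app.py (untouched by this commit); the names refine_reference_mask and validate_masks are illustrative, not identifiers from the repo:

import numpy as np
import torch
import gradio as gr

def refine_reference_mask(image_np, mask_np, iseg_model):
    # Mirrors the deleted process_image_mask helper.
    # HxWx3 uint8 image -> 1x3xHxW float tensor scaled to [0, 1].
    img = torch.from_numpy(image_np.transpose((2, 0, 1)))
    img = img.float().div(255).unsqueeze(0)
    # HxW binary mask -> 1x1xHxW float tensor.
    mask = torch.from_numpy(mask_np).float().unsqueeze(0).unsqueeze(0)
    # Keep the model's first predicted instance, thresholded at 0.5.
    pred = iseg_model(img, mask)['instances'][0, 0].detach().numpy() > 0.5
    return pred.astype(np.uint8)

def validate_masks(ref_mask, mask):
    # Mirrors the empty-mask guards deleted from run_local.
    if ref_mask.sum() == 0:
        raise gr.Error('No mask for the reference image.')
    if mask.sum() == 0:
        raise gr.Error('No mask for the background image.')

Calling validate_masks(ref_mask, mask) and then ref_mask = refine_reference_mask(ref_image, ref_mask, iseg_model) immediately before inference_single_image would reproduce the deleted pre-processing.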