xichenhku commited on
Commit
f0c81bd
·
1 Parent(s): d611412
Files changed (1) hide show
  1. app.py +17 -31
app.py CHANGED
@@ -78,10 +78,11 @@ def inference_single_image(ref_image,
78
  strength,
79
  ddim_steps,
80
  scale,
81
- seed,
 
82
  ):
83
  raw_background = tar_image.copy()
84
- item = process_pairs(ref_image, ref_mask, tar_image, tar_mask)
85
 
86
  ref = item['ref']
87
  hint = item['hint']
@@ -133,7 +134,7 @@ def inference_single_image(ref_image,
133
  return raw_background
134
 
135
 
136
- def process_pairs(ref_image, ref_mask, tar_image, tar_mask, max_ratio = 0.8):
137
  # ========= Reference ===========
138
  # ref expand
139
  ref_box_yyxx = get_bbox_from_mask(ref_mask)
@@ -189,21 +190,23 @@ def process_pairs(ref_image, ref_mask, tar_image, tar_mask, max_ratio = 0.8):
189
 
190
  collage_mask = cropped_target_image.copy() * 0.0
191
  collage_mask[y1:y2,x1:x2,:] = 1.0
192
- collage_mask = np.stack([cropped_tar_mask,cropped_tar_mask,cropped_tar_mask],-1)
 
193
 
194
  # the size before pad
195
  H1, W1 = collage.shape[0], collage.shape[1]
196
 
197
  cropped_target_image = pad_to_square(cropped_target_image, pad_value = 0, random = False).astype(np.uint8)
198
  collage = pad_to_square(collage, pad_value = 0, random = False).astype(np.uint8)
199
- collage_mask = pad_to_square(collage_mask, pad_value = 0, random = False).astype(np.uint8)
200
 
201
  # the size after pad
202
  H2, W2 = collage.shape[0], collage.shape[1]
203
 
204
  cropped_target_image = cv2.resize(cropped_target_image.astype(np.uint8), (512,512)).astype(np.float32)
205
  collage = cv2.resize(collage.astype(np.uint8), (512,512)).astype(np.float32)
206
- collage_mask = (cv2.resize(collage_mask.astype(np.uint8), (512,512)).astype(np.float32) > 0.5).astype(np.float32)
 
207
 
208
  masked_ref_image = masked_ref_image / 255
209
  cropped_target_image = cropped_target_image / 127.5 - 1.0
@@ -225,13 +228,6 @@ ref_list.sort()
225
  image_list=[os.path.join(image_dir,file) for file in os.listdir(image_dir) if '.jpg' in file or '.png' in file or '.jpeg' in file]
226
  image_list.sort()
227
 
228
- def process_image_mask(image_np, mask_np):
229
- img = torch.from_numpy(image_np.transpose((2, 0, 1)))
230
- img_ten = img.float().div(255).unsqueeze(0)
231
- mask_ten = torch.from_numpy(mask_np).float().unsqueeze(0).unsqueeze(0)
232
- return img_ten, mask_ten
233
-
234
-
235
  def mask_image(image, mask):
236
  blanc = np.ones_like(image) * 255
237
  mask = np.stack([mask,mask,mask],-1) / 255
@@ -247,49 +243,38 @@ def run_local(base,
247
  ref_mask = ref["mask"].convert("L")
248
  image = np.asarray(image)
249
  mask = np.asarray(mask)
250
- mask = np.where(mask > 128, 255, 0).astype(np.uint8)
251
  ref_image = np.asarray(ref_image)
252
  ref_mask = np.asarray(ref_mask)
253
  ref_mask = np.where(ref_mask > 128, 1, 0).astype(np.uint8)
254
 
255
- # refine the user annotated coarse mask
256
- if use_interactive_seg:
257
- img_ten, mask_ten = process_image_mask(ref_image, ref_mask)
258
- ref_mask = iseg_model(img_ten, mask_ten)['instances'][0,0].detach().numpy() > 0.5
259
-
260
- processed_item = process_pairs(ref_image.copy(), ref_mask.copy(), image.copy(), mask.copy(), max_ratio = 0.8)
261
- masked_ref = (processed_item['ref']*255)
262
-
263
- mased_image = mask_image(image, mask)
264
- #synthesis = image
265
  synthesis = inference_single_image(ref_image.copy(), ref_mask.copy(), image.copy(), mask.copy(), *args)
266
  synthesis = torch.from_numpy(synthesis).permute(2, 0, 1)
267
  synthesis = synthesis.permute(1, 2, 0).numpy()
268
-
269
- masked_ref = cv2.resize(masked_ref.astype(np.uint8), (512,512))
270
  return [synthesis]
271
 
 
 
272
  with gr.Blocks() as demo:
273
  with gr.Column():
274
  gr.Markdown("# Play with AnyDoor to Teleport your Target Objects! ")
275
  with gr.Row():
276
  baseline_gallery = gr.Gallery(label='Output', show_label=True, elem_id="gallery", columns=1, height=768)
277
  with gr.Accordion("Advanced Option", open=True):
278
- #num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
279
  num_samples = 1
280
  strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
281
  ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=30, step=1)
282
  scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=4.5, step=0.1)
283
  seed = gr.Slider(label="Seed", minimum=-1, maximum=999999999, step=1, value=-1)
 
284
  gr.Markdown(" Higher guidance-scale makes higher fidelity, while lower guidance-scale leads to more harmonized blending.")
285
 
286
-
287
  gr.Markdown("# Upload / Select Images for the Background (left) and Reference Object (right)")
288
  gr.Markdown("### Your could draw coarse masks on the background to indicate the desired location and shape.")
289
  gr.Markdown("### <u>Do not forget</u> to annotate the target object on the reference image.")
290
  with gr.Row():
291
- base = gr.Image(label="Background", tool="sketch", type="pil", height=512, brush_color='#FFFFFF', mask_opacity=0.5)
292
- ref = gr.Image(label="Reference", tool="sketch", type="pil", height=512, brush_color='#FFFFFF', mask_opacity=0.5)
293
  run_local_button = gr.Button(label="Generate", value="Run")
294
 
295
  with gr.Row():
@@ -304,7 +289,8 @@ with gr.Blocks() as demo:
304
  strength,
305
  ddim_steps,
306
  scale,
307
- seed,
 
308
  ],
309
  outputs=[baseline_gallery]
310
  )
 
78
  strength,
79
  ddim_steps,
80
  scale,
81
+ seed,
82
+ enable_shape_control
83
  ):
84
  raw_background = tar_image.copy()
85
+ item = process_pairs(ref_image, ref_mask, tar_image, tar_mask, enable_shape_control = enable_shape_control)
86
 
87
  ref = item['ref']
88
  hint = item['hint']
 
134
  return raw_background
135
 
136
 
137
+ def process_pairs(ref_image, ref_mask, tar_image, tar_mask, max_ratio = 0.8, enable_shape_control = False):
138
  # ========= Reference ===========
139
  # ref expand
140
  ref_box_yyxx = get_bbox_from_mask(ref_mask)
 
190
 
191
  collage_mask = cropped_target_image.copy() * 0.0
192
  collage_mask[y1:y2,x1:x2,:] = 1.0
193
+ if enable_shape_control:
194
+ collage_mask = np.stack([cropped_tar_mask,cropped_tar_mask,cropped_tar_mask],-1)
195
 
196
  # the size before pad
197
  H1, W1 = collage.shape[0], collage.shape[1]
198
 
199
  cropped_target_image = pad_to_square(cropped_target_image, pad_value = 0, random = False).astype(np.uint8)
200
  collage = pad_to_square(collage, pad_value = 0, random = False).astype(np.uint8)
201
+ collage_mask = pad_to_square(collage_mask, pad_value = 2, random = False).astype(np.uint8)
202
 
203
  # the size after pad
204
  H2, W2 = collage.shape[0], collage.shape[1]
205
 
206
  cropped_target_image = cv2.resize(cropped_target_image.astype(np.uint8), (512,512)).astype(np.float32)
207
  collage = cv2.resize(collage.astype(np.uint8), (512,512)).astype(np.float32)
208
+ collage_mask = cv2.resize(collage_mask.astype(np.uint8), (512,512), interpolation = cv2.INTER_NEAREST).astype(np.float32)
209
+ collage_mask[collage_mask == 2] = -1
210
 
211
  masked_ref_image = masked_ref_image / 255
212
  cropped_target_image = cropped_target_image / 127.5 - 1.0
 
228
  image_list=[os.path.join(image_dir,file) for file in os.listdir(image_dir) if '.jpg' in file or '.png' in file or '.jpeg' in file]
229
  image_list.sort()
230
 
 
 
 
 
 
 
 
231
  def mask_image(image, mask):
232
  blanc = np.ones_like(image) * 255
233
  mask = np.stack([mask,mask,mask],-1) / 255
 
243
  ref_mask = ref["mask"].convert("L")
244
  image = np.asarray(image)
245
  mask = np.asarray(mask)
246
+ mask = np.where(mask > 128, 1, 0).astype(np.uint8)
247
  ref_image = np.asarray(ref_image)
248
  ref_mask = np.asarray(ref_mask)
249
  ref_mask = np.where(ref_mask > 128, 1, 0).astype(np.uint8)
250
 
 
 
 
 
 
 
 
 
 
 
251
  synthesis = inference_single_image(ref_image.copy(), ref_mask.copy(), image.copy(), mask.copy(), *args)
252
  synthesis = torch.from_numpy(synthesis).permute(2, 0, 1)
253
  synthesis = synthesis.permute(1, 2, 0).numpy()
 
 
254
  return [synthesis]
255
 
256
+
257
+
258
  with gr.Blocks() as demo:
259
  with gr.Column():
260
  gr.Markdown("# Play with AnyDoor to Teleport your Target Objects! ")
261
  with gr.Row():
262
  baseline_gallery = gr.Gallery(label='Output', show_label=True, elem_id="gallery", columns=1, height=768)
263
  with gr.Accordion("Advanced Option", open=True):
 
264
  num_samples = 1
265
  strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
266
  ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=30, step=1)
267
  scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=4.5, step=0.1)
268
  seed = gr.Slider(label="Seed", minimum=-1, maximum=999999999, step=1, value=-1)
269
+ enable_shape_control = gr.Checkbox(label='Enable Shape Control', value=False)
270
  gr.Markdown(" Higher guidance-scale makes higher fidelity, while lower guidance-scale leads to more harmonized blending.")
271
 
 
272
  gr.Markdown("# Upload / Select Images for the Background (left) and Reference Object (right)")
273
  gr.Markdown("### Your could draw coarse masks on the background to indicate the desired location and shape.")
274
  gr.Markdown("### <u>Do not forget</u> to annotate the target object on the reference image.")
275
  with gr.Row():
276
+ base = gr.Image(label="Background", source="upload", tool="sketch", type="pil", height=512, brush_color='#FFFFFF', mask_opacity=0.5)
277
+ ref = gr.Image(label="Reference", source="upload", tool="sketch", type="pil", height=512, brush_color='#FFFFFF', mask_opacity=0.5)
278
  run_local_button = gr.Button(label="Generate", value="Run")
279
 
280
  with gr.Row():
 
289
  strength,
290
  ddim_steps,
291
  scale,
292
+ seed,
293
+ enable_shape_control,
294
  ],
295
  outputs=[baseline_gallery]
296
  )