jw2yang committed on
Commit aa19070
1 Parent(s): 7eeb189

paste inpainted image back to original image to make sure size is unchanged

Files changed (1)
  1. tasks/ref_in.py +9 -9
tasks/ref_in.py CHANGED
@@ -37,7 +37,7 @@ def crop_image(input_image):
 def referring_inpainting(model, image, texts, inpainting_text, *args, **kwargs):
     model.model.metadata = metadata
     texts = [[texts if texts.strip().endswith('.') else (texts.strip() + '.')]]
-    image_ori = crop_image(transform(image))
+    image_ori = transform(image)
 
     with torch.no_grad():
         width = image_ori.size[0]
@@ -58,20 +58,20 @@ def referring_inpainting(model, image, texts, inpainting_text, *args, **kwargs):
 
     if inpainting_text not in ['no', '']:
         # if we want to do inpainting
-        image_ori = image_ori.convert('RGB')
+        image_crop = crop_image(image_ori.convert('RGB'))
         struct2 = ndimage.generate_binary_structure(2, 2)
         mask_dilated = ndimage.binary_dilation(grd_mask[0], structure=struct2, iterations=3).astype(grd_mask[0].dtype)
-        mask = Image.fromarray(mask_dilated * 255).convert('RGB')
-        # image_ori = pad_image(image_ori)
-        # mask = pad_image(Image.fromarray(grd_mask[0] * 255).convert('RGB'))
+        mask = crop_image(Image.fromarray(mask_dilated * 255).convert('RGB'))
         image_and_mask = {
-            "image": image_ori,
+            "image": image_crop,
             "mask": mask,
         }
-        width = image_ori.size[0]; height = image_ori.size[1]
-        images_inpainting = pipe(prompt = inpainting_text.strip(), image=image_and_mask['image'], mask_image=image_and_mask['mask'], height=height, width=width).images
+        width = image_crop.size[0]; height = image_crop.size[1]
+        images_inpainting = pipe(prompt = inpainting_text.strip(), image=image_and_mask['image'], mask_image=image_and_mask['mask'], height=height, width=width).images[0]
+        # put images_inpainting back to original image
+        image_ori.paste(images_inpainting)
         torch.cuda.empty_cache()
-        return Image.fromarray(res) ,'' , images_inpainting[0]
+        return Image.fromarray(res) ,'' , image_ori
     else:
         torch.cuda.empty_cache()
         return image_ori, 'text', Image.fromarray(res)
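
Note: the patch above follows a crop, inpaint, then paste-back pattern: the input is cropped to a size the inpainting pipeline accepts, only the crop and its mask are passed to the pipeline, and the inpainted crop is pasted back into the original image so the returned result keeps the input resolution. The following is a minimal standalone sketch of that pattern using the diffusers inpainting pipeline; the center_crop_box helper, the file names, and the checkpoint are illustrative assumptions, not code taken from this repo.

# Minimal sketch of the crop -> inpaint -> paste-back pattern (illustrative, not this repo's code).
import torch
from PIL import Image
from diffusers import StableDiffusionInpaintPipeline

def center_crop_box(img, multiple=64):
    # Largest centered box whose sides are multiples of `multiple`
    # (Stable Diffusion expects dimensions divisible by 8).
    w, h = img.size
    cw, ch = (w // multiple) * multiple, (h // multiple) * multiple
    left, top = (w - cw) // 2, (h - ch) // 2
    return (left, top, left + cw, top + ch)

pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16
).to("cuda")

image_ori = Image.open("input.png").convert("RGB")   # arbitrary size
mask_ori = Image.open("mask.png").convert("RGB")     # white = region to inpaint

# Crop image and mask to the same pipeline-friendly box.
box = center_crop_box(image_ori)
image_crop, mask_crop = image_ori.crop(box), mask_ori.crop(box)

inpainted = pipe(
    prompt="a vase of flowers",
    image=image_crop,
    mask_image=mask_crop,
    height=image_crop.size[1],
    width=image_crop.size[0],
).images[0]

# Paste the inpainted crop back at its original offset so the
# output has exactly the same size as the input.
result = image_ori.copy()
result.paste(inpainted, box[:2])
result.save("output.png")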