qlz58793 committed
Commit 469f43d
Parent: c331e65

fast version

Files changed (2):
  1. app.py +59 -20
  2. lama_inpaint.py +70 -0
app.py CHANGED
```diff
@@ -5,12 +5,13 @@ from matplotlib import pyplot as plt
 import torch
 import tempfile
 import os
+from omegaconf import OmegaConf
 from sam_segment import predict_masks_with_sam
-from lama_inpaint import inpaint_img_with_lama
+from lama_inpaint import inpaint_img_with_lama, build_lama_model, inpaint_img_with_builded_lama
 from utils import load_img_to_array, save_array_to_img, dilate_mask, \
     show_mask, show_points
 from PIL import Image
-
+from segment_anything import SamPredictor, sam_model_registry
 
 def mkstemp(suffix, dir=None):
     fd, path = tempfile.mkstemp(suffix=f"{suffix}", dir=dir)
@@ -18,19 +19,21 @@ def mkstemp(suffix, dir=None):
     return Path(path)
 
 
+def get_sam_feat(img):
+    # predictor.set_image(img)
+    model['sam'].set_image(img)
+    return
+
+
 def get_masked_img(img, w, h):
-    point_labels = [1]
     point_coords = [w, h]
+    point_labels = [1]
     dilate_kernel_size = 15
-    device = "cuda" if torch.cuda.is_available() else "cpu"
-
-    masks, _, _ = predict_masks_with_sam(
-        img,
-        [point_coords],
-        point_labels,
-        model_type="vit_h",
-        ckpt_p="pretrained_models/sam_vit_h_4b8939.pth",
-        device=device,
+    # masks, _, _ = predictor.predict(
+    masks, _, _ = model['sam'].predict(
+        point_coords=np.array([point_coords]),
+        point_labels=np.array(point_labels),
+        multimask_output=True,
     )
 
     masks = masks.astype(np.uint8) * 255
@@ -67,22 +70,45 @@ def get_inpainted_img(img, mask0, mask1, mask2):
     for mask in [mask0, mask1, mask2]:
         if len(mask.shape)==3:
             mask = mask[:,:,0]
-        img_inpainted = inpaint_img_with_lama(
-            img, mask, lama_config, lama_ckpt, device=device)
+        img_inpainted = inpaint_img_with_builded_lama(
+            model_lama, img, mask, lama_config, device=device)
         out.append(img_inpainted)
     return out
 
 
+## build models
+model = {}
+# build the sam model
+model_type="vit_h"
+ckpt_p="pretrained_models/sam_vit_h_4b8939.pth"
+model_sam = sam_model_registry[model_type](checkpoint=ckpt_p)
+device = "cuda" if torch.cuda.is_available() else "cpu"
+model_sam.to(device=device)
+# predictor = SamPredictor(model_sam)
+model['sam'] = SamPredictor(model_sam)
+
+# build the lama model
+lama_config = "third_party/lama/configs/prediction/default.yaml"
+lama_ckpt = "pretrained_models/big-lama"
+device = "cuda" if torch.cuda.is_available() else "cpu"
+# model_lama = build_lama_model(lama_config, lama_ckpt, device=device)
+model['lama'] = build_lama_model(lama_config, lama_ckpt, device=device)
+
+
 with gr.Blocks() as demo:
     with gr.Row():
         img = gr.Image(label="Image")
+        # img_pointed = gr.Image(label='Pointed Image')
+        img_pointed = gr.Plot(label='Pointed Image')
         with gr.Column():
             with gr.Row():
                 w = gr.Number(label="Point Coordinate W")
                 h = gr.Number(label="Point Coordinate H")
-            sam = gr.Button("Predict Mask Using SAM")
+            sam_feat = gr.Button("Generate Features Using SAM")
+            sam_mask = gr.Button("Predict Mask Using SAM")
             lama = gr.Button("Inpaint Image Using LaMA")
 
+    # todo: maybe we can delete this row, for it's unnecessary to show the original mask for customers
     with gr.Row():
         mask_0 = gr.outputs.Image(type="numpy", label="Segmentation Mask 0")
         mask_1 = gr.outputs.Image(type="numpy", label="Segmentation Mask 1")
@@ -101,11 +127,23 @@ with gr.Blocks() as demo:
         img_rm_with_mask_2 = gr.outputs.Image(
             type="numpy", label="Image Removed with Segmentation Mask 2")
 
-    def get_select_coords(evt: gr.SelectData):
-        return evt.index[0], evt.index[1]
+    def get_select_coords(img, evt: gr.SelectData):
+        dpi = plt.rcParams['figure.dpi']
+        height, width = img.shape[:2]
+        fig = plt.figure(figsize=(width/dpi/0.77, height/dpi/0.77))
+        plt.imshow(img)
+        plt.axis('off')
+        show_points(plt.gca(), [[evt.index[0], evt.index[1]]], [1],
+                    size=(width*0.04)**2)
+        return evt.index[0], evt.index[1], fig
 
-    img.select(get_select_coords, [], [w, h])
-    sam.click(
+    img.select(get_select_coords, [img], [w, h, img_pointed])
+    sam_feat.click(
+        get_sam_feat,
+        [img],
+        []
+    )
+    sam_mask.click(
         get_masked_img,
         [img, w, h],
         [img_with_mask_0, img_with_mask_1, img_with_mask_2, mask_0, mask_1, mask_2]
@@ -119,4 +157,5 @@ with gr.Blocks() as demo:
 
 
 if __name__ == "__main__":
-    demo.launch()
+    demo.launch(debug=True)
+
```
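Why this is the "fast version": the old `get_masked_img` re-ran the whole `predict_masks_with_sam` pipeline, checkpoint loading included, on every click, whereas the new code builds SAM and LaMA once at module level and caches the image embedding via `get_sam_feat`, so each click only runs SAM's lightweight prompt decoder. A minimal sketch of that feature-caching pattern, assuming the `segment_anything` package and the checkpoint path used in the diff (the blank stand-in image is a hypothetical placeholder):

```python
import numpy as np
import torch
from segment_anything import SamPredictor, sam_model_registry

# Build the model once at startup, as the diff does at module level.
device = "cuda" if torch.cuda.is_available() else "cpu"
sam = sam_model_registry["vit_h"](checkpoint="pretrained_models/sam_vit_h_4b8939.pth")
sam.to(device=device)
predictor = SamPredictor(sam)

img = np.zeros((512, 512, 3), dtype=np.uint8)  # hypothetical stand-in for the uploaded image

# Expensive step, once per image: runs the ViT-H image encoder and caches
# the embedding inside the predictor (what get_sam_feat does on button press).
predictor.set_image(img)

# Cheap step, once per click: only the prompt encoder and mask decoder run.
masks, scores, logits = predictor.predict(
    point_coords=np.array([[256, 256]]),  # (x, y) of the clicked point
    point_labels=np.array([1]),           # 1 marks a foreground point
    multimask_output=True,                # return three candidate masks
)
```

The two-button UI mirrors this split: "Generate Features Using SAM" pays the encoder cost once per uploaded image, and "Predict Mask Using SAM" reuses the cached features on every subsequent click.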
lama_inpaint.py CHANGED
```diff
@@ -82,6 +82,76 @@ def inpaint_img_with_lama(
     cur_res = np.clip(cur_res * 255, 0, 255).astype('uint8')
     return cur_res
 
+
+def build_lama_model(
+        config_p: str,
+        ckpt_p: str,
+        device="cuda"
+):
+    predict_config = OmegaConf.load(config_p)
+    predict_config.model.path = ckpt_p
+    # device = torch.device(predict_config.device)
+    device = torch.device(device)
+
+    train_config_path = os.path.join(
+        predict_config.model.path, 'config.yaml')
+
+    with open(train_config_path, 'r') as f:
+        train_config = OmegaConf.create(yaml.safe_load(f))
+
+    train_config.training_model.predict_only = True
+    train_config.visualizer.kind = 'noop'
+
+    checkpoint_path = os.path.join(
+        predict_config.model.path, 'models',
+        predict_config.model.checkpoint
+    )
+    model = load_checkpoint(
+        train_config, checkpoint_path, strict=False, map_location=device)
+    model.freeze()
+    if not predict_config.get('refine', False):
+        model.to(device)
+
+    return model
+
+
+@torch.no_grad()
+def inpaint_img_with_builded_lama(
+        model,
+        img: np.ndarray,
+        mask: np.ndarray,
+        config_p: str,
+        mod=8,
+        device="cuda"
+):
+    assert len(mask.shape) == 2
+    if np.max(mask) == 1:
+        mask = mask * 255
+    img = torch.from_numpy(img).float().div(255.)
+    mask = torch.from_numpy(mask).float()
+    predict_config = OmegaConf.load(config_p)
+
+    batch = {}
+    batch['image'] = img.permute(2, 0, 1).unsqueeze(0)
+    batch['mask'] = mask[None, None]
+    unpad_to_size = [batch['image'].shape[2], batch['image'].shape[3]]
+    batch['image'] = pad_tensor_to_modulo(batch['image'], mod)
+    batch['mask'] = pad_tensor_to_modulo(batch['mask'], mod)
+    batch = move_to_device(batch, device)
+    batch['mask'] = (batch['mask'] > 0) * 1
+
+    batch = model(batch)
+    cur_res = batch[predict_config.out_key][0].permute(1, 2, 0)
+    cur_res = cur_res.detach().cpu().numpy()
+
+    if unpad_to_size is not None:
+        orig_height, orig_width = unpad_to_size
+        cur_res = cur_res[:orig_height, :orig_width]
+
+    cur_res = np.clip(cur_res * 255, 0, 255).astype('uint8')
+    return cur_res
+
+
 def setup_args(parser):
     parser.add_argument(
         "--input_img", type=str, required=True,
```