pg56714 commited on
Commit
7e0cd5c
1 Parent(s): 8c79ca9

Upload 13 files

Browse files
utils/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .utils import *
utils/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (179 Bytes). View file
 
utils/__pycache__/crop_for_replacing.cpython-310.pyc ADDED
Binary file (3.26 kB). View file
 
utils/__pycache__/mask_processing.cpython-310.pyc ADDED
Binary file (2.87 kB). View file
 
utils/__pycache__/utils.cpython-310.pyc ADDED
Binary file (2.74 kB). View file
 
utils/crop_for_replacing.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+ from typing import Tuple
4
+
5
+ def resize_and_pad(image: np.ndarray, mask: np.ndarray, target_size: int = 512) -> Tuple[np.ndarray, np.ndarray]:
6
+ """
7
+ Resizes an image and its corresponding mask to have the longer side equal to `target_size` and pads them to make them
8
+ both have the same size. The resulting image and mask have dimensions (target_size, target_size).
9
+
10
+ Args:
11
+ image: A numpy array representing the image to resize and pad.
12
+ mask: A numpy array representing the mask to resize and pad.
13
+ target_size: An integer specifying the desired size of the longer side after resizing.
14
+
15
+ Returns:
16
+ A tuple containing two numpy arrays - the resized and padded image and the resized and padded mask.
17
+ """
18
+ height, width, _ = image.shape
19
+ max_dim = max(height, width)
20
+ scale = target_size / max_dim
21
+ new_height = int(height * scale)
22
+ new_width = int(width * scale)
23
+ image_resized = cv2.resize(image, (new_width, new_height), interpolation=cv2.INTER_LINEAR)
24
+ mask_resized = cv2.resize(mask, (new_width, new_height), interpolation=cv2.INTER_LINEAR)
25
+ pad_height = target_size - new_height
26
+ pad_width = target_size - new_width
27
+ top_pad = pad_height // 2
28
+ bottom_pad = pad_height - top_pad
29
+ left_pad = pad_width // 2
30
+ right_pad = pad_width - left_pad
31
+ image_padded = np.pad(image_resized, ((top_pad, bottom_pad), (left_pad, right_pad), (0, 0)), mode='constant')
32
+ mask_padded = np.pad(mask_resized, ((top_pad, bottom_pad), (left_pad, right_pad)), mode='constant')
33
+ return image_padded, mask_padded, (top_pad, bottom_pad, left_pad, right_pad)
34
+
35
+ def recover_size(image_padded: np.ndarray, mask_padded: np.ndarray, orig_size: Tuple[int, int],
36
+ padding_factors: Tuple[int, int, int, int]) -> Tuple[np.ndarray, np.ndarray]:
37
+ """
38
+ Resizes a padded and resized image and mask to the original size.
39
+
40
+ Args:
41
+ image_padded: A numpy array representing the padded and resized image.
42
+ mask_padded: A numpy array representing the padded and resized mask.
43
+ orig_size: A tuple containing two integers - the original height and width of the image before resizing and padding.
44
+
45
+ Returns:
46
+ A tuple containing two numpy arrays - the recovered image and the recovered mask with dimensions `orig_size`.
47
+ """
48
+ h,w,c = image_padded.shape
49
+ top_pad, bottom_pad, left_pad, right_pad = padding_factors
50
+ image = image_padded[top_pad:h-bottom_pad, left_pad:w-right_pad, :]
51
+ mask = mask_padded[top_pad:h-bottom_pad, left_pad:w-right_pad]
52
+ image_resized = cv2.resize(image, orig_size[::-1], interpolation=cv2.INTER_LINEAR)
53
+ mask_resized = cv2.resize(mask, orig_size[::-1], interpolation=cv2.INTER_LINEAR)
54
+ return image_resized, mask_resized
55
+
56
+
57
+
58
+
59
+ if __name__ == '__main__':
60
+
61
+ # image = cv2.imread('example/boat.jpg')
62
+ # mask = cv2.imread('example/boat_mask_2.png', cv2.IMREAD_GRAYSCALE)
63
+ # image = cv2.imread('example/groceries.jpg')
64
+ # mask = cv2.imread('example/groceries_mask_2.png', cv2.IMREAD_GRAYSCALE)
65
+ # image = cv2.imread('example/bridge.jpg')
66
+ # mask = cv2.imread('example/bridge_mask_2.png', cv2.IMREAD_GRAYSCALE)
67
+ # image = cv2.imread('example/person_umbrella.jpg')
68
+ # mask = cv2.imread('example/person_umbrella_mask_2.png', cv2.IMREAD_GRAYSCALE)
69
+ # image = cv2.imread('example/hippopotamus.jpg')
70
+ # mask = cv2.imread('example/hippopotamus_mask_1.png', cv2.IMREAD_GRAYSCALE)
71
+ image = cv2.imread('/data1/yutao/projects/IAM/Inpaint-Anything/example/fill-anything/sample5.jpeg')
72
+ mask = cv2.imread('/data1/yutao/projects/IAM/Inpaint-Anything/example/fill-anything/sample5/mask.png', cv2.IMREAD_GRAYSCALE)
73
+ print(image.shape)
74
+ print(mask.shape)
75
+ cv2.imwrite('original_image.jpg', image)
76
+ cv2.imwrite('original_mask.jpg', mask)
77
+ image_padded, mask_padded, padding_factors = resize_and_pad(image, mask)
78
+ cv2.imwrite('padded_image.png', image_padded)
79
+ cv2.imwrite('padded_mask.png', mask_padded)
80
+ print(image_padded.shape, mask_padded.shape, padding_factors)
81
+
82
+ # ^ ------------------------------------------------------------------------------------
83
+ # ^ Please conduct inpainting or filling here on the cropped image with the cropped mask
84
+ # ^ ------------------------------------------------------------------------------------
85
+
86
+ # resize and pad the image and mask
87
+
88
+ # perform some operation on the 512x512 image and mask
89
+ # ...
90
+
91
+ # recover the image and mask to the original size
92
+ height, width, _ = image.shape
93
+ image_resized, mask_resized = recover_size(image_padded, mask_padded, (height, width), padding_factors)
94
+
95
+ # save the resized and recovered image and mask
96
+ cv2.imwrite('resized_and_padded_image.png', image_padded)
97
+ cv2.imwrite('resized_and_padded_mask.png', mask_padded)
98
+ cv2.imwrite('recovered_image.png', image_resized)
99
+ cv2.imwrite('recovered_mask.png', mask_resized)
100
+
101
+
utils/get_point_coor.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+
3
+ def click_event(event, x, y, flags, param):
4
+ if event == cv2.EVENT_LBUTTONDOWN:
5
+ print("Point coordinates ({}, {})".format(x, y))
6
+ img = cv2.imread("./example/remove-anything/dog.jpg")
7
+
8
+ cv2.imshow("Image", img)
9
+ cv2.setMouseCallback("Image", click_event)
10
+ cv2.waitKey(0)
11
+
12
+ cv2.destroyAllWindows()
utils/mask_processing.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ from matplotlib import pyplot as plt
3
+ import PIL.Image as Image
4
+ import numpy as np
5
+
6
+
7
+ def crop_for_filling_pre(image: np.array, mask: np.array, crop_size: int = 512):
8
+ # Calculate the aspect ratio of the image
9
+ height, width = image.shape[:2]
10
+ aspect_ratio = float(width) / float(height)
11
+
12
+ # If the shorter side is less than 512, resize the image proportionally
13
+ if min(height, width) < crop_size:
14
+ if height < width:
15
+ new_height = crop_size
16
+ new_width = int(new_height * aspect_ratio)
17
+ else:
18
+ new_width = crop_size
19
+ new_height = int(new_width / aspect_ratio)
20
+
21
+ image = cv2.resize(image, (new_width, new_height))
22
+ mask = cv2.resize(mask, (new_width, new_height))
23
+
24
+ # Find the bounding box of the mask
25
+ x, y, w, h = cv2.boundingRect(mask)
26
+
27
+ # Update the height and width of the resized image
28
+ height, width = image.shape[:2]
29
+
30
+ # # If the 512x512 square cannot cover the entire mask, resize the image accordingly
31
+ if w > crop_size or h > crop_size:
32
+ # padding to square at first
33
+ if height < width:
34
+ padding = width - height
35
+ image = np.pad(image, ((padding // 2, padding - padding // 2), (0, 0), (0, 0)), 'constant')
36
+ mask = np.pad(mask, ((padding // 2, padding - padding // 2), (0, 0)), 'constant')
37
+ else:
38
+ padding = height - width
39
+ image = np.pad(image, ((0, 0), (padding // 2, padding - padding // 2), (0, 0)), 'constant')
40
+ mask = np.pad(mask, ((0, 0), (padding // 2, padding - padding // 2)), 'constant')
41
+
42
+ resize_factor = crop_size / max(w, h)
43
+ image = cv2.resize(image, (0, 0), fx=resize_factor, fy=resize_factor)
44
+ mask = cv2.resize(mask, (0, 0), fx=resize_factor, fy=resize_factor)
45
+ x, y, w, h = cv2.boundingRect(mask)
46
+
47
+ # Calculate the crop coordinates
48
+ crop_x = min(max(x + w // 2 - crop_size // 2, 0), width - crop_size)
49
+ crop_y = min(max(y + h // 2 - crop_size // 2, 0), height - crop_size)
50
+
51
+ # Crop the image
52
+ cropped_image = image[crop_y:crop_y + crop_size, crop_x:crop_x + crop_size]
53
+ cropped_mask = mask[crop_y:crop_y + crop_size, crop_x:crop_x + crop_size]
54
+
55
+ return cropped_image, cropped_mask
56
+
57
+
58
+ def crop_for_filling_post(
59
+ image: np.array,
60
+ mask: np.array,
61
+ filled_image: np.array,
62
+ crop_size: int = 512,
63
+ ):
64
+ image_copy = image.copy()
65
+ mask_copy = mask.copy()
66
+ # Calculate the aspect ratio of the image
67
+ height, width = image.shape[:2]
68
+ height_ori, width_ori = height, width
69
+ aspect_ratio = float(width) / float(height)
70
+
71
+ # If the shorter side is less than 512, resize the image proportionally
72
+ if min(height, width) < crop_size:
73
+ if height < width:
74
+ new_height = crop_size
75
+ new_width = int(new_height * aspect_ratio)
76
+ else:
77
+ new_width = crop_size
78
+ new_height = int(new_width / aspect_ratio)
79
+
80
+ image = cv2.resize(image, (new_width, new_height))
81
+ mask = cv2.resize(mask, (new_width, new_height))
82
+
83
+ # Find the bounding box of the mask
84
+ x, y, w, h = cv2.boundingRect(mask)
85
+
86
+ # Update the height and width of the resized image
87
+ height, width = image.shape[:2]
88
+
89
+ # # If the 512x512 square cannot cover the entire mask, resize the image accordingly
90
+ if w > crop_size or h > crop_size:
91
+ flag_padding = True
92
+ # padding to square at first
93
+ if height < width:
94
+ padding = width - height
95
+ image = np.pad(image, ((padding // 2, padding - padding // 2), (0, 0), (0, 0)), 'constant')
96
+ mask = np.pad(mask, ((padding // 2, padding - padding // 2), (0, 0)), 'constant')
97
+ padding_side = 'h'
98
+ else:
99
+ padding = height - width
100
+ image = np.pad(image, ((0, 0), (padding // 2, padding - padding // 2), (0, 0)), 'constant')
101
+ mask = np.pad(mask, ((0, 0), (padding // 2, padding - padding // 2)), 'constant')
102
+ padding_side = 'w'
103
+
104
+ resize_factor = crop_size / max(w, h)
105
+ image = cv2.resize(image, (0, 0), fx=resize_factor, fy=resize_factor)
106
+ mask = cv2.resize(mask, (0, 0), fx=resize_factor, fy=resize_factor)
107
+ x, y, w, h = cv2.boundingRect(mask)
108
+ else:
109
+ flag_padding = False
110
+
111
+ # Calculate the crop coordinates
112
+ crop_x = min(max(x + w // 2 - crop_size // 2, 0), width - crop_size)
113
+ crop_y = min(max(y + h // 2 - crop_size // 2, 0), height - crop_size)
114
+
115
+ # Fill the image
116
+ image[crop_y:crop_y + crop_size, crop_x:crop_x + crop_size] = filled_image
117
+ if flag_padding:
118
+ image = cv2.resize(image, (0, 0), fx=1/resize_factor, fy=1/resize_factor)
119
+ if padding_side == 'h':
120
+ image = image[padding // 2:padding // 2 + height_ori, :]
121
+ else:
122
+ image = image[:, padding // 2:padding // 2 + width_ori]
123
+
124
+ image = cv2.resize(image, (width_ori, height_ori))
125
+
126
+ image_copy[mask_copy==255] = image[mask_copy==255]
127
+ return image_copy
128
+
129
+
130
+ if __name__ == '__main__':
131
+
132
+ # image = cv2.imread('example/boat.jpg')
133
+ # mask = cv2.imread('example/boat_mask_2.png', cv2.IMREAD_GRAYSCALE)
134
+ image = cv2.imread('./example/groceries.jpg')
135
+ mask = cv2.imread('example/groceries_mask_2.png', cv2.IMREAD_GRAYSCALE)
136
+ # image = cv2.imread('example/bridge.jpg')
137
+ # mask = cv2.imread('example/bridge_mask_2.png', cv2.IMREAD_GRAYSCALE)
138
+ # image = cv2.imread('example/person_umbrella.jpg')
139
+ # mask = cv2.imread('example/person_umbrella_mask_2.png', cv2.IMREAD_GRAYSCALE)
140
+ # image = cv2.imread('example/hippopotamus.jpg')
141
+ # mask = cv2.imread('example/hippopotamus_mask_1.png', cv2.IMREAD_GRAYSCALE)
142
+
143
+ cropped_image, cropped_mask = crop_for_filling_pre(image, mask)
144
+ # ^ ------------------------------------------------------------------------------------
145
+ # ^ Please conduct inpainting or filling here on the cropped image with the cropped mask
146
+ # ^ ------------------------------------------------------------------------------------
147
+
148
+ # e.g.
149
+ # cropped_image[cropped_mask==255] = 0
150
+ cv2.imwrite('cropped_image.jpg', cropped_image)
151
+ cv2.imwrite('cropped_mask.jpg', cropped_mask)
152
+ print(cropped_image.shape)
153
+ print(cropped_mask.shape)
154
+
155
+ image = crop_for_filling_post(image, mask, cropped_image)
156
+ cv2.imwrite('filled_image.jpg', image)
157
+ print(image.shape)
158
+
159
+
160
+
utils/paste_object.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+
4
+ def paste_object(source, source_mask, target, target_coords, resize_scale=1):
5
+ assert target_coords[0] < target.shape[1] and target_coords[1] < target.shape[0]
6
+ # Find the bounding box of the source_mask
7
+ x, y, w, h = cv2.boundingRect(source_mask)
8
+ assert h < source.shape[0] and w < source.shape[1]
9
+ obj = source[y:y+h, x:x+w]
10
+ obj_msk = source_mask[y:y+h, x:x+w]
11
+ if resize_scale != 1:
12
+ obj = cv2.resize(obj, (0,0), fx=resize_scale, fy=resize_scale)
13
+ obj_msk = cv2.resize(obj_msk, (0,0), fx=resize_scale, fy=resize_scale)
14
+ _, _, w, h = cv2.boundingRect(obj_msk)
15
+
16
+ xt = max(0, target_coords[0]-w//2)
17
+ yt = max(0, target_coords[1]-h//2)
18
+ if target_coords[0]-w//2 < 0:
19
+ obj = obj[:, w//2-target_coords[0]:]
20
+ obj_msk = obj_msk[:, w//2-target_coords[0]:]
21
+ if target_coords[0]+w//2 > target.shape[1]:
22
+ obj = obj[:, :target.shape[1]-target_coords[0]+w//2]
23
+ obj_msk = obj_msk[:, :target.shape[1]-target_coords[0]+w//2]
24
+ if target_coords[1]-h//2 < 0:
25
+ obj = obj[h//2-target_coords[1]:, :]
26
+ obj_msk = obj_msk[h//2-target_coords[1]:, :]
27
+ if target_coords[1]+h//2 > target.shape[0]:
28
+ obj = obj[:target.shape[0]-target_coords[1]+h//2, :]
29
+ obj_msk = obj_msk[:target.shape[0]-target_coords[1]+h//2, :]
30
+ _, _, w, h = cv2.boundingRect(obj_msk)
31
+
32
+ target[yt:yt+h, xt:xt+w][obj_msk==255] = obj[obj_msk==255]
33
+ target_mask = np.zeros_like(target)
34
+ target_mask = cv2.cvtColor(target_mask, cv2.COLOR_BGR2GRAY)
35
+ target_mask[yt:yt+h, xt:xt+w][obj_msk==255] = 255
36
+
37
+ return target, target_mask
38
+
39
+ if __name__ == '__main__':
40
+ source = cv2.imread('example/boat.jpg')
41
+ source_mask = cv2.imread('example/boat_mask_1.png', 0)
42
+ target = cv2.imread('example/hippopotamus.jpg')
43
+ print(source.shape, source_mask.shape, target.shape)
44
+
45
+ target_coords = (700, 400) # (x, y)
46
+ resize_scale = 1
47
+ target, target_mask = paste_object(source, source_mask, target, target_coords, resize_scale)
48
+ cv2.imwrite('target_pasted.png', target)
49
+ cv2.imwrite('target_mask.png', target_mask)
50
+ print(target.shape, target_mask.shape)
utils/utils.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+ from PIL import Image
4
+ from typing import Any, Dict, List
5
+
6
+
7
+ def load_img_to_array(img_p):
8
+ img = Image.open(img_p)
9
+ if img.mode == "RGBA":
10
+ img = img.convert("RGB")
11
+ return np.array(img)
12
+
13
+
14
+ def save_array_to_img(img_arr, img_p):
15
+ Image.fromarray(img_arr.astype(np.uint8)).save(img_p)
16
+
17
+
18
+ def dilate_mask(mask, dilate_factor=15):
19
+ mask = mask.astype(np.uint8)
20
+ mask = cv2.dilate(
21
+ mask,
22
+ np.ones((dilate_factor, dilate_factor), np.uint8),
23
+ iterations=1
24
+ )
25
+ return mask
26
+
27
+ def erode_mask(mask, dilate_factor=15):
28
+ mask = mask.astype(np.uint8)
29
+ mask = cv2.erode(
30
+ mask,
31
+ np.ones((dilate_factor, dilate_factor), np.uint8),
32
+ iterations=1
33
+ )
34
+ return mask
35
+
36
+ def show_mask(ax, mask: np.ndarray, random_color=False):
37
+ mask = mask.astype(np.uint8)
38
+ if np.max(mask) == 255:
39
+ mask = mask / 255
40
+ if random_color:
41
+ color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
42
+ else:
43
+ color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
44
+ h, w = mask.shape[-2:]
45
+ mask_img = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
46
+ ax.imshow(mask_img)
47
+
48
+
49
+ def show_points(ax, coords: List[List[float]], labels: List[int], size=375):
50
+ coords = np.array(coords)
51
+ labels = np.array(labels)
52
+ color_table = {0: 'red', 1: 'green'}
53
+ for label_value, color in color_table.items():
54
+ points = coords[labels == label_value]
55
+ ax.scatter(points[:, 0], points[:, 1], color=color, marker='*',
56
+ s=size, edgecolor='white', linewidth=1.25)
57
+
58
+ def get_clicked_point(img_path):
59
+ img = cv2.imread(img_path)
60
+ cv2.namedWindow("image")
61
+ cv2.imshow("image", img)
62
+
63
+ last_point = []
64
+ keep_looping = True
65
+
66
+ def mouse_callback(event, x, y, flags, param):
67
+ nonlocal last_point, keep_looping, img
68
+
69
+ if event == cv2.EVENT_LBUTTONDOWN:
70
+ if last_point:
71
+ cv2.circle(img, tuple(last_point), 5, (0, 0, 0), -1)
72
+ last_point = [x, y]
73
+ cv2.circle(img, tuple(last_point), 5, (0, 0, 255), -1)
74
+ cv2.imshow("image", img)
75
+ elif event == cv2.EVENT_RBUTTONDOWN:
76
+ keep_looping = False
77
+
78
+ cv2.setMouseCallback("image", mouse_callback)
79
+
80
+ while keep_looping:
81
+ cv2.waitKey(1)
82
+
83
+ cv2.destroyAllWindows()
84
+
85
+ return last_point
utils/visualize_bbox.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import matplotlib.pyplot as plt
2
+ import matplotlib.patches as patches
3
+ import numpy as np
4
+ import cv2
5
+
6
+
7
+ img = plt.imread('../example/fill-anything/sample1.png')
8
+ fig, ax = plt.subplots(1)
9
+ ax.imshow(img)
10
+
11
+ x1, y1, x2, y2 = 230, 283, 352, 407
12
+ rect = patches.Rectangle((x1, y1), x2-x1, y2-y1, linewidth=2, edgecolor='r', facecolor='none')
13
+ ax.add_patch(rect)
14
+ plt.savefig('bbox.png')
utils/visualize_mask_on_img.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import sys
3
+ import argparse
4
+ import numpy as np
5
+ from PIL import Image
6
+ from pathlib import Path
7
+ from matplotlib import pyplot as plt
8
+ from typing import Any, Dict, List
9
+ import glob
10
+
11
+ from utils import load_img_to_array, show_mask
12
+
13
+
14
+ def setup_args(parser):
15
+ parser.add_argument(
16
+ "--input_img", type=str, required=True,
17
+ help="Path to a single input img",
18
+ )
19
+ parser.add_argument(
20
+ "--input_mask_glob", type=str, required=True,
21
+ help="Glob to input masks",
22
+ )
23
+ parser.add_argument(
24
+ "--output_dir", type=str, required=True,
25
+ help="Output path to the directory with results.",
26
+ )
27
+
28
+ if __name__ == "__main__":
29
+ """Example usage:
30
+ python visual_mask_on_img.py \
31
+ --input_img FA_demo/FA1_dog.png \
32
+ --input_mask_glob "results/FA1_dog/mask*.png" \
33
+ --output_dir results
34
+ """
35
+ parser = argparse.ArgumentParser()
36
+ setup_args(parser)
37
+ args = parser.parse_args(sys.argv[1:])
38
+
39
+ img = load_img_to_array(args.input_img)
40
+ img_stem = Path(args.input_img).stem
41
+
42
+ mask_ps = sorted(glob.glob(args.input_mask_glob))
43
+
44
+ out_dir = Path(args.output_dir) / img_stem
45
+ out_dir.mkdir(parents=True, exist_ok=True)
46
+
47
+ for mask_p in mask_ps:
48
+ mask = load_img_to_array(mask_p)
49
+ mask = mask.astype(np.uint8)
50
+
51
+ # path to the results
52
+ img_mask_p = out_dir / f"with_{Path(mask_p).name}"
53
+
54
+ # save the masked image
55
+ dpi = plt.rcParams['figure.dpi']
56
+ height, width = img.shape[:2]
57
+ plt.figure(figsize=(width/dpi/0.77, height/dpi/0.77))
58
+ plt.imshow(img)
59
+ plt.axis('off')
60
+ show_mask(plt.gca(), mask, random_color=False)
61
+ plt.savefig(img_mask_p, bbox_inches='tight', pad_inches=0)
62
+ plt.close()
weights/.gitkeep ADDED
@@ -0,0 +1 @@
 
 
1
+ # placeholder