pg56714 commited on
Commit
c7e95b3
1 Parent(s): 7e0cd5c

Upload 11 files

Browse files
Files changed (11) hide show
  1. .gitignore +21 -0
  2. LICENSE +201 -0
  3. __init__.py +0 -0
  4. app.py +462 -0
  5. fill_anything.py +137 -0
  6. lama_inpaint.py +200 -0
  7. remove_anything.py +132 -0
  8. replace_anything.py +136 -0
  9. requirements.txt +33 -0
  10. sam_segment.py +133 -0
  11. stable_diffusion_inpaint.py +121 -0
.gitignore ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Python byte code, etc.
2
+ __pycache__/
3
+
4
+ # C/C++ object files/libraries
5
+ *.o
6
+ *.so
7
+
8
+ # macOS
9
+ **/.DS_Store
10
+
11
+ # tmp
12
+ ~*
13
+
14
+ # pretrained_models
15
+ pretrained_models/big-lama
16
+ # pytracking/pretrain/
17
+ *.pth
18
+
19
+ results/*
20
+
21
+
LICENSE ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
__init__.py ADDED
File without changes
app.py ADDED
@@ -0,0 +1,462 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+
4
+ sys.path.append(os.path.abspath(os.path.dirname(os.getcwd())))
5
+ # os.chdir("../")
6
+ import cv2
7
+ import gradio as gr
8
+ import numpy as np
9
+ from pathlib import Path
10
+ from matplotlib import pyplot as plt
11
+ import torch
12
+ import tempfile
13
+
14
+ from stable_diffusion_inpaint import fill_img_with_sd, replace_img_with_sd
15
+ from lama_inpaint import (
16
+ inpaint_img_with_lama,
17
+ build_lama_model,
18
+ inpaint_img_with_builded_lama,
19
+ )
20
+ from utils import (
21
+ load_img_to_array,
22
+ save_array_to_img,
23
+ dilate_mask,
24
+ show_mask,
25
+ show_points,
26
+ )
27
+ from PIL import Image
28
+ from segment_anything import SamPredictor, sam_model_registry
29
+ import argparse
30
+
31
+
32
+ def setup_args(parser):
33
+ parser.add_argument(
34
+ "--lama_config",
35
+ type=str,
36
+ default="./lama/configs/prediction/default.yaml",
37
+ help="The path to the config file of lama model. "
38
+ "Default: the config of big-lama",
39
+ )
40
+ parser.add_argument(
41
+ "--lama_ckpt",
42
+ type=str,
43
+ default="pretrained_models/big-lama",
44
+ help="The path to the lama checkpoint.",
45
+ )
46
+ parser.add_argument(
47
+ "--sam_ckpt",
48
+ type=str,
49
+ default="./pretrained_models/sam_vit_h_4b8939.pth",
50
+ help="The path to the SAM checkpoint to use for mask generation.",
51
+ )
52
+
53
+
54
+ def mkstemp(suffix, dir=None):
55
+ fd, path = tempfile.mkstemp(suffix=f"{suffix}", dir=dir)
56
+ os.close(fd)
57
+ return Path(path)
58
+
59
+
60
+ def get_sam_feat(img):
61
+ model["sam"].set_image(img)
62
+ features = model["sam"].features
63
+ orig_h = model["sam"].orig_h
64
+ orig_w = model["sam"].orig_w
65
+ input_h = model["sam"].input_h
66
+ input_w = model["sam"].input_w
67
+ model["sam"].reset_image()
68
+ return features, orig_h, orig_w, input_h, input_w
69
+
70
+
71
+ def get_fill_img_with_sd(image, mask, image_resolution, text_prompt):
72
+ device = "cuda" if torch.cuda.is_available() else "cpu"
73
+ if len(mask.shape) == 3:
74
+ mask = mask[:, :, 0]
75
+ np_image = np.array(image, dtype=np.uint8)
76
+ H, W, C = np_image.shape
77
+ np_image = HWC3(np_image)
78
+ np_image = resize_image(np_image, image_resolution)
79
+ mask = cv2.resize(
80
+ mask, (np_image.shape[1], np_image.shape[0]), interpolation=cv2.INTER_NEAREST
81
+ )
82
+
83
+ img_fill = fill_img_with_sd(np_image, mask, text_prompt, device=device)
84
+ img_fill = img_fill.astype(np.uint8)
85
+ return img_fill
86
+
87
+
88
+ def get_replace_img_with_sd(image, mask, image_resolution, text_prompt):
89
+ device = "cuda" if torch.cuda.is_available() else "cpu"
90
+ if len(mask.shape) == 3:
91
+ mask = mask[:, :, 0]
92
+ np_image = np.array(image, dtype=np.uint8)
93
+ H, W, C = np_image.shape
94
+ np_image = HWC3(np_image)
95
+ np_image = resize_image(np_image, image_resolution)
96
+ mask = cv2.resize(
97
+ mask, (np_image.shape[1], np_image.shape[0]), interpolation=cv2.INTER_NEAREST
98
+ )
99
+
100
+ img_replaced = replace_img_with_sd(np_image, mask, text_prompt, device=device)
101
+ img_replaced = img_replaced.astype(np.uint8)
102
+ return img_replaced
103
+
104
+
105
+ def HWC3(x):
106
+ assert x.dtype == np.uint8
107
+ if x.ndim == 2:
108
+ x = x[:, :, None]
109
+ assert x.ndim == 3
110
+ H, W, C = x.shape
111
+ assert C == 1 or C == 3 or C == 4
112
+ if C == 3:
113
+ return x
114
+ if C == 1:
115
+ return np.concatenate([x, x, x], axis=2)
116
+ if C == 4:
117
+ color = x[:, :, 0:3].astype(np.float32)
118
+ alpha = x[:, :, 3:4].astype(np.float32) / 255.0
119
+ y = color * alpha + 255.0 * (1.0 - alpha)
120
+ y = y.clip(0, 255).astype(np.uint8)
121
+ return y
122
+
123
+
124
+ def resize_image(input_image, resolution):
125
+ H, W, C = input_image.shape
126
+ k = float(resolution) / min(H, W)
127
+ H = int(np.round(H * k / 64.0)) * 64
128
+ W = int(np.round(W * k / 64.0)) * 64
129
+ img = cv2.resize(
130
+ input_image,
131
+ (W, H),
132
+ interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA,
133
+ )
134
+ return img
135
+
136
+
137
+ def resize_points(clicked_points, original_shape, resolution):
138
+ original_height, original_width, _ = original_shape
139
+ original_height = float(original_height)
140
+ original_width = float(original_width)
141
+
142
+ scale_factor = float(resolution) / min(original_height, original_width)
143
+ resized_points = []
144
+
145
+ for point in clicked_points:
146
+ x, y, lab = point
147
+ resized_x = int(round(x * scale_factor))
148
+ resized_y = int(round(y * scale_factor))
149
+ resized_point = (resized_x, resized_y, lab)
150
+ resized_points.append(resized_point)
151
+
152
+ return resized_points
153
+
154
+
155
+ def get_click_mask(
156
+ clicked_points, features, orig_h, orig_w, input_h, input_w, dilate_kernel_size
157
+ ):
158
+ # model['sam'].set_image(image)
159
+ model["sam"].is_image_set = True
160
+ model["sam"].features = features
161
+ model["sam"].orig_h = orig_h
162
+ model["sam"].orig_w = orig_w
163
+ model["sam"].input_h = input_h
164
+ model["sam"].input_w = input_w
165
+
166
+ # Separate the points and labels
167
+ points, labels = zip(*[(point[:2], point[2]) for point in clicked_points])
168
+
169
+ # Convert the points and labels to numpy arrays
170
+ input_point = np.array(points)
171
+ input_label = np.array(labels)
172
+
173
+ masks, _, _ = model["sam"].predict(
174
+ point_coords=input_point,
175
+ point_labels=input_label,
176
+ multimask_output=False,
177
+ )
178
+ if dilate_kernel_size is not None:
179
+ masks = [dilate_mask(mask, dilate_kernel_size) for mask in masks]
180
+ else:
181
+ masks = [mask for mask in masks]
182
+
183
+ return masks
184
+
185
+
186
+ def process_image_click(
187
+ original_image,
188
+ point_prompt,
189
+ clicked_points,
190
+ image_resolution,
191
+ features,
192
+ orig_h,
193
+ orig_w,
194
+ input_h,
195
+ input_w,
196
+ dilate_kernel_size,
197
+ evt: gr.SelectData,
198
+ ):
199
+ if clicked_points is None:
200
+ clicked_points = []
201
+
202
+ # print("Received click event:", evt)
203
+ if original_image is None:
204
+ # print("No image loaded.")
205
+ return None, clicked_points, None
206
+
207
+ clicked_coords = evt.index
208
+ if clicked_coords is None:
209
+ # print("No valid coordinates received.")
210
+ return None, clicked_points, None
211
+
212
+ x, y = clicked_coords
213
+ label = point_prompt
214
+ lab = 1 if label == "Foreground Point" else 0
215
+ clicked_points.append((x, y, lab))
216
+ # print("Updated points list:", clicked_points)
217
+
218
+ input_image = np.array(original_image, dtype=np.uint8)
219
+ H, W, C = input_image.shape
220
+ input_image = HWC3(input_image)
221
+ img = resize_image(input_image, image_resolution)
222
+ # print("Processed image size:", img.shape)
223
+
224
+ resized_points = resize_points(clicked_points, input_image.shape, image_resolution)
225
+ mask_click_np = get_click_mask(
226
+ resized_points, features, orig_h, orig_w, input_h, input_w, dilate_kernel_size
227
+ )
228
+ mask_click_np = np.transpose(mask_click_np, (1, 2, 0)) * 255.0
229
+ mask_image = HWC3(mask_click_np.astype(np.uint8))
230
+ mask_image = cv2.resize(mask_image, (W, H), interpolation=cv2.INTER_LINEAR)
231
+ # print("Mask image prepared.")
232
+
233
+ edited_image = input_image
234
+ for x, y, lab in clicked_points:
235
+ color = (255, 0, 0) if lab == 1 else (0, 0, 255)
236
+ edited_image = cv2.circle(edited_image, (x, y), 20, color, -1)
237
+
238
+ opacity_mask = 0.75
239
+ opacity_edited = 1.0
240
+ overlay_image = cv2.addWeighted(
241
+ edited_image,
242
+ opacity_edited,
243
+ (mask_image * np.array([0 / 255, 255 / 255, 0 / 255])).astype(np.uint8),
244
+ opacity_mask,
245
+ 0,
246
+ )
247
+
248
+ no_mask_overlay = edited_image.copy()
249
+
250
+ return no_mask_overlay, overlay_image, clicked_points, mask_image
251
+
252
+
253
+ def image_upload(image, image_resolution):
254
+ if image is None:
255
+ return None, None, None, None, None, None
256
+ else:
257
+ np_image = np.array(image, dtype=np.uint8)
258
+ H, W, C = np_image.shape
259
+ np_image = HWC3(np_image)
260
+ np_image = resize_image(np_image, image_resolution)
261
+ features, orig_h, orig_w, input_h, input_w = get_sam_feat(np_image)
262
+ return image, features, orig_h, orig_w, input_h, input_w
263
+
264
+
265
+ def get_inpainted_img(image, mask, image_resolution):
266
+ lama_config = args.lama_config
267
+ device = "cuda" if torch.cuda.is_available() else "cpu"
268
+ if len(mask.shape) == 3:
269
+ mask = mask[:, :, 0]
270
+ img_inpainted = inpaint_img_with_builded_lama(
271
+ model["lama"], image, mask, lama_config, device=device
272
+ )
273
+ return img_inpainted
274
+
275
+
276
+ # get args
277
+ parser = argparse.ArgumentParser()
278
+ setup_args(parser)
279
+ args = parser.parse_args(sys.argv[1:])
280
+ # build models
281
+ model = {}
282
+ # build the sam model
283
+ model_type = "vit_h"
284
+ ckpt_p = args.sam_ckpt
285
+ model_sam = sam_model_registry[model_type](checkpoint=ckpt_p)
286
+ device = "cuda" if torch.cuda.is_available() else "cpu"
287
+ model_sam.to(device=device)
288
+ model["sam"] = SamPredictor(model_sam)
289
+
290
+ # build the lama model
291
+ lama_config = args.lama_config
292
+ lama_ckpt = args.lama_ckpt
293
+ device = "cuda" if torch.cuda.is_available() else "cpu"
294
+ model["lama"] = build_lama_model(lama_config, lama_ckpt, device=device)
295
+
296
+ button_size = (100, 50)
297
+ with gr.Blocks() as demo:
298
+ clicked_points = gr.State([])
299
+ # origin_image = gr.State(None)
300
+ click_mask = gr.State(None)
301
+ features = gr.State(None)
302
+ orig_h = gr.State(None)
303
+ orig_w = gr.State(None)
304
+ input_h = gr.State(None)
305
+ input_w = gr.State(None)
306
+
307
+ with gr.Row():
308
+ with gr.Column(variant="panel"):
309
+ with gr.Row():
310
+ gr.Markdown("## Upload an image and click the region you want to edit.")
311
+ with gr.Row():
312
+ source_image_click = gr.Image(
313
+ type="numpy",
314
+ interactive=True,
315
+ label="Upload and Edit Image",
316
+ )
317
+
318
+ image_edit_complete = gr.Image(
319
+ type="numpy",
320
+ interactive=False,
321
+ label="Editing Complete",
322
+ )
323
+ with gr.Row():
324
+ point_prompt = gr.Radio(
325
+ choices=["Foreground Point", "Background Point"],
326
+ value="Foreground Point",
327
+ label="Point Label",
328
+ interactive=True,
329
+ show_label=False,
330
+ )
331
+ image_resolution = gr.Slider(
332
+ label="Image Resolution",
333
+ minimum=256,
334
+ maximum=768,
335
+ value=512,
336
+ step=64,
337
+ )
338
+ dilate_kernel_size = gr.Slider(
339
+ label="Dilate Kernel Size", minimum=0, maximum=30, value=15, step=1
340
+ )
341
+ with gr.Column(variant="panel"):
342
+ with gr.Row():
343
+ gr.Markdown("## Control Panel")
344
+ text_prompt = gr.Textbox(label="Text Prompt")
345
+ lama = gr.Button("Inpaint Image", variant="primary")
346
+ fill_sd = gr.Button("Fill Anything with SD", variant="primary")
347
+ replace_sd = gr.Button("Replace Anything with SD", variant="primary")
348
+ clear_button_image = gr.Button(value="Reset", variant="secondary")
349
+
350
+ # todo: maybe we can delete this row, for it's unnecessary to show the original mask for customers
351
+ with gr.Row(variant="panel"):
352
+ with gr.Column():
353
+ with gr.Row():
354
+ gr.Markdown("## Mask")
355
+ with gr.Row():
356
+ click_mask = gr.Image(
357
+ type="numpy",
358
+ label="Click Mask",
359
+ interactive=False,
360
+ )
361
+ with gr.Column():
362
+ with gr.Row():
363
+ gr.Markdown("## Image Removed with Mask")
364
+ with gr.Row():
365
+ img_rm_with_mask = gr.Image(
366
+ type="numpy",
367
+ label="Image Removed with Mask",
368
+ interactive=False,
369
+ )
370
+
371
+ with gr.Column():
372
+ with gr.Row():
373
+ gr.Markdown("## Fill Anything with Mask")
374
+ with gr.Row():
375
+ img_fill_with_mask = gr.Image(
376
+ type="numpy",
377
+ label="Image Fill Anything with Mask",
378
+ interactive=False,
379
+ )
380
+
381
+ with gr.Column():
382
+ with gr.Row():
383
+ gr.Markdown("## Replace Anything with Mask")
384
+ with gr.Row():
385
+ img_replace_with_mask = gr.Image(
386
+ type="numpy",
387
+ label="Image Replace Anything with Mask",
388
+ interactive=False,
389
+ )
390
+
391
+ source_image_click.upload(
392
+ image_upload,
393
+ inputs=[source_image_click, image_resolution],
394
+ outputs=[source_image_click, features, orig_h, orig_w, input_h, input_w],
395
+ )
396
+
397
+ source_image_click.select(
398
+ process_image_click,
399
+ inputs=[
400
+ source_image_click,
401
+ point_prompt,
402
+ clicked_points,
403
+ image_resolution,
404
+ features,
405
+ orig_h,
406
+ orig_w,
407
+ input_h,
408
+ input_w,
409
+ dilate_kernel_size,
410
+ ],
411
+ outputs=[source_image_click, image_edit_complete, clicked_points, click_mask],
412
+ show_progress=True,
413
+ queue=True,
414
+ )
415
+
416
+ lama.click(
417
+ get_inpainted_img,
418
+ inputs=[source_image_click, click_mask, image_resolution],
419
+ outputs=[img_rm_with_mask],
420
+ )
421
+
422
+ fill_sd.click(
423
+ get_fill_img_with_sd,
424
+ inputs=[source_image_click, click_mask, image_resolution, text_prompt],
425
+ outputs=[img_fill_with_mask],
426
+ )
427
+
428
+ replace_sd.click(
429
+ get_replace_img_with_sd,
430
+ inputs=[source_image_click, click_mask, image_resolution, text_prompt],
431
+ outputs=[img_replace_with_mask],
432
+ )
433
+
434
+ def reset(*args):
435
+ return [None for _ in args]
436
+
437
+ clear_button_image.click(
438
+ reset,
439
+ inputs=[
440
+ source_image_click,
441
+ image_edit_complete,
442
+ clicked_points,
443
+ click_mask,
444
+ features,
445
+ img_rm_with_mask,
446
+ img_fill_with_mask,
447
+ img_replace_with_mask,
448
+ ],
449
+ outputs=[
450
+ source_image_click,
451
+ image_edit_complete,
452
+ clicked_points,
453
+ click_mask,
454
+ features,
455
+ img_rm_with_mask,
456
+ img_fill_with_mask,
457
+ img_replace_with_mask,
458
+ ],
459
+ )
460
+
461
+ if __name__ == "__main__":
462
+ demo.launch(debug=False, show_error=True)
fill_anything.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import sys
3
+ import argparse
4
+ import numpy as np
5
+ import torch
6
+ from pathlib import Path
7
+ from matplotlib import pyplot as plt
8
+ from typing import Any, Dict, List
9
+
10
+ from sam_segment import predict_masks_with_sam
11
+ from stable_diffusion_inpaint import fill_img_with_sd
12
+ from utils import load_img_to_array, save_array_to_img, dilate_mask, \
13
+ show_mask, show_points, get_clicked_point
14
+
15
+
16
+ def setup_args(parser):
17
+ parser.add_argument(
18
+ "--input_img", type=str, required=True,
19
+ help="Path to a single input img",
20
+ )
21
+ parser.add_argument(
22
+ "--coords_type", type=str, required=True,
23
+ default="key_in", choices=["click", "key_in"],
24
+ help="The way to select coords",
25
+ )
26
+ parser.add_argument(
27
+ "--point_coords", type=float, nargs='+', required=True,
28
+ help="The coordinate of the point prompt, [coord_W coord_H].",
29
+ )
30
+ parser.add_argument(
31
+ "--point_labels", type=int, nargs='+', required=True,
32
+ help="The labels of the point prompt, 1 or 0.",
33
+ )
34
+ parser.add_argument(
35
+ "--text_prompt", type=str, required=True,
36
+ help="Text prompt",
37
+ )
38
+ parser.add_argument(
39
+ "--dilate_kernel_size", type=int, default=None,
40
+ help="Dilate kernel size. Default: None",
41
+ )
42
+ parser.add_argument(
43
+ "--output_dir", type=str, required=True,
44
+ help="Output path to the directory with results.",
45
+ )
46
+ parser.add_argument(
47
+ "--sam_model_type", type=str,
48
+ default="vit_h", choices=['vit_h', 'vit_l', 'vit_b', 'vit_t'],
49
+ help="The type of sam model to load. Default: 'vit_h"
50
+ )
51
+ parser.add_argument(
52
+ "--sam_ckpt", type=str, required=True,
53
+ help="The path to the SAM checkpoint to use for mask generation.",
54
+ )
55
+ parser.add_argument(
56
+ "--seed", type=int,
57
+ help="Specify seed for reproducibility.",
58
+ )
59
+ parser.add_argument(
60
+ "--deterministic", action="store_true",
61
+ help="Use deterministic algorithms for reproducibility.",
62
+ )
63
+
64
+
65
+ if __name__ == "__main__":
66
+ """Example usage:
67
+ python fill_anything.py \
68
+ --input_img FA_demo/FA1_dog.png \
69
+ --coords_type key_in \
70
+ --point_coords 750 500 \
71
+ --point_labels 1 \
72
+ --text_prompt "a teddy bear on a bench" \
73
+ --dilate_kernel_size 15 \
74
+ --output_dir ./results \
75
+ --sam_model_type "vit_h" \
76
+ --sam_ckpt sam_vit_h_4b8939.pth
77
+ """
78
+ parser = argparse.ArgumentParser()
79
+ setup_args(parser)
80
+ args = parser.parse_args(sys.argv[1:])
81
+ device = "cuda" if torch.cuda.is_available() else "cpu"
82
+
83
+ if args.coords_type == "click":
84
+ latest_coords = get_clicked_point(args.input_img)
85
+ elif args.coords_type == "key_in":
86
+ latest_coords = args.point_coords
87
+ img = load_img_to_array(args.input_img)
88
+
89
+ masks, _, _ = predict_masks_with_sam(
90
+ img,
91
+ [latest_coords],
92
+ args.point_labels,
93
+ model_type=args.sam_model_type,
94
+ ckpt_p=args.sam_ckpt,
95
+ device=device,
96
+ )
97
+ masks = masks.astype(np.uint8) * 255
98
+
99
+ # dilate mask to avoid unmasked edge effect
100
+ if args.dilate_kernel_size is not None:
101
+ masks = [dilate_mask(mask, args.dilate_kernel_size) for mask in masks]
102
+
103
+ # visualize the segmentation results
104
+ img_stem = Path(args.input_img).stem
105
+ out_dir = Path(args.output_dir) / img_stem
106
+ out_dir.mkdir(parents=True, exist_ok=True)
107
+ for idx, mask in enumerate(masks):
108
+ # path to the results
109
+ mask_p = out_dir / f"mask_{idx}.png"
110
+ img_points_p = out_dir / f"with_points.png"
111
+ img_mask_p = out_dir / f"with_{Path(mask_p).name}"
112
+
113
+ # save the mask
114
+ save_array_to_img(mask, mask_p)
115
+
116
+ # save the pointed and masked image
117
+ dpi = plt.rcParams['figure.dpi']
118
+ height, width = img.shape[:2]
119
+ plt.figure(figsize=(width/dpi/0.77, height/dpi/0.77))
120
+ plt.imshow(img)
121
+ plt.axis('off')
122
+ show_points(plt.gca(), [latest_coords], args.point_labels,
123
+ size=(width*0.04)**2)
124
+ plt.savefig(img_points_p, bbox_inches='tight', pad_inches=0)
125
+ show_mask(plt.gca(), mask, random_color=False)
126
+ plt.savefig(img_mask_p, bbox_inches='tight', pad_inches=0)
127
+ plt.close()
128
+
129
+ # fill the masked image
130
+ for idx, mask in enumerate(masks):
131
+ if args.seed is not None:
132
+ torch.manual_seed(args.seed)
133
+ mask_p = out_dir / f"mask_{idx}.png"
134
+ img_filled_p = out_dir / f"filled_with_{Path(mask_p).name}"
135
+ img_filled = fill_img_with_sd(
136
+ img, mask, args.text_prompt, device=device)
137
+ save_array_to_img(img_filled, img_filled_p)
lama_inpaint.py ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import numpy as np
4
+ import torch
5
+ import yaml
6
+ import glob
7
+ import argparse
8
+ from PIL import Image
9
+ from omegaconf import OmegaConf
10
+ from pathlib import Path
11
+
12
+ os.environ['OMP_NUM_THREADS'] = '1'
13
+ os.environ['OPENBLAS_NUM_THREADS'] = '1'
14
+ os.environ['MKL_NUM_THREADS'] = '1'
15
+ os.environ['VECLIB_MAXIMUM_THREADS'] = '1'
16
+ os.environ['NUMEXPR_NUM_THREADS'] = '1'
17
+
18
+ sys.path.insert(0, str(Path(__file__).resolve().parent / "lama"))
19
+ from saicinpainting.evaluation.utils import move_to_device
20
+ from saicinpainting.training.trainers import load_checkpoint
21
+ from saicinpainting.evaluation.data import pad_tensor_to_modulo
22
+
23
+ from utils import load_img_to_array, save_array_to_img
24
+
25
+
26
+ @torch.no_grad()
27
+ def inpaint_img_with_lama(
28
+ img: np.ndarray,
29
+ mask: np.ndarray,
30
+ config_p: str,
31
+ ckpt_p: str,
32
+ mod=8,
33
+ device="cuda"
34
+ ):
35
+ assert len(mask.shape) == 2
36
+ if np.max(mask) == 1:
37
+ mask = mask * 255
38
+ img = torch.from_numpy(img).float().div(255.)
39
+ mask = torch.from_numpy(mask).float()
40
+ predict_config = OmegaConf.load(config_p)
41
+ predict_config.model.path = ckpt_p
42
+ # device = torch.device(predict_config.device)
43
+ device = torch.device(device)
44
+
45
+ train_config_path = os.path.join(
46
+ predict_config.model.path, 'config.yaml')
47
+
48
+ with open(train_config_path, 'r') as f:
49
+ train_config = OmegaConf.create(yaml.safe_load(f))
50
+
51
+ train_config.training_model.predict_only = True
52
+ train_config.visualizer.kind = 'noop'
53
+
54
+ checkpoint_path = os.path.join(
55
+ predict_config.model.path, 'models',
56
+ predict_config.model.checkpoint
57
+ )
58
+ model = load_checkpoint(
59
+ train_config, checkpoint_path, strict=False, map_location='cpu')
60
+ model.freeze()
61
+ if not predict_config.get('refine', False):
62
+ model.to(device)
63
+
64
+ batch = {}
65
+ batch['image'] = img.permute(2, 0, 1).unsqueeze(0)
66
+ batch['mask'] = mask[None, None]
67
+ unpad_to_size = [batch['image'].shape[2], batch['image'].shape[3]]
68
+ batch['image'] = pad_tensor_to_modulo(batch['image'], mod)
69
+ batch['mask'] = pad_tensor_to_modulo(batch['mask'], mod)
70
+ batch = move_to_device(batch, device)
71
+ batch['mask'] = (batch['mask'] > 0) * 1
72
+
73
+ batch = model(batch)
74
+ cur_res = batch[predict_config.out_key][0].permute(1, 2, 0)
75
+ cur_res = cur_res.detach().cpu().numpy()
76
+
77
+ if unpad_to_size is not None:
78
+ orig_height, orig_width = unpad_to_size
79
+ cur_res = cur_res[:orig_height, :orig_width]
80
+
81
+ cur_res = np.clip(cur_res * 255, 0, 255).astype('uint8')
82
+ return cur_res
83
+
84
+
85
+ def build_lama_model(
86
+ config_p: str,
87
+ ckpt_p: str,
88
+ device="cuda"
89
+ ):
90
+ predict_config = OmegaConf.load(config_p)
91
+ predict_config.model.path = ckpt_p
92
+ device = torch.device(device)
93
+
94
+ train_config_path = os.path.join(
95
+ predict_config.model.path, 'config.yaml')
96
+
97
+ with open(train_config_path, 'r') as f:
98
+ train_config = OmegaConf.create(yaml.safe_load(f))
99
+
100
+ train_config.training_model.predict_only = True
101
+ train_config.visualizer.kind = 'noop'
102
+
103
+ checkpoint_path = os.path.join(
104
+ predict_config.model.path, 'models',
105
+ predict_config.model.checkpoint
106
+ )
107
+ model = load_checkpoint(train_config, checkpoint_path, strict=False)
108
+ model.to(device)
109
+ model.freeze()
110
+ return model
111
+
112
+
113
+ @torch.no_grad()
114
+ def inpaint_img_with_builded_lama(
115
+ model,
116
+ img: np.ndarray,
117
+ mask: np.ndarray,
118
+ config_p=None,
119
+ mod=8,
120
+ device="cuda"
121
+ ):
122
+ assert len(mask.shape) == 2
123
+ if np.max(mask) == 1:
124
+ mask = mask * 255
125
+ img = torch.from_numpy(img).float().div(255.)
126
+ mask = torch.from_numpy(mask).float()
127
+
128
+ batch = {}
129
+ batch['image'] = img.permute(2, 0, 1).unsqueeze(0)
130
+ batch['mask'] = mask[None, None]
131
+ unpad_to_size = [batch['image'].shape[2], batch['image'].shape[3]]
132
+ batch['image'] = pad_tensor_to_modulo(batch['image'], mod)
133
+ batch['mask'] = pad_tensor_to_modulo(batch['mask'], mod)
134
+ batch = move_to_device(batch, device)
135
+ batch['mask'] = (batch['mask'] > 0) * 1
136
+
137
+ batch = model(batch)
138
+ cur_res = batch["inpainted"][0].permute(1, 2, 0)
139
+ cur_res = cur_res.detach().cpu().numpy()
140
+
141
+ if unpad_to_size is not None:
142
+ orig_height, orig_width = unpad_to_size
143
+ cur_res = cur_res[:orig_height, :orig_width]
144
+
145
+ cur_res = np.clip(cur_res * 255, 0, 255).astype('uint8')
146
+ return cur_res
147
+
148
+
149
+
150
+ def setup_args(parser):
151
+ parser.add_argument(
152
+ "--input_img", type=str, required=True,
153
+ help="Path to a single input img",
154
+ )
155
+ parser.add_argument(
156
+ "--input_mask_glob", type=str, required=True,
157
+ help="Glob to input masks",
158
+ )
159
+ parser.add_argument(
160
+ "--output_dir", type=str, required=True,
161
+ help="Output path to the directory with results.",
162
+ )
163
+ parser.add_argument(
164
+ "--lama_config", type=str,
165
+ default="./lama/configs/prediction/default.yaml",
166
+ help="The path to the config file of lama model. "
167
+ "Default: the config of big-lama",
168
+ )
169
+ parser.add_argument(
170
+ "--lama_ckpt", type=str, required=True,
171
+ help="The path to the lama checkpoint.",
172
+ )
173
+
174
+
175
+ if __name__ == "__main__":
176
+ """Example usage:
177
+ python lama_inpaint.py \
178
+ --input_img FA_demo/FA1_dog.png \
179
+ --input_mask_glob "results/FA1_dog/mask*.png" \
180
+ --output_dir results \
181
+ --lama_config lama/configs/prediction/default.yaml \
182
+ --lama_ckpt big-lama
183
+ """
184
+ parser = argparse.ArgumentParser()
185
+ setup_args(parser)
186
+ args = parser.parse_args(sys.argv[1:])
187
+ device = "cuda" if torch.cuda.is_available() else "cpu"
188
+
189
+ img_stem = Path(args.input_img).stem
190
+ mask_ps = sorted(glob.glob(args.input_mask_glob))
191
+ out_dir = Path(args.output_dir) / img_stem
192
+ out_dir.mkdir(parents=True, exist_ok=True)
193
+
194
+ img = load_img_to_array(args.input_img)
195
+ for mask_p in mask_ps:
196
+ mask = load_img_to_array(mask_p)
197
+ img_inpainted_p = out_dir / f"inpainted_with_{Path(mask_p).name}"
198
+ img_inpainted = inpaint_img_with_lama(
199
+ img, mask, args.lama_config, args.lama_ckpt, device=device)
200
+ save_array_to_img(img_inpainted, img_inpainted_p)
remove_anything.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import sys
3
+ import argparse
4
+ import numpy as np
5
+ from pathlib import Path
6
+ from matplotlib import pyplot as plt
7
+
8
+ from sam_segment import predict_masks_with_sam
9
+ from lama_inpaint import inpaint_img_with_lama
10
+ from utils import load_img_to_array, save_array_to_img, dilate_mask, \
11
+ show_mask, show_points, get_clicked_point
12
+
13
+
14
+ def setup_args(parser):
15
+ parser.add_argument(
16
+ "--input_img", type=str, required=True,
17
+ help="Path to a single input img",
18
+ )
19
+ parser.add_argument(
20
+ "--coords_type", type=str, required=True,
21
+ default="key_in", choices=["click", "key_in"],
22
+ help="The way to select coords",
23
+ )
24
+ parser.add_argument(
25
+ "--point_coords", type=float, nargs='+', required=True,
26
+ help="The coordinate of the point prompt, [coord_W coord_H].",
27
+ )
28
+ parser.add_argument(
29
+ "--point_labels", type=int, nargs='+', required=True,
30
+ help="The labels of the point prompt, 1 or 0.",
31
+ )
32
+ parser.add_argument(
33
+ "--dilate_kernel_size", type=int, default=None,
34
+ help="Dilate kernel size. Default: None",
35
+ )
36
+ parser.add_argument(
37
+ "--output_dir", type=str, required=True,
38
+ help="Output path to the directory with results.",
39
+ )
40
+ parser.add_argument(
41
+ "--sam_model_type", type=str,
42
+ default="vit_h", choices=['vit_h', 'vit_l', 'vit_b', 'vit_t'],
43
+ help="The type of sam model to load. Default: 'vit_h"
44
+ )
45
+ parser.add_argument(
46
+ "--sam_ckpt", type=str, required=True,
47
+ help="The path to the SAM checkpoint to use for mask generation.",
48
+ )
49
+ parser.add_argument(
50
+ "--lama_config", type=str,
51
+ default="./lama/configs/prediction/default.yaml",
52
+ help="The path to the config file of lama model. "
53
+ "Default: the config of big-lama",
54
+ )
55
+ parser.add_argument(
56
+ "--lama_ckpt", type=str, required=True,
57
+ help="The path to the lama checkpoint.",
58
+ )
59
+
60
+
61
+ if __name__ == "__main__":
62
+ """Example usage:
63
+ python remove_anything.py \
64
+ --input_img FA_demo/FA1_dog.png \
65
+ --coords_type key_in \
66
+ --point_coords 750 500 \
67
+ --point_labels 1 \
68
+ --dilate_kernel_size 15 \
69
+ --output_dir ./results \
70
+ --sam_model_type "vit_h" \
71
+ --sam_ckpt sam_vit_h_4b8939.pth \
72
+ --lama_config lama/configs/prediction/default.yaml \
73
+ --lama_ckpt big-lama
74
+ """
75
+ parser = argparse.ArgumentParser()
76
+ setup_args(parser)
77
+ args = parser.parse_args(sys.argv[1:])
78
+ device = "cuda" if torch.cuda.is_available() else "cpu"
79
+
80
+ if args.coords_type == "click":
81
+ latest_coords = get_clicked_point(args.input_img)
82
+ elif args.coords_type == "key_in":
83
+ latest_coords = args.point_coords
84
+ img = load_img_to_array(args.input_img)
85
+
86
+ masks, _, _ = predict_masks_with_sam(
87
+ img,
88
+ [latest_coords],
89
+ args.point_labels,
90
+ model_type=args.sam_model_type,
91
+ ckpt_p=args.sam_ckpt,
92
+ device=device,
93
+ )
94
+ masks = masks.astype(np.uint8) * 255
95
+
96
+ # dilate mask to avoid unmasked edge effect
97
+ if args.dilate_kernel_size is not None:
98
+ masks = [dilate_mask(mask, args.dilate_kernel_size) for mask in masks]
99
+
100
+ # visualize the segmentation results
101
+ img_stem = Path(args.input_img).stem
102
+ out_dir = Path(args.output_dir) / img_stem
103
+ out_dir.mkdir(parents=True, exist_ok=True)
104
+ for idx, mask in enumerate(masks):
105
+ # path to the results
106
+ mask_p = out_dir / f"mask_{idx}.png"
107
+ img_points_p = out_dir / f"with_points.png"
108
+ img_mask_p = out_dir / f"with_{Path(mask_p).name}"
109
+
110
+ # save the mask
111
+ save_array_to_img(mask, mask_p)
112
+
113
+ # save the pointed and masked image
114
+ dpi = plt.rcParams['figure.dpi']
115
+ height, width = img.shape[:2]
116
+ plt.figure(figsize=(width/dpi/0.77, height/dpi/0.77))
117
+ plt.imshow(img)
118
+ plt.axis('off')
119
+ show_points(plt.gca(), [latest_coords], args.point_labels,
120
+ size=(width*0.04)**2)
121
+ plt.savefig(img_points_p, bbox_inches='tight', pad_inches=0)
122
+ show_mask(plt.gca(), mask, random_color=False)
123
+ plt.savefig(img_mask_p, bbox_inches='tight', pad_inches=0)
124
+ plt.close()
125
+
126
+ # inpaint the masked image
127
+ for idx, mask in enumerate(masks):
128
+ mask_p = out_dir / f"mask_{idx}.png"
129
+ img_inpainted_p = out_dir / f"inpainted_with_{Path(mask_p).name}"
130
+ img_inpainted = inpaint_img_with_lama(
131
+ img, mask, args.lama_config, args.lama_ckpt, device=device)
132
+ save_array_to_img(img_inpainted, img_inpainted_p)
replace_anything.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import sys
3
+ import argparse
4
+ import numpy as np
5
+ import torch
6
+ from pathlib import Path
7
+ from matplotlib import pyplot as plt
8
+ from typing import Any, Dict, List
9
+ from sam_segment import predict_masks_with_sam
10
+ from stable_diffusion_inpaint import replace_img_with_sd
11
+ from utils import load_img_to_array, save_array_to_img, dilate_mask, \
12
+ show_mask, show_points, get_clicked_point
13
+
14
+
15
+ def setup_args(parser):
16
+ parser.add_argument(
17
+ "--input_img", type=str, required=True,
18
+ help="Path to a single input img",
19
+ )
20
+ parser.add_argument(
21
+ "--coords_type", type=str, required=True,
22
+ default="key_in", choices=["click", "key_in"],
23
+ help="The way to select coords",
24
+ )
25
+ parser.add_argument(
26
+ "--point_coords", type=float, nargs='+', required=True,
27
+ help="The coordinate of the point prompt, [coord_W coord_H].",
28
+ )
29
+ parser.add_argument(
30
+ "--point_labels", type=int, nargs='+', required=True,
31
+ help="The labels of the point prompt, 1 or 0.",
32
+ )
33
+ parser.add_argument(
34
+ "--text_prompt", type=str, required=True,
35
+ help="Text prompt",
36
+ )
37
+ parser.add_argument(
38
+ "--dilate_kernel_size", type=int, default=None,
39
+ help="Dilate kernel size. Default: None",
40
+ )
41
+ parser.add_argument(
42
+ "--output_dir", type=str, required=True,
43
+ help="Output path to the directory with results.",
44
+ )
45
+ parser.add_argument(
46
+ "--sam_model_type", type=str,
47
+ default="vit_h", choices=['vit_h', 'vit_l', 'vit_b', 'vit_t'],
48
+ help="The type of sam model to load. Default: 'vit_h"
49
+ )
50
+ parser.add_argument(
51
+ "--sam_ckpt", type=str, required=True,
52
+ help="The path to the SAM checkpoint to use for mask generation.",
53
+ )
54
+ parser.add_argument(
55
+ "--seed", type=int,
56
+ help="Specify seed for reproducibility.",
57
+ )
58
+ parser.add_argument(
59
+ "--deterministic", action="store_true",
60
+ help="Use deterministic algorithms for reproducibility.",
61
+ )
62
+
63
+
64
+
65
+ if __name__ == "__main__":
66
+ """Example usage:
67
+ python replace_anything.py \
68
+ --input_img ./example/replace-anything/dog.png \
69
+ --coords_type key_in \
70
+ --point_coords 750 500 \
71
+ --point_labels 1 \
72
+ --text_prompt "sit on the swing" \
73
+ --output_dir ./results \
74
+ --sam_model_type "vit_h" \
75
+ --sam_ckpt ./pretrained_models/sam_vit_h_4b8939.pth
76
+ """
77
+ parser = argparse.ArgumentParser()
78
+ setup_args(parser)
79
+ args = parser.parse_args(sys.argv[1:])
80
+ device = "cuda" if torch.cuda.is_available() else "cpu"
81
+
82
+ if args.coords_type == "click":
83
+ latest_coords = get_clicked_point(args.input_img)
84
+ elif args.coords_type == "key_in":
85
+ latest_coords = args.point_coords
86
+ img = load_img_to_array(args.input_img)
87
+
88
+ masks, _, _ = predict_masks_with_sam(
89
+ img,
90
+ [latest_coords],
91
+ args.point_labels,
92
+ model_type=args.sam_model_type,
93
+ ckpt_p=args.sam_ckpt,
94
+ device=device,
95
+ )
96
+ masks = masks.astype(np.uint8) * 255
97
+
98
+ # dilate mask to avoid unmasked edge effect
99
+ if args.dilate_kernel_size is not None:
100
+ masks = [dilate_mask(mask, args.dilate_kernel_size) for mask in masks]
101
+
102
+ # visualize the segmentation results
103
+ img_stem = Path(args.input_img).stem
104
+ out_dir = Path(args.output_dir) / img_stem
105
+ out_dir.mkdir(parents=True, exist_ok=True)
106
+ for idx, mask in enumerate(masks):
107
+ # path to the results
108
+ mask_p = out_dir / f"mask_{idx}.png"
109
+ img_points_p = out_dir / f"with_points.png"
110
+ img_mask_p = out_dir / f"with_{Path(mask_p).name}"
111
+
112
+ # save the mask
113
+ save_array_to_img(mask, mask_p)
114
+
115
+ # save the pointed and masked image
116
+ dpi = plt.rcParams['figure.dpi']
117
+ height, width = img.shape[:2]
118
+ plt.figure(figsize=(width/dpi/0.77, height/dpi/0.77))
119
+ plt.imshow(img)
120
+ plt.axis('off')
121
+ show_points(plt.gca(), [latest_coords], args.point_labels,
122
+ size=(width*0.04)**2)
123
+ plt.savefig(img_points_p, bbox_inches='tight', pad_inches=0)
124
+ show_mask(plt.gca(), mask, random_color=False)
125
+ plt.savefig(img_mask_p, bbox_inches='tight', pad_inches=0)
126
+ plt.close()
127
+
128
+ # fill the masked image
129
+ for idx, mask in enumerate(masks):
130
+ if args.seed is not None:
131
+ torch.manual_seed(args.seed)
132
+ mask_p = out_dir / f"mask_{idx}.png"
133
+ img_replaced_p = out_dir / f"replaced_with_{Path(mask_p).name}"
134
+ img_replaced = replace_img_with_sd(
135
+ img, mask, args.text_prompt, device=device)
136
+ save_array_to_img(img_replaced, img_replaced_p)
requirements.txt ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ gradio
2
+
3
+ torch
4
+ torchvision
5
+ torchaudio
6
+ segment_anything
7
+ diffusers
8
+ transformers
9
+ accelerate
10
+ scipy
11
+ safetensors
12
+
13
+ # lama
14
+ pyyaml
15
+ tqdm
16
+ numpy
17
+ easydict==1.9.0
18
+ scikit-image==0.17.2
19
+ scikit-learn==0.24.2
20
+ opencv-python
21
+ tensorflow
22
+ joblib
23
+ matplotlib
24
+ pandas
25
+ albumentations==0.5.2
26
+ hydra-core==1.1.0
27
+ pytorch-lightning==1.2.9
28
+ tabulate
29
+ kornia==0.5.0
30
+ webdataset
31
+ packaging
32
+ scikit-learn==0.24.2
33
+ wldhx.yadisk-direct
sam_segment.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import argparse
3
+ import numpy as np
4
+ from pathlib import Path
5
+ from matplotlib import pyplot as plt
6
+ from typing import Any, Dict, List
7
+ import torch
8
+
9
+ from segment_anything import SamPredictor, sam_model_registry
10
+ from utils import load_img_to_array, save_array_to_img, dilate_mask, \
11
+ show_mask, show_points
12
+
13
+
14
+ def predict_masks_with_sam(
15
+ img: np.ndarray,
16
+ point_coords: List[List[float]],
17
+ point_labels: List[int],
18
+ model_type: str,
19
+ ckpt_p: str,
20
+ device="cuda"
21
+ ):
22
+ point_coords = np.array(point_coords)
23
+ point_labels = np.array(point_labels)
24
+ sam = sam_model_registry[model_type](checkpoint=ckpt_p)
25
+ sam.to(device=device)
26
+ predictor = SamPredictor(sam)
27
+
28
+ predictor.set_image(img)
29
+ masks, scores, logits = predictor.predict(
30
+ point_coords=point_coords,
31
+ point_labels=point_labels,
32
+ multimask_output=True,
33
+ )
34
+ return masks, scores, logits
35
+
36
+
37
+ def build_sam_model(model_type: str, ckpt_p: str, device="cuda"):
38
+ sam = sam_model_registry[model_type](checkpoint=ckpt_p)
39
+ sam.to(device=device)
40
+ predictor = SamPredictor(sam)
41
+ return predictor
42
+
43
+
44
+
45
+ def setup_args(parser):
46
+ parser.add_argument(
47
+ "--input_img", type=str, required=True,
48
+ help="Path to a single input img",
49
+ )
50
+ parser.add_argument(
51
+ "--point_coords", type=float, nargs='+', required=True,
52
+ help="The coordinate of the point prompt, [coord_W coord_H].",
53
+ )
54
+ parser.add_argument(
55
+ "--point_labels", type=int, nargs='+', required=True,
56
+ help="The labels of the point prompt, 1 or 0.",
57
+ )
58
+ parser.add_argument(
59
+ "--dilate_kernel_size", type=int, default=None,
60
+ help="Dilate kernel size. Default: None",
61
+ )
62
+ parser.add_argument(
63
+ "--output_dir", type=str, required=True,
64
+ help="Output path to the directory with results.",
65
+ )
66
+ parser.add_argument(
67
+ "--sam_model_type", type=str,
68
+ default="vit_h", choices=['vit_h', 'vit_l', 'vit_b'],
69
+ help="The type of sam model to load. Default: 'vit_h"
70
+ )
71
+ parser.add_argument(
72
+ "--sam_ckpt", type=str, required=True,
73
+ help="The path to the SAM checkpoint to use for mask generation.",
74
+ )
75
+
76
+
77
+ if __name__ == "__main__":
78
+ """Example usage:
79
+ python sam_segment.py \
80
+ --input_img FA_demo/FA1_dog.png \
81
+ --point_coords 750 500 \
82
+ --point_labels 1 \
83
+ --dilate_kernel_size 15 \
84
+ --output_dir ./results \
85
+ --sam_model_type "vit_h" \
86
+ --sam_ckpt sam_vit_h_4b8939.pth
87
+ """
88
+ parser = argparse.ArgumentParser()
89
+ setup_args(parser)
90
+ args = parser.parse_args(sys.argv[1:])
91
+ device = "cuda" if torch.cuda.is_available() else "cpu"
92
+
93
+ img = load_img_to_array(args.input_img)
94
+
95
+ masks, _, _ = predict_masks_with_sam(
96
+ img,
97
+ [args.point_coords],
98
+ args.point_labels,
99
+ model_type=args.sam_model_type,
100
+ ckpt_p=args.sam_ckpt,
101
+ device=device,
102
+ )
103
+ masks = masks.astype(np.uint8) * 255
104
+
105
+ # dilate mask to avoid unmasked edge effect
106
+ if args.dilate_kernel_size is not None:
107
+ masks = [dilate_mask(mask, args.dilate_kernel_size) for mask in masks]
108
+
109
+ # visualize the segmentation results
110
+ img_stem = Path(args.input_img).stem
111
+ out_dir = Path(args.output_dir) / img_stem
112
+ out_dir.mkdir(parents=True, exist_ok=True)
113
+ for idx, mask in enumerate(masks):
114
+ # path to the results
115
+ mask_p = out_dir / f"mask_{idx}.png"
116
+ img_points_p = out_dir / f"with_points.png"
117
+ img_mask_p = out_dir / f"with_{Path(mask_p).name}"
118
+
119
+ # save the mask
120
+ save_array_to_img(mask, mask_p)
121
+
122
+ # save the pointed and masked image
123
+ dpi = plt.rcParams['figure.dpi']
124
+ height, width = img.shape[:2]
125
+ plt.figure(figsize=(width/dpi/0.77, height/dpi/0.77))
126
+ plt.imshow(img)
127
+ plt.axis('off')
128
+ show_points(plt.gca(), [args.point_coords], args.point_labels,
129
+ size=(width*0.04)**2)
130
+ plt.savefig(img_points_p, bbox_inches='tight', pad_inches=0)
131
+ show_mask(plt.gca(), mask, random_color=False)
132
+ plt.savefig(img_mask_p, bbox_inches='tight', pad_inches=0)
133
+ plt.close()
stable_diffusion_inpaint.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import glob
4
+ import argparse
5
+ import torch
6
+ import numpy as np
7
+ import PIL.Image as Image
8
+ from pathlib import Path
9
+ from diffusers import StableDiffusionInpaintPipeline
10
+ from utils.mask_processing import crop_for_filling_pre, crop_for_filling_post
11
+ from utils.crop_for_replacing import recover_size, resize_and_pad
12
+ from utils import load_img_to_array, save_array_to_img
13
+
14
+
15
+ def fill_img_with_sd(
16
+ img: np.ndarray, mask: np.ndarray, text_prompt: str, device="cuda"
17
+ ):
18
+ pipe = StableDiffusionInpaintPipeline.from_pretrained(
19
+ "stabilityai/stable-diffusion-2-inpainting",
20
+ torch_dtype=torch.float32,
21
+ ).to(device)
22
+ img_crop, mask_crop = crop_for_filling_pre(img, mask)
23
+ img_crop_filled = pipe(
24
+ prompt=text_prompt,
25
+ image=Image.fromarray(img_crop),
26
+ mask_image=Image.fromarray(mask_crop),
27
+ ).images[0]
28
+ img_filled = crop_for_filling_post(img, mask, np.array(img_crop_filled))
29
+ return img_filled
30
+
31
+
32
+ def replace_img_with_sd(
33
+ img: np.ndarray, mask: np.ndarray, text_prompt: str, step: int = 50, device="cuda"
34
+ ):
35
+ pipe = StableDiffusionInpaintPipeline.from_pretrained(
36
+ "stabilityai/stable-diffusion-2-inpainting",
37
+ torch_dtype=torch.float32,
38
+ ).to(device)
39
+ img_padded, mask_padded, padding_factors = resize_and_pad(img, mask)
40
+ img_padded = pipe(
41
+ prompt=text_prompt,
42
+ image=Image.fromarray(img_padded),
43
+ mask_image=Image.fromarray(255 - mask_padded),
44
+ num_inference_steps=step,
45
+ ).images[0]
46
+ height, width, _ = img.shape
47
+ img_resized, mask_resized = recover_size(
48
+ np.array(img_padded), mask_padded, (height, width), padding_factors
49
+ )
50
+ mask_resized = np.expand_dims(mask_resized, -1) / 255
51
+ img_resized = img_resized * (1 - mask_resized) + img * mask_resized
52
+ return img_resized
53
+
54
+
55
+ def setup_args(parser):
56
+ parser.add_argument(
57
+ "--input_img",
58
+ type=str,
59
+ required=True,
60
+ help="Path to a single input img",
61
+ )
62
+ parser.add_argument(
63
+ "--text_prompt",
64
+ type=str,
65
+ required=True,
66
+ help="Text prompt",
67
+ )
68
+ parser.add_argument(
69
+ "--input_mask_glob",
70
+ type=str,
71
+ required=True,
72
+ help="Glob to input masks",
73
+ )
74
+ parser.add_argument(
75
+ "--output_dir",
76
+ type=str,
77
+ required=True,
78
+ help="Output path to the directory with results.",
79
+ )
80
+ parser.add_argument(
81
+ "--seed",
82
+ type=int,
83
+ help="Specify seed for reproducibility.",
84
+ )
85
+ parser.add_argument(
86
+ "--deterministic",
87
+ action="store_true",
88
+ help="Use deterministic algorithms for reproducibility.",
89
+ )
90
+
91
+
92
+ if __name__ == "__main__":
93
+ """Example usage:
94
+ python lama_inpaint.py \
95
+ --input_img FA_demo/FA1_dog.png \
96
+ --input_mask_glob "results/FA1_dog/mask*.png" \
97
+ --text_prompt "a teddy bear on a bench" \
98
+ --output_dir results
99
+ """
100
+ parser = argparse.ArgumentParser()
101
+ setup_args(parser)
102
+ args = parser.parse_args(sys.argv[1:])
103
+ device = "cuda" if torch.cuda.is_available() else "cpu"
104
+
105
+ if args.deterministic:
106
+ os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
107
+ torch.use_deterministic_algorithms(True)
108
+
109
+ img_stem = Path(args.input_img).stem
110
+ mask_ps = sorted(glob.glob(args.input_mask_glob))
111
+ out_dir = Path(args.output_dir) / img_stem
112
+ out_dir.mkdir(parents=True, exist_ok=True)
113
+
114
+ img = load_img_to_array(args.input_img)
115
+ for mask_p in mask_ps:
116
+ if args.seed is not None:
117
+ torch.manual_seed(args.seed)
118
+ mask = load_img_to_array(mask_p)
119
+ img_filled_p = out_dir / f"filled_with_{Path(mask_p).name}"
120
+ img_filled = fill_img_with_sd(img, mask, args.text_prompt, device=device)
121
+ save_array_to_img(img_filled, img_filled_p)