Spaces:
Runtime error
Runtime error
Add remove
Browse files
app.py
CHANGED
@@ -11,6 +11,8 @@ import subprocess
|
|
11 |
import copy
|
12 |
import time
|
13 |
import warnings
|
|
|
|
|
14 |
|
15 |
import torch
|
16 |
from torchvision.ops import box_convert
|
@@ -26,13 +28,18 @@ import groundingdino.datasets.transforms as T
|
|
26 |
# segment anything
|
27 |
from segment_anything import build_sam, SamPredictor
|
28 |
|
|
|
|
|
|
|
|
|
|
|
29 |
#stable diffusion
|
30 |
from diffusers import StableDiffusionInpaintPipeline
|
31 |
|
32 |
from huggingface_hub import hf_hub_download
|
33 |
|
34 |
-
if not os.path.exists('./
|
35 |
-
os.system("wget https://github.com/IDEA-Research/Grounded-Segment-Anything/raw/main/assets/
|
36 |
|
37 |
if not os.path.exists('./sam_vit_h_4b8939.pth'):
|
38 |
logger.info(f"get sam_vit_h_4b8939.pth...")
|
@@ -177,6 +184,63 @@ def mix_masks(imgs):
|
|
177 |
re_img = 1 - re_img
|
178 |
return Image.fromarray(np.uint8(255*re_img))
|
179 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
180 |
def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold,
|
181 |
iou_threshold, inpaint_mode, mask_source_radio, remove_mode, remove_mask_extend):
|
182 |
|
@@ -199,6 +263,8 @@ def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_t
|
|
199 |
# load image
|
200 |
image_pil, image_tensor = load_image_and_transform(input_image['image'])
|
201 |
|
|
|
|
|
202 |
# RUN GROUNDINGDINO: we skip DINO if we draw mask on the image
|
203 |
if (task_type == 'inpainting' or task_type == 'remove') and mask_source_radio == mask_source_draw:
|
204 |
pass
|
@@ -218,7 +284,6 @@ def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_t
|
|
218 |
}
|
219 |
|
220 |
# store and save DINO output
|
221 |
-
output_images = []
|
222 |
image_with_box = plot_boxes_to_image(copy.deepcopy(image_pil), pred_dict)[0]
|
223 |
image_path = os.path.join(output_dir, f"grounding_dino_output_{file_temp}.jpg")
|
224 |
image_with_box.save(image_path)
|
@@ -300,7 +365,39 @@ def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_t
|
|
300 |
image_source_for_inpaint = image_pil.resize((512, 512))
|
301 |
image_mask_for_inpaint = mask_pil.resize((512, 512))
|
302 |
image_inpainting = sd_pipe(prompt=inpaint_prompt, image=image_source_for_inpaint, mask_image=image_mask_for_inpaint).images[0]
|
303 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
304 |
|
305 |
image_inpainting = image_inpainting.resize((image_pil.size[0], image_pil.size[1]))
|
306 |
output_images.append(image_inpainting)
|
@@ -330,6 +427,7 @@ def change_radio_display(task_type, mask_source_radio):
|
|
330 |
# model initialization
|
331 |
groundingDino_model = load_model_hf(config_file, ckpt_repo_id, ckpt_filename, groundingdino_device)
|
332 |
sam_predictor = SamPredictor(build_sam(checkpoint=sam_checkpoint))
|
|
|
333 |
|
334 |
# initialize stable-diffusion-inpainting
|
335 |
logger.info(f"initialize stable-diffusion-inpainting...")
|
@@ -359,7 +457,7 @@ if __name__ == "__main__":
|
|
359 |
with gr.Row():
|
360 |
with gr.Column():
|
361 |
input_image = gr.Image(
|
362 |
-
source="upload", elem_id="image_upload", type="pil", tool="sketch", value="
|
363 |
task_type = gr.Radio(["segment", "inpainting", "remove"], value="segment",
|
364 |
label='Task type', visible=True)
|
365 |
|
@@ -368,7 +466,7 @@ if __name__ == "__main__":
|
|
368 |
visible=False)
|
369 |
|
370 |
text_prompt = gr.Textbox(label="Detection Prompt, seperating each name with dot '.', i.e.: bear.cat.dog.chair ]", \
|
371 |
-
value='
|
372 |
inpaint_prompt = gr.Textbox(label="Inpaint Prompt (if this is empty, then remove)", visible=False)
|
373 |
|
374 |
run_button = gr.Button(label="Run")
|
|
|
11 |
import copy
|
12 |
import time
|
13 |
import warnings
|
14 |
+
import io
|
15 |
+
import random
|
16 |
|
17 |
import torch
|
18 |
from torchvision.ops import box_convert
|
|
|
28 |
# segment anything
|
29 |
from segment_anything import build_sam, SamPredictor
|
30 |
|
31 |
+
# lama-cleaner
|
32 |
+
from lama_cleaner.model_manager import ModelManager
|
33 |
+
from lama_cleaner.schema import Config as lama_Config
|
34 |
+
from lama_cleaner.helper import load_img, numpy_to_bytes, resize_max_size
|
35 |
+
|
36 |
#stable diffusion
|
37 |
from diffusers import StableDiffusionInpaintPipeline
|
38 |
|
39 |
from huggingface_hub import hf_hub_download
|
40 |
|
41 |
+
if not os.path.exists('./demo2.jpg'):
|
42 |
+
os.system("wget https://github.com/IDEA-Research/Grounded-Segment-Anything/raw/main/assets/demo2.jpg")
|
43 |
|
44 |
if not os.path.exists('./sam_vit_h_4b8939.pth'):
|
45 |
logger.info(f"get sam_vit_h_4b8939.pth...")
|
|
|
184 |
re_img = 1 - re_img
|
185 |
return Image.fromarray(np.uint8(255*re_img))
|
186 |
|
187 |
+
def lama_cleaner_process(image, mask):
|
188 |
+
ori_image = image
|
189 |
+
if mask.shape[0] == image.shape[1] and mask.shape[1] == image.shape[0] and mask.shape[0] != mask.shape[1]:
|
190 |
+
# rotate image
|
191 |
+
ori_image = np.transpose(image[::-1, ...][:, ::-1], axes=(1, 0, 2))[::-1, ...]
|
192 |
+
image = ori_image
|
193 |
+
|
194 |
+
original_shape = ori_image.shape
|
195 |
+
interpolation = cv2.INTER_CUBIC
|
196 |
+
|
197 |
+
size_limit = 1080
|
198 |
+
if size_limit == "Original":
|
199 |
+
size_limit = max(image.shape)
|
200 |
+
else:
|
201 |
+
size_limit = int(size_limit)
|
202 |
+
|
203 |
+
config = lama_Config(
|
204 |
+
ldm_steps=25,
|
205 |
+
ldm_sampler='plms',
|
206 |
+
zits_wireframe=True,
|
207 |
+
hd_strategy='Original',
|
208 |
+
hd_strategy_crop_margin=196,
|
209 |
+
hd_strategy_crop_trigger_size=1280,
|
210 |
+
hd_strategy_resize_limit=2048,
|
211 |
+
prompt='',
|
212 |
+
use_croper=False,
|
213 |
+
croper_x=0,
|
214 |
+
croper_y=0,
|
215 |
+
croper_height=512,
|
216 |
+
croper_width=512,
|
217 |
+
sd_mask_blur=5,
|
218 |
+
sd_strength=0.75,
|
219 |
+
sd_steps=50,
|
220 |
+
sd_guidance_scale=7.5,
|
221 |
+
sd_sampler='ddim',
|
222 |
+
sd_seed=42,
|
223 |
+
cv2_flag='INPAINT_NS',
|
224 |
+
cv2_radius=5,
|
225 |
+
)
|
226 |
+
|
227 |
+
if config.sd_seed == -1:
|
228 |
+
config.sd_seed = random.randint(1, 999999999)
|
229 |
+
|
230 |
+
# logger.info(f"Origin image shape_0_: {original_shape} / {size_limit}")
|
231 |
+
image = resize_max_size(image, size_limit=size_limit, interpolation=interpolation)
|
232 |
+
# logger.info(f"Resized image shape_1_: {image.shape}")
|
233 |
+
|
234 |
+
# logger.info(f"mask image shape_0_: {mask.shape} / {type(mask)}")
|
235 |
+
mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation)
|
236 |
+
# logger.info(f"mask image shape_1_: {mask.shape} / {type(mask)}")
|
237 |
+
|
238 |
+
res_np_img = lama_cleaner_model(image, mask, config)
|
239 |
+
torch.cuda.empty_cache()
|
240 |
+
|
241 |
+
image = Image.open(io.BytesIO(numpy_to_bytes(res_np_img, 'png')))
|
242 |
+
return image
|
243 |
+
|
244 |
def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold,
|
245 |
iou_threshold, inpaint_mode, mask_source_radio, remove_mode, remove_mask_extend):
|
246 |
|
|
|
263 |
# load image
|
264 |
image_pil, image_tensor = load_image_and_transform(input_image['image'])
|
265 |
|
266 |
+
output_images = []
|
267 |
+
output_images.append(input_image['image'])
|
268 |
# RUN GROUNDINGDINO: we skip DINO if we draw mask on the image
|
269 |
if (task_type == 'inpainting' or task_type == 'remove') and mask_source_radio == mask_source_draw:
|
270 |
pass
|
|
|
284 |
}
|
285 |
|
286 |
# store and save DINO output
|
|
|
287 |
image_with_box = plot_boxes_to_image(copy.deepcopy(image_pil), pred_dict)[0]
|
288 |
image_path = os.path.join(output_dir, f"grounding_dino_output_{file_temp}.jpg")
|
289 |
image_with_box.save(image_path)
|
|
|
365 |
image_source_for_inpaint = image_pil.resize((512, 512))
|
366 |
image_mask_for_inpaint = mask_pil.resize((512, 512))
|
367 |
image_inpainting = sd_pipe(prompt=inpaint_prompt, image=image_source_for_inpaint, mask_image=image_mask_for_inpaint).images[0]
|
368 |
+
else:
|
369 |
+
# remove from mask
|
370 |
+
if mask_source_radio == mask_source_segment:
|
371 |
+
mask_imgs = []
|
372 |
+
masks_shape = masks_ori.shape
|
373 |
+
boxes_filt_ori_array = boxes_filt_ori.numpy()
|
374 |
+
if inpaint_mode == 'merge':
|
375 |
+
extend_shape_0 = masks_shape[0]
|
376 |
+
extend_shape_1 = masks_shape[1]
|
377 |
+
else:
|
378 |
+
extend_shape_0 = 1
|
379 |
+
extend_shape_1 = 1
|
380 |
+
for i in range(extend_shape_0):
|
381 |
+
for j in range(extend_shape_1):
|
382 |
+
mask = masks_ori[i][j].cpu().numpy()
|
383 |
+
mask_pil = Image.fromarray(mask)
|
384 |
+
|
385 |
+
if remove_mode == 'segment':
|
386 |
+
useRectangle = False
|
387 |
+
else:
|
388 |
+
useRectangle = True
|
389 |
+
|
390 |
+
try:
|
391 |
+
remove_mask_extend = int(remove_mask_extend)
|
392 |
+
except:
|
393 |
+
remove_mask_extend = 10
|
394 |
+
mask_pil_exp = mask_extend(copy.deepcopy(mask_pil).convert("RGB"),
|
395 |
+
box_convert(torch.tensor(boxes_filt_ori_array[i]), in_fmt="cxcywh", out_fmt="xyxy").numpy(),
|
396 |
+
extend_pixels=remove_mask_extend, useRectangle=useRectangle)
|
397 |
+
mask_imgs.append(mask_pil_exp)
|
398 |
+
mask_pil = mix_masks(mask_imgs)
|
399 |
+
output_images.append(mask_pil.convert("RGB"))
|
400 |
+
image_inpainting = lama_cleaner_process(np.array(image_pil), np.array(mask_pil.convert("L")))
|
401 |
|
402 |
image_inpainting = image_inpainting.resize((image_pil.size[0], image_pil.size[1]))
|
403 |
output_images.append(image_inpainting)
|
|
|
427 |
# model initialization
|
428 |
groundingDino_model = load_model_hf(config_file, ckpt_repo_id, ckpt_filename, groundingdino_device)
|
429 |
sam_predictor = SamPredictor(build_sam(checkpoint=sam_checkpoint))
|
430 |
+
lama_cleaner_model = ModelManager(name='lama',device='cpu')
|
431 |
|
432 |
# initialize stable-diffusion-inpainting
|
433 |
logger.info(f"initialize stable-diffusion-inpainting...")
|
|
|
457 |
with gr.Row():
|
458 |
with gr.Column():
|
459 |
input_image = gr.Image(
|
460 |
+
source="upload", elem_id="image_upload", type="pil", tool="sketch", value="demo2.jpg", label="Upload")
|
461 |
task_type = gr.Radio(["segment", "inpainting", "remove"], value="segment",
|
462 |
label='Task type', visible=True)
|
463 |
|
|
|
466 |
visible=False)
|
467 |
|
468 |
text_prompt = gr.Textbox(label="Detection Prompt, seperating each name with dot '.', i.e.: bear.cat.dog.chair ]", \
|
469 |
+
value='dog', placeholder="Cannot be empty")
|
470 |
inpaint_prompt = gr.Textbox(label="Inpaint Prompt (if this is empty, then remove)", visible=False)
|
471 |
|
472 |
run_button = gr.Button(label="Run")
|