Inpaint-Anything / remove_anything.py
pg56714's picture
Upload 11 files
c7e95b3 verified
import torch
import sys
import argparse
import numpy as np
from pathlib import Path
from matplotlib import pyplot as plt
from sam_segment import predict_masks_with_sam
from lama_inpaint import inpaint_img_with_lama
from utils import load_img_to_array, save_array_to_img, dilate_mask, \
show_mask, show_points, get_clicked_point
def setup_args(parser: argparse.ArgumentParser) -> None:
    """Register the command-line options of the object-removal script.

    The options are added to ``parser`` in place; nothing is returned.

    Args:
        parser: The argument parser that receives the options.
    """
    parser.add_argument(
        "--input_img", type=str, required=True,
        help="Path to a single input img",
    )
    # Not marked as required: argparse never applies a default to a required
    # option, so combining the two made the "key_in" default unreachable.
    parser.add_argument(
        "--coords_type", type=str,
        default="key_in", choices=["click", "key_in"],
        help="The way to select coords",
    )
    # NOTE(review): this stays required even for `--coords_type click`, where
    # the script takes the point from get_clicked_point() instead. Relaxing it
    # needs a matching "coords given?" check in the caller for key_in mode.
    parser.add_argument(
        "--point_coords", type=float, nargs='+', required=True,
        help="The coordinate of the point prompt, [coord_W coord_H].",
    )
    parser.add_argument(
        "--point_labels", type=int, nargs='+', required=True,
        help="The labels of the point prompt, 1 or 0.",
    )
    parser.add_argument(
        "--dilate_kernel_size", type=int, default=None,
        help="Dilate kernel size. Default: None",
    )
    parser.add_argument(
        "--output_dir", type=str, required=True,
        help="Output path to the directory with results.",
    )
    parser.add_argument(
        "--sam_model_type", type=str,
        default="vit_h", choices=['vit_h', 'vit_l', 'vit_b', 'vit_t'],
        help="The type of sam model to load. Default: 'vit_h'",
    )
    parser.add_argument(
        "--sam_ckpt", type=str, required=True,
        help="The path to the SAM checkpoint to use for mask generation.",
    )
    parser.add_argument(
        "--lama_config", type=str,
        default="./lama/configs/prediction/default.yaml",
        help="The path to the config file of lama model. "
             "Default: the config of big-lama",
    )
    parser.add_argument(
        "--lama_ckpt", type=str, required=True,
        help="The path to the lama checkpoint.",
    )
def main():
    r"""Remove the object selected by a point prompt from a single image.

    Steps: parse the CLI options, obtain the prompt point (clicked or keyed
    in), predict masks with SAM, optionally dilate them, save each mask with
    its visualisations, then inpaint the image with LaMa once per mask.

    Everything is written to ``<output_dir>/<input image stem>/``.

    Example usage:
        python remove_anything.py \
            --input_img FA_demo/FA1_dog.png \
            --coords_type key_in \
            --point_coords 750 500 \
            --point_labels 1 \
            --dilate_kernel_size 15 \
            --output_dir ./results \
            --sam_model_type "vit_h" \
            --sam_ckpt sam_vit_h_4b8939.pth \
            --lama_config lama/configs/prediction/default.yaml \
            --lama_ckpt big-lama
    """
    parser = argparse.ArgumentParser()
    setup_args(parser)
    args = parser.parse_args(sys.argv[1:])
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # `choices` on --coords_type guarantees that one of the branches is taken.
    if args.coords_type == "click":
        latest_coords = get_clicked_point(args.input_img)
    elif args.coords_type == "key_in":
        latest_coords = args.point_coords

    img = load_img_to_array(args.input_img)

    masks, _, _ = predict_masks_with_sam(
        img,
        [latest_coords],
        args.point_labels,
        model_type=args.sam_model_type,
        ckpt_p=args.sam_ckpt,
        device=device,
    )
    # presumably boolean masks; scaled so set pixels become 255 -- TODO confirm
    masks = masks.astype(np.uint8) * 255

    # dilate mask to avoid unmasked edge effect
    if args.dilate_kernel_size is not None:
        masks = [dilate_mask(mask, args.dilate_kernel_size) for mask in masks]

    # visualize the segmentation results
    img_stem = Path(args.input_img).stem
    out_dir = Path(args.output_dir) / img_stem
    out_dir.mkdir(parents=True, exist_ok=True)

    # None of these depend on the mask, so compute them once.
    img_points_p = out_dir / "with_points.png"
    dpi = plt.rcParams['figure.dpi']
    height, width = img.shape[:2]
    figsize = (width / dpi / 0.77, height / dpi / 0.77)
    marker_size = (width * 0.04) ** 2

    for idx, mask in enumerate(masks):
        # path to the results
        mask_p = out_dir / f"mask_{idx}.png"
        img_mask_p = out_dir / f"with_{mask_p.name}"

        # save the mask
        save_array_to_img(mask, mask_p)

        # save the pointed and masked image
        fig = plt.figure(figsize=figsize)
        try:
            plt.imshow(img)
            plt.axis('off')
            show_points(plt.gca(), [latest_coords], args.point_labels,
                        size=marker_size)
            # The image-plus-points picture is the same for every mask, so
            # it is written once instead of being overwritten per mask.
            if idx == 0:
                plt.savefig(img_points_p, bbox_inches='tight', pad_inches=0)
            show_mask(plt.gca(), mask, random_color=False)
            plt.savefig(img_mask_p, bbox_inches='tight', pad_inches=0)
        finally:
            # Release the figure even if drawing or saving fails.
            plt.close(fig)

    # inpaint the masked image
    for idx, mask in enumerate(masks):
        img_inpainted_p = out_dir / f"inpainted_with_mask_{idx}.png"
        img_inpainted = inpaint_img_with_lama(
            img, mask, args.lama_config, args.lama_ckpt, device=device)
        save_array_to_img(img_inpainted, img_inpainted_p)


if __name__ == "__main__":
    main()