import argparse
import os

import cv2
import numpy as np
import torch
from torchvision.ops import box_convert
from ultralytics import YOLO

from FastSAM.tools import *
from groundingdino.util.inference import load_model, load_image, predict


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_path", type=str, default="./FastSAM/FastSAM-x.pt", help="path to the FastSAM checkpoint"
    )
    parser.add_argument(
        "--img_path", type=str, default="./images/dogs.jpg", help="path to image file"
    )
    parser.add_argument(
        "--text", type=str, default="the black dog.", help="text prompt for GroundingDINO"
    )
    parser.add_argument("--imgsz", type=int, default=1024, help="image size")
    parser.add_argument(
        "--iou",
        type=float,
        default=0.9,
        help="IoU threshold for filtering the annotations",
    )
    parser.add_argument(
        "--conf", type=float, default=0.4, help="object confidence threshold"
    )
    parser.add_argument(
        "--output", type=str, default="./output/", help="image save path"
    )
    parser.add_argument(
        "--randomcolor", type=bool, default=True, help="mask random color"
    )
    parser.add_argument(
        "--point_prompt", type=str, default="[[0,0]]", help="[[x1,y1],[x2,y2]]"
    )
    parser.add_argument(
        "--point_label",
        type=str,
        default="[0]",
        help="[1,0] 0:background, 1:foreground",
    )
    parser.add_argument("--box_prompt", type=str, default="[0,0,0,0]", help="[x,y,w,h]")
    parser.add_argument(
        "--better_quality",
        type=bool,
        default=False,
        help="better quality using morphologyEx",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="cuda:[0,1,2,3,4] or cpu",
    )
    parser.add_argument(
        "--retina",
        type=bool,
        default=True,
        help="draw high-resolution segmentation masks",
    )
    parser.add_argument(
        "--withContours", type=bool, default=False, help="draw the edges of the masks"
    )
    return parser.parse_args()


def main(args):
    img_path = args.img_path
    text = args.text

    save_path = args.output
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    basename = os.path.basename(args.img_path).split(".")[0]
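
    # FastSAM ships as a YOLOv8-style checkpoint, so it is loaded and run through the
    # ultralytics YOLO API to segment everything in the image.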
    sam_model = YOLO(args.model_path)
    results = sam_model(
        args.img_path,
        imgsz=args.imgsz,
        device=args.device,
        retina_masks=args.retina,
        iou=args.iou,
        conf=args.conf,
        max_det=100,
    )
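
    # Run GroundingDINO on the same image to detect boxes for the regions described by the text prompt.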
    groundingdino_config = "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py"
    groundingdino_ckpt_path = "./groundingdino_swint_ogc.pth"

    image_source, image = load_image(img_path)
    dino_model = load_model(groundingdino_config, groundingdino_ckpt_path)

    boxes, logits, phrases = predict(
        model=dino_model,
        image=image,
        caption=text,
        box_threshold=0.3,
        text_threshold=0.25,
        device=args.device,
    )
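
    # Read the original image to get its height and width in pixels.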
    ori_img = cv2.imread(img_path)
    ori_h = ori_img.shape[0]
    ori_w = ori_img.shape[1]
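
    # GroundingDINO boxes are normalized cxcywh; scale them to pixel coordinates and convert to xyxy.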
    boxes = boxes * torch.Tensor([ori_w, ori_h, ori_w, ori_h])
    print(f"Detected Boxes: {len(boxes)}")
    boxes = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").cpu().numpy().tolist()
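
    # Use each detected box as a box prompt to pick the matching FastSAM mask, then draw and
    # save one annotated image per box, tagged with the phrase GroundingDINO matched.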
    for box_idx in range(len(boxes)):
        mask, _ = box_prompt(
            results[0].masks.data,
            boxes[box_idx],
            ori_h,
            ori_w,
        )
        annotations = np.array([mask])
        img_array = fast_process(
            annotations=annotations,
            args=args,
            mask_random_color=True,
            bbox=boxes[box_idx],
        )
        cv2.imwrite(
            os.path.join(save_path, f"{basename}_{box_idx}_caption_{phrases[box_idx]}.jpg"),
            cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR),
        )


if __name__ == "__main__":
    args = parse_args()
    main(args)