# File: eri2 / EfficientSAM / grounded_fast_sam.py
# Scraped page header (uploader "mart9992", note "m", revision aede1d5);
# kept as comments so the module parses.
import argparse
import cv2
from ultralytics import YOLO
from FastSAM.tools import *
from groundingdino.util.inference import load_model, load_image, predict, annotate, Model
from torchvision.ops import box_convert
import ast
def _str2bool(value):
    """Convert a command-line string such as "True"/"false"/"1"/"no" to a bool.

    ``argparse`` with ``type=bool`` calls ``bool(str)``, which is True for any
    non-empty string, so ``--retina False`` used to yield True. This parser
    accepts the common spellings and rejects anything else.

    Args:
        value: Raw argument string (an actual ``bool`` is passed through).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ("1", "true", "t", "yes", "y", "on"):
        return True
    if lowered in ("0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_args():
    """Parse the command-line options for the Grounded-Fast-SAM demo.

    Boolean options (``--randomcolor``, ``--better_quality``, ``--retina``,
    ``--withContours``) accept an explicit value (``--retina False``) or can
    be given bare (``--withContours``) to mean True. Their defaults are
    unchanged.

    Returns:
        argparse.Namespace with one attribute per option below.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_path", type=str, default="./FastSAM/FastSAM-x.pt", help="model"
    )
    parser.add_argument(
        "--img_path", type=str, default="./images/dogs.jpg", help="path to image file"
    )
    parser.add_argument(
        "--text", type=str, default="the black dog.", help="text prompt for GroundingDINO"
    )
    parser.add_argument("--imgsz", type=int, default=1024, help="image size")
    parser.add_argument(
        "--iou",
        type=float,
        default=0.9,
        help="iou threshold for filtering the annotations",
    )
    parser.add_argument(
        "--conf", type=float, default=0.4, help="object confidence threshold"
    )
    parser.add_argument(
        "--output", type=str, default="./output/", help="image save path"
    )
    parser.add_argument(
        "--randomcolor",
        type=_str2bool,
        nargs="?",
        const=True,
        default=True,
        help="mask random color",
    )
    parser.add_argument(
        "--point_prompt", type=str, default="[[0,0]]", help="[[x1,y1],[x2,y2]]"
    )
    parser.add_argument(
        "--point_label",
        type=str,
        default="[0]",
        help="[1,0] 0:background, 1:foreground",
    )
    parser.add_argument("--box_prompt", type=str, default="[0,0,0,0]", help="[x,y,w,h]")
    parser.add_argument(
        "--better_quality",
        type=_str2bool,
        nargs="?",
        const=True,
        default=False,
        help="better quality using morphologyEx",
    )
    # Keep the default a plain string so args.device has the same type whether
    # or not --device is passed (argparse does not apply `type` to defaults).
    device = "cuda" if torch.cuda.is_available() else "cpu"
    parser.add_argument(
        "--device", type=str, default=device, help="cuda:[0,1,2,3,4] or cpu"
    )
    parser.add_argument(
        "--retina",
        type=_str2bool,
        nargs="?",
        const=True,
        default=True,
        help="draw high-resolution segmentation masks",
    )
    parser.add_argument(
        "--withContours",
        type=_str2bool,
        nargs="?",
        const=True,
        default=False,
        help="draw the edges of the masks",
    )
    return parser.parse_args()
def main(args):
    """Run Grounded-Fast-SAM on one image and save one overlay per detected box.

    Steps:
      1. FastSAM (loaded via ``YOLO``) segments the image.
      2. GroundingDINO turns ``args.text`` into boxes and phrases.
      3. For each box, ``box_prompt`` selects a FastSAM mask and
         ``fast_process`` renders it. The result is written to ``args.output``
         as ``<image-stem>_<idx>_caption_<phrase>.jpg``.

    Args:
        args: Namespace from ``parse_args()``. This function reads img_path,
            text, output, model_path, imgsz, device, retina, iou, conf and
            randomcolor; the whole namespace is also forwarded to
            ``fast_process``.

    Raises:
        FileNotFoundError: If OpenCV cannot read ``args.img_path``.
    """
    img_path = args.img_path
    text = args.text
    save_path = args.output
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(save_path, exist_ok=True)
    # splitext keeps dotted stems intact ("my.dog.jpg" -> "my.dog");
    # split(".")[0] would truncate it to "my".
    basename = os.path.splitext(os.path.basename(img_path))[0]

    # Build FastSAM and segment the image.
    fastsam_model = YOLO(args.model_path)
    results = fastsam_model(
        img_path,
        imgsz=args.imgsz,
        device=args.device,
        retina_masks=args.retina,
        iou=args.iou,
        conf=args.conf,
        max_det=100,
    )

    # Build GroundingDINO and detect boxes for the text prompt.
    groundingdino_config = "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py"
    groundingdino_ckpt_path = "./groundingdino_swint_ogc.pth"
    _image_source, image = load_image(img_path)
    groundingdino_model = load_model(groundingdino_config, groundingdino_ckpt_path)
    boxes, _logits, phrases = predict(
        model=groundingdino_model,
        image=image,
        caption=text,
        box_threshold=0.3,
        text_threshold=0.25,
        device=args.device,
    )

    # Grounded-Fast-SAM: combine the two models' outputs.
    ori_img = cv2.imread(img_path)
    if ori_img is None:
        raise FileNotFoundError(f"cv2.imread could not read image: {img_path}")
    ori_h, ori_w = ori_img.shape[:2]

    # NOTE(review): boxes are assumed to be normalised cxcywh (the box_convert
    # call below uses in_fmt="cxcywh"). Scale them to pixel units first.
    boxes = boxes * torch.Tensor([ori_w, ori_h, ori_w, ori_h])
    print(f"Detected Boxes: {len(boxes)}")
    boxes = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").cpu().numpy().tolist()

    # The Results object may carry no masks; indexing .data would then fail.
    fastsam_masks = results[0].masks
    if fastsam_masks is None:
        print("FastSAM produced no masks; nothing to save.")
        return
    mask_data = fastsam_masks.data
    # Honour --randomcolor (previously hard-coded to True; default unchanged).
    random_color = getattr(args, "randomcolor", True)

    for box_idx, (box, phrase) in enumerate(zip(boxes, phrases)):
        mask, _ = box_prompt(mask_data, box, ori_h, ori_w)
        img_array = fast_process(
            annotations=np.array([mask]),
            args=args,
            mask_random_color=random_color,
            bbox=box,
        )
        # A path separator inside the phrase would redirect the write into a
        # non-existent sub-directory.
        safe_phrase = str(phrase).replace("/", "_").replace(os.sep, "_")
        out_path = os.path.join(save_path, f"{basename}_{box_idx}_caption_{safe_phrase}.jpg")
        # fast_process output is converted RGB->BGR for OpenCV; imwrite signals
        # failure only through its return value.
        if not cv2.imwrite(out_path, cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR)):
            print(f"Warning: failed to write {out_path}")
if __name__ == "__main__":
    # Script entry point: read the CLI flags once, then run the pipeline.
    args = parse_args()
    main(args=args)