import random
import sys
from typing import Dict, List

import numpy as np
import supervision as sv
import torch
import torchvision
import torchvision.transforms as T
from huggingface_hub import hf_hub_download
from PIL import Image
from segment_anything import SamPredictor  # segment anything

sys.path.append("tag2text")
sys.path.append("GroundingDINO")

from groundingdino.models import build_model
from groundingdino.util.inference import Model as DinoModel
from groundingdino.util.slconfig import SLConfig
from groundingdino.util.utils import clean_state_dict
from tag2text.inference import inference as tag2text_inference


def load_model_hf(repo_id, filename, ckpt_config_filename, device="cpu"):
    """Download a GroundingDINO config and checkpoint from the Hugging Face Hub and build the model."""
    cache_config_file = hf_hub_download(repo_id=repo_id, filename=ckpt_config_filename)

    args = SLConfig.fromfile(cache_config_file)
    args.device = device
    model = build_model(args)

    cache_file = hf_hub_download(repo_id=repo_id, filename=filename)
    checkpoint = torch.load(cache_file, map_location=device)
    model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
    model.eval()
    return model


def download_file_hf(repo_id, filename, cache_dir="./cache"):
    """Download a single file from the Hugging Face Hub into a local cache directory."""
    cache_file = hf_hub_download(
        repo_id=repo_id, filename=filename, force_filename=filename, cache_dir=cache_dir
    )
    return cache_file


def transform_image_tag2text(image_pil: Image.Image) -> torch.Tensor:
    """Resize and normalize a PIL image into the tensor format expected by Tag2Text."""
    transform = T.Compose(
        [
            T.Resize((384, 384)),
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    )
    image = transform(image_pil)  # 3, h, w
    return image


def show_anns_sam(anns: List[Dict]):
    """Extracts the mask annotations from the Segment Anything model output and composites them
    into a single color image. https://github.com/facebookresearch/segment-anything.

    Arguments:
        anns (List[Dict]): Segment Anything model output.

    Returns:
        (np.ndarray): Masked image.
        (np.ndarray): Annotation encoding from https://github.com/LUSSeg/ImageNet-S.
    """
    if len(anns) == 0:
        return
    sorted_anns = sorted(anns, key=(lambda x: x["area"]), reverse=True)

    full_img = None
    mask_map = None
    # Draw masks from largest to smallest area so smaller objects stay visible on top.
    for i, ann in enumerate(sorted_anns):
        m = ann["segmentation"]
        if full_img is None:
            full_img = np.zeros((m.shape[0], m.shape[1], 3))
            mask_map = np.zeros((m.shape[0], m.shape[1]), dtype=np.uint16)
        mask_map[m != 0] = i + 1
        color_mask = np.random.random((1, 3)).tolist()[0]
        full_img[m != 0] = color_mask
    full_img = full_img * 255

    # anno encoding from https://github.com/LUSSeg/ImageNet-S
    res = np.zeros((mask_map.shape[0], mask_map.shape[1], 3))
    res[:, :, 0] = mask_map % 256
    res[:, :, 1] = mask_map // 256
    res = res.astype(np.float32)
    full_img = np.uint8(full_img)
    return full_img, res
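# Usage sketch (illustrative only; the repo id and filenames below are placeholders, not values
# shipped with this module): `load_model_hf` fetches the SLConfig file and checkpoint from the
# Hugging Face Hub, builds GroundingDINO from the config, and loads the cleaned state dict in
# eval mode.
#
#   dino = load_model_hf(
#       repo_id="<hub-repo-with-groundingdino-weights>",   # placeholder repo id
#       filename="<groundingdino-checkpoint>.pth",         # placeholder checkpoint file
#       ckpt_config_filename="<groundingdino-config>.py",  # placeholder config file
#       device="cuda" if torch.cuda.is_available() else "cpu",
#   )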
def show_anns_sv(detections: sv.Detections):
    """Extracts the mask annotations from a Supervision Detections object and composites them
    into a single color image. https://roboflow.github.io/supervision/detection/core/.

    Arguments:
        detections (sv.Detections): Object containing information about the detections.

    Returns:
        (np.ndarray): Masked image.
        (np.ndarray): Annotation encoding from https://github.com/LUSSeg/ImageNet-S.
    """
    if detections.mask is None:
        return

    full_img = None
    mask_map = None
    # Iterate from largest to smallest area so smaller objects stay visible on top.
    for i in np.flip(np.argsort(detections.area)):
        m = detections.mask[i]
        if full_img is None:
            full_img = np.zeros((m.shape[0], m.shape[1], 3))
            mask_map = np.zeros((m.shape[0], m.shape[1]), dtype=np.uint16)
        mask_map[m != 0] = i + 1
        color_mask = np.random.random((1, 3)).tolist()[0]
        full_img[m != 0] = color_mask
    full_img = full_img * 255

    # anno encoding from https://github.com/LUSSeg/ImageNet-S
    res = np.zeros((mask_map.shape[0], mask_map.shape[1], 3))
    res[:, :, 0] = mask_map % 256
    res[:, :, 1] = mask_map // 256
    res = res.astype(np.float32)
    full_img = np.uint8(full_img)
    return full_img, res


def generate_tags(tag2text_model, image, specified_tags, device="cpu"):
    """Generate image tags and a caption using the Tag2Text model.

    Arguments:
        tag2text_model (nn.Module): Tag2Text model to use for prediction.
        image (Image.Image): The input image in PIL format, as expected by
            transform_image_tag2text.
        specified_tags (str): User-specified tags to condition the prediction on.

    Returns:
        (List[str]): Predicted image tags.
        (str): Predicted image caption.
    """
    image = transform_image_tag2text(image).unsqueeze(0).to(device)
    res = tag2text_inference(image, tag2text_model, specified_tags)
    tags = res[0].split(" | ")
    caption = res[2]
    return tags, caption


def detect(
    grounding_dino_model: DinoModel,
    image: np.ndarray,
    caption: str,
    box_threshold: float = 0.3,
    text_threshold: float = 0.25,
    iou_threshold: float = 0.5,
    post_process: bool = True,
):
    """Detect bounding boxes in the given image, using the input caption.

    Arguments:
        grounding_dino_model (DinoModel): The model to use for detection.
        image (np.ndarray): The image to run detection on. Expects an image in HWC uint8 format,
            with pixel values in [0, 255].
        caption (str): Input caption containing the object names to detect. To detect multiple
            objects, separate the names with '.', e.g. "cat . dog . chair".
        box_threshold (float): Box confidence threshold.
        text_threshold (float): Text confidence threshold.
        iou_threshold (float): IoU threshold used by NMS during post-processing.
        post_process (bool): If True, run NMS to remove duplicate detections.

    Returns:
        (sv.Detections): Object containing information about the detections.
        (List[str]): Predicted phrases.
        (List[str]): Predicted classes.
    """
    detections, phrases = grounding_dino_model.predict_with_caption(
        image=image,
        caption=caption,
        box_threshold=box_threshold,
        text_threshold=text_threshold,
    )
    classes = [name.strip() for name in caption.split(".")]
    detections.class_id = DinoModel.phrases2classes(phrases=phrases, classes=classes)

    # NMS post-processing to drop overlapping duplicate boxes.
    if post_process:
        nms_idx = (
            torchvision.ops.nms(
                torch.from_numpy(detections.xyxy),
                torch.from_numpy(detections.confidence),
                iou_threshold,
            )
            .numpy()
            .tolist()
        )

        phrases = [phrases[idx] for idx in nms_idx]
        detections.xyxy = detections.xyxy[nms_idx]
        detections.confidence = detections.confidence[nms_idx]
        detections.class_id = detections.class_id[nms_idx]

    return detections, phrases, classes
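# Usage sketch (illustrative only; variable names are placeholders): the tags returned by
# `generate_tags` can be joined with " . " to build the caption that `detect` expects, so
# GroundingDINO is prompted with whatever Tag2Text found in the image. `predict_with_caption`
# takes an HWC uint8 array (typically a cv2/BGR-loaded image).
#
#   tags, caption = generate_tags(tag2text_model, image_pil, specified_tags, device=device)
#   detections, phrases, classes = detect(
#       grounding_dino_model,            # a groundingdino.util.inference.Model (DinoModel)
#       image=image_bgr,                 # HWC uint8 array
#       caption=" . ".join(tags),
#       box_threshold=0.3,
#       text_threshold=0.25,
#   )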
def segment(sam_model: SamPredictor, image: np.ndarray, boxes: np.ndarray):
    """Predict masks for the given input boxes, using the provided image.

    Arguments:
        sam_model (SamPredictor): The model to use for mask prediction.
        image (np.ndarray): The image for calculating masks. Expects an image in HWC uint8 format,
            with pixel values in [0, 255].
        boxes (np.ndarray or None): A Bx4 array of box prompts to the model, in XYXY format.

    Returns:
        (np.ndarray): The output masks in BxHxW format, where (H, W) is the original image size.
        (np.ndarray): An array of shape (B,) containing the model's predicted quality score for
            each mask.
    """
    sam_model.set_image(image)

    transformed_boxes = None
    if boxes is not None:
        boxes = torch.from_numpy(boxes)
        transformed_boxes = sam_model.transform.apply_boxes_torch(
            boxes.to(sam_model.device), image.shape[:2]
        )

    masks, scores, _ = sam_model.predict_torch(
        point_coords=None,
        point_labels=None,
        boxes=transformed_boxes,
        multimask_output=False,
    )
    # With multimask_output=False there is a single mask per box; drop the mask dimension.
    masks = masks[:, 0, :, :]
    scores = scores[:, 0]
    return masks.cpu().numpy(), scores.cpu().numpy()


def draw_mask(mask, draw, random_color=False):
    """Draw a binary mask onto a PIL ImageDraw canvas as a semi-transparent colored overlay."""
    if random_color:
        color = (
            random.randint(0, 255),
            random.randint(0, 255),
            random.randint(0, 255),
            153,
        )
    else:
        color = (30, 144, 255, 153)

    nonzero_coords = np.transpose(np.nonzero(mask))
    for coord in nonzero_coords:
        # np.nonzero yields (row, col); ImageDraw expects (x, y), so reverse the order.
        draw.point(coord[::-1], fill=color)
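# End-to-end sketch (illustrative only; image variables, model handles, and the alpha-compositing
# step are assumptions, not part of this module): detect boxes from a caption, prompt SAM with
# those boxes, then overlay the masks with `draw_mask` (requires `from PIL import ImageDraw`).
#
#   detections, phrases, classes = detect(grounding_dino_model, image_bgr, caption="cat . dog")
#   masks, scores = segment(sam_predictor, image_rgb, boxes=detections.xyxy)
#   detections.mask = masks.astype(bool)
#
#   overlay = Image.new("RGBA", (image_rgb.shape[1], image_rgb.shape[0]), (0, 0, 0, 0))
#   overlay_draw = ImageDraw.Draw(overlay)
#   for mask in detections.mask:
#       draw_mask(mask, overlay_draw, random_color=True)
#   annotated = Image.alpha_composite(Image.fromarray(image_rgb).convert("RGBA"), overlay)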