import datetime
import cv2
import os
import numpy as np
import torch
# import io
# import cProfile
import csv
# import pstats
import warnings
from memory_profiler import profile
# from pstats import SortKey
from tqdm import tqdm
from torchvision.ops import box_convert
from typing import Optional, Tuple
from GroundingDINO.groundingdino.util.inference import load_model, load_image, annotate, preprocess_caption
from GroundingDINO.groundingdino.util.utils import get_phrases_from_posmap
from segment_anything import sam_model_registry
from segment_anything.utils.transforms import ResizeLongestSide
from video_utils import mp4_to_png, frame_to_timestamp, vid_stitcher

warnings.filterwarnings("ignore")


def prepare_image(image, transform, device):
    """Resize an HWC image for SAM and return it as a contiguous CHW tensor.

    Args:
        image: HWC numpy image. SAM's default input format is RGB, so callers
            should pass RGB pixels.
        transform: a ``ResizeLongestSide`` instance sized for the SAM encoder.
        device: any object exposing a ``.device`` attribute (callers in this
            file pass the SAM model itself); the tensor is created there.

    Returns:
        A contiguous CHW tensor on ``device.device``.
    """
    image = transform.apply_image(image)
    image = torch.as_tensor(image, device=device.device)
    return image.permute(2, 0, 1).contiguous()


# @profile
def sam_dino_vid(
        vid_path: str,
        text_prompt: str,
        box_threshold: float = 0.35,
        text_threshold: float = 0.25,
        fps_processed: int = 1,
        scaling_factor: float = 1.0,
        video_options: Optional[list[str]] = None,
        config_path: str = "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py",
        weights_path: str = "weights/groundingdino_swint_ogc.pth",
        device: str = 'cuda',
        batch_size: int = 5
) -> Tuple[str, str]:
    """Detect ``text_prompt`` in a video with GroundingDINO (optionally segment with SAM) and render the result.

    The video is split into frames under ``../processed/<name><timestamp>/frames``.
    Every ``fps_processed``-th frame (in sorted filename order) is run through
    GroundingDINO, and the annotated image overwrites that frame on disk. The
    frames are then stitched back into ``output.mp4``. Frames that are not
    selected, or that have no detections in "Masks" mode, are stitched
    unannotated.

    Args:
        vid_path: path of the input video.
        text_prompt: text query passed to GroundingDINO.
        box_threshold: minimum max-token score for a query box to be kept.
        text_threshold: token score threshold used to extract the phrase label.
        fps_processed: stride between processed frames (1 = every frame).
        scaling_factor: forwarded to ``mp4_to_png`` when extracting frames.
        video_options: any of ``"Bounding boxes"`` and ``"Masks"``; defaults to
            ``["Bounding boxes"]``. With ``"Masks"``, SAM (vit_h) is loaded and
            masks are drawn; boxes are drawn too only if ``"Bounding boxes"``
            is also present.
        config_path: GroundingDINO config file.
        weights_path: GroundingDINO checkpoint.
        device: torch device string used for both models.
        batch_size: number of frames per GroundingDINO/SAM forward pass.

    Returns:
        ``(csv_path, save_path)``: the per-frame detections CSV and the path
        returned by ``vid_stitcher`` for the rendered video.

    Raises:
        ValueError: if ``fps_processed`` or ``batch_size`` is not positive.
    """
    if fps_processed < 1:
        raise ValueError(f"fps_processed must be >= 1, got {fps_processed}")
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    if video_options is None:
        video_options = ["Bounding boxes"]

    boxes_needed = "Bounding boxes" in video_options
    masks_needed = "Masks" in video_options

    # if masks are selected, load SAM model
    sam = None
    resize_transform = None
    if masks_needed:
        checkpoint = "weights/sam_vit_h_4b8939.pth"
        model_type = "vit_h"
        sam = sam_model_registry[model_type](checkpoint=checkpoint)
        sam.to(device=device)
        resize_transform = ResizeLongestSide(sam.image_encoder.img_size)

    # create new dirs and paths for results
    filename = os.path.splitext(os.path.basename(vid_path))[0]
    results_dir = "../processed/" + filename + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    os.makedirs(results_dir, exist_ok=True)
    frames_dir = os.path.join(results_dir, "frames")
    os.makedirs(frames_dir, exist_ok=True)
    csv_path = os.path.join(results_dir, "detections.csv")

    # load the groundingDINO model
    gd_model = load_model(config_path, weights_path, device=device)

    # process video and create a directory of video frames
    fps = mp4_to_png(vid_path, frames_dir, scaling_factor)

    # os.listdir order is arbitrary; sort so the stride selection, batches and
    # CSV rows follow the frame order implied by the filenames.
    frame_filenames = sorted(os.listdir(frames_dir))
    frame_paths = [
        os.path.join(frames_dir, name)
        for idx, name in enumerate(frame_filenames)
        if idx % fps_processed == 0
    ]

    with open(csv_path, 'w', newline='') as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(["Frame", "Timestamp (hh:mm:ss)", "Boxes (cxcywh)", "# Boxes"])

        for start in tqdm(range(0, len(frame_paths), batch_size), desc="Running batches"):
            batch_paths = frame_paths[start:start + batch_size]

            # load_image returns (RGB numpy image, normalized tensor); load each frame once
            loaded = [load_image(path) for path in batch_paths]
            images_rgb = [source for source, _ in loaded]
            image_stack = torch.stack([transformed for _, transformed in loaded])

            boxes_i, logits_i, phrases_i = dino_predict_batch(
                model=gd_model,
                images=image_stack,
                caption=text_prompt,
                box_threshold=box_threshold,
                text_threshold=text_threshold,
                device=device,
            )

            # annotated frames overwrite the extracted frames so vid_stitcher picks them up
            if masks_needed:
                _sam_annotate_batch(
                    sam, resize_transform, images_rgb, boxes_i, batch_paths, device, boxes_needed
                )
            elif boxes_needed:
                # annotate() takes the RGB frame and returns a BGR image ready for imwrite
                for path, image, boxes, logits, phrases in zip(
                        batch_paths, images_rgb, boxes_i, logits_i, phrases_i):
                    annotated_frame = annotate(image_source=image, boxes=boxes, logits=logits, phrases=phrases)
                    cv2.imwrite(path, annotated_frame)

            # write results to csv
            # TODO: convert boxes to SAM format for clearer understanding
            for path, boxes in zip(batch_paths, boxes_i):
                frame = os.path.basename(path).split(".")[0]
                # NOTE(review): assumes mp4_to_png names frames with an 8-digit frame-index suffix
                writer.writerow([frame, frame_to_timestamp(int(frame[-8:]), fps), boxes, len(boxes)])

    # stitch the frames
    save_path = vid_stitcher(frames_dir, output_path=os.path.join(results_dir, "output.mp4"), fps=fps)
    print("Results saved to: " + save_path)
    return csv_path, save_path


def _sam_annotate_batch(sam, resize_transform, images_rgb, boxes_batch, out_paths, device, boxes_shown):
    """Segment each frame's GroundingDINO boxes with SAM and write mask overlays to ``out_paths``.

    Frames without detections are skipped: SAM gets no empty prompt sets, and
    the already-extracted frame on disk is left as is.
    """
    batched_input = []
    jobs = []  # (output path, BGR copy for drawing, pixel xyxy boxes)
    for image_rgb, boxes, out_path in zip(images_rgb, boxes_batch, out_paths):
        if len(boxes) == 0:
            continue
        height, width = image_rgb.shape[:2]
        # GroundingDINO boxes are normalized cxcywh; SAM wants pixel xyxy
        xyxy = box_convert(
            boxes * torch.Tensor([width, height, width, height]), in_fmt="cxcywh", out_fmt="xyxy"
        ).to(device)
        batched_input.append({
            "image": prepare_image(image_rgb, resize_transform, sam),  # SAM expects RGB
            "boxes": resize_transform.apply_boxes_torch(xyxy, image_rgb.shape[:2]),
            "original_size": image_rgb.shape[:2],
        })
        # plot_sam colors and cv2.imwrite both work in BGR
        jobs.append((out_path, cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR), xyxy))

    if not batched_input:
        return

    batched_output = sam(batched_input, multimask_output=False)
    for (out_path, image_bgr, xyxy), prediction in zip(jobs, batched_output):
        mask = prediction["masks"].cpu().numpy()
        annotated_frame = plot_sam(image_bgr, mask, xyxy.cpu().numpy(), boxes_shown=boxes_shown)
        cv2.imwrite(out_path, annotated_frame)


def dino_predict_batch(
        model,
        images: torch.Tensor,
        caption: str,
        box_threshold: float,
        text_threshold: float,
        device: str = "cuda"
) -> Tuple[list[torch.Tensor], list[torch.Tensor], list[list[str]]]:
    '''Run GroundingDINO on a stacked batch of images with the same caption.

    Args:
        model: a loaded GroundingDINO model (moved to ``device`` here).
        images: stacked, transformed image batch (as produced by ``load_image``).
        caption: text prompt; normalized with ``preprocess_caption``.
        box_threshold: keep a query if its max token score exceeds this.
        text_threshold: token score threshold for phrase extraction.
        device: torch device string for the forward pass.

    return:
        bboxes_batch: list (one per image) of tensors of shape (n, 4), normalized cxcywh
        predicts_batch: list of tensors of shape (n,) holding each box's max token score
        phrases_batch: list of lists of n phrase strings
    '''
    caption = preprocess_caption(caption=caption)
    model = model.to(device)
    images_on_device = images.to(device)
    with torch.no_grad():
        outputs = model(images_on_device, captions=[caption] * len(images))

    prediction_logits = outputs["pred_logits"].cpu().sigmoid()  # (batch, num_queries, 256)
    prediction_boxes = outputs["pred_boxes"].cpu()  # (batch, num_queries, 4)
    keep = prediction_logits.max(dim=2)[0] > box_threshold  # (batch, num_queries)

    tokenizer = model.tokenizer
    tokenized = tokenizer(caption)

    bboxes_batch = []
    predicts_batch = []
    phrases_batch = []
    for logits_all, boxes_all, keep_i in zip(prediction_logits, prediction_boxes, keep):
        logits = logits_all[keep_i]  # (n, 256)
        boxes = boxes_all[keep_i]  # (n, 4)
        phrases = [
            get_phrases_from_posmap(logit > text_threshold, tokenized, tokenizer).replace('.', '')
            for logit in logits
        ]
        phrases_batch.append(phrases)
        bboxes_batch.append(boxes)
        predicts_batch.append(logits.max(dim=1)[0])

    return bboxes_batch, predicts_batch, phrases_batch


def plot_sam(
        image: np.ndarray,
        masks: list[np.ndarray],
        boxes: np.ndarray,
        boxes_shown: bool = True,
        masks_shown: bool = True,
) -> np.ndarray:
    """Return a copy of a BGR ``image`` with masks and/or boxes drawn on it.

    Args:
        image: HWC uint8 image in BGR order (colors below are BGR).
        masks: iterable of masks whose last two dims are (H, W).
        boxes: pixel xyxy boxes, one row per box.
        boxes_shown: draw red rectangles for ``boxes``.
        masks_shown: blend a blue overlay for each mask.

    Returns:
        The annotated image; the input array is not modified.
    """
    # draw on a copy so the caller's array is not mutated (cv2.rectangle is in place)
    image = image.copy()
    if boxes_shown:
        for box in boxes:
            # red bbox
            cv2.rectangle(image, (int(box[0]), int(box[1])), (int(box[2]), int(box[3])), (0, 0, 255), 2)
    if masks_shown:
        # blue mask
        color = np.array([255, 144, 30], dtype=np.uint8)
        for mask in masks:
            # turn the mask into a colored uint8 overlay (addWeighted needs matching dtypes)
            h, w = mask.shape[-2:]
            overlay = (mask.reshape(h, w, 1) * color.reshape(1, 1, -1)).astype(np.uint8)
            image = cv2.addWeighted(image, 1, overlay, 0.5, 0)
    return image

# if __name__ == '__main__':
#     def run_sam_dino_vid():
#         sam_dino_vid("baboon_15s.mp4", "baboon", box_threshold=0.3, text_threshold=0.3, fps_processed=30, video_options=['Bounding boxes', 'Masks'])
#     start_time = datetime.datetime.now()
#     stats = run_sam_dino_vid()
#     print("elapsed: " + str(datetime.datetime.now() - start_time))