# Primate-Detection-GPU / dino_sam.py
import datetime
import cv2
import os
import numpy as np
import torch
# import io
# import cProfile
import csv
# import pstats
import warnings
from memory_profiler import profile
# from pstats import SortKey
from tqdm import tqdm
from torchvision.ops import box_convert
from typing import Tuple
from GroundingDINO.groundingdino.util.inference import load_model, load_image, annotate, preprocess_caption
from GroundingDINO.groundingdino.util.utils import get_phrases_from_posmap
from segment_anything import sam_model_registry
from segment_anything.utils.transforms import ResizeLongestSide
from video_utils import mp4_to_png, frame_to_timestamp, vid_stitcher
warnings.filterwarnings("ignore")


def prepare_image(image, transform, model):
    """Apply SAM's resize transform and return the image as a contiguous CHW tensor on the model's device."""
    image = transform.apply_image(image)
    image = torch.as_tensor(image, device=model.device)
    return image.permute(2, 0, 1).contiguous()


# @profile
def sam_dino_vid(
vid_path: str,
text_prompt: str,
box_threshold: float = 0.35,
text_threshold: float = 0.25,
fps_processed: int = 1,
video_options: list[str] = ["Bounding boxes"],
config_path: str = "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py",
weights_path: str = "weights/groundingdino_swint_ogc.pth",
device: str = 'cuda',
batch_size: int = 5
) -> Tuple[str, str]:
""" Args:
Returns:
"""
masks_needed = False
boxes_needed = True
# if masks are selected, load SAM model
if "Bounding boxes" not in video_options:
boxes_needed = False
if "Masks" in video_options:
masks_needed = True
checkpoint = "weights/sam_vit_h_4b8939.pth"
model_type = "vit_h"
sam = sam_model_registry[model_type](checkpoint=checkpoint)
sam.to(device=device)
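        # ResizeLongestSide rescales images and boxes so their longest side matches SAM's encoder input size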
resize_transform = ResizeLongestSide(sam.image_encoder.img_size)
# create new dirs and paths for results
filename = os.path.splitext(os.path.basename(vid_path))[0]
results_dir = "../processed/" + filename + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
os.makedirs(results_dir, exist_ok=True)
frames_dir = os.path.join(results_dir, "frames")
os.makedirs(frames_dir, exist_ok=True)
csv_path = os.path.join(results_dir, "detections.csv")
# load the groundingDINO model
gd_model = load_model(config_path, weights_path, device=device)
    # extract every frame of the video to frames_dir; fps is the source video's frame rate
    fps = mp4_to_png(vid_path, frames_dir)
# get the frame paths for the images to process
    frame_filenames = sorted(os.listdir(frames_dir))  # sort so the every-Nth selection follows frame order
frame_paths = [] # list of frame paths to process based on fps_processed
    other_paths = [] # list of frame paths that are skipped (not run through the models)
for i, frame in enumerate(frame_filenames):
if i % fps_processed == 0:
frame_paths.append(os.path.join(frames_dir, frame))
else:
other_paths.append(os.path.join(frames_dir, frame))
# TODO: rename vars to be more clear
# run dino_predict_batch and sam_predict_batch in batches of frames
# write the results to a csv
with open(csv_path, 'w', newline='') as csvfile:
writer = csv.writer(csvfile)
writer.writerow(["Frame", "Timestamp (hh:mm:ss)", "Boxes (cxcywh)", "# Boxes"])
# run groundingDINO in batches
for i in tqdm(range(0, len(frame_paths), batch_size), desc="Running batches"):
            batch_paths = frame_paths[i:i+batch_size] # paths for this batch
            # load_image returns (original image array, normalized tensor); load each frame only once
            loaded = [load_image(path) for path in batch_paths]
            images_orig = [img for img, _ in loaded]
            image_stack = torch.stack([tensor for _, tensor in loaded])
boxes_i, logits_i, phrases_i = dino_predict_batch(
model=gd_model,
images=image_stack,
caption=text_prompt,
box_threshold=box_threshold,
text_threshold=text_threshold
)
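            # annotated frames are written back over the extracted frames so vid_stitcher picks them up later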
annotated_frame_paths = [os.path.join(frames_dir, os.path.basename(frame_path)) for frame_path in batch_paths]
# convert images_orig to rgb from bgr
images_orig_rgb = [cv2.cvtColor(image, cv2.COLOR_BGR2RGB) for image in images_orig]
if masks_needed:
# run SAM in batches on boxes from dino
batched_input = []
sam_boxes = []
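                # build SAM's batched input: boxes are scaled from normalized cxcywh to pixel xyxy, then resized to SAM's input space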
for image, box in zip(images_orig_rgb, boxes_i):
height, width = image.shape[:2]
# convert the boxes from groundingDINO format to SAM format
box = box * torch.Tensor([width, height, width, height])
                    box = box_convert(box, in_fmt="cxcywh", out_fmt="xyxy").to(device)
sam_boxes.append(box)
batched_input.append({
"image": prepare_image(image, resize_transform, sam),
"boxes": resize_transform.apply_boxes_torch(box, image.shape[:2]),
"original_size": image.shape[:2]
})
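                # SAM's batched forward returns one dict per image with "masks", "iou_predictions", and "low_res_logits"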
batched_output = sam(batched_input, multimask_output=False)
                for j, prediction in enumerate(batched_output):
                    # write annotated frames back to frames_dir for stitching
                    mask = prediction["masks"].cpu().numpy()
                    box = sam_boxes[j].cpu().numpy()
                    annotated_frame = plot_sam(images_orig_rgb[j], mask, box, boxes_shown=boxes_needed)
                    cv2.imwrite(annotated_frame_paths[j], annotated_frame)
            elif boxes_needed:
                # get groundingDINO annotated frames
                for j, (image, box, logit, phrase) in enumerate(zip(images_orig, boxes_i, logits_i, phrases_i)):
                    annotated_frame = annotate(image_source=image, boxes=box, logits=logit, phrases=phrase)
                    cv2.imwrite(annotated_frame_paths[j], annotated_frame)
# write results to csv
# TODO: convert boxes to SAM format for clearer understanding
frame_names = [os.path.basename(frame_path).split(".")[0] for frame_path in batch_paths]
            for j, frame in enumerate(frame_names):
                writer.writerow([frame, frame_to_timestamp(int(frame[-8:]), fps), boxes_i[j], len(boxes_i[j])])
# stitch the frames
save_path = vid_stitcher(frames_dir, output_path=os.path.join(results_dir, "output.mp4"), fps=fps)
print("Results saved to: " + save_path)
return csv_path, save_path


def dino_predict_batch(
model,
images: torch.Tensor,
caption: str,
box_threshold: float,
text_threshold: float,
device: str = "cuda"
) -> Tuple[list[torch.Tensor], list[torch.Tensor], list[list[str]]]:
    """
    Run GroundingDINO on a batch of images with a single caption.

    Returns:
        bboxes_batch: list of (n, 4) tensors of boxes in normalized cxcywh format
        predicts_batch: list of (n,) tensors of confidence scores
        phrases_batch: list of lists of matched phrases, one list per image
    """
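    # preprocess_caption lowercases the prompt and makes sure it ends with a period, as GroundingDINO expects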
caption = preprocess_caption(caption=caption)
model = model.to(device)
image = images.to(device)
with torch.no_grad():
outputs = model(image, captions=[caption for _ in range(len(images))])
prediction_logits = outputs["pred_logits"].cpu().sigmoid() # prediction_logits.shape = (num_batch, nq, 256)
prediction_boxes = outputs["pred_boxes"].cpu() # prediction_boxes.shape = (num_batch, nq, 4)
    mask = prediction_logits.max(dim=2)[0] > box_threshold # mask.shape = (num_batch, nq)
bboxes_batch = []
predicts_batch = []
phrases_batch = [] # list of lists
tokenizer = model.tokenizer
tokenized = tokenizer(caption)
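    # for each image, keep boxes whose best token logit clears box_threshold and recover their matched phrases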
for i in range(prediction_logits.shape[0]):
logits = prediction_logits[i][mask[i]] # logits.shape = (n, 256)
phrases = [
get_phrases_from_posmap(logit > text_threshold, tokenized, tokenizer).replace('.', '')
for logit # logit is a tensor of shape (256,) torch.Size([256])
in logits # torch.Size([7, 256])
]
boxes = prediction_boxes[i][mask[i]] # boxes.shape = (n, 4)
phrases_batch.append(phrases)
bboxes_batch.append(boxes)
predicts_batch.append(logits.max(dim=1)[0])
return bboxes_batch, predicts_batch, phrases_batch


def plot_sam(
image: np.ndarray,
    masks: np.ndarray,
boxes: np.ndarray,
boxes_shown: bool = True,
masks_shown: bool = True,
) -> np.ndarray:
"""
Plot image with masks and/or boxes.
"""
# Use cv2 to plot the boxes and masks if they exist
if boxes_shown:
for box in boxes:
# red bbox
cv2.rectangle(image, (int(box[0]), int(box[1])), (int(box[2]), int(box[3])), (0, 0, 255), 2)
if masks_shown:
        # blue overlay color
        color = np.array([255, 144, 30], dtype=np.uint8)
for mask in masks:
# turn the mask into a colored mask
h, w = mask.shape[-2:]
mask = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
image = cv2.addWeighted(image, 1, mask, 0.5, 0)
return image


# if __name__ == '__main__':
# def run_sam_dino_vid():
# sam_dino_vid("baboon_15s.mp4", "baboon", box_threshold=0.3, text_threshold=0.3, fps_processed=30, video_options=['Bounding boxes', 'Masks'])
# start_time = datetime.datetime.now()
# stats = run_sam_dino_vid()
# print("elapsed: " + str(datetime.datetime.now() - start_time))