import csv
import datetime
import os
import warnings
from typing import Tuple

import cv2
import numpy as np
import torch
from torchvision.ops import box_convert
from tqdm import tqdm

from GroundingDINO.groundingdino.util.inference import load_model, load_image, annotate, preprocess_caption
from GroundingDINO.groundingdino.util.utils import get_phrases_from_posmap
from segment_anything import sam_model_registry
from segment_anything.utils.transforms import ResizeLongestSide
from video_utils import mp4_to_png, frame_to_timestamp, vid_stitcher

warnings.filterwarnings("ignore")


def prepare_image(image, transform, model):
    """Resize an HxWx3 uint8 image into the 3xHxW tensor SAM expects."""
    image = transform.apply_image(image)
    image = torch.as_tensor(image, device=model.device)
    return image.permute(2, 0, 1).contiguous()
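
# SAM's batched forward pass takes a list of dicts, one per image, of the form
#   {"image": 3xHxW tensor (resized with ResizeLongestSide),
#    "boxes": transformed box prompts, "original_size": (H, W)}
# prepare_image builds the "image" entry; the masks branch of sam_dino_vid
# below assembles the full dict.
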
def sam_dino_vid(
    vid_path: str,
    text_prompt: str,
    box_threshold: float = 0.35,
    text_threshold: float = 0.25,
    fps_processed: int = 1,
    video_options: list[str] = ["Bounding boxes"],
    config_path: str = "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py",
    weights_path: str = "weights/groundingdino_swint_ogc.pth",
    device: str = 'cuda',
    batch_size: int = 5
) -> Tuple[str, str]:
    """
    Run GroundingDINO (and optionally SAM) over a video and save annotated frames.

    Args:
        vid_path: path to the input video.
        text_prompt: text query for GroundingDINO.
        box_threshold: minimum box confidence to keep a detection.
        text_threshold: minimum token confidence for phrase extraction.
        fps_processed: run detection on every nth frame.
        video_options: any of "Bounding boxes" and "Masks".
        config_path: GroundingDINO config file.
        weights_path: GroundingDINO checkpoint.
        device: torch device string.
        batch_size: number of frames per GroundingDINO batch.

    Returns:
        (path to the detections CSV, path to the stitched output video)
    """
    # figure out which outputs were requested
    boxes_needed = "Bounding boxes" in video_options
    masks_needed = "Masks" in video_options

    if masks_needed:
        # SAM is only needed when mask output is requested
        checkpoint = "weights/sam_vit_h_4b8939.pth"
        model_type = "vit_h"
        sam = sam_model_registry[model_type](checkpoint=checkpoint)
        sam.to(device=device)
        resize_transform = ResizeLongestSide(sam.image_encoder.img_size)

    # timestamped results directory for this run
    filename = os.path.splitext(os.path.basename(vid_path))[0]
    results_dir = os.path.join("../processed", filename + datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
    os.makedirs(results_dir, exist_ok=True)
    frames_dir = os.path.join(results_dir, "frames")
    os.makedirs(frames_dir, exist_ok=True)
    csv_path = os.path.join(results_dir, "detections.csv")

    # load GroundingDINO and split the video into PNG frames
    gd_model = load_model(config_path, weights_path, device=device)
    fps = mp4_to_png(vid_path, frames_dir)

    # sort so indices line up with video order, then keep every nth frame
    frame_filenames = sorted(os.listdir(frames_dir))

    frame_paths = []  # frames that get run through detection
    other_paths = []  # frames skipped by fps_processed (left unannotated)
    for i, frame in enumerate(frame_filenames):
        if i % fps_processed == 0:
            frame_paths.append(os.path.join(frames_dir, frame))
        else:
            other_paths.append(os.path.join(frames_dir, frame))

    with open(csv_path, 'w', newline='') as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(["Frame", "Timestamp (hh:mm:ss)", "Boxes (cxcywh)", "# Boxes"])

        for i in tqdm(range(0, len(frame_paths), batch_size), desc="Running batches"):
            batch_paths = frame_paths[i:i+batch_size]
            # load_image returns (numpy RGB image, preprocessed tensor); load once
            loaded = [load_image(img) for img in batch_paths]
            images_orig = [img for img, _ in loaded]
            image_stack = torch.stack([tensor for _, tensor in loaded])
            boxes_i, logits_i, phrases_i = dino_predict_batch(
                model=gd_model,
                images=image_stack,
                caption=text_prompt,
                box_threshold=box_threshold,
                text_threshold=text_threshold,
                device=device
            )

            # annotated frames overwrite the extracted frames in frames_dir
            annotated_frame_paths = [os.path.join(frames_dir, os.path.basename(frame_path)) for frame_path in batch_paths]

            # load_image yields RGB; keep BGR copies for OpenCV drawing/writing
            images_bgr = [cv2.cvtColor(image, cv2.COLOR_RGB2BGR) for image in images_orig]

            if masks_needed:
                # build SAM's batched input: one dict per frame with the image
                # tensor, the transformed box prompts, and the original size
                # (note: assumes every frame in the batch has at least one box)
                batched_input = []
                sam_boxes = []
                for image, box in zip(images_orig, boxes_i):  # SAM expects RGB input
                    height, width = image.shape[:2]
                    # GroundingDINO boxes are normalized cxcywh; scale to pixels
                    # and convert to the xyxy corners SAM expects
                    box = box * torch.Tensor([width, height, width, height])
                    box = box_convert(box, in_fmt="cxcywh", out_fmt="xyxy").to(device)
                    sam_boxes.append(box)
                    batched_input.append({
                        "image": prepare_image(image, resize_transform, sam),
                        "boxes": resize_transform.apply_boxes_torch(box, image.shape[:2]),
                        "original_size": image.shape[:2]
                    })
                batched_output = sam(batched_input, multimask_output=False)
                for j, prediction in enumerate(batched_output):  # j indexes within the batch
                    mask = prediction["masks"].cpu().numpy()
                    box = sam_boxes[j].cpu().numpy()
                    annotated_frame = plot_sam(images_bgr[j], mask, box, boxes_shown=boxes_needed)
                    cv2.imwrite(annotated_frame_paths[j], annotated_frame)

            elif boxes_needed:
                # annotate() takes the RGB source and returns a BGR image,
                # which cv2.imwrite can save directly
                for j, (image, box, logit, phrase) in enumerate(zip(images_orig, boxes_i, logits_i, phrases_i)):
                    annotated_frame = annotate(image_source=image, boxes=box, logits=logit, phrases=phrase)
                    cv2.imwrite(annotated_frame_paths[j], annotated_frame)

            # one CSV row per processed frame; the frame index is assumed to be
            # a zero-padded 8-digit suffix of the filename (as written by mp4_to_png)
            frame_names = [os.path.basename(frame_path).split(".")[0] for frame_path in batch_paths]
            for j, frame in enumerate(frame_names):
                writer.writerow([frame, frame_to_timestamp(int(frame[-8:]), fps), boxes_i[j].tolist(), len(boxes_i[j])])

    # stitch the (annotated and skipped) frames back into a video at the original fps
    save_path = vid_stitcher(frames_dir, output_path=os.path.join(results_dir, "output.mp4"), fps=fps)
    print("Results saved to: " + save_path)
    return csv_path, save_path


def dino_predict_batch(
    model,
    images: torch.Tensor,
    caption: str,
    box_threshold: float,
    text_threshold: float,
    device: str = "cuda"
) -> Tuple[list[torch.Tensor], list[torch.Tensor], list[list[str]]]:
    """
    Batched version of GroundingDINO's predict().

    Returns:
        bboxes_batch: per-image tensors of shape (n, 4), normalized cxcywh
        predicts_batch: per-image confidence tensors of shape (n,)
        phrases_batch: per-image lists of n matched phrases
    """
    caption = preprocess_caption(caption=caption)
    model = model.to(device)
    images = images.to(device)

    with torch.no_grad():
        # GroundingDINO takes one caption per image in the batch
        outputs = model(images, captions=[caption for _ in range(len(images))])
    prediction_logits = outputs["pred_logits"].cpu().sigmoid()  # (batch, n_queries, n_tokens)
    prediction_boxes = outputs["pred_boxes"].cpu()              # (batch, n_queries, 4)

    # keep queries whose best token score clears the box threshold
    mask = prediction_logits.max(dim=2)[0] > box_threshold

    bboxes_batch = []
    predicts_batch = []
    phrases_batch = []
    tokenizer = model.tokenizer
    tokenized = tokenizer(caption)
    for i in range(prediction_logits.shape[0]):
        logits = prediction_logits[i][mask[i]]
        phrases = [
            get_phrases_from_posmap(logit > text_threshold, tokenized, tokenizer).replace('.', '')
            for logit in logits
        ]
        boxes = prediction_boxes[i][mask[i]]
        phrases_batch.append(phrases)
        bboxes_batch.append(boxes)
        predicts_batch.append(logits.max(dim=1)[0])

    return bboxes_batch, predicts_batch, phrases_batch
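
# Example usage (hypothetical frame paths; detection counts depend on the
# prompt and thresholds):
#
#   model = load_model(config_path, weights_path, device="cuda")
#   stack = torch.stack([load_image(p)[1] for p in ["f0.png", "f1.png"]])
#   boxes, scores, phrases = dino_predict_batch(model, stack, "fish", 0.35, 0.25)
#   boxes[0]    # (n, 4) normalized cxcywh boxes for the first frame
#   phrases[0]  # the n matched phrases, e.g. ["fish", "fish"]
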
|
|
|
def plot_sam(
    image: np.ndarray,
    masks: np.ndarray,
    boxes: np.ndarray,
    boxes_shown: bool = True,
    masks_shown: bool = True,
) -> np.ndarray:
    """
    Overlay SAM masks and/or detection boxes on a BGR image and return it.
    """
    if boxes_shown:
        for box in boxes:
            # boxes are absolute xyxy; draw in red (BGR)
            cv2.rectangle(image, (int(box[0]), int(box[1])), (int(box[2]), int(box[3])), (0, 0, 255), 2)
    if masks_shown:
        color = np.array([255, 144, 30], dtype=np.uint8)  # BGR highlight color
        for mask in masks:
            # each mask is (1, H, W); broadcast against the color to get an
            # (H, W, 3) uint8 overlay, then alpha-blend it onto the image
            h, w = mask.shape[-2:]
            overlay = (mask.reshape(h, w, 1) * color.reshape(1, 1, -1)).astype(np.uint8)
            image = cv2.addWeighted(image, 1, overlay, 0.5, 0)
    return image
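

# Minimal usage sketch (not part of the original pipeline): the video path and
# prompt below are placeholders; adjust them, and the weight/config defaults
# above, for your setup.
if __name__ == "__main__":
    csv_out, video_out = sam_dino_vid(
        vid_path="../videos/example.mp4",  # hypothetical input video
        text_prompt="fish",                # hypothetical prompt
        fps_processed=5,                   # detect on every 5th frame
        video_options=["Bounding boxes", "Masks"],
    )
    print(csv_out, video_out)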