peekaboo-demo / data /generate_ssv2_st.py
Anshul Nasery
Demo commit
44f2ca8
raw
history blame
5.45 kB
#@title Get bounding boxes for the subject
from transformers import pipeline
from moviepy.editor import VideoFileClip
from PIL import Image
import os
import numpy as np
import matplotlib.pyplot as plt
import tqdm
import pickle
import torch
checkpoint = "google/owlvit-large-patch14"
detector = pipeline(model=checkpoint, task="zero-shot-object-detection", cache_dir="/coc/pskynet4/yashjain/", device='cuda:0')
# from transformers import Owlv2Processor, Owlv2ForObjectDetection
# processor = Owlv2Processor.from_pretrained("google/owlv2-base-patch16-ensemble")
# model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16-ensemble")
# def owl_inference(image, text):
# inputs = inputs = processor(text=text, images=image, return_tensors="pt")
# outputs = model(**inputs)
# target_sizes = torch.Tensor([image.size[::-1]])
# results = processor.post_process_object_detection(outputs=outputs, threshold=0.1, target_sizes=target_sizes)
# return results[0]['boxes']
def find_surrounding_masks(mask_presence):
# Finds the indices of the surrounding masks for each gap
gap_info = []
start = None
for i, present in enumerate(mask_presence):
if present and start is not None:
end = i
gap_info.append((start, end))
start = None
elif not present and start is None and i > 0:
start = i - 1
# Handle the special case where the gap is at the end
if start is not None:
gap_info.append((start, len(mask_presence)))
return gap_info
def copy_edge_masks(mask_list, mask_presence):
if not mask_presence[-1]:
# Find the last present mask and copy it to the end
for i in reversed(range(len(mask_presence))):
if mask_presence[i]:
mask_list[i+1:] = [mask_list[i]] * (len(mask_presence) - i - 1)
break
def interpolate_masks(mask_list, mask_presence):
# Ensure the mask list and mask presence list are the same length
assert len(mask_list) == len(mask_presence), "Mask list and presence list must have the same length."
# Copy edge masks if there are gaps at the start or end
# copy_edge_masks(mask_list, mask_presence)
# Find surrounding masks for gaps
gap_info = find_surrounding_masks(mask_presence)
# Interpolate the masks in the gaps
for start, end in gap_info:
end = min(end, len(mask_list)-1)
num_steps = end - start - 1
prev_mask = mask_list[start]
next_mask = mask_list[end]
step = (next_mask - prev_mask) / (num_steps + 1)
interpolated_masks = [(prev_mask + step * (i + 1)).round().astype(int) for i in range(num_steps)]
mask_list[start + 1:end] = interpolated_masks
return mask_list
def get_bounding_boxes(clip_path, subject):
# Read video from the path
clip = VideoFileClip(clip_path)
all_bboxes = []
bbox_present = []
num_bb = 0
for fidx,frame in enumerate(clip.iter_frames()):
if fidx > 24: break
frame = Image.fromarray(frame)
predictions = detector(
frame,
candidate_labels=[subject,],
)
try:
bbox = predictions[0]["box"]
bbox = (bbox["xmin"], bbox["ymin"], bbox["xmax"], bbox["ymax"])
# Get a zeros array of the same size as the frame
canvas = np.zeros(frame.size[::-1])
# Draw the bounding box on the canvas
canvas[bbox[1]:bbox[3], bbox[0]:bbox[2]] = 1
# Add the canvas to the list of bounding boxes
all_bboxes.append(canvas)
bbox_present.append(True)
num_bb += 1
except Exception as e:
# Append an empty canvas, we will interpolate later
all_bboxes.append(np.zeros(frame.size[::-1]))
bbox_present.append(False)
continue
# Design decision
interpolated_masks = interpolate_masks(all_bboxes, bbox_present)
return interpolated_masks, num_bb
import json
BASE_DIR = '/scr/clips_downsampled_5fps_downsized_224x224'
annotations = json.load(open('/gscratch/sewoong/anasery/datasets/ssv2/datasets/SSv2/ssv2_label_ssv2_template/ssv2_ret_label_val_small_filtered.json', 'r'))
records_with_masks = []
ridx = 0
for idx,record in tqdm.tqdm(enumerate(annotations)):
video_id = record['video']
print(f"{record['caption']} - {record['nouns']}")
# for video_id in video_ids:
new_record = record.copy()
new_record['video'] = video_id.replace('webm', 'mp4')
all_masks = []
all_num_bb = []
for subject in record['nouns']:
masks, num_bb = get_bounding_boxes(clip_path=os.path.join(BASE_DIR, video_id.replace('webm', 'mp4')), subject=subject)
all_masks.append(masks)
all_num_bb.append(num_bb)
try:
print(f"{record['video']} , subj - {record['nouns']}, bb - {all_num_bb}")
except:
continue
new_record['masks'] = all_masks
records_with_masks.append(new_record)
ridx += 1
if ridx % 100 == 0:
with open(f'/gscratch/sewoong/anasery/datasets/ssv2/datasets/SSv2/SSv2_label_with_two_obj_masks.pkl', 'wb') as f:
pickle.dump(records_with_masks, f)
with open(f'/gscratch/sewoong/anasery/datasets/ssv2/datasets/SSv2/SSv2_label_with_two_obj_masks.pkl', 'wb') as f:
pickle.dump(records_with_masks, f)