File size: 5,448 Bytes
44f2ca8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
#@title Get bounding boxes for the subject
from transformers import pipeline
from moviepy.editor import VideoFileClip
from PIL import Image
import os
import numpy as np
import matplotlib.pyplot as plt
import tqdm
import pickle
import torch

# Zero-shot object detector: OWL-ViT (large, patch-14) served through the
# transformers pipeline API on the first CUDA device.
# NOTE(review): cache_dir is a cluster-specific path — confirm it exists on
# the machine this actually runs on.
checkpoint = "google/owlvit-large-patch14"
detector = pipeline(model=checkpoint, task="zero-shot-object-detection", cache_dir="/coc/pskynet4/yashjain/", device='cuda:0')


# from transformers import Owlv2Processor, Owlv2ForObjectDetection

# processor = Owlv2Processor.from_pretrained("google/owlv2-base-patch16-ensemble")
# model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16-ensemble")

# def owl_inference(image, text):
#     inputs = inputs = processor(text=text, images=image, return_tensors="pt")
#     outputs = model(**inputs)
#     target_sizes = torch.Tensor([image.size[::-1]])
#     results = processor.post_process_object_detection(outputs=outputs, threshold=0.1, target_sizes=target_sizes)
#     return results[0]['boxes']

def find_surrounding_masks(mask_presence):
    """Locate gaps (runs of missing masks) and return their anchor indices.

    Each gap is reported as a ``(start, end)`` pair where ``start`` is the
    index of the mask just before the gap and ``end`` the index of the first
    mask after it. A gap running to the end of the sequence is reported as
    ``(start, len(mask_presence))`` so callers can recognize the open end.

    NOTE(review): when the sequence *begins* with missing masks, index 0 is
    still used as the left anchor even though no mask is present there —
    confirm this is intended by the interpolation caller.
    """
    gaps = []
    anchor = None

    for idx, has_mask in enumerate(mask_presence):
        if has_mask:
            if anchor is not None:
                gaps.append((anchor, idx))
                anchor = None
        elif anchor is None and idx > 0:
            anchor = idx - 1

    # A still-open anchor means the gap reaches the end of the sequence.
    if anchor is not None:
        gaps.append((anchor, len(mask_presence)))

    return gaps

def copy_edge_masks(mask_list, mask_presence):
    """Fill edge gaps of ``mask_list`` in place by replicating the nearest
    present mask.

    The call site describes this as handling "gaps at the start or end",
    but the original implementation only filled the trailing edge; the
    leading edge is now handled symmetrically. Interior gaps are left for
    interpolation. Empty inputs are a no-op (the original raised
    IndexError on ``mask_presence[-1]``).

    Args:
        mask_list: list of masks, mutated in place.
        mask_presence: parallel list of booleans; True where a real mask
            exists at that index.
    """
    if not mask_presence:
        return

    # Trailing gap: replicate the last present mask forward to the end.
    if not mask_presence[-1]:
        for i in reversed(range(len(mask_presence))):
            if mask_presence[i]:
                mask_list[i + 1:] = [mask_list[i]] * (len(mask_presence) - i - 1)
                break

    # Leading gap: replicate the first present mask backward to the start.
    if not mask_presence[0]:
        for i in range(len(mask_presence)):
            if mask_presence[i]:
                mask_list[:i] = [mask_list[i]] * i
                break

def interpolate_masks(mask_list, mask_presence):
    """Linearly interpolate missing masks in place and return the list.

    Gaps are located with ``find_surrounding_masks``; every missing entry
    between two anchor masks is replaced by a rounded integer blend of the
    two anchors. An open-ended (trailing) gap is clamped to the last index,
    which makes its step count zero — trailing gaps are therefore left
    untouched (edge copying below is intentionally disabled).
    """
    assert len(mask_list) == len(mask_presence), "Mask list and presence list must have the same length."

    # Edge-gap copying is deliberately switched off here:
    # copy_edge_masks(mask_list, mask_presence)

    for left, right in find_surrounding_masks(mask_presence):
        # Clamp the sentinel index produced for gaps that reach the end.
        right = min(right, len(mask_list) - 1)
        gap_len = right - left - 1
        lo = mask_list[left]
        hi = mask_list[right]
        delta = (hi - lo) / (gap_len + 1)
        mask_list[left + 1:right] = [
            (lo + delta * k).round().astype(int) for k in range(1, gap_len + 1)
        ]

    return mask_list

def get_bounding_boxes(clip_path, subject):
    """Detect ``subject`` in (up to) the first 25 frames of a video clip.

    Each frame is run through the module-level zero-shot ``detector``; the
    top prediction's box is rasterized as a binary (height, width) numpy
    canvas. Frames with no usable detection get an all-zero canvas and are
    later filled by ``interpolate_masks``.

    Returns:
        (masks, num_bb): the interpolated list of per-frame canvases, and
        the count of frames where a detection was actually found.
    """
    clip = VideoFileClip(clip_path)
    all_bboxes = []
    bbox_present = []
    num_bb = 0

    for frame_idx, raw_frame in enumerate(clip.iter_frames()):
        # Only frames 0..24 are processed.
        if frame_idx > 24:
            break

        image = Image.fromarray(raw_frame)
        predictions = detector(
            image,
            candidate_labels=[subject],
        )

        try:
            box = predictions[0]["box"]
            xmin, ymin, xmax, ymax = box["xmin"], box["ymin"], box["xmax"], box["ymax"]
            # Rasterize the box onto a zero canvas shaped like the frame
            # (PIL size is (w, h), numpy wants (h, w)).
            canvas = np.zeros(image.size[::-1])
            canvas[ymin:ymax, xmin:xmax] = 1
            all_bboxes.append(canvas)
            bbox_present.append(True)
            num_bb += 1
        except Exception as e:
            # Best effort: no (usable) detection for this frame — record an
            # empty canvas and mark it absent so interpolation fills it in.
            all_bboxes.append(np.zeros(image.size[::-1]))
            bbox_present.append(False)
            continue

    # Design decision
    interpolated = interpolate_masks(all_bboxes, bbox_present)
    return interpolated, num_bb

import json
# Root directory holding the downsampled 224x224 5fps clips.
BASE_DIR = '/scr/clips_downsampled_5fps_downsized_224x224'
# SSv2 retrieval annotations; each record appears to carry 'video',
# 'caption' and a list of 'nouns' — TODO confirm against the JSON schema.
annotations = json.load(open('/gscratch/sewoong/anasery/datasets/ssv2/datasets/SSv2/ssv2_label_ssv2_template/ssv2_ret_label_val_small_filtered.json', 'r'))

records_with_masks = []
ridx = 0  # records successfully processed; drives periodic checkpointing
for idx,record in tqdm.tqdm(enumerate(annotations)):
    video_id = record['video']
    print(f"{record['caption']} - {record['nouns']}")
    # for video_id in video_ids:
    new_record = record.copy()
    # Clips on disk are stored as .mp4 although the annotation says .webm.
    new_record['video'] = video_id.replace('webm', 'mp4')
    all_masks = []
    all_num_bb = []
    # One mask sequence (and detection count) per annotated noun.
    for subject in record['nouns']:
        masks, num_bb = get_bounding_boxes(clip_path=os.path.join(BASE_DIR, video_id.replace('webm', 'mp4')), subject=subject)
        all_masks.append(masks)
        all_num_bb.append(num_bb)
    try:    
        print(f"{record['video']} , subj - {record['nouns']}, bb - {all_num_bb}")
    except:
        # NOTE(review): if this print fails (e.g. an encoding error), the
        # record is silently dropped from the output — confirm intended.
        continue
    new_record['masks'] = all_masks
    records_with_masks.append(new_record)
    ridx += 1

    # Checkpoint partial results every 100 successful records.
    if ridx % 100 == 0:
        with open(f'/gscratch/sewoong/anasery/datasets/ssv2/datasets/SSv2/SSv2_label_with_two_obj_masks.pkl', 'wb') as f:
            pickle.dump(records_with_masks, f)

# Final dump of everything collected.
with open(f'/gscratch/sewoong/anasery/datasets/ssv2/datasets/SSv2/SSv2_label_with_two_obj_masks.pkl', 'wb') as f:
    pickle.dump(records_with_masks, f)