Spaces:
Running
on
T4
Running
on
T4
File size: 5,793 Bytes
dfd72c0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the sav_dataset directory of this source tree.
import json
import os
from typing import Dict, List, Optional, Tuple
import cv2
import matplotlib.pyplot as plt
import numpy as np
import pycocotools.mask as mask_util
def decode_video(video_path: str) -> List[np.ndarray]:
    """
    Decode a video file and return its frames as RGB arrays.

    Args:
        video_path: path to a video file readable by OpenCV.

    Returns:
        A list of H x W x 3 uint8 RGB frames; empty if the video
        cannot be opened or contains no frames.
    """
    video = cv2.VideoCapture(video_path)
    video_frames = []
    try:
        while video.isOpened():
            ret, frame = video.read()
            if not ret:
                break
            # OpenCV decodes frames as BGR; convert to RGB for display.
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            video_frames.append(frame)
    finally:
        # Release the capture handle explicitly; otherwise the file /
        # decoder resources leak until garbage collection.
        video.release()
    return video_frames
def show_anns(masks, colors: List, borders=True) -> None:
    """
    Overlay segmentation masks on the current matplotlib axes.

    Args:
        masks: list of boolean H x W mask arrays.
        colors: list of RGB colors (values in [0, 1]), one per mask.
        borders: when True, trace a dark contour around each mask.
    """
    if not masks:
        return

    # Render larger masks first so smaller ones remain visible on top.
    ordered = sorted(
        zip(masks, colors), key=lambda pair: pair[0].sum(), reverse=True
    )

    height = ordered[0][0].shape[0]
    width = ordered[0][0].shape[1]

    # RGBA overlay canvas, fully transparent to start (alpha = 0).
    overlay = np.ones((height, width, 4))
    overlay[:, :, 3] = 0

    # Contour thickness scales with image size, clamped to [1, 5] px.
    thickness = max(1, int(min(5, 0.01 * min(height, width))))

    for mask, color in ordered:
        # Paint the mask region with its color at 0.55 alpha.
        overlay[mask] = np.concatenate([color, [0.55]])
        if borders:
            contours, _ = cv2.findContours(
                np.array(mask, dtype=np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE
            )
            cv2.drawContours(
                overlay, contours, -1, (0.05, 0.05, 0.05, 1), thickness=thickness
            )

    plt.gca().imshow(overlay)
class SAVDataset:
    """
    SAVDataset is a class to load the SAV dataset and visualize the annotations.

    It reads a video (``<video_id>.mp4``) together with its manual
    (``<video_id>_manual.json``) and automatic (``<video_id>_auto.json``)
    masklet annotation files from ``sav_dir``.
    """

    def __init__(self, sav_dir, annot_sample_rate=4):
        """
        Args:
            sav_dir: the directory of the SAV dataset
            annot_sample_rate: the sampling rate of the annotations.
                The annotations are aligned with the videos at 6 fps.
        """
        self.sav_dir = sav_dir
        self.annot_sample_rate = annot_sample_rate
        # Fixed random palettes (256 RGB colors each) so the manual and
        # auto masks are colored consistently within a session.
        self.manual_mask_colors = np.random.random((256, 3))
        self.auto_mask_colors = np.random.random((256, 3))

    def read_frames(self, mp4_path: str) -> Optional[List[np.ndarray]]:
        """
        Read the frames and downsample them to align with the annotations.

        Args:
            mp4_path: path to the video file.

        Returns:
            The downsampled RGB frames, or None if the file doesn't exist.
        """
        if not os.path.exists(mp4_path):
            print(f"{mp4_path} doesn't exist.")
            return None
        # decode the video
        frames = decode_video(mp4_path)
        print(f"There are {len(frames)} frames decoded from {mp4_path} (24fps).")
        # downsample the frames to align with the annotations
        frames = frames[:: self.annot_sample_rate]
        print(
            f"Videos are annotated every {self.annot_sample_rate} frames. "
            "To align with the annotations, "
            f"downsample the video to {len(frames)} frames."
        )
        return frames

    def get_frames_and_annotations(
        self, video_id: str
    ) -> Tuple[Optional[List], Optional[Dict], Optional[Dict]]:
        """
        Get the frames and annotations for video.

        Args:
            video_id: basename of the video (without extension).

        Returns:
            (frames, manual_annot, auto_annot); each element may be None
            when the corresponding file is missing.
        """
        # load the video
        mp4_path = os.path.join(self.sav_dir, video_id + ".mp4")
        frames = self.read_frames(mp4_path)
        if frames is None:
            return None, None, None

        # load the manual annotations
        manual_annot_path = os.path.join(self.sav_dir, video_id + "_manual.json")
        if not os.path.exists(manual_annot_path):
            print(f"{manual_annot_path} doesn't exist. Something might be wrong.")
            manual_annot = None
        else:
            # Context manager so the JSON file handle is closed promptly.
            with open(manual_annot_path) as f:
                manual_annot = json.load(f)

        # load the auto annotations
        auto_annot_path = os.path.join(self.sav_dir, video_id + "_auto.json")
        if not os.path.exists(auto_annot_path):
            print(f"{auto_annot_path} doesn't exist.")
            auto_annot = None
        else:
            with open(auto_annot_path) as f:
                auto_annot = json.load(f)

        return frames, manual_annot, auto_annot

    def visualize_annotation(
        self,
        frames: List[np.ndarray],
        auto_annot: Optional[Dict],
        manual_annot: Optional[Dict],
        annotated_frame_id: int,
        show_auto=True,
        show_manual=True,
    ) -> None:
        """
        Visualize the annotations on the annotated_frame_id.
        If show_manual is True, show the manual annotations.
        If show_auto is True, show the auto annotations.
        By default, show both auto and manual annotations.
        """
        if annotated_frame_id >= len(frames):
            print("invalid annotated_frame_id")
            return

        rles = []
        colors = []
        if show_manual and manual_annot is not None:
            # annotations are stored as one list of RLE masks per frame
            rles.extend(manual_annot["masklet"][annotated_frame_id])
            colors.extend(
                self.manual_mask_colors[
                    : len(manual_annot["masklet"][annotated_frame_id])
                ]
            )
        if show_auto and auto_annot is not None:
            rles.extend(auto_annot["masklet"][annotated_frame_id])
            colors.extend(
                self.auto_mask_colors[: len(auto_annot["masklet"][annotated_frame_id])]
            )

        plt.imshow(frames[annotated_frame_id])
        if rles:
            # decode COCO RLEs to boolean masks before drawing
            masks = [mask_util.decode(rle) > 0 for rle in rles]
            show_anns(masks, colors)
        else:
            print("No annotation will be shown")
        plt.axis("off")
        plt.show()
|