# Provenance: copied from a Hugging Face Space (running on A10G); file size 7,435 bytes; revision 8b54513.
import math
import torch
import torch.nn as nn
from pytorchvideo import transforms as pv_transforms
from pytorchvideo.data.clip_sampling import ConstantClipsPerVideoSampler
from pytorchvideo.data.encoded_video import EncodedVideo
from pytorchvideo.data.encoded_video_decord import EncodedVideoDecord
from torchvision import transforms
from torchvision.transforms._transforms_video import NormalizeVideo
def get_clip_timepoints(clip_sampler, duration):
    """
    Enumerate the (start, end) time span of every clip the sampler yields.
    Args:
        clip_sampler (callable): called as
            `clip_sampler(last_end, duration, annotation=None)`; must return a
            5-item sequence whose first two items are the clip start and end
            and whose last item flags the final clip.
        duration: total duration, forwarded unchanged to the sampler.
    Returns:
        timepoints (list): `(start, end)` tuples in the order the sampler
            produced them. The sampler is always queried at least once.
    """
    timepoints = []
    # The first query starts from 0.0; every later query resumes from the
    # end of the clip that was returned last.
    last_end = 0.0
    while True:
        clip_start, last_end, _, _, reached_last = clip_sampler(
            last_end, duration, annotation=None
        )
        timepoints.append((clip_start, last_end))
        if reached_last:
            break
    return timepoints
def crop_boxes(boxes, x_offset, y_offset):
    """
    Perform crop on the bounding boxes given the offsets.
    Args:
        boxes (ndarray or None): bounding boxes to perform crop. The dimension
            is `num boxes` x 4. Columns 0 and 2 are shifted by `x_offset`,
            columns 1 and 3 by `y_offset` (presumably an x1, y1, x2, y2
            layout -- confirm against callers).
        x_offset (int): cropping offset in the x axis.
        y_offset (int): cropping offset in the y axis.
    Returns:
        cropped_boxes (ndarray or None): the cropped boxes with dimension of
            `num boxes` x 4, or None when `boxes` is None. The input array is
            never modified.
    """
    # The documented contract allows `boxes` to be None; without this guard
    # the `.copy()` call below raises AttributeError.
    if boxes is None:
        return None

    # Work on a copy so the caller's array keeps its original coordinates.
    cropped_boxes = boxes.copy()
    cropped_boxes[:, [0, 2]] = boxes[:, [0, 2]] - x_offset
    cropped_boxes[:, [1, 3]] = boxes[:, [1, 3]] - y_offset
    return cropped_boxes
def uniform_crop(images, size, spatial_idx, boxes=None, scale_size=None):
    """
    Take one of three evenly spaced square crops from the images (and boxes).
    Args:
        images (tensor): images to crop. The dimension is
            `num frames` x `channel` x `height` x `width`. A 3-dimensional
            input gets a leading dimension added for the crop and removed
            again before returning.
        size (int): side length (height and width) of the square crop.
        spatial_idx (int): 0, 1, or 2 for left, center, and right crop if width
            is larger than or equal to height. Or 0, 1, or 2 for top, center,
            and bottom crop if height is larger than width.
        boxes (ndarray or None): optional. Corresponding boxes to images.
            Dimension is `num boxes` x 4.
        scale_size (int): optional. If not None, the images are first resized
            (bilinear) so that their shorter side equals scale_size.
    Returns:
        cropped (tensor): images with dimension of
            `num frames` x `channel` x `size` x `size`.
        cropped_boxes (ndarray or None): the cropped boxes with dimension of
            `num boxes` x 4, or None when no boxes were given.
    """
    assert spatial_idx in [0, 1, 2]

    added_leading_dim = len(images.shape) == 3
    if added_leading_dim:
        images = images.unsqueeze(0)

    height, width = images.shape[2], images.shape[3]

    if scale_size is not None:
        # Shorter side becomes scale_size; the longer side keeps the aspect
        # ratio (truncated to an integer).
        if width <= height:
            width, height = scale_size, int(height / width * scale_size)
        else:
            width, height = int(width / height * scale_size), scale_size
        images = torch.nn.functional.interpolate(
            images,
            size=(height, width),
            mode="bilinear",
            align_corners=False,
        )

    # Both offsets start out centered; only the offset along the longer axis
    # is moved for the first (0) and last (2) crop.
    y_offset = int(math.ceil((height - size) / 2))
    x_offset = int(math.ceil((width - size) / 2))

    slide_along_height = height > width
    extent = height if slide_along_height else width
    if spatial_idx == 0:
        edge_offset = 0
    elif spatial_idx == 2:
        edge_offset = extent - size
    else:
        edge_offset = None  # center crop: keep the centered offsets

    if edge_offset is not None:
        if slide_along_height:
            y_offset = edge_offset
        else:
            x_offset = edge_offset

    rows = slice(y_offset, y_offset + size)
    cols = slice(x_offset, x_offset + size)
    cropped = images[:, :, rows, cols]

    cropped_boxes = None
    if boxes is not None:
        cropped_boxes = crop_boxes(boxes, x_offset, y_offset)

    if added_leading_dim:
        cropped = cropped.squeeze(0)
    return cropped, cropped_boxes
class SpatialCrop(nn.Module):
    """
    Expand each video into several spatially cropped copies.

    Intended to run after the temporal crops, on full frames (the original
    note says to use -2 for the spatial crop at the slowfast augmentation
    stage). The output list holds one entry per requested crop position for
    every input video.
    """

    def __init__(self, crop_size: int = 224, num_crops: int = 3):
        """
        Args:
            crop_size: side length handed to `uniform_crop`.
            num_crops: 3 selects crop positions 0, 1 and 2; 1 selects only the
                center position (1).
        Raises:
            NotImplementedError: for any other `num_crops`.
        """
        super().__init__()
        self.crop_size = crop_size
        if num_crops == 3:
            crop_positions = [0, 1, 2]
        elif num_crops == 1:
            crop_positions = [1]
        else:
            raise NotImplementedError("Nothing else supported yet")
        self.crops_to_ext = crop_positions
        # No configuration currently requests crops of the flipped video.
        self.flipped_crops_to_ext = []

    def forward(self, videos):
        """
        Args:
            videos: A list of C, T, H, W videos.
        Returns:
            videos: A list with `len(self.crops_to_ext)` entries per input
                video (plus any flipped crops), each converted to
                C, T, H', W' by spatial cropping.
        """
        assert isinstance(videos, list), "Must be a list of videos after temporal crops"
        ndims = [video.ndim for video in videos]
        assert all(ndim == 4 for ndim in ndims), "Must be (C,T,H,W)"

        cropped_videos = []
        for video in videos:
            cropped_videos.extend(
                uniform_crop(video, self.crop_size, position)[0]
                for position in self.crops_to_ext
            )
            if self.flipped_crops_to_ext:
                mirrored = transforms.functional.hflip(video)
                cropped_videos.extend(
                    uniform_crop(mirrored, self.crop_size, position)[0]
                    for position in self.flipped_crops_to_ext
                )
        return cropped_videos
def load_and_transform_video_data(
    video_file,
    video_path,
    clip_duration=2,
    clips_per_video=5,
    sample_rate=16000,
    with_audio=False
):
    """
    Decode a video, sample clips from it and return them as one stacked tensor.

    Each sampled clip is temporally subsampled, divided by 255, short-side
    scaled to 224, normalized, and expanded into three 224x224 spatial crops.

    Args:
        video_file: a path string (opened with `EncodedVideo.from_path` and the
            "decord" decoder), or any other object, which is handed directly
            to `EncodedVideoDecord` (presumably a file-like object -- confirm
            against callers).
        video_path: used only as `video_name` when `video_file` is not a string.
        clip_duration: clip length given to the clip sampler. It is also used
            as the number of frames kept per clip.
        clips_per_video: number of clips requested from the clip sampler.
        sample_rate: forwarded to `EncodedVideoDecord` only; the path branch
            does not use it.
        with_audio: whether audio is decoded and returned.
    Returns:
        The stacked crops (all crops of all clips along a new dimension 0).
        When `with_audio` is truthy, a tuple of that tensor and the `'audio'`
        entry of the last clip that was read.
    Raises:
        ValueError: if the decoder returns no clip for a sampled time span.
    """
    video_transform = transforms.Compose(
        [
            pv_transforms.ShortSideScale(224),
            NormalizeVideo(
                mean=(0.48145466, 0.4578275, 0.40821073),
                std=(0.26862954, 0.26130258, 0.27577711),
            ),
        ]
    )
    clip_sampler = ConstantClipsPerVideoSampler(
        clip_duration=clip_duration, clips_per_video=clips_per_video
    )
    # NOTE(review): the frame count per clip is tied to clip_duration here --
    # confirm this coupling is intended.
    frame_sampler = pv_transforms.UniformTemporalSubsample(num_samples=clip_duration)

    if not isinstance(video_file, str):
        video = EncodedVideoDecord(
            video_file,
            video_name=video_path,
            decode_video=True,
            decode_audio=with_audio,
            sample_rate=sample_rate,
        )
    else:
        # NOTE(review): sample_rate is not forwarded on this branch.
        video = EncodedVideo.from_path(
            video_file,
            decoder="decord",
            decode_audio=with_audio,
        )

    scaled_clips = []
    for start_sec, end_sec in get_clip_timepoints(clip_sampler, video.duration):
        # Read the clip, then keep only the subsampled frames.
        clip = video.get_clip(start_sec, end_sec)
        if clip is None:
            raise ValueError("No clip found")
        # Convert pixel values to the 0-1 range expected by the normalization.
        scaled_clips.append(frame_sampler(clip["video"]) / 255.0)

    transformed_clips = list(map(video_transform, scaled_clips))
    cropped_clips = SpatialCrop(224, num_crops=3)(transformed_clips)
    stacked = torch.stack(cropped_clips, dim=0)

    if with_audio:
        # NOTE(review): `clip` is the last clip from the loop above, so only
        # that clip's audio is returned -- confirm this is intended.
        return stacked, clip['audio']
    return stacked
if __name__ == "__main__":
    # Manual smoke test: decode one sample video with audio and report what
    # came back. The original ended in a leftover `pdb.set_trace()` breakpoint
    # that halted every run of the script; it is replaced by a printed summary.
    video_path = "datasets/InstructionTuning/video/music_aqa/MUSIC-AVQA-videos-Real/00000002.mp4"
    video, audio = load_and_transform_video_data(
        video_path,
        video_path,
        clip_duration=1,
        clips_per_video=5,
        with_audio=True,
    )
    print(f"video shape: {tuple(video.shape)}")
    print(f"audio: type={type(audio).__name__}, shape={getattr(audio, 'shape', None)}")