#!/usr/bin/env python3
# Portions Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import logging
import math

import torch
import torch.nn as nn
import torchaudio
from models.multimodal_preprocessors import SimpleTokenizer
from PIL import Image
from pytorchvideo import transforms as pv_transforms
from pytorchvideo.data.clip_sampling import ConstantClipsPerVideoSampler
from pytorchvideo.data.encoded_video import EncodedVideo
from torchvision import transforms
from torchvision.transforms._transforms_video import NormalizeVideo

DEFAULT_AUDIO_FRAME_SHIFT_MS = 10  # in milliseconds

BPE_PATH = "bpe/bpe_simple_vocab_16e6.txt.gz"


def waveform2melspec(waveform, sample_rate, num_mel_bins, target_length):
    # Based on https://github.com/YuanGongND/ast/blob/d7d8b4b8e06cdaeb6c843cdb38794c1c7692234c/src/dataloader.py#L102
    waveform -= waveform.mean()
    fbank = torchaudio.compliance.kaldi.fbank(
        waveform,
        htk_compat=True,
        sample_frequency=sample_rate,
        use_energy=False,
        window_type="hanning",
        num_mel_bins=num_mel_bins,
        dither=0.0,
        frame_length=25,
        frame_shift=DEFAULT_AUDIO_FRAME_SHIFT_MS,
    )
    # Convert to [mel_bins, num_frames] shape
    fbank = fbank.transpose(0, 1)
    # Pad to target_length
    n_frames = fbank.size(1)
    p = target_length - n_frames
    # If the gap is too large (say >20%), flag a warning
    if abs(p) / n_frames > 0.2:
        logging.warning(
            "Large gap between audio n_frames(%d) and "
            "target_length (%d). Is the audio_target_length "
            "setting correct?",
            n_frames,
            target_length,
        )
    # Cut or pad to exactly target_length frames
    if p > 0:
        fbank = torch.nn.functional.pad(fbank, (0, p), mode="constant", value=0)
    elif p < 0:
        fbank = fbank[:, 0:target_length]
    # Convert to [1, mel_bins, num_frames] shape, essentially like a
    # 1-channel image
    fbank = fbank.unsqueeze(0)
    return fbank


def get_clip_timepoints(clip_sampler, duration):
    # Read out all clips in this video
    all_clips_timepoints = []
    is_last_clip = False
    end = 0.0
    while not is_last_clip:
        start, end, _, _, is_last_clip = clip_sampler(end, duration, annotation=None)
        all_clips_timepoints.append((start, end))
    return all_clips_timepoints


def load_and_transform_vision_data(image_paths, device):
    if image_paths is None:
        return None

    image_outputs = []
    # The transform is path-independent, so build it once outside the loop
    data_transform = transforms.Compose(
        [
            transforms.Resize(
                224, interpolation=transforms.InterpolationMode.BICUBIC
            ),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=(0.48145466, 0.4578275, 0.40821073),
                std=(0.26862954, 0.26130258, 0.27577711),
            ),
        ]
    )
    for image_path in image_paths:
        with open(image_path, "rb") as fopen:
            image = Image.open(fopen).convert("RGB")

        image = data_transform(image).to(device)
        image_outputs.append(image)
    return torch.stack(image_outputs, dim=0)


def load_and_transform_text(text, device):
    if text is None:
        return None
    tokenizer = SimpleTokenizer(bpe_path=BPE_PATH)
    tokens = [tokenizer(t).unsqueeze(0).to(device) for t in text]
    tokens = torch.cat(tokens, dim=0)
    return tokens

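
# A minimal self-checking sketch (not part of the original module): it runs two
# seconds of synthetic audio through waveform2melspec and verifies the fixed
# output shape. The 440 Hz tone and the parameter values (chosen to match the
# defaults of load_and_transform_audio_data below) are assumptions.
def _demo_waveform2melspec():
    sample_rate = 16000
    t = torch.arange(2 * sample_rate, dtype=torch.float32) / sample_rate
    waveform = torch.sin(2 * math.pi * 440.0 * t).unsqueeze(0)  # (1, num_samples)
    fbank = waveform2melspec(
        waveform, sample_rate, num_mel_bins=128, target_length=204
    )
    # A 2 s clip at a 10 ms frame shift yields ~198 frames, padded up to 204
    assert fbank.shape == (1, 128, 204)
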

def load_and_transform_audio_data(
    audio_paths,
    device,
    num_mel_bins=128,
    target_length=204,
    sample_rate=16000,
    clip_duration=2,
    clips_per_video=3,
    mean=-4.268,
    std=9.138,
):
    if audio_paths is None:
        return None

    audio_outputs = []
    clip_sampler = ConstantClipsPerVideoSampler(
        clip_duration=clip_duration, clips_per_video=clips_per_video
    )

    for audio_path in audio_paths:
        waveform, sr = torchaudio.load(audio_path)
        if sample_rate != sr:
            waveform = torchaudio.functional.resample(
                waveform, orig_freq=sr, new_freq=sample_rate
            )
        all_clips_timepoints = get_clip_timepoints(
            clip_sampler, waveform.size(1) / sample_rate
        )
        all_clips = []
        for clip_timepoints in all_clips_timepoints:
            waveform_clip = waveform[
                :,
                int(clip_timepoints[0] * sample_rate) : int(
                    clip_timepoints[1] * sample_rate
                ),
            ]
            waveform_melspec = waveform2melspec(
                waveform_clip, sample_rate, num_mel_bins, target_length
            )
            all_clips.append(waveform_melspec)

        # Normalize each mel spectrogram with the (default) dataset-level stats
        normalize = transforms.Normalize(mean=mean, std=std)
        all_clips = [normalize(ac).to(device) for ac in all_clips]

        all_clips = torch.stack(all_clips, dim=0)
        audio_outputs.append(all_clips)

    return torch.stack(audio_outputs, dim=0)


def crop_boxes(boxes, x_offset, y_offset):
    """
    Perform crop on the bounding boxes given the offsets.
    Args:
        boxes (ndarray or None): bounding boxes to perform crop. The dimension
            is `num boxes` x 4.
        x_offset (int): cropping offset in the x axis.
        y_offset (int): cropping offset in the y axis.
    Returns:
        cropped_boxes (ndarray or None): the cropped boxes with dimension of
            `num boxes` x 4.
    """
    cropped_boxes = boxes.copy()
    cropped_boxes[:, [0, 2]] = boxes[:, [0, 2]] - x_offset
    cropped_boxes[:, [1, 3]] = boxes[:, [1, 3]] - y_offset

    return cropped_boxes


def uniform_crop(images, size, spatial_idx, boxes=None, scale_size=None):
    """
    Perform uniform spatial sampling on the images and corresponding boxes.
    Args:
        images (tensor): images to perform uniform crop. The dimension is
            `num frames` x `channel` x `height` x `width`.
        size (int): size of height and width to crop the images.
        spatial_idx (int): 0, 1, or 2 for left, center, and right crop if width
            is larger than height. Or 0, 1, or 2 for top, center, and bottom
            crop if height is larger than width.
        boxes (ndarray or None): optional. Corresponding boxes to images.
            Dimension is `num boxes` x 4.
        scale_size (int): optional. If not None, resize the images to
            scale_size before performing any crop.
    Returns:
        cropped (tensor): images with dimension of
            `num frames` x `channel` x `size` x `size`.
        cropped_boxes (ndarray or None): the cropped boxes with dimension of
            `num boxes` x 4.
    """
    assert spatial_idx in [0, 1, 2]
    ndim = len(images.shape)
    if ndim == 3:
        images = images.unsqueeze(0)
    height = images.shape[2]
    width = images.shape[3]

    if scale_size is not None:
        if width <= height:
            width, height = scale_size, int(height / width * scale_size)
        else:
            width, height = int(width / height * scale_size), scale_size
        images = torch.nn.functional.interpolate(
            images,
            size=(height, width),
            mode="bilinear",
            align_corners=False,
        )

    y_offset = int(math.ceil((height - size) / 2))
    x_offset = int(math.ceil((width - size) / 2))

    if height > width:
        if spatial_idx == 0:
            y_offset = 0
        elif spatial_idx == 2:
            y_offset = height - size
    else:
        if spatial_idx == 0:
            x_offset = 0
        elif spatial_idx == 2:
            x_offset = width - size
    cropped = images[:, :, y_offset : y_offset + size, x_offset : x_offset + size]
    cropped_boxes = crop_boxes(boxes, x_offset, y_offset) if boxes is not None else None
    if ndim == 3:
        cropped = cropped.squeeze(0)
    return cropped, cropped_boxes

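
# A minimal self-checking sketch (not part of the original module): it takes
# the left, center, and right crops of a synthetic wider-than-tall clip with
# uniform_crop. The tensor sizes are arbitrary assumptions.
def _demo_uniform_crop():
    frames = torch.rand(8, 3, 224, 320)  # (num frames, channel, height, width)
    left, _ = uniform_crop(frames, size=224, spatial_idx=0)
    center, _ = uniform_crop(frames, size=224, spatial_idx=1)
    right, _ = uniform_crop(frames, size=224, spatial_idx=2)
    assert left.shape == center.shape == right.shape == (8, 3, 224, 224)
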

class SpatialCrop(nn.Module):
    """
    Convert the video into 3 smaller clips spatially. Must be used after the
        temporal crops to get spatial crops, and should be used with
        -2 in the spatial crop at the slowfast augmentation stage (so full
        frames are passed in here). Will return a larger list with the
        3x spatial crops as well.
    """

    def __init__(self, crop_size: int = 224, num_crops: int = 3):
        super().__init__()
        self.crop_size = crop_size
        if num_crops == 3:
            self.crops_to_ext = [0, 1, 2]
            self.flipped_crops_to_ext = []
        elif num_crops == 1:
            self.crops_to_ext = [1]
            self.flipped_crops_to_ext = []
        else:
            raise NotImplementedError("Nothing else supported yet")

    def forward(self, videos):
        """
        Args:
            videos: A list of C, T, H, W videos.
        Returns:
            videos: A list with 3x the number of elements. Each video converted
                to C, T, H', W' by spatial cropping.
        """
        assert isinstance(videos, list), "Must be a list of videos after temporal crops"
        assert all([video.ndim == 4 for video in videos]), "Must be (C,T,H,W)"
        res = []
        for video in videos:
            for spatial_idx in self.crops_to_ext:
                res.append(uniform_crop(video, self.crop_size, spatial_idx)[0])
            if not self.flipped_crops_to_ext:
                continue
            flipped_video = transforms.functional.hflip(video)
            for spatial_idx in self.flipped_crops_to_ext:
                res.append(uniform_crop(flipped_video, self.crop_size, spatial_idx)[0])
        return res


def load_and_transform_video_data(
    video_paths,
    device,
    clip_duration=2,
    clips_per_video=5,
    sample_rate=16000,
):
    if video_paths is None:
        return None

    video_outputs = []
    video_transform = transforms.Compose(
        [
            pv_transforms.ShortSideScale(224),
            NormalizeVideo(
                mean=(0.48145466, 0.4578275, 0.40821073),
                std=(0.26862954, 0.26130258, 0.27577711),
            ),
        ]
    )

    clip_sampler = ConstantClipsPerVideoSampler(
        clip_duration=clip_duration, clips_per_video=clips_per_video
    )
    frame_sampler = pv_transforms.UniformTemporalSubsample(num_samples=clip_duration)

    for video_path in video_paths:
        video = EncodedVideo.from_path(
            video_path,
            decoder="decord",
            decode_audio=False,
            **{"sample_rate": sample_rate},
        )

        all_clips_timepoints = get_clip_timepoints(clip_sampler, video.duration)

        all_video = []
        for clip_timepoints in all_clips_timepoints:
            # Read the clip, get frames
            clip = video.get_clip(clip_timepoints[0], clip_timepoints[1])
            if clip is None:
                raise ValueError("No clip found")
            video_clip = frame_sampler(clip["video"])
            video_clip = video_clip / 255.0  # since this is float, need 0-1

            all_video.append(video_clip)

        all_video = [video_transform(clip) for clip in all_video]
        all_video = SpatialCrop(224, num_crops=3)(all_video)

        all_video = torch.stack(all_video, dim=0)
        video_outputs.append(all_video)

    return torch.stack(video_outputs, dim=0).to(device)
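
# Hedged smoke test (not part of the original module): exercises only the
# synthetic-input sketches above, since the load_and_transform_* loaders need
# real media files on disk (any path such as "path/to/image.jpg" would be a
# hypothetical placeholder here).
if __name__ == "__main__":
    _demo_waveform2melspec()
    _demo_uniform_crop()
    # SpatialCrop triples a list of (C, T, H, W) videos via uniform_crop
    crops = SpatialCrop(crop_size=224, num_crops=3)([torch.rand(3, 4, 224, 320)])
    assert len(crops) == 3 and crops[0].shape == (3, 4, 224, 224)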