# R-FLAV/dataset.py
# Copyright (c) Meta Platforms, Inc. All Rights Reserved
import os.path as osp
import math
import pickle
import warnings
import glob
import torch.utils.data as data
import torch.nn.functional as F
from torchvision.datasets.video_utils import VideoClips
from converter import normalize, normalize_spectrogram, get_mel_spectrogram_from_audio
from torchaudio import transforms as Ta
from torchvision import transforms as Tv
from torchvision.io.video import read_video
import torch
from torchvision.transforms import InterpolationMode


class LatentDataset(data.Dataset):
    """Generic dataset for latents pre-generated from another dataset.

    Each item is loaded from a single '*.pt' file and returned as a
    (video latents, audio latents, label) tuple."""
exts = ['pt']
def __init__(self, data_folder, train=True):
"""
Args:
            data_folder: path to the folder with pre-encoded latent files. The
                folder should contain a 'train' and a 'test' directory, each
                with the corresponding '*.pt' files stored inside
"""
super().__init__()
self.train = train
folder = osp.join(data_folder, 'train' if train else 'test')
self.files = sum([glob.glob(osp.join(folder, '**', f'*.{ext}'), recursive=True)
for ext in self.exts], [])
warnings.filterwarnings('ignore')
def __len__(self):
return len(self.files)
def __getitem__(self, idx):
while True:
try:
latents = torch.load(self.files[idx], map_location="cpu")
except Exception as e:
print(f"Dataset Exception: {e}")
idx = (idx + 1) % len(self.files)
continue
break
return latents["video"], latents["audio"], latents["y"]
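
# A hypothetical helper (its name and arguments are placeholders) that sketches
# the '*.pt' layout LatentDataset expects: a dict with "video", "audio" and "y"
# entries, matching the keys read in __getitem__ above.
def save_latent_example(video_latents, audio_latents, label, out_path):
    # Save one pre-encoded sample in the format loaded by LatentDataset.
    torch.save({"video": video_latents, "audio": audio_latents, "y": label}, out_path)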


class AudioVideoDataset(data.Dataset):
    """Generic dataset for video files stored in class-named folders.

    Each item is a (video, spectrogram, label) tuple; video clips are
    returned as TCHW tensors scaled to the range [-1, 1]."""
exts = ['avi', 'mp4', 'webm']
    def __init__(self, data_folder, train=True, resolution=64, sample_every_n_frames=1,
                 sequence_length=8, audio_channels=1, sample_rate=16000, min_length=1,
                 ignore_cache=False, labeled=True, target_video_fps=10):
"""
Args:
            data_folder: path to the folder with videos. The folder
                should contain a 'train' and a 'test' directory, each with
                the corresponding videos stored inside one subdirectory
                per class.
            resolution: output frames are resized and padded to
                resolution x resolution pixels.
            sequence_length: length of the extracted video clips, in frames.
            sample_rate: target audio sample rate in Hz.
            target_video_fps: frame rate the clips are resampled to by VideoClips.
"""
super().__init__()
self.train = train
self.sequence_length = sequence_length
self.resolution = resolution
self.sample_every_n_frames = sample_every_n_frames
self.audio_channels = audio_channels
self.sample_rate = sample_rate
self.min_length = min_length
self.labeled = labeled
folder = osp.join(data_folder, 'train' if train else 'test')
files = sum([glob.glob(osp.join(folder, '**', f'*.{ext}'), recursive=True)
for ext in self.exts], [])
# hacky way to compute # of classes (count # of unique parent directories)
self.classes = list(set([get_parent_dir(f) for f in files]))
self.classes.sort()
self.class_to_label = {c: i for i, c in enumerate(self.classes)}
warnings.filterwarnings('ignore')
cache_file = osp.join(folder, f"metadata_{self.sequence_length}.pkl")
        # NOTE: the trailing 'or True' bypasses the cache, so clip metadata is
        # recomputed on every run and the cache file is never read or written.
        if not osp.exists(cache_file) or ignore_cache or True:
            clips = VideoClips(files, self.sequence_length, num_workers=32, frame_rate=target_video_fps)
            # pickle.dump(clips.metadata, open(cache_file, 'wb'))
        else:
            metadata = pickle.load(open(cache_file, 'rb'))
            clips = VideoClips(files, self.sequence_length,
                               _precomputed_metadata=metadata)
        self._clips = clips
@property
def n_classes(self):
return len(self.classes)
def __len__(self):
return self._clips.num_clips()
def __getitem__(self, idx):
resolution = self.resolution
while True:
try:
video, _, info, _ = self._clips.get_clip(idx)
except Exception:
idx = (idx + 1) % self._clips.num_clips()
continue
break
        return (
            preprocess(video, resolution, sample_every_n_frames=self.sample_every_n_frames),
            self.get_audio(info, idx),
            self.get_label(idx),
        )
def get_label(self, idx):
if not self.labeled:
return -1
video_idx, clip_idx = self._clips.get_clip_location(idx)
class_name = get_parent_dir(self._clips.video_paths[video_idx])
label = self.class_to_label[class_name]
return label
def get_audio(self, info, idx):
video_idx, clip_idx = self._clips.get_clip_location(idx)
video_path = self._clips.video_paths[video_idx]
video_fps = self._clips.video_fps[video_idx]
duration_per_frame = self._clips.video_pts[video_idx][1] - self._clips.video_pts[video_idx][0]
clip_pts = self._clips.clips[video_idx][clip_idx]
clip_pid = clip_pts // duration_per_frame
        start_t = (clip_pid[0] / video_fps).item()
        end_t = ((clip_pid[-1] + 1) / video_fps).item()
        _, raw_audio, _ = read_video(video_path, start_t, end_t, pts_unit='sec')
        raw_audio = prepare_audio(raw_audio, info["audio_fps"], self.sample_rate,
                                  self.audio_channels, self.sequence_length, self.min_length)
_, spec = get_mel_spectrogram_from_audio(raw_audio[0].numpy())
norm_spec = normalize_spectrogram(spec)
        norm_spec = normalize(norm_spec)  # normalize to [-1, 1]; the pipeline does not normalize torch.Tensor inputs
        return norm_spec
def get_parent_dir(path):
return osp.basename(osp.dirname(path))
def preprocess(video, resolution, sample_every_n_frames=1):
    # NOTE: sample_every_n_frames is accepted for API compatibility but is not
    # applied here; frames are used as provided by VideoClips.
    video = video.permute(0, 3, 1, 2).float() / 255.  # THWC -> TCHW, scaled to [0, 1]
    old_size = video.shape[2:4]
    # resize so that both sides fit within `resolution`, preserving the aspect ratio
    ratio = min(float(resolution) / old_size[0], float(resolution) / old_size[1])
    new_size = tuple(int(i * ratio) for i in old_size)
    # pad symmetrically to a square resolution x resolution frame
    pad_w = resolution - new_size[1]
    pad_h = resolution - new_size[0]
    top, bottom = pad_h // 2, pad_h - (pad_h // 2)
    left, right = pad_w // 2, pad_w - (pad_w // 2)
    transform = Tv.Compose([
        Tv.Resize(new_size, interpolation=InterpolationMode.BICUBIC),
        Tv.Pad((left, top, right, bottom)),
    ])
    video_new = transform(video)
    video_new = video_new * 2 - 1  # scale to [-1, 1]
return video_new
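
# Worked example for preprocess, assuming a 240x320 (HxW) input and resolution=64:
# ratio = min(64/240, 64/320) = 0.2, so frames are resized to 48x64 and then padded
# with 8 rows on top and bottom (no side padding), giving square 64x64 frames while
# preserving the original aspect ratio.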
def pad_crop_audio(audio, target_length):
    # zero-pad or crop along the sample axis so the output has exactly target_length samples
    target_length = int(target_length)
    n, s = audio.shape
start = 0
end = start + target_length
output = audio.new_zeros([n, target_length])
output[:, :min(s, target_length)] = audio[:, start:end]
return output
def prepare_audio(audio, in_sr, target_sr, target_channels, sequence_length, min_length):
    if in_sr != target_sr:
        resample_tf = Ta.Resample(in_sr, target_sr)
        audio = resample_tf(audio)
    # audio samples per clip, assuming 10 video frames per second
    # (matches the target_video_fps default of AudioVideoDataset)
    max_length = target_sr / 10 * sequence_length
    # round up to the next multiple of min_length
    target_length = max_length + (min_length - (max_length % min_length)) % min_length
    audio = pad_crop_audio(audio, target_length)
    audio = set_audio_channels(audio, target_channels)
    return audio
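
# Worked example for prepare_audio's length arithmetic, assuming the
# AudioVideoDataset defaults (sample_rate=16000, target_video_fps=10,
# sequence_length=8, min_length=1): max_length = 16000 / 10 * 8 = 12800 samples
# (1600 audio samples per video frame); with min_length=1 nothing is added, so
# pad_crop_audio pads or crops the waveform to exactly 12800 samples. A larger
# min_length would round target_length up to the next multiple of min_length.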
def set_audio_channels(audio, target_channels):
if target_channels == 1:
# Convert to mono
# audio = audio.mean(0, keepdim=True)
audio = audio[:1, :]
elif target_channels == 2:
# Convert to stereo
if audio.shape[0] == 1:
audio = audio.repeat(2, 1)
elif audio.shape[0] > 2:
audio = audio[:2, :]
return audio
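

# A minimal smoke test, assuming a dataset root laid out as
# <root>/train/<class_name>/*.mp4 (the structure AudioVideoDataset expects).
# The path and hyperparameters below are placeholders.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    dataset = AudioVideoDataset(
        "data/example_dataset",  # placeholder path, replace with a real dataset root
        train=True,
        resolution=64,
        sequence_length=8,
        sample_rate=16000,
    )
    loader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=0)
    video, spec, label = next(iter(loader))
    print("video:", video.shape, "spec:", spec.shape, "label:", label)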