# Copyright (c) Meta Platforms, Inc. All Rights Reserved
import glob
import os.path as osp
import pickle
import warnings

import torch
import torch.utils.data as data
from torchaudio import transforms as Ta
from torchvision import transforms as Tv
from torchvision.datasets.video_utils import VideoClips
from torchvision.io.video import read_video
from torchvision.transforms import InterpolationMode

from converter import normalize, normalize_spectrogram, get_mel_spectrogram_from_audio
class LatentDataset(data.Dataset):
    """Generic dataset for latents pregenerated from a dataset.
    Returns the video latent, audio latent, and label encoded from the
    original dataset."""
    exts = ['pt']

    def __init__(self, data_folder, train=True):
        """
        Args:
            data_folder: path to the folder with latent files. The folder
                should contain a 'train' and a 'test' directory,
                each with the corresponding latents stored as .pt files
        """
        super().__init__()
        self.train = train
        folder = osp.join(data_folder, 'train' if train else 'test')
        self.files = sum([glob.glob(osp.join(folder, '**', f'*.{ext}'), recursive=True)
                          for ext in self.exts], [])
        warnings.filterwarnings('ignore')

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        # Skip over unreadable files instead of crashing the data loader.
        while True:
            try:
                latents = torch.load(self.files[idx], map_location="cpu")
            except Exception as e:
                print(f"Dataset Exception: {e}")
                idx = (idx + 1) % len(self.files)
                continue
            break
        return latents["video"], latents["audio"], latents["y"]
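
# A minimal usage sketch for LatentDataset (the path is hypothetical; each .pt
# file is assumed to hold a dict with "video", "audio", and "y" tensors,
# matching __getitem__ above):
#
#   dataset = LatentDataset("/path/to/latents", train=True)
#   loader = data.DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2)
#   video_latent, audio_latent, y = next(iter(loader))
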
class AudioVideoDataset(data.Dataset):
    """Generic dataset for video files stored in folders.
    Returns TCHW videos in the range [-1, 1], along with a normalized mel
    spectrogram of the matching audio and a class label."""
    exts = ['avi', 'mp4', 'webm']
    def __init__(self, data_folder, train=True, resolution=64, sample_every_n_frames=1,
                 sequence_length=8, audio_channels=1, sample_rate=16000, min_length=1,
                 ignore_cache=False, labeled=True, target_video_fps=10):
        """
        Args:
            data_folder: path to the folder with videos. The folder
                should contain a 'train' and a 'test' directory,
                each with corresponding videos stored
            sequence_length: length of extracted video sequences
        """
        super().__init__()
        self.train = train
        self.sequence_length = sequence_length
        self.resolution = resolution
        self.sample_every_n_frames = sample_every_n_frames
        self.audio_channels = audio_channels
        self.sample_rate = sample_rate
        self.min_length = min_length
        self.labeled = labeled
        folder = osp.join(data_folder, 'train' if train else 'test')
        files = sum([glob.glob(osp.join(folder, '**', f'*.{ext}'), recursive=True)
                     for ext in self.exts], [])
        # hacky way to compute # of classes (count # of unique parent directories)
        self.classes = list(set([get_parent_dir(f) for f in files]))
        self.classes.sort()
        self.class_to_label = {c: i for i, c in enumerate(self.classes)}
        warnings.filterwarnings('ignore')
        cache_file = osp.join(folder, f"metadata_{self.sequence_length}.pkl")
        if not osp.exists(cache_file) or ignore_cache:
            clips = VideoClips(files, self.sequence_length, num_workers=32,
                               frame_rate=target_video_fps)
            pickle.dump(clips.metadata, open(cache_file, 'wb'))
        else:
            metadata = pickle.load(open(cache_file, 'rb'))
            clips = VideoClips(files, self.sequence_length, frame_rate=target_video_fps,
                               _precomputed_metadata=metadata)
        self._clips = clips

    def n_classes(self):
        return len(self.classes)

    def __len__(self):
        return self._clips.num_clips()

    def __getitem__(self, idx):
        resolution = self.resolution
        # Skip over clips that fail to decode instead of crashing the loader.
        while True:
            try:
                video, _, info, _ = self._clips.get_clip(idx)
            except Exception:
                idx = (idx + 1) % self._clips.num_clips()
                continue
            break
        return (preprocess(video, resolution, sample_every_n_frames=self.sample_every_n_frames),
                self.get_audio(info, idx),
                self.get_label(idx))

    def get_label(self, idx):
        if not self.labeled:
            return -1
        video_idx, clip_idx = self._clips.get_clip_location(idx)
        class_name = get_parent_dir(self._clips.video_paths[video_idx])
        label = self.class_to_label[class_name]
        return label

    def get_audio(self, info, idx):
        video_idx, clip_idx = self._clips.get_clip_location(idx)
        video_path = self._clips.video_paths[video_idx]
        video_fps = self._clips.video_fps[video_idx]
        # Recover the clip's frame indices from its presentation timestamps,
        # then convert them to start/end times in seconds.
        duration_per_frame = self._clips.video_pts[video_idx][1] - self._clips.video_pts[video_idx][0]
        clip_pts = self._clips.clips[video_idx][clip_idx]
        clip_pid = clip_pts // duration_per_frame
        start_t = (clip_pid[0] / video_fps).item()
        end_t = ((clip_pid[-1] + 1) / video_fps).item()
        _, raw_audio, _ = read_video(video_path, start_t, end_t, pts_unit='sec')
        raw_audio = prepare_audio(raw_audio, info["audio_fps"], self.sample_rate,
                                  self.audio_channels, self.sequence_length, self.min_length)
        _, spec = get_mel_spectrogram_from_audio(raw_audio[0].numpy())
        norm_spec = normalize_spectrogram(spec)
        norm_spec = normalize(norm_spec)  # normalize to [-1, 1]; the pipeline does not normalize torch.Tensor inputs
        norm_spec = norm_spec.unsqueeze(1)  # add channel dimension (unsqueeze is not in-place, so assign the result)
        return norm_spec
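
# A minimal usage sketch for AudioVideoDataset (hypothetical folder layout:
# one subdirectory per class under train/ and test/):
#
#   dataset = AudioVideoDataset("/path/to/videos", train=True, resolution=64,
#                               sequence_length=16)
#   video, spec, label = dataset[0]
#   # video: (T, C, H, W) floats in [-1, 1]; spec: normalized mel spectrogram;
#   # label: int index into the sorted class names
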
def get_parent_dir(path):
    """Name of the immediate parent directory, e.g. 'dog' for 'train/dog/clip0.mp4';
    used as the class name."""
    return osp.basename(osp.dirname(path))

def preprocess(video, resolution, sample_every_n_frames=1):
    if sample_every_n_frames > 1:
        # Temporally subsample the clip.
        video = video[::sample_every_n_frames]
    video = video.permute(0, 3, 1, 2).float() / 255.  # THWC -> TCHW, scaled to [0, 1]
    old_size = video.shape[2:4]
    # Resize so the frame fits inside resolution x resolution, preserving aspect ratio.
    ratio = min(float(resolution) / old_size[0], float(resolution) / old_size[1])
    new_size = tuple([int(i * ratio) for i in old_size])
    # Pad symmetrically to a square resolution x resolution frame.
    pad_w = resolution - new_size[1]
    pad_h = resolution - new_size[0]
    top, bottom = pad_h // 2, pad_h - pad_h // 2
    left, right = pad_w // 2, pad_w - pad_w // 2
    transform = Tv.Compose([Tv.Resize(new_size, interpolation=InterpolationMode.BICUBIC),
                            Tv.Pad((left, top, right, bottom))])
    video_new = transform(video)
    video_new = video_new * 2 - 1  # rescale [0, 1] -> [-1, 1]
    return video_new
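
# Sketch of the resize-and-pad arithmetic above, assuming a 240x320 clip and
# resolution=64: ratio = min(64/240, 64/320) = 0.2, so frames are resized to
# 48x64 and padded with 8 rows on top and bottom to reach 64x64.
#
#   frames = torch.rand(16, 240, 320, 3)  # THWC, as returned by VideoClips
#   out = preprocess(frames, 64)
#   assert out.shape == (16, 3, 64, 64)
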

def pad_crop_audio(audio, target_length):
    """Zero-pad or truncate (channels, samples) audio to exactly target_length samples."""
    target_length = int(target_length)
    n, s = audio.shape
    output = audio.new_zeros([n, target_length])
    output[:, :min(s, target_length)] = audio[:, :target_length]
    return output
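
# pad_crop_audio sketch: short clips are zero-padded on the right, long clips
# are truncated.
#
#   a = torch.ones(1, 5)
#   pad_crop_audio(a, 8)  # shape (1, 8), last 3 samples are zero
#   pad_crop_audio(a, 3)  # shape (1, 3)
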

def prepare_audio(audio, in_sr, target_sr, target_channels, sequence_length, min_length):
    if in_sr != target_sr:
        resample_tf = Ta.Resample(in_sr, target_sr)
        audio = resample_tf(audio)
    # Audio span matching the video clip, assuming 10 fps video (the default
    # target_video_fps), so each frame covers target_sr / 10 samples.
    max_length = target_sr / 10 * sequence_length
    # Round up to the next multiple of min_length.
    target_length = max_length + (min_length - (max_length % min_length)) % min_length
    audio = pad_crop_audio(audio, target_length)
    audio = set_audio_channels(audio, target_channels)
    return audio
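
# Length arithmetic sketch for prepare_audio, under the 10 fps assumption:
# target_sr=16000 and sequence_length=16 give max_length = 16000/10*16 = 25600
# samples; with min_length=1024 that is already a multiple, so target_length
# stays 25600 (1.6 s of audio for a 1.6 s clip).
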

def set_audio_channels(audio, target_channels):
    if target_channels == 1:
        # Convert to mono by keeping the first channel
        # (an alternative would be averaging: audio.mean(0, keepdim=True)).
        audio = audio[:1, :]
    elif target_channels == 2:
        # Convert to stereo
        if audio.shape[0] == 1:
            audio = audio.repeat(2, 1)
        elif audio.shape[0] > 2:
            audio = audio[:2, :]
    return audio
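
if __name__ == "__main__":
    # Minimal smoke test for the pure-tensor helpers; it exercises the resize,
    # padding, and audio-length logic without touching the filesystem or the
    # converter module.
    frames = torch.rand(8, 120, 160, 3)  # THWC, like a decoded clip
    clip = preprocess(frames, 64)
    assert clip.shape == (8, 3, 64, 64)

    audio = torch.randn(2, 30000)
    out = prepare_audio(audio, in_sr=16000, target_sr=16000,
                        target_channels=1, sequence_length=16, min_length=1024)
    assert out.shape == (1, 25600)  # 16 frames at 10 fps -> 1.6 s at 16 kHz

    assert set_audio_channels(torch.randn(1, 100), 2).shape == (2, 100)
    print("smoke test passed")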