import os
import math
import os.path as osp
import random
import pickle
import warnings
import glob

import numpy as np
from PIL import Image

import torch
import torch.utils.data as data
import torch.nn.functional as F
import torch.distributed as dist
from torchvision.datasets.video_utils import VideoClips

IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG']
VID_EXTENSIONS = ['.avi', '.mp4', '.webm', '.mov', '.mkv', '.m4v']


def get_dataloader(data_path, image_folder, resolution=128, sequence_length=16, sample_every_n_frames=1,
                   batch_size=16, num_workers=8):
    """Builds a VideoData wrapper around `data_path` and returns its dataloader."""
    data = VideoData(data_path, image_folder, resolution, sequence_length, sample_every_n_frames,
                     batch_size, num_workers)
    loader = data._dataloader()
    return loader
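

# Example usage (a minimal sketch; the dataset path is hypothetical):
#
#     loader = get_dataloader('/data/videos', image_folder=False,
#                             resolution=128, sequence_length=16)
#     batch = next(iter(loader))
#     batch['video'].shape  # -> torch.Size([16, 3, 16, 128, 128]), i.e. BCTHW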


def is_image_file(filename):
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)


def get_parent_dir(path):
    return osp.basename(osp.dirname(path))


def preprocess(video, resolution, sequence_length=None, in_channels=3, sample_every_n_frames=1):
    # video: THWC uint8 tensor -> CTHW float tensor in [0, 1]
    assert in_channels == 3
    video = video.permute(0, 3, 1, 2).float() / 255.
    t, c, h, w = video.shape

    # temporal truncation / subsampling
    if sequence_length is not None:
        assert sequence_length <= t
        video = video[:sequence_length]

    if sample_every_n_frames > 1:
        video = video[::sample_every_n_frames]

    # scale the shorter side to `resolution`, preserving the aspect ratio
    scale = resolution / min(h, w)
    if h < w:
        target_size = (resolution, math.ceil(w * scale))
    else:
        target_size = (math.ceil(h * scale), resolution)
    video = F.interpolate(video, size=target_size, mode='bilinear',
                          align_corners=False, antialias=True)

    # center-crop to a `resolution` x `resolution` square
    t, c, h, w = video.shape
    w_start = (w - resolution) // 2
    h_start = (h - resolution) // 2
    video = video[:, :, h_start:h_start + resolution, w_start:w_start + resolution]
    video = video.permute(1, 0, 2, 3).contiguous()  # TCHW -> CTHW

    return {'video': video}
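

# Worked example (shapes only, assuming a 240x320 clip and resolution=128):
# the shorter side 240 scales by 128/240, giving a (128, 171) resize target;
# the center crop then keeps columns 21..149, so a (16, 240, 320, 3) THWC
# input comes out as a (3, 16, 128, 128) CTHW tensor in [0, 1].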


def preprocess_image(image):
    # image: HWC numpy array -> HWC torch tensor
    img = torch.from_numpy(image)
    return img


class VideoData(data.Dataset):
    """ Class to create dataloaders for video datasets

    Args:
        data_path: Path to the folder with video frames or videos.
        image_folder: If True, the data is stored as images in folders.
        resolution: Resolution of the returned videos.
        sequence_length: Length of extracted video sequences.
        sample_every_n_frames: Sample every n frames from the video.
        batch_size: Batch size.
        num_workers: Number of workers for the dataloader.
        shuffle: If True, shuffle the data.
    """

    def __init__(self, data_path: str, image_folder: bool, resolution: int, sequence_length: int,
                 sample_every_n_frames: int, batch_size: int, num_workers: int, shuffle: bool = True):
        super().__init__()
        self.data_path = data_path
        self.image_folder = image_folder
        self.resolution = resolution
        self.sequence_length = sequence_length
        self.sample_every_n_frames = sample_every_n_frames
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.shuffle = shuffle

    def _dataset(self):
        '''
        Initializes and returns the dataset.
        '''
        Dataset = FrameDataset if self.image_folder else VideoDataset
        dataset = Dataset(self.data_path, self.sequence_length,
                          resolution=self.resolution, sample_every_n_frames=self.sample_every_n_frames)
        return dataset

    def _dataloader(self):
        '''
        Initializes and returns the dataloader.
        '''
        dataset = self._dataset()
        if dist.is_initialized():
            sampler = data.distributed.DistributedSampler(
                dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank()
            )
        else:
            sampler = None
        dataloader = data.DataLoader(
            dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=True,
            sampler=sampler,
            shuffle=sampler is None and self.shuffle
        )
        return dataloader
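

# A minimal sketch of direct use (hypothetical path; under an initialized
# torch.distributed process group, _dataloader() swaps per-loader shuffling
# for a DistributedSampler automatically):
#
#     data = VideoData('/data/videos', image_folder=False, resolution=128,
#                      sequence_length=16, sample_every_n_frames=1,
#                      batch_size=8, num_workers=4)
#     loader = data._dataloader()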


class VideoDataset(data.Dataset):
    """
    Generic dataset for video files stored in folders.
    Videos of the same class are expected to be stored in a single folder. Multiple folders can exist in the provided directory.
    The class depends on `torchvision.datasets.video_utils.VideoClips` to load the videos.
    Returns CTHW videos in the range [0, 1] (a DataLoader collates them to BCTHW).

    Args:
        data_folder: Path to the folder with corresponding videos stored.
        sequence_length: Length of extracted video sequences.
        resolution: Resolution of the returned videos.
        sample_every_n_frames: Sample every n frames from the video.
    """

    def __init__(self, data_folder: str, sequence_length: int = 16, resolution: int = 128, sample_every_n_frames: int = 1):
        super().__init__()
        self.sequence_length = sequence_length
        self.resolution = resolution
        self.sample_every_n_frames = sample_every_n_frames

        folder = data_folder
        files = sum([glob.glob(osp.join(folder, '**', f'*{ext}'), recursive=True)
                     for ext in VID_EXTENSIONS], [])

        # VideoClips is noisy about corrupt or short files
        warnings.filterwarnings('ignore')
        # cache the expensive VideoClips indexing pass next to the data
        cache_file = osp.join(folder, f"metadata_{sequence_length}.pkl")
        if not osp.exists(cache_file):
            clips = VideoClips(files, sequence_length, num_workers=4)
            try:
                with open(cache_file, 'wb') as f:
                    pickle.dump(clips.metadata, f)
            except Exception:
                print(f"Failed to save metadata to {cache_file}")
        else:
            with open(cache_file, 'rb') as f:
                metadata = pickle.load(f)
            clips = VideoClips(files, sequence_length,
                               _precomputed_metadata=metadata)

        self._clips = clips
        # override the flat-index lookup so each __getitem__ draws a random
        # clip from the idx-th video instead of a fixed clip position
        self._clips.get_clip_location = self.get_random_clip_from_video

    def get_random_clip_from_video(self, idx: int) -> tuple:
        '''
        Sample a random clip starting index from the video.

        Args:
            idx: Index of the video.

        Returns:
            A (video index, clip index) tuple; videos without any valid clip
            are skipped by advancing to the next index.
        '''
        while self._clips.clips[idx].shape[0] <= 0:
            idx += 1
        n_clip = self._clips.clips[idx].shape[0]
        clip_id = random.randint(0, n_clip - 1)
        return idx, clip_id

    def __len__(self):
        return self._clips.num_videos()

    def __getitem__(self, idx):
        resolution = self.resolution
        while True:
            try:
                video, _, _, idx = self._clips.get_clip(idx)
            except Exception as e:
                # skip unreadable clips rather than crashing the worker
                print(idx, e)
                idx = (idx + 1) % self._clips.num_clips()
                continue
            break

        return dict(**preprocess(video, resolution, sample_every_n_frames=self.sample_every_n_frames))
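

# Example (hypothetical folder of .mp4 files; a minimal sketch):
#
#     dataset = VideoDataset('/data/ucf101', sequence_length=16, resolution=128)
#     item = dataset[0]
#     item['video'].shape  # -> torch.Size([3, 16, 128, 128]), i.e. CTHW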


class FrameDataset(data.Dataset):
    """
    Generic dataset for videos stored as images. Loading iterates over all the folders and subfolders
    in the provided directory. Each leaf folder is assumed to contain frames from a single video.

    Args:
        data_folder: path to the folder with video frames. The folder
            should contain folders with frames from each video.
        sequence_length: length of extracted video sequences
        resolution: resolution of the returned videos
        sample_every_n_frames: sample every n frames from the video
    """

    def __init__(self, data_folder, sequence_length, resolution=64, sample_every_n_frames=1):
        self.resolution = resolution
        self.sequence_length = sequence_length
        self.sample_every_n_frames = sample_every_n_frames
        self.data_all = self.load_video_frames(data_folder)
        self.video_num = len(self.data_all)

    def __getitem__(self, index):
        batch_data = self.getTensor(index)
        return_list = {'video': batch_data}

        return return_list

    def load_video_frames(self, dataroot: str) -> list:
        '''
        Loads all the video frames under the dataroot and returns a list of
        frame-path lists, one per video.

        Args:
            dataroot: The root directory containing the video frames.

        Returns:
            A list of frame-path lists, one per video.
        '''
        data_all = []
        for root, _, files in os.walk(dataroot):
            frames = [item for item in files if is_image_file(item)]
            try:
                # frame files are expected to end in an integer index, e.g. frame_0001.png
                frames = sorted(frames, key=lambda item: int(item.split('.')[0].split('_')[-1]))
            except ValueError:
                # skip folders whose frame names do not follow the naming scheme
                print(root, files)
                continue
            frames = [os.path.join(root, item) for item in frames]
            # keep only videos long enough for one full sampled sequence
            if len(frames) >= max(1, self.sequence_length * self.sample_every_n_frames):
                data_all.append(frames)

        return data_all
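
    # Expected layout (names are illustrative): each leaf folder is one video,
    # with frame files carrying a sortable integer suffix:
    #
    #     dataroot/
    #         video_a/frame_0000.png ... frame_0127.png
    #         video_b/frame_0000.png ... frame_0063.png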

    def getTensor(self, index: int) -> torch.Tensor:
        '''
        Returns a tensor of the video frames at the given index.

        Args:
            index: The index of the video frames to return.

        Returns:
            A CTHW tensor in the range `[0, 1]` of the video frames at the given index.
        '''
        video = self.data_all[index]
        video_len = len(video)

        # sequence_length == -1 means use the whole video
        if self.sequence_length == -1:
            assert self.sample_every_n_frames == 1
            start_idx = 0
            end_idx = video_len
        else:
            n_frames_interval = self.sequence_length * self.sample_every_n_frames
            start_idx = random.randint(0, video_len - n_frames_interval)
            end_idx = start_idx + n_frames_interval
        img = Image.open(video[0])
        h, w = img.height, img.width

        # center-crop to a square along the longer side (no crop if h == w)
        if h > w:
            half = (h - w) // 2
            cropsize = (0, half, w, half + w)
        elif w > h:
            half = (w - h) // 2
            cropsize = (half, 0, half + h, h)

        images = []
        for i in range(start_idx, end_idx,
                       self.sample_every_n_frames):
            path = video[i]
            img = Image.open(path)

            if h != w:
                img = img.crop(cropsize)

            img = img.resize(
                (self.resolution, self.resolution),
                Image.LANCZOS)  # Image.ANTIALIAS was removed in Pillow 10
            img = np.asarray(img, dtype=np.float32)
            img /= 255.
            img_tensor = preprocess_image(img).unsqueeze(0)
            images.append(img_tensor)

        # (T, H, W, C) -> (C, T, H, W)
        video_clip = torch.cat(images).permute(3, 0, 1, 2)
        return video_clip

    def __len__(self):
        return self.video_num
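

if __name__ == "__main__":
    # Smoke test: a minimal sketch assuming a hypothetical folder of videos
    # exists at ./data. Adjust the path and flags for your setup.
    loader = get_dataloader("./data", image_folder=False,
                            resolution=128, sequence_length=16,
                            batch_size=2, num_workers=0)
    batch = next(iter(loader))
    print(batch['video'].shape)  # expected: torch.Size([2, 3, 16, 128, 128])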