|
import os |
|
from torch.utils.data import Dataset |
|
from torch.utils.data import DataLoader |
|
|
|
class VideosDataset(Dataset):
    """Dataset whose items are the paths of the sub-directories of ``data_dir``.

    Each immediate sub-directory of ``data_dir`` is treated as one video
    folder. ``__getitem__`` returns that folder's path rather than any loaded
    content. Regular files sitting directly inside ``data_dir`` are ignored.
    """

    def __init__(self, data_dir):
        """Scan ``data_dir`` once and record its sub-directories.

        Args:
            data_dir: Directory to scan (anything ``os.listdir`` accepts).

        Raises:
            OSError (e.g. FileNotFoundError, NotADirectoryError): propagated
                from ``os.listdir`` when ``data_dir`` cannot be listed.
        """
        self.data_dir = data_dir
        # os.listdir yields entries in arbitrary, filesystem-dependent order;
        # sort the names so the index -> folder mapping is deterministic
        # across runs and machines.
        candidate_paths = (
            os.path.join(data_dir, name) for name in sorted(os.listdir(data_dir))
        )
        self.video_folders = [path for path in candidate_paths if os.path.isdir(path)]

    def __len__(self):
        """Return the number of video folders found at construction time."""
        return len(self.video_folders)

    def __getitem__(self, idx):
        """Return the path of the ``idx``-th video folder (list indexing semantics)."""
        return self.video_folders[idx]
|
|
|
def create_video_dataloader(opt):
    """Build a ``DataLoader`` over the video folders described by ``opt``.

    Reads the ``opt.videos_dataset`` section: ``dataset_dir`` is scanned by
    ``VideosDataset``. ``batch_size``, ``shuffle`` and ``num_workers`` are
    forwarded unchanged to ``DataLoader``.

    Args:
        opt: Configuration object exposing a ``videos_dataset`` attribute with
            the four fields listed above.

    Returns:
        torch.utils.data.DataLoader: Loader over ``VideosDataset`` items.
    """
    cfg = opt.videos_dataset
    # The dataset is built (directory scanned) before the loader options are
    # read, mirroring the original evaluation order.
    dataset = VideosDataset(cfg.dataset_dir)
    return DataLoader(
        dataset,
        batch_size=cfg.batch_size,
        shuffle=cfg.shuffle,
        num_workers=cfg.num_workers,
    )
|
|