"""Smoke test: build the training GridDataset/DataLoader and print one batch.

The dataset and loader are configured from the project ``options`` module.
Running this file as a script pulls a few batches and prints the first one.
"""
# NOTE(review): ``options`` is imported first, as in the original; it may set
# up the environment before torch/matplotlib load — confirm before reordering.
import options as opt
import matplotlib.pyplot as plt
import torch.optim as optim
import numpy as np
import time
import itertools
from dataset import GridDataset
from torch.utils.data import DataLoader


def dataset2dataloader(dataset, num_workers=opt.num_workers, shuffle=True):
    """Wrap *dataset* in a DataLoader using ``opt.batch_size``.

    Args:
        dataset: The dataset to iterate over.
        num_workers: Number of loader worker processes. The default is the
            value of ``opt.num_workers`` when this module was imported.
        shuffle: Whether the DataLoader shuffles the data.

    Returns:
        A ``DataLoader`` over *dataset* with batch size ``opt.batch_size``.
        ``drop_last=False`` means the final batch may be smaller.
    """
    return DataLoader(
        dataset,
        batch_size=opt.batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        drop_last=False,
    )


# Training dataset and loader, built from the ``options`` settings. They stay
# at module level so importers (and fetch_samples' default) can still use them.
# Building them does not start worker processes; only iterating does.
dataset = GridDataset(
    video_path=opt.video_path,
    alignments_dir=opt.alignments_dir,
    file_list=opt.train_list,
    vid_pad=opt.vid_padding,
    image_dir=opt.images_dir,
    txt_pad=opt.txt_padding,
    phase='train',
)
loader = dataset2dataloader(dataset)


def fetch_samples(num_samples=10, data_loader=None):
    """Return up to *num_samples* batches drawn from a DataLoader.

    Args:
        num_samples: Maximum number of batches to collect. Values <= 0 yield
            an empty list.
        data_loader: Iterable of batches to draw from. Defaults to the
            module-level ``loader``.

    Returns:
        A list of at most *num_samples* batches, in iteration order. It is
        shorter if the loader is exhausted first.
    """
    if data_loader is None:
        data_loader = loader
    # islice stops pulling after num_samples batches. The previous counter
    # loop appended before checking, so it always consumed at least one batch.
    return list(itertools.islice(data_loader, max(num_samples, 0)))


if __name__ == "__main__":
    # Guarded: with the spawn start method, DataLoader workers re-import the
    # main module, and unguarded iteration here would be re-executed.
    samples = fetch_samples()
    if samples:
        print(samples[0])
    else:
        print('No samples fetched: the loader produced no batches.')
    print('END')