File size: 991 Bytes
df07554 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 |
import options as opt
import matplotlib.pyplot as plt
import torch.optim as optim
import numpy as np
import time
from dataset import GridDataset
from torch.utils.data import DataLoader
def dataset2dataloader(
    dataset, num_workers=opt.num_workers, shuffle=True, batch_size=None
):
    """Wrap *dataset* in a ``torch.utils.data.DataLoader``.

    Args:
        dataset: The dataset to iterate over (in this script, a
            ``GridDataset``).
        num_workers: Number of loader worker processes. The default is read
            from ``opt.num_workers`` once, when this function is defined.
        shuffle: Whether the loader shuffles the sample order.
        batch_size: Samples per batch. ``None`` (the default) uses
            ``opt.batch_size``, looked up at call time, which matches the
            previous hard-coded behaviour.

    Returns:
        A ``DataLoader`` that keeps the final, possibly smaller batch
        (``drop_last=False``).
    """
    if batch_size is None:
        batch_size = opt.batch_size
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        drop_last=False,
    )
# Training split of the GRID dataset. Every constructor argument except the
# phase is taken from the project-wide ``options`` module.
dataset = GridDataset(
    phase='train',
    # Where the data lives and which files belong to this split.
    video_path=opt.video_path,
    image_dir=opt.images_dir,
    alignments_dir=opt.alignments_dir,
    file_list=opt.train_list,
    # Padding settings for video and text; their exact meaning is defined by
    # GridDataset (presumably padded sequence lengths — confirm there).
    vid_pad=opt.vid_padding,
    txt_pad=opt.txt_padding,
)
# Loader over the training split, built with dataset2dataloader's defaults.
loader = dataset2dataloader(dataset)
def fetch_samples(num_samples=10, data_loader=None):
    """Return up to *num_samples* items drawn from a data loader.

    Args:
        num_samples: Maximum number of items (batches, when drawing from a
            ``DataLoader``) to collect. Zero or a negative value yields an
            empty list.
        data_loader: Any iterable to draw from. ``None`` (the default) uses
            the module-level ``loader``.

    Returns:
        A list of at most *num_samples* items in iteration order. It is
        shorter if the iterable runs out first.
    """
    from itertools import islice

    source = loader if data_loader is None else data_loader
    # islice stops once enough items have been taken, so nothing beyond the
    # requested count is pulled from the loader. Clamping to 0 fixes the old
    # behaviour where a request for 0 samples still returned one.
    return list(islice(source, max(num_samples, 0)))
samples = fetch_samples()
# Print the first batch as a quick sanity check. If the loader produced
# nothing (e.g. an empty training list), report that instead of crashing
# with an IndexError.
if samples:
    print(samples[0])
else:
    print('No samples were produced by the loader.')
print('END')