# torchnet / dataset_test.py
# Author: milselarch
# Commit message: push to main
# Commit: df07554
import options as opt
import matplotlib.pyplot as plt
import torch.optim as optim
import numpy as np
import time
from dataset import GridDataset
from torch.utils.data import DataLoader
def dataset2dataloader(
    dataset, num_workers=opt.num_workers, shuffle=True, batch_size=None
):
    """Wrap *dataset* in a ``DataLoader`` configured from ``options``.

    Args:
        dataset: The dataset to iterate over.
        num_workers: Worker count passed to ``DataLoader``. The default is
            ``opt.num_workers`` as read when this module was imported.
        shuffle: Passed straight through to ``DataLoader``.
        batch_size: Samples per batch. ``None`` (the default) means use
            ``opt.batch_size``, read at call time.

    Returns:
        A ``torch.utils.data.DataLoader`` that keeps the final, possibly
        smaller, batch (``drop_last=False``).
    """
    if batch_size is None:
        batch_size = opt.batch_size
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        drop_last=False,
    )
# Training split, with every path/padding setting taken from `options`.
dataset = GridDataset(
    video_path=opt.video_path,
    alignments_dir=opt.alignments_dir,
    file_list=opt.train_list,
    vid_pad=opt.vid_padding,
    image_dir=opt.images_dir,
    txt_pad=opt.txt_padding,
    phase='train',
)
# Batched, shuffled iterator over the dataset, used by fetch_samples().
loader = dataset2dataloader(dataset)
def fetch_samples(num_samples=10, data_loader=None):
    """Collect the first *num_samples* items yielded by a data loader.

    Args:
        num_samples: Maximum number of items to collect. Values <= 0 give
            an empty list.
        data_loader: Any iterable to draw items from. ``None`` (the default)
            uses the module-level ``loader``.

    Returns:
        A list of at most *num_samples* items. It is shorter if the loader is
        exhausted first.
    """
    from itertools import islice

    if data_loader is None:
        data_loader = loader
    # islice stops pulling items once the limit is reached, like an early
    # `break`. Clamp to 0 because islice rejects negative stop values.
    return list(islice(data_loader, max(num_samples, 0)))
# Smoke test: pull up to the default number of batches through `loader`.
samples = fetch_samples()
# Raises IndexError if the loader yielded nothing.
print(samples[0])
print('END')