import os
from pathlib import Path
import random
import numpy as np
import pickle as pk
import cv2
from tqdm import tqdm
from PIL import Image
import torchvision.transforms as transforms
import torch
# from prefetch_generator import BackgroundGenerator
from torch.utils.data import DataLoader, Dataset


class VideoDataset(Dataset):
    """Video-classification dataset.

    Expects each entry of ``directory_list`` to be laid out as
    ``<directory>/<label>/<video_file>``.  ``__getitem__`` returns a tuple
    ``(clip, label_index)`` where ``clip`` is a float32 array of shape
    (3, clip_len, crop_size, crop_size) — channels first, then time — and
    ``label_index`` is an int from ``self.label2index``.
    """

    def __init__(self, directory_list, local_rank=0, enable_GPUs_num=0,
                 distributed_load=False, resize_shape=[224, 224],
                 mode='train', clip_len=32, crop_size=168):
        # NOTE: resize_shape keeps its original mutable-list default for
        # signature compatibility; it is only read (indexed), never mutated.
        self.clip_len, self.crop_size, self.resize_shape = clip_len, crop_size, resize_shape
        self.mode = mode

        # Collect (file path, label) pairs; label == sub-folder name.
        # Outside train mode only the first 10 files per class are kept.
        self.fnames, labels = [], []
        for directory in directory_list:
            folder = Path(directory)
            print("Load dataset from folder : ", folder)
            for label in sorted(os.listdir(folder)):
                class_files = os.listdir(os.path.join(folder, label))
                if mode != "train":
                    class_files = class_files[:10]
                for fname in class_files:
                    self.fnames.append(os.path.join(folder, label, fname))
                    labels.append(label)

        # Shuffle paths and labels together so pairs stay aligned.
        random_list = list(zip(self.fnames, labels))
        random.shuffle(random_list)
        self.fnames[:], labels[:] = zip(*random_list)
        self.labels = labels

        # Manual sharding: each rank takes a contiguous slice of the
        # shuffled list (remainder files past GPUs*single_num_ are dropped).
        if mode == 'train' and distributed_load:
            single_num_ = len(self.fnames) // enable_GPUs_num
            self.fnames = self.fnames[local_rank * single_num_:(local_rank + 1) * single_num_]
            labels = labels[local_rank * single_num_:(local_rank + 1) * single_num_]

        # Map label names (strings) to indices (ints), sorted for determinism.
        self.label2index = {label: index for index, label in enumerate(sorted(set(labels)))}
        self.label_array = np.array([self.label2index[label] for label in labels], dtype=int)

        # Per-frame preprocessing: resize -> tensor -> normalize to [-1, 1].
        # Hoisted out of loadvideo(); it is identical on every call.
        self.transform = transforms.Compose([
            transforms.Resize([self.resize_shape[0], self.resize_shape[1]]),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ])

    def __getitem__(self, index):
        """Load a clip and apply a random spatial crop (also in eval mode)."""
        # TODO move loading/augmentation into transform classes.
        buffer = self.loadvideo(self.fnames[index])
        height_index = np.random.randint(buffer.shape[2] - self.crop_size)
        width_index = np.random.randint(buffer.shape[3] - self.crop_size)
        clip = buffer[:, :,
                      height_index:height_index + self.crop_size,
                      width_index:width_index + self.crop_size]
        return clip, self.label_array[index]

    def __len__(self):
        return len(self.fnames)

    def loadvideo(self, fname):
        """Decode ``fname`` into a float32 array (3, clip_len, H, W).

        A random temporal window of clip_len frames (optionally subsampled at
        2x when the video is long enough) is read with OpenCV; in train mode
        each clip is flipped with probability 0.5.
        """
        # BUGFIX: the original `flip, flipCode = 1, choice(...) if cond else 0`
        # parsed as flip=1 always, with flipCode=0 (a *valid* cv2 code:
        # vertical flip) outside train mode — so eval frames were flipped too.
        if self.mode == "train" and np.random.random() < 0.5:
            flip, flipCode = 1, random.choice([-1, 0, 1])
        else:
            flip, flipCode = 0, None

        try:
            video_stream = cv2.VideoCapture(fname)
            frame_count = int(video_stream.get(cv2.CAP_PROP_FRAME_COUNT))
        except RuntimeError:
            # Best-effort fallback: substitute a random other video.
            index = np.random.randint(self.__len__())
            video_stream = cv2.VideoCapture(self.fnames[index])
            frame_count = int(video_stream.get(cv2.CAP_PROP_FRAME_COUNT))

        # Resample until the video is long enough to cut a clip from.
        # NOTE(review): the original condition was garbled in the source;
        # clip_len + 2 guarantees randint() below gets a positive bound at
        # speed_rate == 1 — confirm against the upstream version.
        while frame_count < self.clip_len + 2:
            index = np.random.randint(self.__len__())
            video_stream = cv2.VideoCapture(self.fnames[index])
            frame_count = int(video_stream.get(cv2.CAP_PROP_FRAME_COUNT))

        # 2x temporal subsampling only when there is headroom for it.
        speed_rate = np.random.randint(1, 3) if frame_count > self.clip_len * 2 + 2 else 1
        time_index = np.random.randint(frame_count - self.clip_len * speed_rate)
        start_idx = time_index
        end_idx = time_index + self.clip_len * speed_rate

        # float32 buffer so PyTorch converts it to a FloatTensor later.
        buffer = np.empty(
            (self.clip_len, 3, self.resize_shape[0], self.resize_shape[1]),
            np.dtype('float32'))

        count, sample_count, retaining = 0, 0, True
        while count <= end_idx and retaining:
            retaining, frame = video_stream.read()
            if count < start_idx:
                count += 1
                continue
            # Keep every speed_rate-th frame inside the window.
            if count % speed_rate == speed_rate - 1 and sample_count < self.clip_len:
                if flip:
                    frame = cv2.flip(frame, flipCode=flipCode)
                try:
                    buffer[sample_count] = self.transform(
                        Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
                except cv2.error:
                    # Undecodable frame (e.g. read() returned None at EOF):
                    # skip it, but still advance the frame counter — the
                    # original `continue` left `count` frozen.
                    count += 1
                    continue
                sample_count += 1
            count += 1

        video_stream.release()
        # (clip_len, 3, H, W) -> (3, clip_len, H, W): channels first, then time.
        return buffer.transpose((1, 0, 2, 3))


if __name__ == '__main__':
    datapath = ['/data/datasets/ucf101/videos']
    dataset = VideoDataset(datapath, resize_shape=[224, 224], mode='validation')
    x, y = dataset[0]
    # x: (3, num_frames, w, h)
    print(x.shape, y.shape, y)

    # Example throughput check:
    # dataloader = DataLoader(dataset, batch_size=8, shuffle=True,
    #                         num_workers=24, pin_memory=True)
    # for step, (buffer, labels) in enumerate(tqdm(dataloader, ncols=80)):
    #     print(buffer.shape, "label:", labels)