from torchvision import transforms

from datasets import video_transforms
from .ucf101_datasets import UCF101
from .dummy_datasets import DummyDataset
from .webvid_datasets import WebVid10M
from .videoswap_datasets import VideoSwapDataset
from .dl3dv_datasets import DL3DVDataset
from .pair_datasets import PairDataset
from .metric_datasets import MetricDataset
from .sakuga_ref_datasets import SakugaRefDataset


def _build_temporal_sample(args):
    """Return the temporal random-crop sampler for ``args.dataset``.

    The crop window is ``args.num_frames * args.frame_interval`` frames
    (e.g. 16 * 1); for ``'sakuga_ref'`` it is widened by
    ``args.ref_jump_frames``.
    """
    window = args.num_frames * args.frame_interval
    if args.dataset == 'sakuga_ref':
        window = window + args.ref_jump_frames
    return video_transforms.TemporalRandomCrop(window)


def _build_normalize():
    """Return the per-channel normalisation shared by every pipeline here."""
    # presumably maps [0, 1] inputs to [-1, 1] -- TODO confirm input range
    return transforms.Normalize(
        mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=False
    )


def _build_center_crop_transform(size):
    """Return the ToTensor -> center-crop -> normalize pipeline.

    Shared by the ``'dummy'``, ``'sakuga_ref'`` and ``'webvid'`` datasets.
    No random horizontal flip is applied in this pipeline.
    """
    return transforms.Compose([
        video_transforms.ToTensorVideo(),  # TCHW
        video_transforms.UCFCenterCropVideo(size=size),
        _build_normalize(),
    ])


def get_dataset(args):
    """Build the dataset selected by ``args.dataset``.

    Args:
        args: Configuration object. ``args.dataset`` picks the dataset; the
            other attributes read depend on that choice:

            * ``'ucf101'``: ``image_size``, ``num_frames``,
              ``frame_interval`` (``args`` itself is also forwarded).
            * ``'dummy'``: ``height``, ``width``, ``num_frames``,
              ``frame_interval``, ``base_folder``, ``seed``, ``file_list``.
            * ``'sakuga_ref'``: as ``'dummy'`` plus ``ref_jump_frames``.
            * ``'webvid'``: ``height``, ``width``, ``num_frames``,
              ``frame_interval``, ``base_folder``, ``seed``.
            * ``'videoswap'``: ``height``, ``width``, ``num_frames``,
              ``frame_interval``, ``base_folder``, ``seed``.
            * ``'dl3dv'``: ``height``, ``width``, ``num_frames``,
              ``frame_interval``, ``base_folder``, ``file_list``, ``seed``.
            * ``'pair_dataset'``: ``base_folder``, ``with_pair``.
            * ``'metric_dataset'``: ``base_folder``.

    Returns:
        The constructed dataset instance.

    Raises:
        NotImplementedError: If ``args.dataset`` is not one of the names
            listed above. The temporal sampler is only built inside the
            branches that use it, so this error is raised without touching
            any frame-related attribute of ``args``.
    """
    name = args.dataset

    if name == 'ucf101':
        # UCF101 is the only pipeline that keeps the random horizontal flip
        # and crops to a single ``image_size`` value.
        transform_ucf101 = transforms.Compose([
            video_transforms.ToTensorVideo(),  # TCHW
            video_transforms.RandomHorizontalFlipVideo(),
            video_transforms.UCFCenterCropVideo(args.image_size),
            _build_normalize(),
        ])
        return UCF101(
            args,
            transform=transform_ucf101,
            temporal_sample=_build_temporal_sample(args),
        )

    if name == 'dummy':
        return DummyDataset(
            sample_frames=args.num_frames,
            base_folder=args.base_folder,
            temporal_sample=_build_temporal_sample(args),
            transform=_build_center_crop_transform((args.height, args.width)),
            seed=args.seed,
            file_list=args.file_list,
        )

    if name == 'sakuga_ref':
        return SakugaRefDataset(
            video_frames=args.num_frames,
            ref_jump_frames=args.ref_jump_frames,
            base_folder=args.base_folder,
            temporal_sample=_build_temporal_sample(args),
            transform=_build_center_crop_transform((args.height, args.width)),
            seed=args.seed,
            file_list=args.file_list,
        )

    if name == 'webvid':
        return WebVid10M(
            sample_frames=args.num_frames,
            base_folder=args.base_folder,
            temporal_sample=_build_temporal_sample(args),
            transform=_build_center_crop_transform((args.height, args.width)),
            seed=args.seed,
        )

    if name == 'videoswap':
        # No flip and no center crop here; width/height are handed to the
        # dataset instead (presumably it resizes itself -- TODO confirm).
        transform = transforms.Compose([
            video_transforms.ToTensorVideo(),  # TCHW
            _build_normalize(),
        ])
        return VideoSwapDataset(
            width=args.width,
            height=args.height,
            sample_frames=args.num_frames,
            base_folder=args.base_folder,
            temporal_sample=_build_temporal_sample(args),
            transform=transform,
            seed=args.seed,
        )

    if name == 'dl3dv':
        # No transform is passed to this dataset.
        return DL3DVDataset(
            width=args.width,
            height=args.height,
            sample_frames=args.num_frames,
            base_folder=args.base_folder,
            file_list=args.file_list,
            temporal_sample=_build_temporal_sample(args),
            seed=args.seed,
        )

    if name == "pair_dataset":
        # Only the folder and the pairing flag are forwarded; no transform,
        # temporal sampler, size or seed is passed.
        return PairDataset(
            base_folder=args.base_folder,
            with_pair=args.with_pair,
        )

    if name == "metric_dataset":
        return MetricDataset(
            base_folder=args.base_folder,
        )

    raise NotImplementedError(args.dataset)