import os
import re

import torch
import torchvision.transforms as transforms
from easydict import EasyDict as edict
from PIL import Image
from torch.utils.data.dataset import Dataset as _TorchDataset


def data_list(img_root, mode):
    """Return the panorama image paths listed in a split CSV file.

    Reads ``splits/train-19zl.csv`` when ``mode == 'train'`` and
    ``splits/val-19zl.csv`` for any other mode. Each row is comma separated;
    the second column is taken as the panorama path relative to ``img_root``.

    Args:
        img_root: Root directory of the dataset.
        mode: ``'train'`` selects the training split, anything else the
            validation split.

    Returns:
        A list of panorama paths, each joined onto ``img_root``.

    Raises:
        IndexError: If a non-blank row has fewer than two columns.
    """
    split_name = 'splits/train-19zl.csv' if mode == 'train' else 'splits/val-19zl.csv'
    # Join the root and the split file as separate components so the path is
    # correct whether or not ``img_root`` ends with a path separator.
    split_file = os.path.join(img_root, split_name)

    panorama_names = []
    with open(split_file) as f:
        for line in f:
            row = line.rstrip('\n')
            if not row.strip():
                # Tolerate blank lines (e.g. a trailing empty line).
                continue
            columns = row.split(',')
            # Column 0 is the aerial image, column 1 the panorama.
            panorama_names.append(columns[1])

    print('length of dataset is: ', len(panorama_names))
    return [os.path.join(img_root, name) for name in panorama_names]


def img_read(img, size=None, datatype='RGB'):
    """Load an image file as a tensor, optionally resizing it.

    Args:
        img: Path (or file object) accepted by ``PIL.Image.open``.
        size: Target size passed to ``Image.resize``. An ``int`` is expanded
            to ``(size, size)``. Falsy values skip resizing.
        datatype: ``'RGB'`` loads a colour image resized with bicubic
            resampling; any other value loads a single-channel ``'L'`` image
            resized with nearest-neighbour resampling.

    Returns:
        The tensor produced by ``torchvision.transforms.ToTensor``.
    """
    is_rgb = datatype == 'RGB'
    # ``convert`` returns an independent copy, so the file can be closed
    # as soon as the pixel data has been read.
    with Image.open(img) as opened:
        image = opened.convert('RGB' if is_rgb else 'L')

    if size:
        if isinstance(size, int):
            size = (size, size)
        resample = Image.BICUBIC if is_rgb else Image.NEAREST
        image = image.resize(size=size, resample=resample)

    return transforms.ToTensor()(image)


def _sky_mask_path(pano_path):
    """Derive the sky-mask file path that corresponds to a panorama path."""
    mask_path = pano_path.replace('streetview/panos', 'sky_mask')
    stem, ext = os.path.splitext(mask_path)
    # Swap only the file extension; a blanket ``replace('jpg', 'png')`` would
    # also rewrite any directory name that happens to contain "jpg".
    return stem + '.png' if ext == '.jpg' else mask_path


class Dataset(_TorchDataset):
    """Dataset of paired aerial ("sat") and street-level panorama images.

    Each item is a dict with the keys ``'sat'``, ``'pano'`` and ``'paths'``.
    When ``opt.data.sky_mask`` is truthy it also carries ``'sky_mask'`` and,
    depending on ``opt.data.histo_mode``, ``'sky_histc'``.
    """

    def __init__(self, opt, split='train', sub=None, sty_img=None):
        """Build the list of panorama paths to serve.

        Args:
            opt: Options object. Reads ``opt.data.root`` and ``opt.task``
                here, plus ``opt.demo_img`` / ``opt.sty_img`` when the
                corresponding branch is taken.
            split: Split name forwarded to ``data_list``.
            sub: If truthy, keep only the first ``sub`` panoramas.
            sty_img: If truthy, serve only the image named by ``opt.sty_img``.
        """
        super().__init__()
        self.opt = opt
        self.pano_list = data_list(img_root=opt.data.root, mode=split)
        if sub:
            self.pano_list = self.pano_list[:sub]

        pano_dir = 'streetview/panos'
        if opt.task == 'test_vid':
            # Video test mode serves the single demo image only.
            self.pano_list = [os.path.join(opt.data.root, pano_dir, opt.demo_img)]
        if sty_img:
            # NOTE(review): the ``sty_img`` argument is only used as a flag;
            # the file name itself comes from ``opt.sty_img`` - confirm this
            # is intended.
            assert opt.sty_img.split('.')[-1] == 'jpg'
            self.pano_list = [os.path.join(opt.data.root, pano_dir, opt.sty_img)]

    def __len__(self):
        """Return the number of panoramas served by this dataset."""
        return len(self.pano_list)

    def __getitem__(self, index):
        """Load the sample at ``index``.

        Returns:
            A dict with ``'sat'`` (aerial image tensor), ``'pano'`` (panorama
            tensor) and ``'paths'`` (panorama path). If ``opt.data.sky_mask``
            is truthy it also has ``'sky_mask'`` and, for a ``histo_mode`` of
            ``'grey'``, ``'rgb'`` or ``'RGB'``, ``'sky_histc'``.
        """
        data_opt = self.opt.data
        pano_path = self.pano_list[index]
        sat_path = pano_path.replace('streetview/panos', 'bingmap/19')

        sat = img_read(sat_path, size=data_opt.sat_size)
        pano = img_read(pano_path, size=data_opt.pano_size)
        sample = {'sat': sat, 'pano': pano, 'paths': pano_path}

        if not data_opt.sky_mask:
            return sample

        sky = img_read(_sky_mask_path(pano_path), size=data_opt.pano_size, datatype='L')
        sample['sky_mask'] = sky

        # Keep panorama pixels weighted by the mask; everything else is zero.
        # (Adding an explicit all-zero background term is a no-op.)
        sky_pixels = pano * sky

        histo_mode = data_opt.histo_mode
        if histo_mode == 'grey':
            # One histogram over all channels, minus the 10 lowest bins
            # (the zeroed-out pixels land in the lowest bin).
            sample['sky_histc'] = sky_pixels.histc()[10:]
        elif histo_mode in ('rgb', 'RGB'):
            per_channel = [channel.histc()[10:] for channel in sky_pixels]
            # Concatenate last channel first, matching the established
            # layout of this feature vector.
            sample['sky_histc'] = torch.cat(per_channel[::-1], dim=0)

        return sample