r""" PF-WILLOW dataset """

import os

import pandas as pd
import numpy as np
import torch

from .dataset import CorrespondenceDataset


class PFWillowDataset(CorrespondenceDataset):
    r"""PF-WILLOW correspondence dataset.

    Reads image pairs and keypoint annotations from the CSV at ``self.spt_path``
    (attribute presumably set by the ``CorrespondenceDataset`` base class — not
    visible here). CSV layout as consumed below: column 0 = source image name,
    column 1 = target image name, the next ``2 * _NUM_KPS`` columns = source
    keypoints, and all remaining columns = target keypoints.
    """

    # Each PF-WILLOW pair is annotated with exactly this many keypoints; the
    # per-image keypoint columns are laid out as _NUM_KPS x-values followed by
    # _NUM_KPS y-values (see get_points).
    _NUM_KPS = 10

    def __init__(self, benchmark, datapath, thres, split):
        r"""PF-WILLOW dataset constructor"""
        super(PFWillowDataset, self).__init__(benchmark, datapath, thres, split)

        self.train_data = pd.read_csv(self.spt_path)
        self.src_imnames = np.array(self.train_data.iloc[:, 0])
        self.trg_imnames = np.array(self.train_data.iloc[:, 1])

        # Source keypoints occupy columns [2, 2 + 2*_NUM_KPS); target keypoints
        # take everything after that.
        trg_start = 2 + 2 * self._NUM_KPS
        self.src_kps = self.train_data.iloc[:, 2:trg_start].values
        self.trg_kps = self.train_data.iloc[:, trg_start:].values

        self.cls = ['car(G)', 'car(M)', 'car(S)', 'duck(S)',
                    'motorbike(G)', 'motorbike(M)', 'motorbike(S)',
                    'winebottle(M)', 'winebottle(wC)', 'winebottle(woC)']
        # The second '/'-separated path component of each source image name is
        # looked up in self.cls to obtain the class index.
        self.cls_ids = [self.cls.index(name.split('/')[1]) for name in self.src_imnames]

        # Drop the leading path component and rejoin with the OS separator.
        self.src_imnames = [os.path.join(*name.split('/')[1:]) for name in self.src_imnames]
        self.trg_imnames = [os.path.join(*name.split('/')[1:]) for name in self.trg_imnames]

    def __getitem__(self, idx):
        r""" Constructs and returns a batch for PF-WILLOW dataset """
        batch = super(PFWillowDataset, self).__getitem__(idx)
        batch['pckthres'] = self.get_pckthres(batch)

        return batch

    def get_pckthres(self, batch):
        r""" Computes PCK threshold

        Args:
            batch: dict produced by the base-class ``__getitem__``; reads
                ``batch['trg_kps']`` (2 x max_pts, as built by ``get_points``)
                or ``batch['trg_img']``.

        Returns:
            0-dim tensor: the larger side of the target keypoint bounding box
            (``thres == 'bbox'``) or the larger of ``trg_img.size()[1]`` and
            ``trg_img.size()[2]`` (``thres == 'img'``).

        Raises:
            ValueError: if ``self.thres`` is neither ``'bbox'`` nor ``'img'``.
        """
        if self.thres == 'bbox':
            # get_points pads columns beyond the real keypoints with -2; including
            # them would drag the per-axis minimum down to -2 and inflate the
            # threshold, so only the annotated keypoint columns are used.
            # NOTE(review): assumes the base class stores get_points' output
            # under 'trg_kps' without reordering columns — confirm.
            kps = batch['trg_kps'][:, :self._NUM_KPS]
            extent = kps.max(1)[0] - kps.min(1)[0]
            return extent.max()
        elif self.thres == 'img':
            # presumably trg_img is (C, H, W) — TODO confirm
            return torch.tensor(max(batch['trg_img'].size()[1], batch['trg_img'].size()[2]))
        else:
            raise ValueError('Invalid pck evaluation level: %s' % self.thres)

    def get_points(self, pts_list, idx, org_imsize):
        r""" Returns key-points of an image

        Args:
            pts_list: 2-D array of per-pair keypoint rows (``self.src_kps`` or
                ``self.trg_kps``).
            idx: row index of the pair.
            org_imsize: original image size; element 0 scales x and element 1
                scales y.

        Returns:
            (kps, n_pts): ``kps`` is a float32 tensor of shape (2, max_pts) with
            rescaled x/y coordinates in the first ``n_pts`` columns and -2 in the
            padded columns; ``n_pts`` is the number of real keypoints.
        """
        point_coords = pts_list[idx, :].reshape(2, self._NUM_KPS)
        point_coords = torch.tensor(point_coords.astype(np.float32))
        xy, n_pts = point_coords.size()
        # Pad to a fixed width (self.max_pts, presumably from the base class)
        # so batches can be stacked; -2 marks a missing keypoint.
        pad_pts = torch.zeros((xy, self.max_pts - n_pts)) - 2
        # Rescale from the original image size to self.img_size.
        x_crds = point_coords[0] * (self.img_size / org_imsize[0])
        y_crds = point_coords[1] * (self.img_size / org_imsize[1])
        kps = torch.cat([torch.stack([x_crds, y_crds]), pad_pts], dim=1)

        return kps, n_pts