|
r""" Superclass for semantic correspondence datasets """
|
|
|
|
import os
|
|
|
|
from torch.utils.data import Dataset
|
|
from torchvision import transforms
|
|
from PIL import Image
|
|
import torch
|
|
|
|
from model.base.geometry import Geometry
|
|
|
|
|
|
class CorrespondenceDataset(Dataset):
    r""" Parent class of PFPascal, PFWillow, and SPair """

    def __init__(self, benchmark, datapath, thres, split):
        r""" CorrespondenceDataset constructor.

        Args:
            benchmark (str): one of 'pfwillow', 'pfpascal', 'spair'
            datapath (str): root directory that contains the benchmark folders
            thres (str): PCK threshold type; 'auto' selects the benchmark's
                conventional type, otherwise 'bbox' or 'img'
            split (str): dataset split name (e.g. 'trn', 'val', 'test')
        """
        super().__init__()

        # benchmark -> (dataset dir, pair-list spec, image dir, annotation dir,
        #               default PCK threshold type)
        self.metadata = {
            'pfwillow': ('PF-WILLOW', 'test_pairs.csv', '', '', 'bbox'),
            'pfpascal': ('PF-PASCAL', '_pairs.csv', 'JPEGImages', 'Annotations', 'img'),
            'spair':    ('SPair-71k', 'Layout/large', 'JPEGImages', 'PairAnnotation', 'bbox'),
        }

        base_path = os.path.join(os.path.abspath(datapath), self.metadata[benchmark][0])

        # Path of the file listing (source, target) image pairs; the layout of
        # that file differs per benchmark.
        if benchmark == 'pfpascal':
            self.spt_path = os.path.join(base_path, split + '_pairs.csv')
        elif benchmark == 'spair':
            self.spt_path = os.path.join(base_path, self.metadata[benchmark][1], split + '.txt')
        else:
            self.spt_path = os.path.join(base_path, self.metadata[benchmark][1])

        # Directory containing the raw benchmark images.
        self.img_path = os.path.join(base_path, self.metadata[benchmark][2])

        # Annotation directory (SPair-71k keeps one subdirectory per split).
        if benchmark == 'spair':
            self.ann_path = os.path.join(base_path, self.metadata[benchmark][3], split)
        else:
            self.ann_path = os.path.join(base_path, self.metadata[benchmark][3])

        # Miscellaneous attributes
        self.max_pts = 40                  # keypoint tensors are padded to this length
        self.split = split
        self.img_size = Geometry.img_size  # square side length images are resized to
        self.benchmark = benchmark
        self.range_ts = torch.arange(self.max_pts)
        # 'auto' falls back to the benchmark's conventional threshold type.
        self.thres = self.metadata[benchmark][4] if thres == 'auto' else thres
        # Square resize followed by ImageNet mean/std normalization.
        self.transform = transforms.Compose([transforms.Resize((self.img_size, self.img_size)),
                                             transforms.ToTensor(),
                                             transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                                  std=[0.229, 0.224, 0.225])])

        # Pair metadata; populated by the subclass that parses the pair list.
        self.train_data = []
        self.src_imnames = []
        self.trg_imnames = []
        self.cls = []
        self.cls_ids = []
        self.src_kps = []
        self.trg_kps = []

    def __len__(self):
        r""" Returns the number of image pairs in the dataset """
        return len(self.train_data)

    def __getitem__(self, idx):
        r""" Constructs and returns a batch for the idx-th image pair.

        Returns:
            dict with image names, category info, original PIL sizes,
            normalized image tensors, padded keypoints, the number of valid
            keypoints, and the dataset length.
        """
        batch = dict()

        # Image names
        batch['src_imname'] = self.src_imnames[idx]
        batch['trg_imname'] = self.trg_imnames[idx]

        # Object category of the pair
        batch['category_id'] = self.cls_ids[idx]
        batch['category'] = self.cls[batch['category_id']]

        # Original (width, height) of each image, needed to rescale keypoints
        src_pil = self.get_image(self.src_imnames, idx)
        trg_pil = self.get_image(self.trg_imnames, idx)
        batch['src_imsize'] = src_pil.size
        batch['trg_imsize'] = trg_pil.size

        # Resized + normalized image tensors
        batch['src_img'] = self.transform(src_pil)
        batch['trg_img'] = self.transform(trg_pil)

        # Keypoints rescaled to the resized images; padded to self.max_pts
        batch['src_kps'], num_pts = self.get_points(self.src_kps, idx, src_pil.size)
        batch['trg_kps'], _ = self.get_points(self.trg_kps, idx, trg_pil.size)
        batch['n_pts'] = torch.tensor(num_pts)

        batch['datalen'] = len(self.train_data)

        return batch

    def get_image(self, imnames, idx):
        r""" Reads a PIL image from self.img_path and forces 3-channel RGB """
        path = os.path.join(self.img_path, imnames[idx])
        return Image.open(path).convert('RGB')

    def get_pckthres(self, batch, imsize):
        r""" Computes the PCK threshold for the target image.

        'bbox': max side of the target bounding box (batch['trg_bbox']);
        'img':  max side of the resized target image (batch['trg_img']).

        Raises:
            Exception: if self.thres is neither 'bbox' nor 'img'.
        """
        if self.thres == 'bbox':
            bbox = batch['trg_bbox'].clone()  # (x_min, y_min, x_max, y_max)
            bbox_w = (bbox[2] - bbox[0])
            bbox_h = (bbox[3] - bbox[1])
            pckthres = torch.max(bbox_w, bbox_h)
        elif self.thres == 'img':
            imsize_t = batch['trg_img'].size()  # (C, H, W)
            pckthres = torch.tensor(max(imsize_t[1], imsize_t[2]))
        else:
            raise Exception(f'Invalid pck threshold type: {self.thres}')
        return pckthres.float()

    def get_points(self, pts_list, idx, org_imsize):
        r""" Returns keypoints of an image, rescaled and padded.

        Args:
            pts_list: list of (2, n) tensors of (x, y) coordinates
            idx (int): pair index
            org_imsize: original PIL (width, height) of the image

        Returns:
            (kps, n_pts): (2, self.max_pts) tensor with invalid slots set to
            the sentinel value -2, and the number of valid keypoints.
        """
        xy, n_pts = pts_list[idx].size()
        # Pad unused slots with -2 so downstream code can mask them out.
        pad_pts = torch.zeros((xy, self.max_pts - n_pts)) - 2
        # Rescale from the original image size to the resized (img_size) frame.
        x_crds = pts_list[idx][0] * (self.img_size / org_imsize[0])
        y_crds = pts_list[idx][1] * (self.img_size / org_imsize[1])
        kps = torch.cat([torch.stack([x_crds, y_crds]), pad_pts], dim=1)

        return kps, n_pts
|
|
|