r""" Superclass for semantic correspondence datasets """
import os

import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

from model.base.geometry import Geometry


class CorrespondenceDataset(Dataset):
r""" Parent class of PFPascal, PFWillow, and SPair """
def __init__(self, benchmark, datapath, thres, split):
r""" CorrespondenceDataset constructor """
super(CorrespondenceDataset, self).__init__()

        # (Directory name, layout path, image dir, annotation dir, PCK threshold type)
self.metadata = {
'pfwillow': ('PF-WILLOW',
'test_pairs.csv',
'',
'',
'bbox'),
'pfpascal': ('PF-PASCAL',
'_pairs.csv',
'JPEGImages',
'Annotations',
'img'),
'spair': ('SPair-71k',
'Layout/large',
'JPEGImages',
'PairAnnotation',
'bbox')
}

        # Pair-list file path for the train, val, or test split
base_path = os.path.join(os.path.abspath(datapath), self.metadata[benchmark][0])
if benchmark == 'pfpascal':
self.spt_path = os.path.join(base_path, split+'_pairs.csv')
elif benchmark == 'spair':
self.spt_path = os.path.join(base_path, self.metadata[benchmark][1], split+'.txt')
else:
self.spt_path = os.path.join(base_path, self.metadata[benchmark][1])

        # Directory path for images
self.img_path = os.path.join(base_path, self.metadata[benchmark][2])

        # Directory path for annotations
if benchmark == 'spair':
self.ann_path = os.path.join(base_path, self.metadata[benchmark][3], split)
else:
self.ann_path = os.path.join(base_path, self.metadata[benchmark][3])

        # Miscellaneous
self.max_pts = 40
self.split = split
self.img_size = Geometry.img_size
self.benchmark = benchmark
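        # Index tensor over the padded keypoint slots (e.g., for masking padded entries)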
self.range_ts = torch.arange(self.max_pts)
self.thres = self.metadata[benchmark][4] if thres == 'auto' else thres
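        # Input normalization uses ImageNet mean/std, matching torchvision's pretrained backbones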
self.transform = transforms.Compose([transforms.Resize((self.img_size, self.img_size)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])])

        # To get initialized in subclass constructors
self.train_data = []
self.src_imnames = []
self.trg_imnames = []
self.cls = []
self.cls_ids = []
self.src_kps = []
self.trg_kps = []

    def __len__(self):
r""" Returns the number of pairs """
return len(self.train_data)

    def __getitem__(self, idx):
        r""" Constructs and returns a batch """
# Image name
batch = dict()
batch['src_imname'] = self.src_imnames[idx]
batch['trg_imname'] = self.trg_imnames[idx]

        # Object category
batch['category_id'] = self.cls_ids[idx]
batch['category'] = self.cls[batch['category_id']]

        # Image as PIL (original width, original height)
src_pil = self.get_image(self.src_imnames, idx)
trg_pil = self.get_image(self.trg_imnames, idx)
batch['src_imsize'] = src_pil.size
batch['trg_imsize'] = trg_pil.size

        # Image as tensor
batch['src_img'] = self.transform(src_pil)
batch['trg_img'] = self.transform(trg_pil)

        # Key-points (rescaled to img_size and padded up to max_pts)
batch['src_kps'], num_pts = self.get_points(self.src_kps, idx, src_pil.size)
batch['trg_kps'], _ = self.get_points(self.trg_kps, idx, trg_pil.size)
batch['n_pts'] = torch.tensor(num_pts)

        # Total number of pairs in the current split
batch['datalen'] = len(self.train_data)
return batch

    def get_image(self, imnames, idx):
r""" Reads PIL image from path """
path = os.path.join(self.img_path, imnames[idx])
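        # convert('RGB') normalizes grayscale or paletted images to 3 channels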
return Image.open(path).convert('RGB')

    def get_pckthres(self, batch, imsize):
r""" Computes PCK threshold """
if self.thres == 'bbox':
bbox = batch['trg_bbox'].clone()
bbox_w = (bbox[2] - bbox[0])
bbox_h = (bbox[3] - bbox[1])
pckthres = torch.max(bbox_w, bbox_h)
elif self.thres == 'img':
imsize_t = batch['trg_img'].size()
pckthres = torch.tensor(max(imsize_t[1], imsize_t[2]))
else:
raise Exception('Invalid pck threshold type: %s' % self.thres)
return pckthres.float()

    def get_points(self, pts_list, idx, org_imsize):
r""" Returns key-points of an image """
xy, n_pts = pts_list[idx].size()
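        # Pad with sentinel value -2 up to max_pts so keypoint tensors of
        # varying length can be stacked by the default DataLoader collation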
pad_pts = torch.zeros((xy, self.max_pts - n_pts)) - 2
x_crds = pts_list[idx][0] * (self.img_size / org_imsize[0])
y_crds = pts_list[idx][1] * (self.img_size / org_imsize[1])
kps = torch.cat([torch.stack([x_crds, y_crds]), pad_pts], dim=1)
return kps, n_pts
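

if __name__ == '__main__':
    # Minimal sanity-check sketch (illustrative only; sizes below are assumed,
    # not taken from Geometry): mimics get_points to show how keypoints are
    # rescaled to the network input resolution and padded to a fixed length.
    img_size, max_pts = 240, 40                 # assumed input resolution and pad length
    org_imsize = (320, 480)                     # hypothetical (width, height) of a source image
    kps = torch.tensor([[16., 160., 304.],      # x-coordinates of 3 keypoints
                        [24., 240., 456.]])     # y-coordinates of 3 keypoints
    n_pts = kps.size(1)
    pad_pts = torch.zeros((2, max_pts - n_pts)) - 2       # -2 marks padded slots
    scale = torch.tensor([[img_size / org_imsize[0]],     # per-row scale factors
                          [img_size / org_imsize[1]]])
    kps = torch.cat([kps * scale, pad_pts], dim=1)
    print(kps.shape, n_pts)                     # torch.Size([2, 40]) 3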