|
r""" SPair-71k dataset """
|
|
|
|
import json
|
|
import glob
|
|
import os
|
|
|
|
import torch.nn.functional as F
|
|
import torch
|
|
from PIL import Image
|
|
import numpy as np
|
|
|
|
from .dataset import CorrespondenceDataset
|
|
|
|
|
|
class SPairDataset(CorrespondenceDataset):
    r""" SPair-71k correspondence dataset

    Extends CorrespondenceDataset with per-pair object masks, bounding boxes
    and the four difficulty annotations stored in each pair's JSON file
    (viewpoint variation, scale variation, truncation, occlusion).

    NOTE(review): ``spt_path``, ``img_path``, ``ann_path`` and ``img_size`` are
    read here but never assigned in this class; presumably the base-class
    constructor sets them -- confirm against CorrespondenceDataset.
    """

    # Category name -> zero-based index. get_mask() adds 1 to this index to
    # obtain the pixel value it searches for in the segmentation PNG.
    _CLASS_DICT = {
        'aeroplane': 0, 'bicycle': 1, 'bird': 2, 'boat': 3, 'bottle': 4,
        'bus': 5, 'car': 6, 'cat': 7, 'chair': 8, 'cow': 9,
        'diningtable': 10, 'dog': 11, 'horse': 12, 'motorbike': 13, 'person': 14,
        'pottedplant': 15, 'sheep': 16, 'sofa': 17, 'train': 18, 'tvmonitor': 19,
    }

    def __init__(self, benchmark, datapath, thres, split):
        r""" SPair-71k dataset constructor

        Reads the split file at ``self.spt_path`` (one pair name per line),
        derives source/target image names from each pair name and loads the
        matching ``<ann_path>/<pair name>.json`` annotation for every pair.

        All four arguments are forwarded unchanged to the base constructor.

        Raises:
            IndexError: if the annotation file of a pair cannot be found, or a
                pair name does not have the ``<x>-<src>-<trg>:<y>`` layout.
        """
        super(SPairDataset, self).__init__(benchmark, datapath, thres, split)

        # Close the split file deterministically and skip empty lines instead
        # of blindly dropping the last entry: a file without a trailing newline
        # no longer loses its final pair, and stray blank lines no longer crash
        # the name parsing below.
        with open(self.spt_path) as split_file:
            pair_names = split_file.read().split('\n')
        self.train_data = [name for name in pair_names if name]

        # Pair names are parsed as "<x>-<src image>-<trg image>:<y>".
        name_fields = [name.split('-') for name in self.train_data]
        self.src_imnames = [fields[1] + '.jpg' for fields in name_fields]
        self.trg_imnames = [fields[2].split(':')[0] + '.jpg' for fields in name_fields]

        # Segmentation masks live in a 'Segmentation' folder next to the image folder.
        self.seg_path = os.path.abspath(os.path.join(self.img_path, os.pardir, 'Segmentation'))

        # Category list = sorted entries of the image directory; cls_ids index into it.
        self.cls = os.listdir(self.img_path)
        self.cls.sort()

        # Load one annotation dict per pair, closing each file right after parsing
        # so that large splits do not accumulate open file descriptors.
        annotations = []
        for data_name in self.train_data:
            anntn_path = glob.glob('%s/%s.json' % (self.ann_path, data_name))[0]
            with open(anntn_path) as anntn_file:
                annotations.append(json.load(anntn_file))

        # Key-points are transposed after loading, bounding boxes are kept as stored.
        self.src_kps = [torch.tensor(ann['src_kps']).t().float() for ann in annotations]
        self.trg_kps = [torch.tensor(ann['trg_kps']).t().float() for ann in annotations]
        self.src_bbox = [torch.tensor(ann['src_bndbox']).float() for ann in annotations]
        self.trg_bbox = [torch.tensor(ann['trg_bndbox']).float() for ann in annotations]
        self.cls_ids = [self.cls.index(ann['category']) for ann in annotations]

        # Per-pair difficulty annotations, stored as tensors.
        self.vpvar = [torch.tensor(ann['viewpoint_variation']) for ann in annotations]
        self.scvar = [torch.tensor(ann['scale_variation']) for ann in annotations]
        self.trncn = [torch.tensor(ann['truncation']) for ann in annotations]
        self.occln = [torch.tensor(ann['occlusion']) for ann in annotations]

    def __getitem__(self, idx):
        r""" Construct and return a batch for SPair-71k dataset

        Starts from the sample built by the base class and adds the keys
        'src_mask', 'trg_mask', 'src_bbox', 'trg_bbox', 'pckthres', 'vpvar',
        'scvar', 'trncn' and 'occln'.

        NOTE(review): relies on the base sample providing 'src_imname',
        'trg_imname', 'src_imsize', 'trg_imsize' and 'category' -- confirm
        against CorrespondenceDataset.__getitem__.
        """
        sample = super(SPairDataset, self).__getitem__(idx)

        sample['src_mask'] = self.get_mask(sample, sample['src_imname'])
        sample['trg_mask'] = self.get_mask(sample, sample['trg_imname'])

        # Boxes must be in place before the PCK threshold is computed, since
        # get_pckthres receives the whole sample.
        sample['src_bbox'] = self.get_bbox(self.src_bbox, idx, sample['src_imsize'])
        sample['trg_bbox'] = self.get_bbox(self.trg_bbox, idx, sample['trg_imsize'])
        sample['pckthres'] = self.get_pckthres(sample, sample['trg_imsize'])

        sample['vpvar'] = self.vpvar[idx]
        sample['scvar'] = self.scvar[idx]
        sample['trncn'] = self.trncn[idx]
        sample['occln'] = self.occln[idx]

        return sample

    def get_mask(self, sample, imname):
        r""" Return the object mask of an image, resized to the network input size

        Loads ``<seg_path>/<category>/<image stem>.png``, sets pixels equal to
        the category label (class index + 1) to 255 and all others to 0, then
        resizes the result bilinearly to (img_size, img_size) and casts it to
        an int tensor with singleton dimensions squeezed out.

        Raises:
            KeyError: if ``sample['category']`` is not a known class name.
        """
        # splitext strips only the extension, so stems containing dots survive.
        stem = os.path.splitext(imname)[0]
        mask_path = os.path.join(self.seg_path, sample['category'], stem + '.png')

        # Read the pixel data while the file is open, then release the handle.
        with Image.open(mask_path) as mask_img:
            tensor_mask = torch.tensor(np.array(mask_img))

        class_id = self._CLASS_DICT[sample['category']] + 1
        # Order matters: clear the background first, then mark the object.
        tensor_mask[tensor_mask != class_id] = 0
        tensor_mask[tensor_mask == class_id] = 255

        # interpolate expects a batched, channelled float tensor.
        resized = F.interpolate(tensor_mask.unsqueeze(0).unsqueeze(0).float(),
                                size=(self.img_size, self.img_size),
                                mode='bilinear', align_corners=True)

        return resized.int().squeeze()

    def get_image(self, img_names, idx):
        r""" Return the idx-th image of img_names as an RGB PIL image """
        path = os.path.join(self.img_path, self.cls[self.cls_ids[idx]], img_names[idx])

        # convert() returns an independent image, so the file can be closed here.
        with Image.open(path) as image:
            return image.convert('RGB')

    def get_pckthres(self, sample, imsize):
        r""" Compute PCK threshold (delegates to the base class) """
        return super(SPairDataset, self).get_pckthres(sample, imsize)

    def get_points(self, pts_list, idx, imsize):
        r""" Return key-points of an image (delegates to the base class) """
        return super(SPairDataset, self).get_points(pts_list, idx, imsize)

    def match_idx(self, kps, n_pts):
        r""" Sample the nearest feature (receptive field) indices (delegates to the base class) """
        return super(SPairDataset, self).match_idx(kps, n_pts)

    def get_bbox(self, bbox_list, idx, imsize):
        r""" Return object bounding-box rescaled to the network input size

        Works on a copy, so the stored box in bbox_list is left untouched.
        Even-indexed entries are scaled by img_size / imsize[0] and odd-indexed
        entries by img_size / imsize[1].

        NOTE(review): presumably the box is (x1, y1, x2, y2) and imsize is
        (width, height) -- confirm against the annotation format.
        """
        bbox = bbox_list[idx].clone()
        bbox[0::2] *= (self.img_size / imsize[0])
        bbox[1::2] *= (self.img_size / imsize[1])
        return bbox
|
|
|