taesiri's picture
Initial Commit
8390f90
raw
history blame
4.6 kB
r""" PF-PASCAL dataset """
import os
import scipy.io as sio
import pandas as pd
import numpy as np
import torch
from .dataset import CorrespondenceDataset
class PFPascalDataset(CorrespondenceDataset):
def __init__(self, benchmark, datapath, thres, split):
r""" PF-PASCAL dataset constructor """
super(PFPascalDataset, self).__init__(benchmark, datapath, thres, split)
self.train_data = pd.read_csv(self.spt_path)
self.src_imnames = np.array(self.train_data.iloc[:, 0])
self.trg_imnames = np.array(self.train_data.iloc[:, 1])
self.cls = ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle',
'bus', 'car', 'cat', 'chair', 'cow',
'diningtable', 'dog', 'horse', 'motorbike', 'person',
'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor']
self.cls_ids = self.train_data.iloc[:, 2].values.astype('int') - 1
if split == 'trn':
self.flip = self.train_data.iloc[:, 3].values.astype('int')
self.src_kps = []
self.trg_kps = []
self.src_bbox = []
self.trg_bbox = []
for src_imname, trg_imname, cls in zip(self.src_imnames, self.trg_imnames, self.cls_ids):
src_anns = os.path.join(self.ann_path, self.cls[cls],
os.path.basename(src_imname))[:-4] + '.mat'
trg_anns = os.path.join(self.ann_path, self.cls[cls],
os.path.basename(trg_imname))[:-4] + '.mat'
src_kp = torch.tensor(read_mat(src_anns, 'kps')).float()
trg_kp = torch.tensor(read_mat(trg_anns, 'kps')).float()
src_box = torch.tensor(read_mat(src_anns, 'bbox')[0].astype(float))
trg_box = torch.tensor(read_mat(trg_anns, 'bbox')[0].astype(float))
src_kps = []
trg_kps = []
for src_kk, trg_kk in zip(src_kp, trg_kp):
if len(torch.isnan(src_kk).nonzero()) != 0 or \
len(torch.isnan(trg_kk).nonzero()) != 0:
continue
else:
src_kps.append(src_kk)
trg_kps.append(trg_kk)
self.src_kps.append(torch.stack(src_kps).t())
self.trg_kps.append(torch.stack(trg_kps).t())
self.src_bbox.append(src_box)
self.trg_bbox.append(trg_box)
self.src_imnames = list(map(lambda x: os.path.basename(x), self.src_imnames))
self.trg_imnames = list(map(lambda x: os.path.basename(x), self.trg_imnames))
def __getitem__(self, idx):
r""" Constructs and returns a batch for PF-PASCAL dataset """
batch = super(PFPascalDataset, self).__getitem__(idx)
# Object bounding-box (resized following self.img_size)
batch['src_bbox'] = self.get_bbox(self.src_bbox, idx, batch['src_imsize'])
batch['trg_bbox'] = self.get_bbox(self.trg_bbox, idx, batch['trg_imsize'])
batch['pckthres'] = self.get_pckthres(batch, batch['trg_imsize'])
# Horizontal flipping key-points during training
if self.split == 'trn' and self.flip[idx]:
self.horizontal_flip(batch)
batch['flip'] = 1
else:
batch['flip'] = 0
return batch
def get_bbox(self, bbox_list, idx, imsize):
r""" Returns object bounding-box """
bbox = bbox_list[idx].clone()
bbox[0::2] *= (self.img_size / imsize[0])
bbox[1::2] *= (self.img_size / imsize[1])
return bbox
def horizontal_flip(self, batch):
tmp = batch['src_bbox'][0].clone()
batch['src_bbox'][0] = batch['src_img'].size(2) - batch['src_bbox'][2]
batch['src_bbox'][2] = batch['src_img'].size(2) - tmp
tmp = batch['trg_bbox'][0].clone()
batch['trg_bbox'][0] = batch['trg_img'].size(2) - batch['trg_bbox'][2]
batch['trg_bbox'][2] = batch['trg_img'].size(2) - tmp
batch['src_kps'][0][:batch['n_pts']] = batch['src_img'].size(2) - batch['src_kps'][0][:batch['n_pts']]
batch['trg_kps'][0][:batch['n_pts']] = batch['trg_img'].size(2) - batch['trg_kps'][0][:batch['n_pts']]
batch['src_img'] = torch.flip(batch['src_img'], dims=(2,))
batch['trg_img'] = torch.flip(batch['trg_img'], dims=(2,))
def read_mat(path, obj_name):
r""" Reads specified objects from Matlab data file. (.mat) """
mat_contents = sio.loadmat(path)
mat_obj = mat_contents[obj_name]
return mat_obj