File size: 4,781 Bytes
8390f90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
r""" SPair-71k dataset """

import json
import glob
import os

import torch.nn.functional as F
import torch
from PIL import Image
import numpy as np

from .dataset import CorrespondenceDataset


class SPairDataset(CorrespondenceDataset):
    r""" SPair-71k semantic-correspondence dataset.

    On construction it reads the pair list of the requested split and the
    per-pair JSON annotations (keypoints, bounding boxes, category and the
    viewpoint/scale/truncation/occlusion flags).  ``__getitem__`` extends the
    base-class sample with foreground masks, rescaled bounding boxes, the PCK
    threshold and those per-pair flags.
    """

    # Lookup used by get_mask: a category's segmentation label is its index
    # here plus one.  The names are the 20 PASCAL VOC classes in VOC order;
    # this list may differ from ``self.cls``, which is read from the image
    # directory, so the two must not be used interchangeably.
    VOC_CLASS_INDEX = {'aeroplane': 0, 'bicycle': 1, 'bird': 2, 'boat': 3, 'bottle': 4,
                       'bus': 5, 'car': 6, 'cat': 7, 'chair': 8, 'cow': 9,
                       'diningtable': 10, 'dog': 11, 'horse': 12, 'motorbike': 13, 'person': 14,
                       'pottedplant': 15, 'sheep': 16, 'sofa': 17, 'train': 18, 'tvmonitor': 19}

    def __init__(self, benchmark, datapath, thres, split):
        r""" SPair-71k dataset constructor

        Args:
            benchmark, datapath, thres, split: forwarded unchanged to
                ``CorrespondenceDataset.__init__``.  That constructor is
                expected to set ``spt_path``, ``img_path``, ``ann_path`` and
                ``img_size`` (NOTE(review): defined outside this file — confirm).
        """
        super().__init__(benchmark, datapath, thres, split)

        # One pair identifier per non-empty line.  Blank lines are skipped
        # (including the one left by a trailing newline) rather than always
        # dropping the last entry.  As a result, a file without a trailing
        # newline keeps its final pair, and CRLF line endings are tolerated.
        with open(self.spt_path) as split_file:
            self.train_data = [line.strip() for line in split_file if line.strip()]

        # Identifiers are split on '-' and ':'.  Field 1 is the source image
        # name; field 2 (up to ':') is the target image name.
        self.src_imnames = [name.split('-')[1] + '.jpg' for name in self.train_data]
        self.trg_imnames = [name.split('-')[2].split(':')[0] + '.jpg' for name in self.train_data]
        self.seg_path = os.path.abspath(os.path.join(self.img_path, os.pardir, 'Segmentation'))
        # Category names are the sorted entries of the image directory;
        # ``cls_ids`` below are indices into this list.
        self.cls = sorted(os.listdir(self.img_path))

        annotations = [self._load_annotation(name) for name in self.train_data]

        # Keypoints are transposed after loading.  Presumably they are stored
        # as a list of (x, y) points, which gives a 2 x N tensor (TODO confirm).
        self.src_kps = [torch.tensor(ann['src_kps']).t().float() for ann in annotations]
        self.trg_kps = [torch.tensor(ann['trg_kps']).t().float() for ann in annotations]
        self.src_bbox = [torch.tensor(ann['src_bndbox']).float() for ann in annotations]
        self.trg_bbox = [torch.tensor(ann['trg_bndbox']).float() for ann in annotations]
        self.cls_ids = [self.cls.index(ann['category']) for ann in annotations]

        # Per-pair difficulty annotations, returned unchanged by __getitem__.
        self.vpvar = [torch.tensor(ann['viewpoint_variation']) for ann in annotations]
        self.scvar = [torch.tensor(ann['scale_variation']) for ann in annotations]
        self.trncn = [torch.tensor(ann['truncation']) for ann in annotations]
        self.occln = [torch.tensor(ann['occlusion']) for ann in annotations]

    def _load_annotation(self, pair_name):
        r""" Read and parse the JSON file ``<ann_path>/<pair_name>.json`` """
        # The path is built directly rather than via glob.  This keeps
        # wildcard characters in ann_path from being expanded, and a missing
        # file raises FileNotFoundError naming the path instead of IndexError.
        anntn_path = os.path.join(self.ann_path, pair_name + '.json')
        with open(anntn_path) as anntn_file:
            return json.load(anntn_file)

    def __getitem__(self, idx):
        r""" Construct and return a batch for SPair-71k dataset

        Starts from the base-class sample and adds: ``src_mask``/``trg_mask``,
        ``src_bbox``/``trg_bbox`` (rescaled to ``img_size``), ``pckthres``
        (computed from the target image size), and the per-pair flags
        ``vpvar``, ``scvar``, ``trncn`` and ``occln``.
        """
        sample = super().__getitem__(idx)

        sample['src_mask'] = self.get_mask(sample, sample['src_imname'])
        sample['trg_mask'] = self.get_mask(sample, sample['trg_imname'])

        sample['src_bbox'] = self.get_bbox(self.src_bbox, idx, sample['src_imsize'])
        sample['trg_bbox'] = self.get_bbox(self.trg_bbox, idx, sample['trg_imsize'])
        sample['pckthres'] = self.get_pckthres(sample, sample['trg_imsize'])

        sample['vpvar'] = self.vpvar[idx]
        sample['scvar'] = self.scvar[idx]
        sample['trncn'] = self.trncn[idx]
        sample['occln'] = self.occln[idx]

        return sample

    def get_mask(self, sample, imname):
        r""" Return the foreground mask of ``imname`` resized to ``img_size``.

        The mask is read from ``<seg_path>/<category>/<imname stem>.png``.
        Pixels whose label equals the category's VOC index + 1 become 255 and
        all other pixels become 0.  The result is bilinearly resized to
        ``(img_size, img_size)`` and truncated to an int tensor.
        """
        mask_path = os.path.join(self.seg_path, sample['category'], imname.split('.')[0] + '.png')

        with Image.open(mask_path) as mask_img:
            labels = torch.tensor(np.array(mask_img))

        # Label values are offset by one from VOC_CLASS_INDEX (presumably
        # because 0 is background — confirm against the dataset release).
        class_id = self.VOC_CLASS_INDEX[sample['category']] + 1
        mask = (labels == class_id).float() * 255

        mask = F.interpolate(mask.unsqueeze(0).unsqueeze(0),
                             size=(self.img_size, self.img_size),
                             mode='bilinear', align_corners=True).int().squeeze()

        return mask

    def get_image(self, img_names, idx):
        r""" Return the RGB PIL image ``img_names[idx]`` from its category folder """
        path = os.path.join(self.img_path, self.cls[self.cls_ids[idx]], img_names[idx])

        # convert() produces an in-memory copy, so the file can be closed here.
        with Image.open(path) as img:
            return img.convert('RGB')

    def get_pckthres(self, sample, imsize):
        r""" Compute PCK threshold (delegates to the base class) """
        return super().get_pckthres(sample, imsize)

    def get_points(self, pts_list, idx, imsize):
        r""" Return key-points of an image (delegates to the base class) """
        return super().get_points(pts_list, idx, imsize)

    def match_idx(self, kps, n_pts):
        r""" Sample the nearest feature (receptive field) indices (delegates to the base class) """
        return super().match_idx(kps, n_pts)

    def get_bbox(self, bbox_list, idx, imsize):
        r""" Return a copy of the object bounding-box rescaled to ``img_size``.

        Even-indexed entries are scaled by ``img_size / imsize[0]`` and
        odd-indexed entries by ``img_size / imsize[1]``.  NOTE(review): this
        assumes the box is (x1, y1, x2, y2) and imsize is (width, height),
        as with PIL's ``Image.size`` — confirm against the base class.
        """
        bbox = bbox_list[idx].clone()
        bbox[0::2] *= (self.img_size / imsize[0])
        bbox[1::2] *= (self.img_size / imsize[1])
        return bbox