File size: 2,444 Bytes
8390f90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
r""" PF-WILLOW dataset """

import os

import pandas as pd
import numpy as np
import torch

from .dataset import CorrespondenceDataset


class PFWillowDataset(CorrespondenceDataset):
    r"""PF-WILLOW keypoint-correspondence dataset.

    Reads image pairs and their annotated keypoints from the split CSV at
    ``self.spt_path`` (attribute presumably set by ``CorrespondenceDataset``).
    Each CSV row holds: source image path, target image path, 20 source
    keypoint values (columns 2:22), then the target keypoint values.
    """

    # Fill value written into unused keypoint slots by get_points; also used
    # by get_pckthres to recognise (and ignore) those slots.
    _PAD_VALUE = -2.0

    def __init__(self, benchmark, datapath, thres, split):
        r"""PF-WILLOW dataset constructor.

        Args:
            benchmark, datapath, thres, split: forwarded unchanged to
                ``CorrespondenceDataset.__init__``.
        """
        super().__init__(benchmark, datapath, thres, split)

        # NOTE: holds whichever split ``self.spt_path`` points to (not only
        # training data); the attribute name is kept for backward compatibility.
        self.train_data = pd.read_csv(self.spt_path)
        self.src_imnames = np.array(self.train_data.iloc[:, 0])
        self.trg_imnames = np.array(self.train_data.iloc[:, 1])
        self.src_kps = self.train_data.iloc[:, 2:22].values
        self.trg_kps = self.train_data.iloc[:, 22:].values
        self.cls = ['car(G)', 'car(M)', 'car(S)', 'duck(S)',
                    'motorbike(G)', 'motorbike(M)', 'motorbike(S)',
                    'winebottle(M)', 'winebottle(wC)', 'winebottle(woC)']
        # The second '/'-separated component of each source path is its
        # category name; map it to an index into self.cls.
        self.cls_ids = [self.cls.index(name.split('/')[1]) for name in self.src_imnames]
        # Drop the leading path component and rejoin with the OS separator.
        self.src_imnames = [os.path.join(*name.split('/')[1:]) for name in self.src_imnames]
        self.trg_imnames = [os.path.join(*name.split('/')[1:]) for name in self.trg_imnames]

    def __getitem__(self, idx):
        r"""Returns sample ``idx`` from the base dataset with ``pckthres`` added."""
        batch = super().__getitem__(idx)
        batch['pckthres'] = self.get_pckthres(batch)

        return batch

    def get_pckthres(self, batch):
        r"""Computes the PCK threshold for one sample.

        Args:
            batch: sample dict; reads ``trg_kps`` (for ``'bbox'``) or
                ``trg_img`` (for ``'img'``).

        Returns:
            0-d tensor. For ``'bbox'``: the larger side of the tight box
            around the (non-padded) target keypoints. For ``'img'``: the
            larger of ``trg_img.size()[1]`` and ``trg_img.size()[2]``.

        Raises:
            ValueError: if ``self.thres`` is neither ``'bbox'`` nor ``'img'``.
        """
        if self.thres == 'bbox':
            trg_kps = batch['trg_kps']
            # Ignore padded slots: their _PAD_VALUE coordinates would drag the
            # per-axis minimum down to -2 and inflate the box. Assumes the
            # padding from get_points is still intact here — TODO confirm the
            # base class does not transform keypoints beforehand.
            valid = (trg_kps != self._PAD_VALUE).any(dim=0)
            if valid.any():
                trg_kps = trg_kps[:, valid]
            extent = trg_kps.max(dim=1)[0] - trg_kps.min(dim=1)[0]
            return extent.max().clone()
        if self.thres == 'img':
            return torch.tensor(max(batch['trg_img'].size()[1], batch['trg_img'].size()[2]))
        raise ValueError('Invalid pck evaluation level: %s' % self.thres)

    def get_points(self, pts_list, idx, org_imsize):
        r"""Returns the rescaled, padded keypoints of one image.

        Args:
            pts_list: 2-D array of keypoint rows; row ``idx`` stores all x
                coordinates followed by all y coordinates.
            idx: row index into ``pts_list``.
            org_imsize: original image size; entry 0 rescales x and entry 1
                rescales y (presumably (width, height) — confirm with caller).

        Returns:
            ``(kps, n_pts)``: ``kps`` is a float32 tensor of shape
            ``(2, self.max_pts)`` whose first ``n_pts`` columns are the
            keypoints rescaled to ``self.img_size`` and whose remaining
            columns are filled with ``_PAD_VALUE``; ``n_pts`` is the number
            of real keypoints.
        """
        # Row 0 = x coordinates, row 1 = y coordinates.
        point_coords = torch.tensor(pts_list[idx, :].reshape(2, -1).astype(np.float32))
        xy, n_pts = point_coords.size()
        pad_pts = torch.full((xy, self.max_pts - n_pts), self._PAD_VALUE,
                             dtype=point_coords.dtype)
        x_crds = point_coords[0] * (self.img_size / org_imsize[0])
        y_crds = point_coords[1] * (self.img_size / org_imsize[1])
        kps = torch.cat([torch.stack([x_crds, y_crds]), pad_pts], dim=1)

        return kps, n_pts