import os
import sys
import math
import numpy as np
import torch
import torch.utils.data as data

sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..'))
from stacked_hourglass.utils.imutils import im_to_torch
from stacked_hourglass.utils.transforms import crop, color_normalize
from stacked_hourglass.utils.pilutil import imresize
from configs.data_info import COMPLETE_DATA_INFO_24


class ImgCrops(data.Dataset):
    """Dataset that wraps a list of in-memory images (optionally with bounding
    boxes) and returns fixed-size, color-normalized crops together with a
    placeholder target dictionary."""
    DATA_INFO = COMPLETE_DATA_INFO_24
    ACC_JOINTS = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 16]

    def __init__(self, image_list, bbox_list=None, inp_res=256, dataset_mode='keyp_only'):
        # image_list contains the images themselves (arrays), not their paths;
        # dataset_mode is unused here and only kept for interface consistency
        self.image_list = image_list
        self.bbox_list = bbox_list
        self.inp_res = inp_res
        self.test_name_list = [str(ind) for ind in range(len(self.image_list))]
        print('len(dataset): ' + str(self.__len__()))

    def __getitem__(self, index):
        # load image (the list holds image arrays directly, not paths)
        img = im_to_torch(self.image_list[index])  # CxHxW, float in [0, 1]

        # crop around the bounding box if one is available
        if self.bbox_list is not None:
            # the bbox is given as two corner points [[x0, y0], [x1, y1]]
            bbox = self.bbox_list[index]
            bbox_xywh = [bbox[0][0], bbox[0][1], bbox[1][0]-bbox[0][0], bbox[1][1]-bbox[0][1]]
            bbox_c = [bbox_xywh[0]+0.5*bbox_xywh[2], bbox_xywh[1]+0.5*bbox_xywh[3]]
            bbox_max = max(bbox_xywh[2], bbox_xywh[3])
            # hourglass scale convention: crop() cuts out a square of 200*s pixels;
            # the factor 256/200 adds a margin so that the longer bbox side covers
            # about 200 of the 256 crop pixels
            bbox_s = bbox_max / 200. * 256. / 200.
            c = torch.Tensor(bbox_c)
            s = bbox_s
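            # worked example (hypothetical numbers): a 300 x 400 px bbox gives
            # bbox_max = 400 and s = 400/200 * 256/200 = 2.56; crop() cuts out a
            # 200*s = 512 px square and resizes it to 256 px, so the longer bbox
            # side ends up covering 400 * 256/512 = 200 of the 256 crop pixels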
            img_prep = crop(img, c, s, [self.inp_res, self.inp_res], rot=0)

        else:
            # no bounding box: pad the image to a square and resize it
            img_max = max(img.shape[1], img.shape[2])
            img_padded = torch.zeros((img.shape[0], img_max, img_max))
            if img_max == img.shape[2]:
                start = (img_max - img.shape[1]) // 2
                img_padded[:, start:start+img.shape[1], :] = img
            else:
                start = (img_max - img.shape[2]) // 2
                img_padded[:, :, start:start+img.shape[2]] = img
            img = img_padded
            # imresize returns a HxWxC uint8 array; im_to_torch converts it back
            # to a CxHxW float tensor in [0, 1]
            img_prep = im_to_torch(imresize(img, [self.inp_res, self.inp_res], interp='bilinear'))
        
        inp = color_normalize(img_prep, self.DATA_INFO.rgb_mean, self.DATA_INFO.rgb_stddev)
        # add the fields expected by the stanext dataset interface; most of them
        # are placeholders since raw image crops come without annotations
        target_dict = {'index': index, 'center': -2, 'scale': -2,
                       'breed_index': -2, 'sim_breed_index': -2,
                       'ind_dataset': 1}
        target_dict['pts'] = np.zeros((self.DATA_INFO.n_keyp, 3))
        target_dict['tpts'] = np.zeros((self.DATA_INFO.n_keyp, 3))
        target_dict['target_weight'] = np.zeros((self.DATA_INFO.n_keyp, 1))
        target_dict['silh'] = np.zeros((self.inp_res, self.inp_res))
        return inp, target_dict

    def __len__(self):
        return len(self.image_list)
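

# Minimal usage sketch, assuming OpenCV (cv2) is installed and that
# 'example.jpg' (a hypothetical path) points to a readable image; any loader
# that yields an HxWxC uint8 RGB array works the same way.
if __name__ == '__main__':
    import cv2

    img_bgr = cv2.imread('example.jpg')                 # HxWxC, uint8, BGR
    img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)  # the dataset expects RGB
    # bounding box as two corner points [[x0, y0], [x1, y1]] (here: full image)
    bbox = [[0, 0], [img_rgb.shape[1], img_rgb.shape[0]]]

    dataset = ImgCrops(image_list=[img_rgb], bbox_list=[bbox], inp_res=256)
    loader = data.DataLoader(dataset, batch_size=1, shuffle=False)
    for inp, target_dict in loader:
        print(inp.shape)             # torch.Size([1, 3, 256, 256])
        print(target_dict['index'])  # tensor([0])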