# 24 joints instead of 20!!

import gzip
import json
import os
import random
import math
import numpy as np
import torch
import torch.utils.data as data
from importlib_resources import open_binary
from scipy.io import loadmat
from tabulate import tabulate
import itertools
from scipy import ndimage
from csv import DictReader
from pycocotools.mask import decode as decode_RLE

import sys
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..'))
from configs.data_info import COMPLETE_DATA_INFO_24
from stacked_hourglass.utils.imutils import load_image, draw_labelmap, draw_multiple_labelmaps
from stacked_hourglass.utils.misc import to_torch
from stacked_hourglass.utils.transforms import shufflelr, crop, color_normalize, fliplr, transform
import stacked_hourglass.datasets.utils_stanext as utils_stanext
from stacked_hourglass.utils.visualization import save_input_image_with_keypoints
from configs.dog_breeds.dog_breed_class import COMPLETE_ABBREV_DICT, COMPLETE_SUMMARY_BREEDS, SIM_MATRIX_RAW, SIM_ABBREV_INDICES
from configs.dataset_path_configs import STANEXT_RELATED_DATA_ROOT_DIR


class StanExt(data.Dataset):
    """StanExt dataset yielding cropped images, keypoint heatmaps and meta data.

    Each sample carries 24 joints. What is returned per item depends on
    ``dataset_mode``:

    * ``'keyp_only'``: ``(inp, target, meta)``
    * ``'keyp_and_seg'``: ``(inp, target, meta)`` with ``meta['silh']`` and
      ``meta['name']`` added
    * ``'keyp_and_seg_and_partseg'``: ``(inp, target, meta2)`` with a
      silhouette and a placeholder (all ``-1``) body part matrix
    * ``'complete'``: ``(inp, target_dict)`` where ``target_dict`` is the
      meta dict extended by ``'silh'``, ``'img_border_mask'`` and ``'has_seg'``
    """

    DATA_INFO = COMPLETE_DATA_INFO_24

    # Suggested joints to use for keypoint reprojection error calculations
    ACC_JOINTS = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 16]

    # Dataset modes for which a segmentation map is loaded and transformed.
    _SEG_MODES = ('complete', 'keyp_and_seg', 'keyp_and_seg_and_partseg')

    def __init__(self, image_path=None, is_train=True, inp_res=256, out_res=64, sigma=1,
                 scale_factor=0.25, rot_factor=30, label_type='Gaussian',
                 do_augment='default', shorten_dataset_to=None, dataset_mode='keyp_only',
                 V12=None, val_opt='test'):
        """Load the annotation dicts and build the (shuffled) name lists.

        Args:
            image_path: Unused; kept for interface compatibility.
            is_train: Serve the training split if true, otherwise the
                test/validation split selected by ``val_opt``.
            inp_res: Side length of the cropped network input.
            out_res: Side length of the target heatmaps.
            sigma: Passed to ``draw_labelmap``.
            scale_factor: Spread/clamp range of the random scale augmentation.
            rot_factor: Spread of the random rotation augmentation.
            label_type: Passed to ``draw_labelmap`` as ``type``.
            do_augment: ``'yes'``, ``'no'`` or ``'default'`` (augment only
                when ``is_train`` is true).
            shorten_dataset_to: If not ``None``, truncate both name lists to
                at most this many entries.
            dataset_mode: Selects what ``__getitem__`` returns.
            V12: Forwarded to the ``utils_stanext`` loading helpers.
            val_opt: ``'test'`` or ``'val'``; which held-out split is served
                when ``is_train`` is false.

        Raises:
            ValueError: If ``do_augment`` is not one of the accepted strings.
            NotImplementedError: If ``val_opt`` is neither ``'test'`` nor
                ``'val'``.
        """
        self.V12 = V12
        self.is_train = is_train    # training set or test set
        if do_augment == 'yes':
            self.do_augment = True
        elif do_augment == 'no':
            self.do_augment = False
        elif do_augment == 'default':
            self.do_augment = bool(self.is_train)
        else:
            raise ValueError(
                "do_augment must be 'yes', 'no' or 'default', got %r" % (do_augment,))
        self.inp_res = inp_res
        self.out_res = out_res
        self.sigma = sigma
        self.scale_factor = scale_factor
        self.rot_factor = rot_factor
        self.label_type = label_type
        self.dataset_mode = dataset_mode
        self.calc_seg = self.dataset_mode in self._SEG_MODES
        self.val_opt = val_opt

        # create train/val split
        self.img_folder = utils_stanext.get_img_dir(V12=self.V12)
        self.train_dict, init_test_dict, init_val_dict = \
            utils_stanext.load_stanext_json_as_dict(split_train_test=True, V12=self.V12)
        self.train_name_list = list(self.train_dict.keys())     # 7004
        if self.val_opt == 'test':
            self.test_dict = init_test_dict
        elif self.val_opt == 'val':
            self.test_dict = init_val_dict
        else:
            raise NotImplementedError(
                "val_opt must be 'test' or 'val', got %r" % (self.val_opt,))
        self.test_name_list = list(self.test_dict.keys())

        # stanext breed dict (contains for each name a stanext specific index)
        breed_json_path = os.path.join(STANEXT_RELATED_DATA_ROOT_DIR, 'StanExt_breed_dict_v2.json')
        self.breed_dict = self.get_breed_dict(breed_json_path, create_new_breed_json=False)

        # Sort first so that the seeded shuffle yields the same order on
        # every run. Note that this reseeds the process-wide `random` module.
        self.train_name_list = sorted(self.train_name_list)
        self.test_name_list = sorted(self.test_name_list)
        random.seed(4)
        random.shuffle(self.train_name_list)
        random.shuffle(self.test_name_list)

        if shorten_dataset_to is not None:
            # sometimes it is useful to have a smaller set (validation speed, debugging)
            self.train_name_list = self.train_name_list[:shorten_dataset_to]
            self.test_name_list = self.test_name_list[:shorten_dataset_to]
            # special case for debugging: 12 similar images
            # (guarded so that a split with fewer than three entries does
            # not raise an IndexError)
            if shorten_dataset_to == 12 and len(self.test_name_list) > 2:
                my_sample = self.test_name_list[2]
                self.test_name_list = [my_sample] * len(self.test_name_list)
        print('len(dataset): ' + str(self.__len__()))

        # add results for eyes, withers and throat as obtained through anipose -> they are used
        # as pseudo ground truth at training time.
        self.path_anipose_out_root = os.path.join(
            STANEXT_RELATED_DATA_ROOT_DIR, 'animalpose_hg8_v0_results_on_StanExt')

    def get_data_sampler_info(self):
        """Return the name list of the active split plus breed tables.

        The result is meant for a custom data sampler.
        """
        # for custom data sampler
        if self.is_train:
            name_list = self.train_name_list
        else:
            name_list = self.test_name_list
        info_dict = {'name_list': name_list,
                     'stanext_breed_dict': self.breed_dict,
                     'breeds_abbrev_dict': COMPLETE_ABBREV_DICT,
                     'breeds_summary': COMPLETE_SUMMARY_BREEDS,
                     'breeds_sim_martix_raw': SIM_MATRIX_RAW,
                     'breeds_sim_abbrev_inds': SIM_ABBREV_INDICES
                     }
        return info_dict

    @staticmethod
    def _breed_name_from_folder(folder_name):
        """Strip the leading ``<prefix>-`` (text up to the first dash) from a folder name."""
        return folder_name.split(folder_name.split('-')[0] + '-')[1]

    def get_breed_dict(self, breed_json_path, create_new_breed_json=False):
        """Load the breed dict from JSON, or rebuild and write it.

        Args:
            breed_json_path: Path of the JSON file to read or write.
            create_new_breed_json: If true, derive the dict from
                ``self.train_name_list`` (one entry per image folder, indexed
                in order of first appearance) and write it to
                ``breed_json_path``; otherwise read the existing file.

        Returns:
            Dict mapping folder name to ``{'breed_name': ..., 'index': ...}``
            when rebuilt; otherwise whatever the JSON file contains.
        """
        if create_new_breed_json:
            breed_dict = {}
            breed_index = 0
            for img_name in self.train_name_list:
                folder_name = img_name.split('/')[0]
                breed_name = self._breed_name_from_folder(folder_name)
                if folder_name not in breed_dict:
                    breed_dict[folder_name] = {
                        'breed_name': breed_name,
                        'index': breed_index}
                    breed_index += 1
            with open(breed_json_path, 'w', encoding='utf-8') as f:
                json.dump(breed_dict, f, ensure_ascii=False, indent=4)
        else:
            # Read with the same encoding that is used for writing above.
            with open(breed_json_path, encoding='utf-8') as json_file:
                breed_dict = json.load(json_file)
        return breed_dict

    def __getitem__(self, index):
        """Return the sample at ``index`` in the format of ``dataset_mode``.

        In ``'complete'`` mode a sample whose silhouette is empty is replaced
        by another sample (``index - 1`` downwards first, then the indices
        above ``index``), except when serving the ``'test'`` split, where an
        empty silhouette is an error.

        Raises:
            ValueError: For an unknown ``dataset_mode``, for an empty
                silhouette on the test split, or if no sample with a
                non-empty silhouette can be found.
        """
        sample = self._load_sample(index)
        if self.dataset_mode != 'complete':
            return sample

        inp, target_dict = sample
        if not (target_dict['silh'].sum() < 1):
            return inp, target_dict

        if (not self.is_train) and self.val_opt == 'test':
            raise ValueError('test sample %d has an empty silhouette' % index)

        # There seem to be a few training/validation images without
        # segmentation, which would lead to nan in the iou calculation.
        # A bounded search replaces the former `max(0, index - 1)` recursion,
        # which never terminated when sample 0 itself had no silhouette.
        candidates = itertools.chain(range(index - 1, -1, -1),
                                     range(index + 1, len(self)))
        for replacement_index in candidates:
            if self.is_train:
                print('had to replace training image')
            inp, target_dict = self._load_sample(replacement_index)
            if not (target_dict['silh'].sum() < 1):
                return inp, target_dict
        raise ValueError('no sample with a non-empty silhouette available')

    def _load_sample(self, index):
        """Load, augment and crop one sample without any replacement logic."""
        if self.is_train:
            name = self.train_name_list[index]
            entry = self.train_dict[name]
        else:
            name = self.test_name_list[index]
            entry = self.test_dict[name]

        sf = self.scale_factor
        rf = self.rot_factor

        img_path = os.path.join(self.img_folder, entry['img_path'])

        # NOTE(review): in the version under review the text of this method
        # between the anipose score thresholding and the bounding-box scale
        # was lost (it looks like a markup strip from "<anipose_thr" up to
        # "-> bbox_max = 256"). The statements marked RECONSTRUCTED are
        # inferred from how `pts`, `bbox_c` and `bbox_max` are used further
        # down; confirm them against the upstream file before merging.
        try:
            anipose_res_path = os.path.join(
                self.path_anipose_out_root, entry['img_path'].replace('.jpg', '.json'))
            with open(anipose_res_path) as f:
                anipose_data = json.load(f)
            anipose_thr = 0.2
            anipose_joints_0to24 = np.asarray(anipose_data['anipose_joints_0to24']).reshape((-1, 3))
            anipose_joints_0to24_scores = anipose_joints_0to24[:, 2]
            # anipose_joints_0to24_scores[anipose_joints_0to24_scores>anipose_thr] = 1.0
            # RECONSTRUCTED: zero the scores below the threshold.
            anipose_joints_0to24_scores[anipose_joints_0to24_scores < anipose_thr] = 0.0
            anipose_joints_0to24[:, 2] = anipose_joints_0to24_scores
        except (OSError, KeyError, ValueError, IndexError, TypeError):
            # RECONSTRUCTED: no usable anipose result for this image, fall
            # back to "all joints invisible".
            anipose_joints_0to24 = np.zeros((24, 3))

        # RECONSTRUCTED: annotated joints 0..19 followed by the anipose
        # pseudo ground truth for joints 20..23.
        # presumably entry['joints'] is a list of (x, y, visibility) rows; verify
        joints = np.concatenate(
            (np.asarray(entry['joints'])[:20, :], anipose_joints_0to24[20:24, :]), axis=0)
        joints[joints[:, 2] == 0, :2] = 0   # avoid nan values
        pts = torch.Tensor(joints)

        # RECONSTRUCTED: bounding box center and longest side.
        # presumably entry['img_bbox'] is (x, y, width, height); verify
        bbox_xywh = entry['img_bbox']
        bbox_c = [bbox_xywh[0] + 0.5 * bbox_xywh[2], bbox_xywh[1] + 0.5 * bbox_xywh[3]]
        bbox_max = max(bbox_xywh[2], bbox_xywh[3])
        # bbox_s = bbox_max / 200.      # the dog will fill the image -> bbox_max = 256
        # bbox_s = bbox_diag / 200.     # diagonal of the boundingbox will be 200
        bbox_s = bbox_max / 200. * 256. / 200.     # maximum side of the bbox will be 200
        c = torch.Tensor(bbox_c)
        s = bbox_s

        # For single-person pose estimation with a centered/scaled figure
        nparts = pts.size(0)
        img = load_image(img_path)  # CxHxW

        # segmentation map (we reshape it to 3xHxW, such that we can do the
        # same transformations as with the image)
        if self.calc_seg:
            seg = torch.Tensor(utils_stanext.get_seg_from_entry(entry)[None, :, :])
            seg = torch.cat(3 * [seg])

        r = 0
        do_flip = False
        if self.do_augment:
            s = s*torch.randn(1).mul_(sf).add_(1).clamp(1-sf, 1+sf)[0]
            r = torch.randn(1).mul_(rf).clamp(-2*rf, 2*rf)[0] if random.random() <= 0.6 else 0
            # Flip
            if random.random() <= 0.5:
                do_flip = True
                img = fliplr(img)
                if self.calc_seg:
                    seg = fliplr(seg)
                pts = shufflelr(pts, img.size(2), self.DATA_INFO.hflip_indices)
                c[0] = img.size(2) - c[0]
            # Color
            img[0, :, :].mul_(random.uniform(0.8, 1.2)).clamp_(0, 1)
            img[1, :, :].mul_(random.uniform(0.8, 1.2)).clamp_(0, 1)
            img[2, :, :].mul_(random.uniform(0.8, 1.2)).clamp_(0, 1)

        # Prepare image and groundtruth map
        inp = crop(img, c, s, [self.inp_res, self.inp_res], rot=r)
        # Computed before normalization: a pixel counts as foreground when
        # all of its channels exceed 1/256.
        img_border_mask = torch.all(inp > 1.0/256, dim=0).unsqueeze(0).float()     # 1 is foreground
        inp = color_normalize(inp, self.DATA_INFO.rgb_mean, self.DATA_INFO.rgb_stddev)
        if self.calc_seg:
            seg = crop(seg, c, s, [self.inp_res, self.inp_res], rot=r)

        # Generate ground truth
        tpts = pts.clone()
        target_weight = tpts[:, 2].clone().view(nparts, 1)

        target = torch.zeros(nparts, self.out_res, self.out_res)
        for i in range(nparts):
            # if tpts[i, 2] > 0: # This is evil!!
            if tpts[i, 1] > 0:
                tpts[i, 0:2] = to_torch(transform(tpts[i, 0:2]+1, c, s, [self.out_res, self.out_res], rot=r, as_int=False))
                target[i], vis = draw_labelmap(target[i], tpts[i]-1, self.sigma, type=self.label_type)
                target_weight[i, 0] *= vis
        # NEW (disabled alternative that draws all label maps at once):
        # target_new, vis_new = draw_multiple_labelmaps((self.out_res, self.out_res), tpts[:, :2]-1, self.sigma, type=self.label_type)
        # target_weight_new = tpts[:, 2].clone().view(nparts, 1) * vis_new
        # target_new[(target_weight_new==0).reshape((-1)), :, :] = 0

        # --- Meta info
        folder_name = name.split('/')[0]
        this_breed = self.breed_dict[folder_name]   # 120
        # add information about location within breed similarity matrix
        breed_name = self._breed_name_from_folder(folder_name)
        abbrev = COMPLETE_ABBREV_DICT[breed_name]
        try:
            sim_breed_index = COMPLETE_SUMMARY_BREEDS[abbrev]._ind_in_xlsx_matrix
        except (KeyError, AttributeError):
            # some breeds are not in the xlsx file
            sim_breed_index = -1
        meta = {'index': index, 'center': c, 'scale': s,
                'pts': pts, 'tpts': tpts, 'target_weight': target_weight,
                'breed_index': this_breed['index'], 'sim_breed_index': sim_breed_index,
                'ind_dataset': 0}  # ind_dataset=0 for stanext or stanexteasy or stanext 2
        meta2 = {'index': index, 'center': c, 'scale': s,
                 'pts': pts, 'tpts': tpts, 'target_weight': target_weight,
                 'ind_dataset': 3}

        # return different things depending on dataset_mode
        if self.dataset_mode == 'keyp_only':
            # save_input_image_with_keypoints(inp, meta['tpts'], out_path='./test_input_stanext.png', ratio_in_out=self.inp_res/self.out_res)
            return inp, target, meta
        elif self.dataset_mode == 'keyp_and_seg':
            meta['silh'] = seg[0, :, :]
            meta['name'] = name
            return inp, target, meta
        elif self.dataset_mode == 'keyp_and_seg_and_partseg':
            # partseg is fake! It only exists so that this dataset can be
            # combined with another dataset that has part segmentations.
            meta2['silh'] = seg[0, :, :]
            meta2['name'] = name
            fake_body_part_matrix = torch.ones((3, 256, 256)).long() * (-1)
            meta2['body_part_matrix'] = fake_body_part_matrix
            return inp, target, meta2
        elif self.dataset_mode == 'complete':
            target_dict = meta
            target_dict['silh'] = seg[0, :, :]
            # NEW for silhouette loss
            target_dict['img_border_mask'] = img_border_mask
            target_dict['has_seg'] = True
            return inp, target_dict
        else:
            # Previously this dropped into pdb, which blocks non-interactive
            # runs and DataLoader worker processes.
            raise ValueError('sampling error: unknown dataset_mode %r' % (self.dataset_mode,))

    def __len__(self):
        """Return the number of samples in the active split."""
        if self.is_train:
            return len(self.train_name_list)
        else:
            return len(self.test_name_list)