# 24 joints instead of 20!!

import gzip
import json
import os
import random
import math
import numpy as np
import torch
import torch.utils.data as data
from importlib_resources import open_binary
from scipy.io import loadmat
from tabulate import tabulate
import itertools
from scipy import ndimage
import csv
import pickle as pkl
from csv import DictReader
from pycocotools.mask import decode as decode_RLE

import sys
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..'))
from configs.data_info import COMPLETE_DATA_INFO_24
from stacked_hourglass.utils.imutils import load_image, draw_labelmap, draw_multiple_labelmaps
from stacked_hourglass.utils.misc import to_torch
from stacked_hourglass.utils.transforms import shufflelr, crop, color_normalize, fliplr, transform
import stacked_hourglass.datasets.utils_stanext as utils_stanext
from stacked_hourglass.utils.visualization import save_input_image_with_keypoints
from configs.dog_breeds.dog_breed_class import COMPLETE_ABBREV_DICT, COMPLETE_SUMMARY_BREEDS, SIM_MATRIX_RAW, SIM_ABBREV_INDICES
from configs.dataset_path_configs import STANEXT_RELATED_DATA_ROOT_DIR
from smal_pytorch.smal_model.smal_basics import get_symmetry_indices


def read_csv(csv_file):
    with open(csv_file, 'r') as f:
        reader = csv.reader(f)
        headers = next(reader)
        row_list = [{h: x for (h, x) in zip(headers, row)} for row in reader]
    return row_list
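
# Example (sketch): for a CSV file with header "name,label", read_csv returns a
# list of per-row dicts such as [{'name': 'img1.jpg', 'label': '1'}, ...]; note
# that all values remain strings.
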

class StanExtGC(data.Dataset):
    DATA_INFO = COMPLETE_DATA_INFO_24

    # Suggested joints to use for keypoint reprojection error calculations
    ACC_JOINTS = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 16]

    def __init__(self, image_path=None, is_train=True, inp_res=256, out_res=64, sigma=1,
                 scale_factor=0.25, rot_factor=30, label_type='Gaussian', do_augment='default',
                 shorten_dataset_to=None, dataset_mode='keyp_only', V12=None, val_opt='test'):
        self.V12 = V12
        self.is_train = is_train    # training set or test set
        if do_augment == 'yes':
            self.do_augment = True
        elif do_augment == 'no':
            self.do_augment = False
        elif do_augment == 'default':
            if self.is_train:
                self.do_augment = True
            else:
                self.do_augment = False
        else:
            raise ValueError
        self.inp_res = inp_res
        self.out_res = out_res
        self.sigma = sigma
        self.scale_factor = scale_factor
        self.rot_factor = rot_factor
        self.label_type = label_type
        self.dataset_mode = dataset_mode
        if self.dataset_mode in ('complete', 'complete_with_gc', 'keyp_and_seg', 'keyp_and_seg_and_partseg'):
            self.calc_seg = True
        else:
            self.calc_seg = False
        self.val_opt = val_opt

        # create train/val split
        self.img_folder = utils_stanext.get_img_dir(V12=self.V12)
        self.train_dict, init_test_dict, init_val_dict = utils_stanext.load_stanext_json_as_dict(split_train_test=True, V12=self.V12)
        self.train_name_list = list(self.train_dict.keys())    # 7004
        if self.val_opt == 'test':
            self.test_dict = init_test_dict
            self.test_name_list = list(self.test_dict.keys())
        elif self.val_opt == 'val':
            self.test_dict = init_val_dict
            self.test_name_list = list(self.test_dict.keys())
        else:
            raise NotImplementedError

        # path_gc_annots_overview = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/data/stanext_related_data/ground_contact_annotations/stage3/gc_annots_overview_first699.pkl'
        path_gc_annots_overview = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/data/stanext_related_data/ground_contact_annotations/stage3/gc_annots_overview_stage3complete.pkl'
        with open(path_gc_annots_overview, 'rb') as f:
            self.gc_annots_overview = pkl.load(f)
        list_gc_labelled_images = list(self.gc_annots_overview.keys())

        # restrict the train and test lists to images with ground contact labels
        test_name_list_gc = []
        for name in self.test_name_list:
            if name.split('.')[0] in list_gc_labelled_images:
                test_name_list_gc.append(name)
        train_name_list_gc = []
        for name in self.train_name_list:
            if name.split('.')[0] in list_gc_labelled_images:
                train_name_list_gc.append(name)
        self.test_name_list = test_name_list_gc
        self.train_name_list = train_name_list_gc
        random.seed(4)
        random.shuffle(self.test_name_list)

        '''
        already_labelled = ['n02093991-Irish_terrier/n02093991_2874.jpg',
                            'n02093754-Border_terrier/n02093754_1062.jpg',
                            'n02092339-Weimaraner/n02092339_1672.jpg',
                            'n02096177-cairn/n02096177_4916.jpg',
                            'n02110185-Siberian_husky/n02110185_725.jpg',
                            'n02110806-basenji/n02110806_761.jpg',
                            'n02094433-Yorkshire_terrier/n02094433_2474.jpg',
                            'n02097474-Tibetan_terrier/n02097474_8796.jpg',
                            'n02099601-golden_retriever/n02099601_2495.jpg']
        self.trainvaltest_dict = dict(self.train_dict)
        for d in (init_test_dict, init_val_dict):
            self.trainvaltest_dict.update(d)
        gc_annot_csv = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/data/stanext_related_data/ground_contact_annotations/my_gcannotations_qualification.csv'
        gc_row_list = read_csv(gc_annot_csv)
        json_acceptable_string = (gc_row_list[0]['vertices']).replace("'", "\"")
        self.gc_dict = json.loads(json_acceptable_string)
        self.train_name_list = already_labelled
        self.test_name_list = already_labelled
        '''

        # stanext breed dict (contains for each name a stanext specific index)
        breed_json_path = os.path.join(STANEXT_RELATED_DATA_ROOT_DIR, 'StanExt_breed_dict_v2.json')
        self.breed_dict = self.get_breed_dict(breed_json_path, create_new_breed_json=False)

        # load smal symmetry info
        self.sym_ids_dict = get_symmetry_indices()

        '''
        self.train_name_list = sorted(self.train_name_list)
        self.test_name_list = sorted(self.test_name_list)
        random.seed(4)
        random.shuffle(self.train_name_list)
        random.shuffle(self.test_name_list)
        if shorten_dataset_to is not None:
            # sometimes it is useful to have a smaller set (validation speed, debugging)
            self.train_name_list = self.train_name_list[0:min(len(self.train_name_list), shorten_dataset_to)]
            self.test_name_list = self.test_name_list[0:min(len(self.test_name_list), shorten_dataset_to)]
            # special case for debugging: 12 similar images
            if shorten_dataset_to == 12:
                my_sample = self.test_name_list[2]
                for ind in range(0, 12):
                    self.test_name_list[ind] = my_sample
        '''

        print('len(dataset): ' + str(self.__len__()))

        # add results for eyes, withers and throat as obtained through anipose -> they are used
        # as pseudo ground truth at training time.
        # self.path_anipose_out_root = os.path.join(STANEXT_RELATED_DATA_ROOT_DIR, 'animalpose_hg8_v0_results_on_StanExt')
        self.path_anipose_out_root = os.path.join(STANEXT_RELATED_DATA_ROOT_DIR, 'animalpose_hg8_v1_results_on_StanExt')    # this is from hg_anipose_after01bugfix_v1
        # self.prepare_anipose_res_and_save()
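
    # Return values of __getitem__, depending on dataset_mode (see the return
    # statements at the end of __getitem__ below):
    #   'keyp_only':                   inp, target, meta
    #   'keyp_and_seg':                inp, target, meta   (meta incl. 'silh', 'name')
    #   'keyp_and_seg_and_partseg':    inp, target, meta2  (incl. a fake 'body_part_matrix')
    #   'complete'/'complete_with_gc': inp, target_dict    (incl. silhouette, image border
    #                                  mask and, for the gc variant, ground contact labels)
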
    def get_data_sampler_info(self):
        # for custom data sampler
        if self.is_train:
            name_list = self.train_name_list
        else:
            name_list = self.test_name_list
        info_dict = {'name_list': name_list,
                     'stanext_breed_dict': self.breed_dict,
                     'breeds_abbrev_dict': COMPLETE_ABBREV_DICT,
                     'breeds_summary': COMPLETE_SUMMARY_BREEDS,
                     'breeds_sim_martix_raw': SIM_MATRIX_RAW,
                     'breeds_sim_abbrev_inds': SIM_ABBREV_INDICES}
        return info_dict

    def get_breed_dict(self, breed_json_path, create_new_breed_json=False):
        if create_new_breed_json:
            breed_dict = {}
            breed_index = 0
            for img_name in self.train_name_list:
                folder_name = img_name.split('/')[0]
                breed_name = folder_name.split(folder_name.split('-')[0] + '-')[1]
                if not (folder_name in breed_dict):
                    breed_dict[folder_name] = {'breed_name': breed_name, 'index': breed_index}
                    breed_index += 1
            with open(breed_json_path, 'w', encoding='utf-8') as f:
                json.dump(breed_dict, f, ensure_ascii=False, indent=4)
        else:
            with open(breed_json_path) as json_file:
                breed_dict = json.load(json_file)
        return breed_dict
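
    # breed_dict maps a StanExt folder name to breed info as built above, e.g.
    # (illustrative entry):
    #   {'n02085620-Chihuahua': {'breed_name': 'Chihuahua', 'index': 0}, ...}
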
    def prepare_anipose_res_and_save(self):
        # I only had to run this once ...
        # path_animalpose_res_root = '/ps/scratch/nrueegg/new_projects/Animals/dog_project/pytorch-stacked-hourglass/results/animalpose_hg8_v0/'
        path_animalpose_res_root = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/results/results/hg_anipose_after01bugfix_v1/stanext24_XXX_e300_json/'
        train_dict, init_test_dict, init_val_dict = utils_stanext.load_stanext_json_as_dict(split_train_test=True, V12=self.V12)
        train_name_list = list(train_dict.keys())
        val_name_list = list(init_val_dict.keys())
        test_name_list = list(init_test_dict.keys())
        all_dicts = [train_dict, init_val_dict, init_test_dict]
        all_name_lists = [train_name_list, val_name_list, test_name_list]
        all_prefixes = ['train', 'val', 'test']
        for ind in range(3):
            this_name_list = all_name_lists[ind]
            this_dict = all_dicts[ind]
            this_prefix = all_prefixes[ind]
            for index in range(0, len(this_name_list)):
                print(index)
                name = this_name_list[index]
                data = this_dict[name]
                img_path = os.path.join(self.img_folder, data['img_path'])
                path_animalpose_res = os.path.join(path_animalpose_res_root.replace('XXX', this_prefix), data['img_path'].replace('.jpg', '.json'))
                # prepare predicted keypoints
                '''if is_train:
                    path_animalpose_res = os.path.join(path_animalpose_res_root, 'train_stanext', 'res_' + str(index) + '.json')
                else:
                    path_animalpose_res = os.path.join(path_animalpose_res_root, 'test_stanext', 'res_' + str(index) + '.json')
                '''
                with open(path_animalpose_res) as f:
                    animalpose_data = json.load(f)
                anipose_joints_256 = np.asarray(animalpose_data['pred_joints_256']).reshape((-1, 3))
                anipose_center = animalpose_data['center']
                anipose_scale = animalpose_data['scale']
                anipose_joints_64 = anipose_joints_256 / 4
                '''thrs_21to24 = 0.2
                anipose_joints_21to24 = np.zeros((4, 3))
                for ind_j in range(0, 4):
                    anipose_joints_untrans = transform(anipose_joints_64[20+ind_j, 0:2], anipose_center, anipose_scale, [64, 64], invert=True, rot=0, as_int=False) - 1
                    anipose_joints_trans_again = transform(anipose_joints_untrans+1, anipose_center, anipose_scale, [64, 64], invert=False, rot=0, as_int=False)
                    anipose_joints_21to24[ind_j, :2] = anipose_joints_untrans
                    if anipose_joints_256[20+ind_j, 2] >= thrs_21to24:
                        anipose_joints_21to24[ind_j, 2] = 1'''
                anipose_joints_0to24 = np.zeros((24, 3))
                for ind_j in range(24):
                    # anipose_joints_untrans = transform(anipose_joints_64[ind_j, 0:2], anipose_center, anipose_scale, [64, 64], invert=True, rot=0, as_int=False) - 1
                    anipose_joints_untrans = transform(anipose_joints_64[ind_j, 0:2]+1, anipose_center, anipose_scale, [64, 64], invert=True, rot=0, as_int=False) - 1
                    anipose_joints_0to24[ind_j, :2] = anipose_joints_untrans
                    anipose_joints_0to24[ind_j, 2] = anipose_joints_256[ind_j, 2]
                # save anipose result for usage later on
                out_path = os.path.join(self.path_anipose_out_root, data['img_path'].replace('.jpg', '.json'))
                if not os.path.exists(os.path.dirname(out_path)):
                    os.makedirs(os.path.dirname(out_path))
                out_dict = {'orig_anipose_joints_256': list(anipose_joints_256.reshape((-1))),
                            'anipose_joints_0to24': list(anipose_joints_0to24[:, :3].reshape((-1))),
                            'orig_index': index,
                            'orig_scale': animalpose_data['scale'],
                            'orig_center': animalpose_data['center'],
                            'data_split': this_prefix,
                            # 'is_train': is_train,
                            }
                with open(out_path, 'w') as outfile:
                    json.dump(out_dict, outfile)
        return

    def __getitem__(self, index):
        if self.is_train:
            train_val_test_Prefix = 'train'
            name = self.train_name_list[index]
            data = self.train_dict[name]
        else:
            train_val_test_Prefix = self.val_opt    # 'val' or 'test'
            name = self.test_name_list[index]
            data = self.test_dict[name]
        img_path = os.path.join(self.img_folder, data['img_path'])

        '''
        # for debugging only
        train_val_test_Prefix = 'train'
        name = self.train_name_list[index]
        data = self.trainvaltest_dict[name]
        img_path = os.path.join(self.img_folder, data['img_path'])
        if self.dataset_mode == 'complete_with_gc':
            n_verts_smal = 3889
            gc_info_raw = self.gc_dict['bite/' + name]    # a list with all vertex numbers that are in ground contact
            gc_info = []
            gc_info_tch = torch.zeros((n_verts_smal))
            for ind_v in gc_info_raw:
                if ind_v < n_verts_smal:
                    gc_info.append(ind_v)
                    gc_info_tch[ind_v] = 1
            gc_info_available = True
        '''

        # array of shape (n_verts_smal, 3) with [first: no-contact=0 contact=1, second: index of vertex, third: dist]
        gc_vertdists_overview = self.gc_annots_overview[name.split('.')[0]]['gc_vertdists_overview']
        gc_info_tch = torch.tensor(gc_vertdists_overview[:, :])    # torch.tensor(gc_vertdists_overview[:, 0])
        gc_info_available = True

        # import pdb; pdb.set_trace()
        debugging = False
        if debugging:
            import shutil
            import trimesh
            from smal_pytorch.smal_model.smal_torch_new import SMAL
            smal = SMAL()
            verts = smal.v_template.detach().cpu().numpy()
            faces = smal.faces.detach().cpu().numpy()
            vert_colors = np.repeat(255*gc_info_tch[:, 0].detach().cpu().numpy()[:, None], 3, 1)
            # vert_colors = np.repeat(255*gc_info_np[:, None], 3, 1)
            my_mesh = trimesh.Trimesh(vertices=verts, faces=faces, process=False, maintain_order=True)
            my_mesh.visual.vertex_colors = vert_colors
            debug_folder = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/debugging/gc_debugging/'
            my_mesh.export(debug_folder + (name.split('/')[1]).replace('.jpg', '_withgc.obj'))
            shutil.copy(img_path, debug_folder + name.split('/')[1])
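
        # Next: sample augmentation parameters, then load the anipose pseudo
        # ground truth for joints 20-23 (eyes, withers, throat) and merge it with
        # the 20 annotated StanExt joints; predictions with a confidence below
        # anipose_thr are treated as invisible.
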
        sf = self.scale_factor
        rf = self.rot_factor
        try:
            # import pdb; pdb.set_trace()
            '''new_anipose_root_path = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/results/results/hg_anipose_after01bugfix_v1/stanext24_XXX_e300_json/'
            adjusted_new_anipose_root_path = new_anipose_root_path.replace('XXX', train_val_test_Prefix)
            new_anipose_res_path = adjusted_new_anipose_root_path + data['img_path'].replace('.jpg', '.json')
            with open(new_anipose_res_path) as f:
                new_anipose_data = json.load(f)
            '''
            anipose_res_path = os.path.join(self.path_anipose_out_root, data['img_path'].replace('.jpg', '.json'))
            with open(anipose_res_path) as f:
                anipose_data = json.load(f)
            anipose_thr = 0.2
            anipose_joints_0to24 = np.asarray(anipose_data['anipose_joints_0to24']).reshape((-1, 3))
            anipose_joints_0to24_scores = anipose_joints_0to24[:, 2]
            # anipose_joints_0to24_scores[anipose_joints_0to24_scores>anipose_thr] = 1.0
            anipose_joints_0to24_scores[anipose_joints_0to24_scores < anipose_thr] = 0.0
            anipose_joints_0to24[:, 2] = anipose_joints_0to24_scores
        except:
            # fall back to all-invisible anipose joints if no result file exists
            anipose_joints_0to24 = np.zeros((24, 3))

        # the first 20 joints come from the StanExt annotations, the last 4 from anipose
        joints = np.concatenate((np.asarray(data['joints'])[:20, :], anipose_joints_0to24[20:24, :]), axis=0)
        joints[joints[:, 2] == 0, :2] = 0    # avoid nan values
        pts = torch.Tensor(joints)

        # bounding box -> center and scale as expected by crop()
        bbox_xywh = data['img_bbox']
        bbox_c = [bbox_xywh[0] + 0.5*bbox_xywh[2], bbox_xywh[1] + 0.5*bbox_xywh[3]]
        bbox_max = max(bbox_xywh[2], bbox_xywh[3])
        bbox_diag = math.sqrt(bbox_xywh[2]**2 + bbox_xywh[3]**2)
        # bbox_s = bbox_max / 200.     # the dog fills the image -> bbox_max = 256
        # bbox_s = bbox_diag / 200.    # diagonal of the boundingbox will be 200
        bbox_s = bbox_max / 200. * 256. / 200.    # maximum side of the bbox will be 200
        c = torch.Tensor(bbox_c)
        s = bbox_s

        # For single-person pose estimation with a centered/scaled figure
        nparts = pts.size(0)
        img = load_image(img_path)    # CxHxW

        # segmentation map (we reshape it to 3xHxW, such that we can do the
        # same transformations as with the image)
        if self.calc_seg:
            seg = torch.Tensor(utils_stanext.get_seg_from_entry(data)[None, :, :])
            seg = torch.cat(3*[seg])

        r = 0
        do_flip = False
        if self.do_augment:
            s = s*torch.randn(1).mul_(sf).add_(1).clamp(1-sf, 1+sf)[0]
            r = torch.randn(1).mul_(rf).clamp(-2*rf, 2*rf)[0] if random.random() <= 0.6 else 0
            # Flip
            if random.random() <= 0.5:
                do_flip = True
                img = fliplr(img)
                if self.calc_seg:
                    seg = fliplr(seg)
                pts = shufflelr(pts, img.size(2), self.DATA_INFO.hflip_indices)
                c[0] = img.size(2) - c[0]
                # flip ground contact annotations
                gc_info_tch_swapped = torch.zeros_like(gc_info_tch)
                gc_info_tch_swapped[self.sym_ids_dict['center'], :] = gc_info_tch[self.sym_ids_dict['center'], :]
                gc_info_tch_swapped[self.sym_ids_dict['right'], :] = gc_info_tch[self.sym_ids_dict['left'], :]
                gc_info_tch_swapped[self.sym_ids_dict['left'], :] = gc_info_tch[self.sym_ids_dict['right'], :]
                gc_info_tch = gc_info_tch_swapped
            # Color
            img[0, :, :].mul_(random.uniform(0.8, 1.2)).clamp_(0, 1)
            img[1, :, :].mul_(random.uniform(0.8, 1.2)).clamp_(0, 1)
            img[2, :, :].mul_(random.uniform(0.8, 1.2)).clamp_(0, 1)

        # import pdb; pdb.set_trace()
        debugging = False
        if debugging and do_flip:
            import shutil
            import trimesh
            from smal_pytorch.smal_model.smal_torch_new import SMAL
            smal = SMAL()
            verts = smal.v_template.detach().cpu().numpy()
            faces = smal.faces.detach().cpu().numpy()
            vert_colors = np.repeat(255*gc_info_tch[:, 0].detach().cpu().numpy()[:, None], 3, 1)
            # vert_colors = np.repeat(255*gc_info_np[:, None], 3, 1)
            my_mesh = trimesh.Trimesh(vertices=verts, faces=faces, process=False, maintain_order=True)
            my_mesh.visual.vertex_colors = vert_colors
            debug_folder = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/debugging/gc_debugging/'
            my_mesh.export(debug_folder + (name.split('/')[1]).replace('.jpg', '_withgc_flip.obj'))

        # Prepare image and groundtruth map
        inp = crop(img, c, s, [self.inp_res, self.inp_res], rot=r)
        img_border_mask = torch.all(inp > 1.0/256, dim=0).unsqueeze(0).float()    # 1 is foreground
        inp = color_normalize(inp, self.DATA_INFO.rgb_mean, self.DATA_INFO.rgb_stddev)
        if self.calc_seg:
            seg = crop(seg, c, s, [self.inp_res, self.inp_res], rot=r)
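
        # Each visible joint is rendered as a Gaussian peak of std self.sigma on
        # its own out_res x out_res heatmap; draw_labelmap also reports whether
        # the transformed joint still lies on the map, which is multiplied into
        # target_weight.
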
        # Generate ground truth
        tpts = pts.clone()
        target_weight = tpts[:, 2].clone().view(nparts, 1)
        target = torch.zeros(nparts, self.out_res, self.out_res)
        for i in range(nparts):
            # if tpts[i, 2] > 0:    # This is evil!!
            if tpts[i, 1] > 0:
                tpts[i, 0:2] = to_torch(transform(tpts[i, 0:2]+1, c, s, [self.out_res, self.out_res], rot=r, as_int=False)) - 1
                target[i], vis = draw_labelmap(target[i], tpts[i], self.sigma, type=self.label_type)
                target_weight[i, 0] *= vis
        # NEW:
        '''target_new, vis_new = draw_multiple_labelmaps((self.out_res, self.out_res), tpts[:, :2]-1, self.sigma, type=self.label_type)
        target_weight_new = tpts[:, 2].clone().view(nparts, 1) * vis_new
        target_new[(target_weight_new==0).reshape((-1)), :, :] = 0'''

        # --- Meta info
        this_breed = self.breed_dict[name.split('/')[0]]    # 120
        # add information about location within breed similarity matrix
        folder_name = name.split('/')[0]
        breed_name = folder_name.split(folder_name.split('-')[0] + '-')[1]
        abbrev = COMPLETE_ABBREV_DICT[breed_name]
        try:
            sim_breed_index = COMPLETE_SUMMARY_BREEDS[abbrev]._ind_in_xlsx_matrix
        except:    # some breeds are not in the xlsx file
            sim_breed_index = -1
        meta = {'index': index, 'center': c, 'scale': s,
                'pts': pts, 'tpts': tpts, 'target_weight': target_weight,
                'breed_index': this_breed['index'], 'sim_breed_index': sim_breed_index,
                'ind_dataset': 0}    # ind_dataset=0 for stanext or stanexteasy or stanext 2
        meta2 = {'index': index, 'center': c, 'scale': s,
                 'pts': pts, 'tpts': tpts, 'target_weight': target_weight,
                 'ind_dataset': 3}

        # import pdb; pdb.set_trace()
        # out_path_root = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/debugging/stanext_preprocessing/old_animalpose_version/'
        # out_path_root = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/debugging/stanext_preprocessing/v0/'
        # save_input_image_with_keypoints(inp, meta['tpts'], out_path=out_path_root + name.replace('/', '_'), ratio_in_out=self.inp_res/self.out_res)

        # return different things depending on dataset_mode
        if self.dataset_mode == 'keyp_only':
            # save_input_image_with_keypoints(inp, meta['tpts'], out_path='./test_input_stanext.png', ratio_in_out=self.inp_res/self.out_res)
            return inp, target, meta
        elif self.dataset_mode == 'keyp_and_seg':
            meta['silh'] = seg[0, :, :]
            meta['name'] = name
            return inp, target, meta
        elif self.dataset_mode == 'keyp_and_seg_and_partseg':
            # partseg is fake! it only exists so that this dataset can be combined
            # with another dataset that has part segmentations
            meta2['silh'] = seg[0, :, :]
            meta2['name'] = name
            fake_body_part_matrix = torch.ones((3, 256, 256)).long() * (-1)
            meta2['body_part_matrix'] = fake_body_part_matrix
            return inp, target, meta2
        elif (self.dataset_mode == 'complete') or (self.dataset_mode == 'complete_with_gc'):
            target_dict = meta
            target_dict['silh'] = seg[0, :, :]
            # NEW for silhouette loss
            target_dict['img_border_mask'] = img_border_mask
            target_dict['has_seg'] = True
            # ground contact
            if self.dataset_mode == 'complete_with_gc':
                target_dict['has_gc_is_touching'] = True
                target_dict['has_gc'] = gc_info_available
                target_dict['gc'] = gc_info_tch
            if target_dict['silh'].sum() < 1:
                if ((not self.is_train) and self.val_opt == 'test'):
                    raise ValueError
                elif self.is_train:
                    print('had to replace training image')
                    replacement_index = max(0, index - 1)
                    inp, target_dict = self.__getitem__(replacement_index)
                else:
                    # There seem to be a few validation images without segmentation
                    # which would lead to nan in iou calculation
                    replacement_index = max(0, index - 1)
                    inp, target_dict = self.__getitem__(replacement_index)
            return inp, target_dict
        else:
            print('sampling error')
            import pdb; pdb.set_trace()
            raise ValueError

    def __len__(self):
        if self.is_train:
            return len(self.train_name_list)
        else:
            return len(self.test_name_list)
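

if __name__ == '__main__':
    # Minimal usage sketch: it assumes the StanExt images, the ground-contact
    # annotation pickle and the anipose result files configured above exist on
    # this machine; V12=True is an assumption, not taken from the original file.
    dataset = StanExtGC(is_train=False, dataset_mode='keyp_only', V12=True, val_opt='test')
    inp, target, meta = dataset[0]
    print(inp.shape)       # torch.Size([3, 256, 256]): normalized input crop
    print(target.shape)    # torch.Size([24, 64, 64]): one heatmap per joint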