# Spaces:
# Runtime error
# Runtime error
# 24 joints instead of 20!!
import gzip
import itertools
import json
import math
import os
import random
import sys
from csv import DictReader

import numpy as np
import torch
import torch.utils.data as data
from importlib_resources import open_binary
from pycocotools.mask import decode as decode_RLE
from scipy import ndimage
from scipy.io import loadmat
from tabulate import tabulate

# Make the directory two levels above this file importable; the `configs` and
# `stacked_hourglass` imports below rely on it, so this must stay before them.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..'))
from configs.data_info import COMPLETE_DATA_INFO_24
from stacked_hourglass.utils.imutils import load_image, draw_labelmap, draw_multiple_labelmaps
from stacked_hourglass.utils.misc import to_torch
from stacked_hourglass.utils.transforms import shufflelr, crop, color_normalize, fliplr, transform
import stacked_hourglass.datasets.utils_stanext as utils_stanext
from stacked_hourglass.utils.visualization import save_input_image_with_keypoints
from configs.dog_breeds.dog_breed_class import COMPLETE_ABBREV_DICT, COMPLETE_SUMMARY_BREEDS, SIM_MATRIX_RAW, SIM_ABBREV_INDICES
from configs.dataset_path_configs import STANEXT_RELATED_DATA_ROOT_DIR


class StanExt(data.Dataset):
    """Keypoint (and optionally silhouette) dataset for StanExt dog images.

    Each sample has 24 keypoints. The first 20 come from the annotation entry
    (``entry['joints']``). Keypoints 20-23 (eyes, withers and throat according
    to the original authors) are read from per-image anipose result files and
    are used as pseudo ground truth. If such a file is missing or unreadable,
    those rows are all zero.

    What ``__getitem__`` returns depends on ``dataset_mode``:

    - ``'keyp_only'``: ``(inp, target, meta)``.
    - ``'keyp_and_seg'``: ``(inp, target, meta)``. ``meta`` additionally holds
      ``'silh'`` and ``'name'``.
    - ``'keyp_and_seg_and_partseg'``: ``(inp, target, meta)``. Here ``meta``
      has no breed info and uses ``ind_dataset=3``. It holds ``'silh'``,
      ``'name'`` and a placeholder ``'body_part_matrix'`` filled with -1.
    - ``'complete'``: ``(inp, target_dict)``, where ``target_dict`` is the
      breed-aware meta plus ``'silh'``, ``'img_border_mask'`` and ``'has_seg'``.
    """

    DATA_INFO = COMPLETE_DATA_INFO_24
    # Suggested joints to use for keypoint reprojection error calculations
    ACC_JOINTS = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 16]

    def __init__(self, image_path=None, is_train=True, inp_res=256, out_res=64, sigma=1,
                 scale_factor=0.25, rot_factor=30, label_type='Gaussian',
                 do_augment='default', shorten_dataset_to=None, dataset_mode='keyp_only',
                 V12=None, val_opt='test'):
        """Build the shuffled train/test name lists and load the breed dictionary.

        Args:
            image_path: unused; kept for interface compatibility.
            is_train: serve the training split if True, otherwise the split
                selected by ``val_opt``.
            inp_res: side length of the cropped input image.
            out_res: side length of the keypoint heatmaps.
            sigma: heatmap sigma passed to ``draw_labelmap``.
            scale_factor: magnitude of the random scale jitter (augmentation).
            rot_factor: magnitude of the random rotation passed to ``crop`` /
                ``transform`` (augmentation).
            label_type: heatmap type passed to ``draw_labelmap``.
            do_augment: ``'yes'``, ``'no'`` or ``'default'`` (augment only when
                ``is_train``).
            shorten_dataset_to: if given, truncate both name lists to this
                length. The value 12 is a debugging special case: every test
                entry is replaced by the third test image.
            dataset_mode: ``'keyp_only'``, ``'keyp_and_seg'``,
                ``'keyp_and_seg_and_partseg'`` or ``'complete'``.
            V12: forwarded to ``utils_stanext`` when locating images and
                annotations.
            val_opt: ``'test'`` or ``'val'``; selects which held-out split is
                served when ``is_train`` is False.

        Raises:
            ValueError: for an unknown ``do_augment`` value.
            NotImplementedError: for an unknown ``val_opt`` value.
        """
        self.V12 = V12
        self.is_train = is_train  # training set or test set
        if do_augment == 'yes':
            self.do_augment = True
        elif do_augment == 'no':
            self.do_augment = False
        elif do_augment == 'default':
            # augment only at training time
            self.do_augment = bool(self.is_train)
        else:
            raise ValueError(
                "do_augment must be 'yes', 'no' or 'default', got {!r}".format(do_augment))
        self.inp_res = inp_res
        self.out_res = out_res
        self.sigma = sigma
        self.scale_factor = scale_factor
        self.rot_factor = rot_factor
        self.label_type = label_type
        self.dataset_mode = dataset_mode
        # the silhouette is only loaded and cropped when the mode returns it
        self.calc_seg = self.dataset_mode in ('complete', 'keyp_and_seg', 'keyp_and_seg_and_partseg')
        self.val_opt = val_opt

        # create train/val split
        self.img_folder = utils_stanext.get_img_dir(V12=self.V12)
        self.train_dict, init_test_dict, init_val_dict = utils_stanext.load_stanext_json_as_dict(
            split_train_test=True, V12=self.V12)
        self.train_name_list = list(self.train_dict.keys())  # 7004
        if self.val_opt == 'test':
            self.test_dict = init_test_dict
        elif self.val_opt == 'val':
            self.test_dict = init_val_dict
        else:
            raise NotImplementedError
        self.test_name_list = list(self.test_dict.keys())

        # stanext breed dict (contains for each name a stanext specific index)
        breed_json_path = os.path.join(STANEXT_RELATED_DATA_ROOT_DIR, 'StanExt_breed_dict_v2.json')
        self.breed_dict = self.get_breed_dict(breed_json_path, create_new_breed_json=False)

        # Sort first so the order does not depend on the json key order, then
        # shuffle with a fixed seed. A private Random(4) yields exactly the order
        # that random.seed(4) + random.shuffle produced, without resetting the
        # global `random` state that __getitem__ uses for augmentation.
        split_rng = random.Random(4)
        self.train_name_list = sorted(self.train_name_list)
        self.test_name_list = sorted(self.test_name_list)
        split_rng.shuffle(self.train_name_list)
        split_rng.shuffle(self.test_name_list)

        if shorten_dataset_to is not None:
            # sometimes it is useful to have a smaller set (validation speed, debugging)
            self.train_name_list = self.train_name_list[:shorten_dataset_to]
            self.test_name_list = self.test_name_list[:shorten_dataset_to]
            # special case for debugging: 12 copies of the same test image
            if shorten_dataset_to == 12:
                my_sample = self.test_name_list[2]
                for ind in range(12):
                    self.test_name_list[ind] = my_sample
        print('len(dataset): ' + str(len(self)))

        # add results for eyes, withers and throat as obtained through anipose -> they are used
        # as pseudo ground truth at training time.
        self.path_anipose_out_root = os.path.join(STANEXT_RELATED_DATA_ROOT_DIR, 'animalpose_hg8_v0_results_on_StanExt')

    def _active_split(self):
        """Return ``(name_list, entries_by_name)`` for the split this dataset serves."""
        if self.is_train:
            return self.train_name_list, self.train_dict
        return self.test_name_list, self.test_dict

    @staticmethod
    def _breed_name_from_folder(folder_name):
        """Return the breed part of an image folder name (text after the first '-').

        Splitting only at the first '-' keeps any '-' that belongs to the breed
        name itself. A folder name without '-' raises IndexError.
        """
        return folder_name.split('-', 1)[1]

    def get_data_sampler_info(self):
        """Return the information a custom data sampler needs.

        Returns:
            dict with the active split's name list, the StanExt breed dict and
            the breed abbreviation / summary / similarity tables.
        """
        name_list, _ = self._active_split()
        info_dict = {'name_list': name_list,
                     'stanext_breed_dict': self.breed_dict,
                     'breeds_abbrev_dict': COMPLETE_ABBREV_DICT,
                     'breeds_summary': COMPLETE_SUMMARY_BREEDS,
                     'breeds_sim_martix_raw': SIM_MATRIX_RAW,
                     'breeds_sim_abbrev_inds': SIM_ABBREV_INDICES
                     }
        return info_dict

    def get_breed_dict(self, breed_json_path, create_new_breed_json=False):
        """Load, or build and save, the StanExt breed dictionary.

        The dictionary maps an image folder name (the first '/'-separated
        component of an image name) to ``{'breed_name': ..., 'index': ...}``.

        Args:
            breed_json_path: json file to read, or to overwrite when
                ``create_new_breed_json`` is True.
            create_new_breed_json: if True, build the dictionary from
                ``self.train_name_list``, assigning indices in order of first
                appearance, and write it to ``breed_json_path``.

        Returns:
            dict: folder name -> breed info.
        """
        if not create_new_breed_json:
            # the file is written as UTF-8 below, so read it back the same way
            with open(breed_json_path, encoding='utf-8') as json_file:
                return json.load(json_file)
        breed_dict = {}
        breed_index = 0
        for img_name in self.train_name_list:
            folder_name = img_name.split('/')[0]
            if folder_name not in breed_dict:
                breed_dict[folder_name] = {
                    'breed_name': self._breed_name_from_folder(folder_name),
                    'index': breed_index}
                breed_index += 1
        with open(breed_json_path, 'w', encoding='utf-8') as f:
            json.dump(breed_dict, f, ensure_ascii=False, indent=4)
        return breed_dict

    def _load_joints(self, entry, anipose_thr=0.2):
        """Return the joint array of ``entry`` as a numpy array with 3 columns.

        Rows 0-19 are the annotated joints. Rows 20-23 come from the anipose
        result file, with scores (third column) below ``anipose_thr`` set to 0.
        If that file is missing or cannot be parsed, rows 20-23 are zero.
        Coordinates of joints whose third column is 0 are zeroed as well.
        """
        anipose_res_path = os.path.join(self.path_anipose_out_root,
                                        entry['img_path'].replace('.jpg', '.json'))
        try:
            with open(anipose_res_path) as f:
                anipose_data = json.load(f)
            anipose_joints_0to24 = np.asarray(anipose_data['anipose_joints_0to24']).reshape((-1, 3))
            anipose_scores = anipose_joints_0to24[:, 2]  # view: edits write through
            anipose_scores[anipose_scores < anipose_thr] = 0.0
        except (OSError, ValueError, KeyError, IndexError, TypeError):
            # REMARK (original authors): this happens sometimes, maybe once every
            # 10th image. Best effort: treat the anipose joints as invisible.
            anipose_joints_0to24 = np.zeros((24, 3))
        joints = np.concatenate((np.asarray(entry['joints'])[:20, :], anipose_joints_0to24[20:24, :]), axis=0)
        joints[joints[:, 2] == 0, :2] = 0  # avoid nan values
        return joints

    @staticmethod
    def _bbox_center_and_scale(bbox_xywh):
        """Return the crop center (as a tensor) and crop scale for an ``[x, y, w, h]`` bbox.

        Original notes on the crop convention: ``h = 200 * scale`` and
        ``scale = 256 / 200`` for a 256 px crop. The chosen scale makes the
        longer bbox side cover 200 px of a 256 px crop.
        """
        box_x, box_y, box_w, box_h = bbox_xywh[0], bbox_xywh[1], bbox_xywh[2], bbox_xywh[3]
        center = torch.Tensor([box_x + 0.5 * box_w, box_y + 0.5 * box_h])
        # maximum side of the bbox will be 200
        # (earlier variants: bbox_max / 200. fills the crop; bbox_diag / 200. fits the diagonal)
        scale = max(box_w, box_h) / 200. * 256. / 200.
        return center, scale

    def _breed_indices(self, name):
        """Return ``(breed_index, sim_breed_index)`` for image ``name``.

        ``sim_breed_index`` is the breed's position in the similarity (xlsx)
        matrix, or -1 if the breed is not listed there.
        """
        folder_name = name.split('/')[0]
        breed_index = self.breed_dict[folder_name]['index']
        abbrev = COMPLETE_ABBREV_DICT[self._breed_name_from_folder(folder_name)]
        try:
            sim_breed_index = COMPLETE_SUMMARY_BREEDS[abbrev]._ind_in_xlsx_matrix
        except (KeyError, AttributeError):  # some breeds are not in the xlsx file
            sim_breed_index = -1
        return breed_index, sim_breed_index

    def __getitem__(self, index):
        """Load, optionally augment, and crop one sample.

        In 'complete' mode, a training or validation sample whose cropped
        silhouette is empty is replaced by the previous sample (wrapping around
        at index 0). The same case on the test split raises.

        Returns:
            A tuple whose layout depends on ``dataset_mode`` (see class docstring).

        Raises:
            ValueError: in 'complete' mode for a test-split sample with an empty
                silhouette, or for an unknown ``dataset_mode``.
        """
        name_list, entries = self._active_split()
        name = name_list[index]
        entry = entries[name]  # not called `data`: that would shadow torch.utils.data
        sf = self.scale_factor
        rf = self.rot_factor
        img_path = os.path.join(self.img_folder, entry['img_path'])

        pts = torch.Tensor(self._load_joints(entry))
        nparts = pts.size(0)
        c, s = self._bbox_center_and_scale(entry['img_bbox'])

        img = load_image(img_path)  # CxHxW
        # segmentation map (we reshape it to 3xHxW, such that we can do the
        # same transformations as with the image)
        if self.calc_seg:
            seg = torch.Tensor(utils_stanext.get_seg_from_entry(entry)[None, :, :])
            seg = torch.cat(3 * [seg])

        r = 0
        if self.do_augment:
            # random scale; random rotation with probability 0.6
            s = s * torch.randn(1).mul_(sf).add_(1).clamp(1 - sf, 1 + sf)[0]
            r = torch.randn(1).mul_(rf).clamp(-2 * rf, 2 * rf)[0] if random.random() <= 0.6 else 0
            # Flip
            if random.random() <= 0.5:
                img = fliplr(img)
                if self.calc_seg:
                    seg = fliplr(seg)
                pts = shufflelr(pts, img.size(2), self.DATA_INFO.hflip_indices)
                c[0] = img.size(2) - c[0]
            # Color: scale each channel by its own random factor
            for channel in range(3):
                img[channel, :, :].mul_(random.uniform(0.8, 1.2)).clamp_(0, 1)

        # Prepare image and groundtruth map
        inp = crop(img, c, s, [self.inp_res, self.inp_res], rot=r)
        # 1 where every channel exceeds 1/256 ("1 is foreground"); presumably 0 marks
        # the black padding added by crop outside the image -- TODO confirm
        img_border_mask = torch.all(inp > 1.0 / 256, dim=0).unsqueeze(0).float()
        inp = color_normalize(inp, self.DATA_INFO.rgb_mean, self.DATA_INFO.rgb_stddev)
        if self.calc_seg:
            seg = crop(seg, c, s, [self.inp_res, self.inp_res], rot=r)

        # Generate ground truth
        tpts = pts.clone()
        target_weight = tpts[:, 2].clone().view(nparts, 1)
        target = torch.zeros(nparts, self.out_res, self.out_res)
        for i in range(nparts):
            # Gate on the y coordinate; the original authors marked gating on the
            # visibility flag (tpts[i, 2] > 0) as "evil".
            if tpts[i, 1] > 0:
                tpts[i, 0:2] = to_torch(transform(tpts[i, 0:2] + 1, c, s, [self.out_res, self.out_res],
                                                  rot=r, as_int=False))
                target[i], vis = draw_labelmap(target[i], tpts[i] - 1, self.sigma, type=self.label_type)
                target_weight[i, 0] *= vis

        # --- Meta info
        breed_index, sim_breed_index = self._breed_indices(name)
        meta = {'index': index, 'center': c, 'scale': s,
                'pts': pts, 'tpts': tpts, 'target_weight': target_weight,
                'breed_index': breed_index, 'sim_breed_index': sim_breed_index,
                'ind_dataset': 0}  # ind_dataset=0 for stanext or stanexteasy or stanext 2

        # return different things depending on dataset_mode
        if self.dataset_mode == 'keyp_only':
            return inp, target, meta
        if self.dataset_mode == 'keyp_and_seg':
            meta['silh'] = seg[0, :, :]
            meta['name'] = name
            return inp, target, meta
        if self.dataset_mode == 'keyp_and_seg_and_partseg':
            # partseg is fake! It only exists so that this dataset can be combined
            # with another dataset that has part segmentations.
            meta2 = {'index': index, 'center': c, 'scale': s,
                     'pts': pts, 'tpts': tpts, 'target_weight': target_weight,
                     'ind_dataset': 3}
            meta2['silh'] = seg[0, :, :]
            meta2['name'] = name
            # NOTE(review): fixed 256x256 size, not self.inp_res -- kept as-is
            meta2['body_part_matrix'] = torch.full((3, 256, 256), -1, dtype=torch.long)
            return inp, target, meta2
        if self.dataset_mode == 'complete':
            target_dict = meta
            target_dict['silh'] = seg[0, :, :]
            # NEW for silhouette loss
            target_dict['img_border_mask'] = img_border_mask
            target_dict['has_seg'] = True
            if target_dict['silh'].sum() < 1:
                if (not self.is_train) and self.val_opt == 'test':
                    raise ValueError('empty silhouette for test image ' + str(name))
                if self.is_train:
                    print('had to replace training image')
                # There seem to be a few validation images without segmentation,
                # which would lead to nan in the iou calculation. Step back one
                # sample; wrap around so that index 0 cannot recurse into itself.
                replacement_index = (index - 1) % len(self)
                inp, target_dict = self.__getitem__(replacement_index)
            return inp, target_dict
        raise ValueError('sampling error: unknown dataset_mode ' + repr(self.dataset_mode))

    def __len__(self):
        """Return the number of samples in the active split."""
        name_list, _ = self._active_split()
        return len(name_list)