diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..c57d39753a1eadba7ee703bcfc4042c43cfdc9b4 --- /dev/null +++ b/LICENSE @@ -0,0 +1,60 @@ +License +Software Copyright License for non-commercial scientific research purposes +Please read carefully the following terms and conditions and any accompanying documentation before you download and/or use BITE data, model and software, (the "Data & Software"), including 3D meshes, images, videos, textures, software, scripts, and animations. By downloading and/or using the Data & Software (including downloading, cloning, installing, and any other use of the corresponding github repository), you acknowledge that you have read these terms and conditions, understand them, and agree to be bound by them. If you do not agree with these terms and conditions, you must not download and/or use the Data & Software. Any infringement of the terms of this agreement will automatically terminate your rights under this License + +Ownership / Licensees +The Software and the associated materials has been developed at the + +Max Planck Institute for Intelligent Systems +and +ETH Zurich + +Any copyright or patent right is owned by and proprietary material of the + +Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (hereinafter “MPG”; MPI and MPG hereinafter collectively “Max-Planck”) + +hereinafter the “Licensor”. + +License Grant +Licensor grants you (Licensee) personally a single-user, non-exclusive, non-transferable, free of charge right: + +To install the Data & Software on computers owned, leased or otherwise controlled by you and/or your organization; +To use the Data & Software for the sole purpose of performing non-commercial scientific research, non-commercial education, or non-commercial artistic projects; +Any other use, in particular any use for commercial, pornographic, military, or surveillance, purposes is prohibited. This includes, without limitation, incorporation in a commercial product, use in a commercial service, or production of other artifacts for commercial purposes. The Data & Software may not be used to create fake, libelous, misleading, or defamatory content of any kind excluding analyses in peer-reviewed scientific research. The Data & Software may not be reproduced, modified and/or made available in any form to any third party without Max-Planck’s prior written permission. + +The Data & Software may not be used for pornographic purposes or to generate pornographic material whether commercial or not. This license also prohibits the use of the Software to train methods/algorithms/neural networks/etc. for commercial, pornographic, military, surveillance, or defamatory use of any kind. By downloading the Data & Software, you agree not to reverse engineer it. + +No Distribution +The Data & Software and the license herein granted shall not be copied, shared, distributed, re-sold, offered for re-sale, transferred or sub-licensed in whole or in part except that you may make one copy for archive purposes only. + +Disclaimer of Representations and Warranties +You expressly acknowledge and agree that the Data & Software results from basic research, is provided “AS IS”, may contain errors, and that any use of the Data & Software is at your sole risk. LICENSOR MAKES NO REPRESENTATIONS OR WARRANTIES OF ANY KIND CONCERNING THE DATA & SOFTWARE, NEITHER EXPRESS NOR IMPLIED, AND THE ABSENCE OF ANY LEGAL OR ACTUAL DEFECTS, WHETHER DISCOVERABLE OR NOT. Specifically, and not to limit the foregoing, licensor makes no representations or warranties (i) regarding the merchantability or fitness for a particular purpose of the Data & Software, (ii) that the use of the Data & Software will not infringe any patents, copyrights or other intellectual property rights of a third party, and (iii) that the use of the Data & Software will not cause any damage of any kind to you or a third party. + +Limitation of Liability +Because this Data & Software License Agreement qualifies as a donation, according to Section 521 of the German Civil Code (Bürgerliches Gesetzbuch – BGB) Licensor as a donor is liable for intent and gross negligence only. If the Licensor fraudulently conceals a legal or material defect, they are obliged to compensate the Licensee for the resulting damage. +Licensor shall be liable for loss of data only up to the amount of typical recovery costs which would have arisen had proper and regular data backup measures been taken. For the avoidance of doubt Licensor shall be liable in accordance with the German Product Liability Act in the event of product liability. The foregoing applies also to Licensor’s legal representatives or assistants in performance. Any further liability shall be excluded. +Patent claims generated through the usage of the Data & Software cannot be directed towards the copyright holders. +The Data & Software is provided in the state of development the licensor defines. If modified or extended by Licensee, the Licensor makes no claims about the fitness of the Data & Software and is not responsible for any problems such modifications cause. + +No Maintenance Services +You understand and agree that Licensor is under no obligation to provide either maintenance services, update services, notices of latent defects, or corrections of defects with regard to the Data & Software. Licensor nevertheless reserves the right to update, modify, or discontinue the Data & Software at any time. + +Defects of the Data & Software must be notified in writing to the Licensor with a comprehensible description of the error symptoms. The notification of the defect should enable the reproduction of the error. The Licensee is encouraged to communicate any use, results, modification or publication. + +Publications using the Data & Software +You acknowledge that the Data & Software is a valuable scientific resource and agree to appropriately reference the following paper in any publication making use of the Data & Software. + +Citation: + + +@inproceedings{BITE:2023, + title = {BITE: Beyond priors for Improved Three-D dog pose Estimation}, + author = {Rueegg, Nadine and Tripathi, Shashank and Schindler, Konrad and Black, Michael J. and Zuffi, Silvia}, + booktitle = {under review}, + year = {2023} + url = {https://bite.is.tue.mpg.de} +} +Commercial licensing opportunities +For commercial uses of the Data & Software, please send email to ps-license@tue.mpg.de + +This Agreement shall be governed by the laws of the Federal Republic of Germany except for the UN Sales Convention. diff --git a/README.md b/README.md index 1866f2d33f6cf1e7bd2803e7a8a964dc0106824f..23b5aaf500d0613dffe2dbd67b44d8e7da6e3395 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,9 @@ ---- -title: Bite Gradio -emoji: 👀 -colorFrom: blue -colorTo: pink +title: BITE +emoji: 🐩 🐶 🐕 +colorFrom: pink +colorTo: green sdk: gradio -sdk_version: 3.35.2 -app_file: app.py +sdk_version: 3.0.2 +app_file: ./scripts/gradio_demo.py pinned: false ---- - -Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference +python_version: 3.7.6 diff --git a/packages.txt b/packages.txt new file mode 100644 index 0000000000000000000000000000000000000000..957cfbfef10cb618b24ca2fb74ddb49994bf45cf --- /dev/null +++ b/packages.txt @@ -0,0 +1,8 @@ +libgl1 +unzip +ffmpeg +libsm6 +libxext6 +libgl1-mesa-dri +libegl1-mesa +libgbm1 diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..a76fa4a87ec0c42cfc7a8397641986a96e3054f2 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,15 @@ +torch==1.6.0 +torchvision==0.7.0 +pytorch3d==0.2.5 +kornia==0.4.0 +matplotlib +opencv-python +trimesh +scipy +chumpy +pymp +importlib-resources +pycocotools +openpyxl +dominate +git+https://github.com/runa91/FrEIA.git diff --git a/scripts/gradio_demo.py b/scripts/gradio_demo.py new file mode 100644 index 0000000000000000000000000000000000000000..5b1ff5278bdc39c0fa795ff32bb7836b6324946e --- /dev/null +++ b/scripts/gradio_demo.py @@ -0,0 +1,672 @@ + +# aenv_new_icon_2 + +# was used for ttoptv6_sketchfab_v16: python src/test_time_optimization/ttopt_fromref_v6_sketchfab.py --workers 12 --save-images True --config refinement_cfg_visualization_withgc_withvertexwisegc_isflat.yaml --model-file-complete=cvpr23_dm39dnnv3barcv2b_refwithgcpervertisflat0morestanding0/checkpoint.pth.tar --sketchfab 1 + +# for stanext images: +# python scripts/gradio.py --workers 12 --config refinement_cfg_test_withvertexwisegc_csaddnonflat.yaml --model-file-complete=cvpr23_dm39dnnv3barcv2b_refwithgcpervertisflat0morestanding0/checkpoint.pth.tar -s ttopt_vtest1 +# for all images from the folder datasets/test_image_crops: +# python scripts/gradio.py --workers 12 --config refinement_cfg_test_withvertexwisegc_csaddnonflat_crops.yaml --model-file-complete=cvpr23_dm39dnnv3barcv2b_refwithgcpervertisflat0morestanding0/checkpoint.pth.tar -s ttopt_vtest2 + +'''import os +os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID" +os.environ["CUDA_VISIBLE_DEVICES"]="0" +try: + # os.system("pip install --upgrade torch==1.11.0+cu113 torchvision==0.12.0+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html") + os.system("pip install --upgrade torch==1.6.0+cu101 torchvision==0.7.0+cu101 -f https://download.pytorch.org/whl/cu101/torch_stable.html") +except Exception as e: + print(e)''' + +import argparse +import os.path +import json +import numpy as np +import pickle as pkl +import csv +from distutils.util import strtobool +import torch +from torch import nn +import torch.backends.cudnn +from torch.nn import DataParallel +from torch.utils.data import DataLoader +from collections import OrderedDict +import glob +from tqdm import tqdm +from dominate import document +from dominate.tags import * +from PIL import Image +from matplotlib import pyplot as plt +import trimesh +import cv2 +import shutil +import random +import gradio as gr + +import torchvision +from torchvision.models.detection.faster_rcnn import FastRCNNPredictor +import torchvision.transforms as T +from pytorch3d.structures import Meshes +from pytorch3d.loss import mesh_edge_loss, mesh_laplacian_smoothing, mesh_normal_consistency + + +import sys +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src')) + +from combined_model.train_main_image_to_3d_wbr_withref import do_validation_epoch +from combined_model.model_shape_v7_withref_withgraphcnn import ModelImageTo3d_withshape_withproj + +from configs.barc_cfg_defaults import get_cfg_defaults, update_cfg_global_with_yaml, get_cfg_global_updated + +from lifting_to_3d.utils.geometry_utils import rot6d_to_rotmat, rotmat_to_rot6d +from stacked_hourglass.datasets.utils_dataset_selection import get_evaluation_dataset, get_sketchfab_evaluation_dataset, get_crop_evaluation_dataset, get_norm_dict, get_single_crop_dataset_from_image + +from test_time_optimization.bite_inference_model_for_ttopt import BITEInferenceModel +from smal_pytorch.smal_model.smal_torch_new import SMAL +from configs.SMAL_configs import SMAL_MODEL_CONFIG +from smal_pytorch.renderer.differentiable_renderer import SilhRenderer +from test_time_optimization.utils.utils_ttopt import reset_loss_values, get_optimed_pose_with_glob + +from combined_model.loss_utils.loss_utils import leg_sideway_error, leg_torsion_error, tail_sideway_error, tail_torsion_error, spine_torsion_error, spine_sideway_error +from combined_model.loss_utils.loss_utils_gc import LossGConMesh, calculate_plane_errors_batch +from combined_model.loss_utils.loss_arap import Arap_Loss +from combined_model.loss_utils.loss_laplacian_mesh_comparison import LaplacianCTF # (coarse to fine animal) +from graph_networks import graphcmr # .utils_mesh import Mesh +from stacked_hourglass.utils.visualization import save_input_image_with_keypoints, save_input_image + +random.seed(0) + +print( + "torch: ", torch.__version__, + "\ntorchvision: ", torchvision.__version__, +) + + +def get_prediction(model, img_path_or_img, confidence=0.5): + """ + see https://haochen23.github.io/2020/04/object-detection-faster-rcnn.html#.YsMCm4TP3-g + get_prediction + parameters: + - img_path - path of the input image + - confidence - threshold value for prediction score + method: + - Image is obtained from the image path + - the image is converted to image tensor using PyTorch's Transforms + - image is passed through the model to get the predictions + - class, box coordinates are obtained, but only prediction score > threshold + are chosen. + """ + if isinstance(img_path_or_img, str): + img = Image.open(img_path_or_img).convert('RGB') + else: + img = img_path_or_img + transform = T.Compose([T.ToTensor()]) + img = transform(img) + pred = model([img]) + # pred_class = [COCO_INSTANCE_CATEGORY_NAMES[i] for i in list(pred[0]['labels'].numpy())] + pred_class = list(pred[0]['labels'].numpy()) + pred_boxes = [[(int(i[0]), int(i[1])), (int(i[2]), int(i[3]))] for i in list(pred[0]['boxes'].detach().numpy())] + pred_score = list(pred[0]['scores'].detach().numpy()) + try: + pred_t = [pred_score.index(x) for x in pred_score if x>confidence][-1] + pred_boxes = pred_boxes[:pred_t+1] + pred_class = pred_class[:pred_t+1] + return pred_boxes, pred_class, pred_score + except: + print('no bounding box with a score that is high enough found! -> work on full image') + return None, None, None + + +def detect_object(model, img_path_or_img, confidence=0.5, rect_th=2, text_size=0.5, text_th=1): + """ + see https://haochen23.github.io/2020/04/object-detection-faster-rcnn.html#.YsMCm4TP3-g + object_detection_api + parameters: + - img_path_or_img - path of the input image + - confidence - threshold value for prediction score + - rect_th - thickness of bounding box + - text_size - size of the class label text + - text_th - thichness of the text + method: + - prediction is obtained from get_prediction method + - for each prediction, bounding box is drawn and text is written + with opencv + - the final image is displayed + """ + boxes, pred_cls, pred_scores = get_prediction(model, img_path_or_img, confidence) + if isinstance(img_path_or_img, str): + img = cv2.imread(img_path_or_img) + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + else: + img = img_path_or_img + is_first = True + bbox = None + if boxes is not None: + for i in range(len(boxes)): + cls = pred_cls[i] + if cls == 18 and bbox is None: + cv2.rectangle(img, boxes[i][0], boxes[i][1],color=(0, 255, 0), thickness=rect_th) + # cv2.putText(img, pred_cls[i], boxes[i][0], cv2.FONT_HERSHEY_SIMPLEX, text_size, (0,255,0),thickness=text_th) + # cv2.putText(img, str(pred_scores[i]), boxes[i][0], cv2.FONT_HERSHEY_SIMPLEX, text_size, (0,255,0),thickness=text_th) + bbox = boxes[i] + return img, bbox + + +# -------------------------------------------------------------------------------------------------------------------- # +model_bbox = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True) +model_bbox.eval() + +def run_bbox_inference(input_image): + # load configs + cfg = get_cfg_global_updated() + out_path = os.path.join(cfg.paths.ROOT_OUT_PATH, 'gradio_examples', 'test2.png') + img, bbox = detect_object(model=model_bbox, img_path_or_img=input_image, confidence=0.5) + fig = plt.figure() # plt.figure(figsize=(20,30)) + plt.imsave(out_path, img) + return img, bbox + + + +# -------------------------------------------------------------------------------------------------------------------- # +# python scripts/gradio.py --workers 12 --config refinement_cfg_test_withvertexwisegc_csaddnonflat.yaml --model-file-complete=cvpr23_dm39dnnv3barcv2b_refwithgcpervertisflat0morestanding0/checkpoint.pth.tar +args_config = "refinement_cfg_test_withvertexwisegc_csaddnonflat.yaml" +args_model_file_complete = "cvpr23_dm39dnnv3barcv2b_refwithgcpervertisflat0morestanding0/checkpoint.pth.tar" +args_suffix = "ttopt_v0" +args_loss_weight_ttopt_path = "bite_loss_weights_ttopt.json" +args_workers = 12 +# -------------------------------------------------------------------------------------------------------------------- # + + + +# load configs +# step 1: load default configs +# step 2: load updates from .yaml file +path_config = os.path.join(get_cfg_defaults().barc_dir, 'src', 'configs', args_config) +update_cfg_global_with_yaml(path_config) +cfg = get_cfg_global_updated() + +# define path to load the trained model +path_model_file_complete = os.path.join(cfg.paths.ROOT_CHECKPOINT_PATH, args_model_file_complete) + +# define and create paths to save results +out_sub_name = cfg.data.VAL_OPT + '_' + cfg.data.DATASET + '_' + args_suffix + '/' +root_out_path = os.path.join(os.path.dirname(path_model_file_complete).replace(cfg.paths.ROOT_CHECKPOINT_PATH, cfg.paths.ROOT_OUT_PATH + 'results_gradio/'), out_sub_name) +root_out_path_details = root_out_path + 'details/' +if not os.path.exists(root_out_path): os.makedirs(root_out_path) +if not os.path.exists(root_out_path_details): os.makedirs(root_out_path_details) +print('root_out_path: ' + root_out_path) + +# other paths +root_data_path = os.path.join(os.path.dirname(__file__), '../', 'data') +# downsampling as used in graph neural network +root_smal_downsampling = os.path.join(root_data_path, 'graphcmr_data') +# remeshing as used for ground contact +remeshing_path = os.path.join(root_data_path, 'smal_data_remeshed', 'uniform_surface_sampling', 'my_smpl_39dogsnorm_Jr_4_dog_remesh4000_info.pkl') + +loss_weight_path = os.path.join(os.path.dirname(__file__), '../', 'src', 'configs', 'ttopt_loss_weights', args_loss_weight_ttopt_path) +print(loss_weight_path) + + +# Select the hardware device to use for training. +if torch.cuda.is_available() and cfg.device=='cuda': + device = torch.device('cuda', torch.cuda.current_device()) + torch.backends.cudnn.benchmark = False # True +else: + device = torch.device('cpu') + +print('structure_pose_net: ' + cfg.params.STRUCTURE_POSE_NET) +print('refinement network type: ' + cfg.params.REF_NET_TYPE) +print('smal_model_type: ' + cfg.smal.SMAL_MODEL_TYPE) + +# prepare complete model +norm_dict = get_norm_dict(data_info=None, device=device) +bite_model = BITEInferenceModel(cfg, path_model_file_complete, norm_dict) +smal_model_type = bite_model.smal_model_type +logscale_part_list = SMAL_MODEL_CONFIG[smal_model_type]['logscale_part_list'] # ['legs_l', 'legs_f', 'tail_l', 'tail_f', 'ears_y', 'ears_l', 'head_l'] +smal = SMAL(smal_model_type=smal_model_type, template_name='neutral', logscale_part_list=logscale_part_list).to(device) +silh_renderer = SilhRenderer(image_size=256).to(device) + +# load loss modules -> not necessary! +# loss_module = Loss(smal_model_type=cfg.smal.SMAL_MODEL_TYPE, data_info=StanExt.DATA_INFO, nf_version=cfg.params.NF_VERSION).to(device) +# loss_module_ref = LossRef(smal_model_type=cfg.smal.SMAL_MODEL_TYPE, data_info=StanExt.DATA_INFO, nf_version=cfg.params.NF_VERSION).to(device) + +# remeshing utils +with open(remeshing_path, 'rb') as fp: + remeshing_dict = pkl.load(fp) +remeshing_relevant_faces = torch.tensor(remeshing_dict['smal_faces'][remeshing_dict['faceid_closest']], dtype=torch.long, device=device) +remeshing_relevant_barys = torch.tensor(remeshing_dict['barys_closest'], dtype=torch.float32, device=device) + + + + +# create path for output files +save_imgs_path = os.path.join(cfg.paths.ROOT_OUT_PATH, 'gradio_examples') +if not os.path.exists(save_imgs_path): + os.makedirs(save_imgs_path) + + + + + +def run_bite_inference(input_image, bbox=None): + + with open(loss_weight_path, 'r') as j: + losses = json.loads(j.read()) + shutil.copyfile(loss_weight_path, root_out_path_details + os.path.basename(loss_weight_path)) + print(losses) + + # prepare dataset and dataset loader + val_dataset, val_loader, len_val_dataset, test_name_list, stanext_data_info, stanext_acc_joints = get_single_crop_dataset_from_image(input_image, bbox=bbox) + + # summarize information for normalization + norm_dict = get_norm_dict(stanext_data_info, device) + # get keypoint weights + keypoint_weights = torch.tensor(stanext_data_info.keypoint_weights, dtype=torch.float)[None, :].to(device) + + + # prepare progress bar + iterable = enumerate(val_loader) # the length of this iterator should be 1 + progress = None + if True: # not quiet: + progress = tqdm(iterable, desc='Train', total=len(val_loader), ascii=True, leave=False) + iterable = progress + ind_img_tot = 0 + + for i, (input, target_dict) in iterable: + batch_size = input.shape[0] + # prepare variables, put them on the right device + for key in target_dict.keys(): + if key == 'breed_index': + target_dict[key] = target_dict[key].long().to(device) + elif key in ['index', 'pts', 'tpts', 'target_weight', 'silh', 'silh_distmat_tofg', 'silh_distmat_tobg', 'sim_breed_index', 'img_border_mask']: + target_dict[key] = target_dict[key].float().to(device) + elif key == 'has_seg': + target_dict[key] = target_dict[key].to(device) + else: + pass + input = input.float().to(device) + + # get starting values for the optimization + preds_dict = bite_model.get_all_results(input) + # res_normal_and_ref = bite_model.get_selected_results(preds_dict=preds_dict, result_networks=['normal', 'ref']) + res = bite_model.get_selected_results(preds_dict=preds_dict, result_networks=['ref'])['ref'] + bs = res['pose_rotmat'].shape[0] + all_pose_6d = rotmat_to_rot6d(res['pose_rotmat'][:, None, 1:, :, :].clone().reshape((-1, 3, 3))).reshape((bs, -1, 6)) # [bs, 34, 6] + all_orient_6d = rotmat_to_rot6d(res['pose_rotmat'][:, None, :1, :, :].clone().reshape((-1, 3, 3))).reshape((bs, -1, 6)) # [bs, 1, 6] + + + ind_img = 0 + name = (test_name_list[target_dict['index'][ind_img].long()]).replace('/', '__').split('.')[0] + + print('ind_img_tot: ' + str(ind_img_tot) + ' -> ' + name) + ind_img_tot += 1 + batch_size = 1 + + # save initial visualizations + # save the image with keypoints as predicted by the stacked hourglass + pred_unp_prep = torch.cat((res['hg_keyp_256'][ind_img, :, :].detach(), res['hg_keyp_scores'][ind_img, :, :]), 1) + inp_img = input[ind_img, :, :, :].detach().clone() + out_path = root_out_path + name + '_hg_key.png' + save_input_image_with_keypoints(inp_img, pred_unp_prep, out_path=out_path, threshold=0.01, print_scores=True, ratio_in_out=1.0) # threshold=0.3 + # save the input image + img_inp = input[ind_img, :, :, :].clone() + for t, m, s in zip(img_inp, stanext_data_info.rgb_mean, stanext_data_info.rgb_stddev): t.add_(m) # inverse to transforms.color_normalize() + img_inp = img_inp.detach().cpu().numpy().transpose(1, 2, 0) + img_init = Image.fromarray(np.uint8(255*img_inp)).convert('RGB') + img_init.save(root_out_path_details + name + '_img_ainit.png') + # save ground truth silhouette (for visualization only, it is not used during the optimization) + target_img_silh = Image.fromarray(np.uint8(255*target_dict['silh'][ind_img, :, :].detach().cpu().numpy())).convert('RGB') + target_img_silh.save(root_out_path_details + name + '_target_silh.png') + # save the silhouette as predicted by the stacked hourglass + hg_img_silh = Image.fromarray(np.uint8(255*res['hg_silh_prep'][ind_img, :, :].detach().cpu().numpy())).convert('RGB') + hg_img_silh.save(root_out_path + name + '_hg_silh.png') + + # initialize the variables over which we want to optimize + optimed_pose_6d = all_pose_6d[ind_img, None, :, :].to(device).clone().detach().requires_grad_(True) + optimed_orient_6d = all_orient_6d[ind_img, None, :, :].to(device).clone().detach().requires_grad_(True) # [1, 1, 6] + optimed_betas = res['betas'][ind_img, None, :].to(device).clone().detach().requires_grad_(True) # [1,30] + optimed_trans_xy = res['trans'][ind_img, None, :2].to(device).clone().detach().requires_grad_(True) + optimed_trans_z =res['trans'][ind_img, None, 2:3].to(device).clone().detach().requires_grad_(True) + optimed_camera_flength = res['flength'][ind_img, None, :].to(device).clone().detach().requires_grad_(True) # [1,1] + n_vert_comp = 2*smal.n_center + 3*smal.n_left + optimed_vert_off_compact = torch.tensor(np.zeros((batch_size, n_vert_comp)), dtype=torch.float, + device=device, + requires_grad=True) + assert len(logscale_part_list) == 7 + new_betas_limb_lengths = res['betas_limbs'][ind_img, None, :] + optimed_betas_limbs = new_betas_limb_lengths.to(device).clone().detach().requires_grad_(True) # [1,7] + + # define the optimizers + optimizer = torch.optim.SGD( + # [optimed_pose, optimed_trans_xy, optimed_betas, optimed_betas_limbs, optimed_orient, optimed_vert_off_compact], + [optimed_camera_flength, optimed_trans_z, optimed_trans_xy, optimed_pose_6d, optimed_orient_6d, optimed_betas, optimed_betas_limbs], + lr=5*1e-4, # 1e-3, + momentum=0.9) + optimizer_vshift = torch.optim.SGD( + [optimed_camera_flength, optimed_trans_z, optimed_trans_xy, optimed_pose_6d, optimed_orient_6d, optimed_betas, optimed_betas_limbs, optimed_vert_off_compact], + lr=1e-4, # 1e-4, + momentum=0.9) + nopose_optimizer = torch.optim.SGD( + # [optimed_pose, optimed_trans_xy, optimed_betas, optimed_betas_limbs, optimed_orient, optimed_vert_off_compact], + [optimed_camera_flength, optimed_trans_z, optimed_trans_xy, optimed_orient_6d, optimed_betas, optimed_betas_limbs], + lr=5*1e-4, # 1e-3, + momentum=0.9) + nopose_optimizer_vshift = torch.optim.SGD( + [optimed_camera_flength, optimed_trans_z, optimed_trans_xy, optimed_orient_6d, optimed_betas, optimed_betas_limbs, optimed_vert_off_compact], + lr=1e-4, # 1e-4, + momentum=0.9) + # define schedulers + patience = 5 + scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( + optimizer, + mode='min', + factor=0.5, + verbose=0, + min_lr=1e-5, + patience=patience) + scheduler_vshift = torch.optim.lr_scheduler.ReduceLROnPlateau( + optimizer_vshift, + mode='min', + factor=0.5, + verbose=0, + min_lr=1e-5, + patience=patience) + + # set all loss values to 0 + losses = reset_loss_values(losses) + + # prepare all the target labels: keypoints, silhouette, ground contact, ... + with torch.no_grad(): + thr_kp = 0.2 + kp_weights = res['hg_keyp_scores'] + kp_weights[res['hg_keyp_scores']ai', remeshing_relevant_barys, target_gc_class[:, remeshing_relevant_faces].to(device=device, dtype=torch.float32)) + target_gc_class_remeshed_prep = torch.round(target_gc_class_remeshed).to(torch.long) + vert_colors = np.repeat(255*target_gc_class.detach().cpu().numpy()[0, :, None], 3, 1) + vert_colors[:, 2] = 255 + faces_prep = smal.faces.unsqueeze(0).expand((batch_size, -1, -1)) + # prepare target silhouette and keypoints, from stacked hourglass predictions + target_hg_silh = res['hg_silh_prep'][ind_img, :, :].detach() + target_kp_resh = res['hg_keyp_256'][ind_img, None, :, :].reshape((-1, 2)).detach() + # find out if ground contact constraints should be used for the image at hand + # print('is flat: ' + str(res['isflat_prep'][ind_img])) + if res['isflat_prep'][ind_img] >= 0.5: # threshold should probably be set higher + isflat = [True] + else: + isflat = [False] + if target_gc_class_remeshed_prep.sum() > 3: + istouching = [True] + else: + istouching = [False] + ignore_pose_optimization = False + + + ########################################################################################################## + # start optimizing for this image + n_iter = 301 # how many iterations are desired? (+1) + loop = tqdm(range(n_iter)) + per_loop_lst = [] + list_error_procrustes = [] + for i in loop: + # for the first 150 iterations steps we don't allow vertex shifts + if i == 0: + current_i = 0 + if ignore_pose_optimization: + current_optimizer = nopose_optimizer + else: + current_optimizer = optimizer + current_scheduler = scheduler + current_weight_name = 'weight' + # after 150 iteration steps we start with vertex shifts + elif i == 150: + current_i = 0 + if ignore_pose_optimization: + current_optimizer = nopose_optimizer_vshift + else: + current_optimizer = optimizer_vshift + current_scheduler = scheduler_vshift + current_weight_name = 'weight_vshift' + # set up arap loss + if losses["arap"]['weight_vshift'] > 0.0: + with torch.no_grad(): + torch_mesh_comparison = Meshes(smal_verts.detach(), faces_prep.detach()) + arap_loss = Arap_Loss(meshes=torch_mesh_comparison, device=device) + # is there a laplacian loss similar as in coarse-to-fine? + if losses["lapctf"]['weight_vshift'] > 0.0: + torch_verts_comparison = smal_verts.detach().clone() + smal_model_type_downsampling = '39dogs_norm' + smal_downsampling_npz_name = 'mesh_downsampling_' + os.path.basename(SMAL_MODEL_CONFIG[smal_model_type_downsampling]['smal_model_path']).replace('.pkl', '_template.npz') + smal_downsampling_npz_path = os.path.join(root_smal_downsampling, smal_downsampling_npz_name) + data = np.load(smal_downsampling_npz_path, encoding='latin1', allow_pickle=True) + adjmat = data['A'][0] + laplacian_ctf = LaplacianCTF(adjmat, device=device) + else: + pass + + + current_optimizer.zero_grad() + + # get 3d smal model + optimed_pose_with_glob = get_optimed_pose_with_glob(optimed_orient_6d, optimed_pose_6d) + optimed_trans = torch.cat((optimed_trans_xy, optimed_trans_z), dim=1) + smal_verts, keyp_3d, _ = smal(beta=optimed_betas, betas_limbs=optimed_betas_limbs, pose=optimed_pose_with_glob, vert_off_compact=optimed_vert_off_compact, trans=optimed_trans, keyp_conf='olive', get_skin=True) + + # render silhouette and keypoints + pred_silh_images, pred_keyp_raw = silh_renderer(vertices=smal_verts, points=keyp_3d, faces=faces_prep, focal_lengths=optimed_camera_flength) + pred_keyp = pred_keyp_raw[:, :24, :] + + # save silhouette reprojection visualization + if i==0: + img_silh = Image.fromarray(np.uint8(255*pred_silh_images[0, 0, :, :].detach().cpu().numpy())).convert('RGB') + img_silh.save(root_out_path_details + name + '_silh_ainit.png') + my_mesh_tri = trimesh.Trimesh(vertices=smal_verts[0, ...].detach().cpu().numpy(), faces=faces_prep[0, ...].detach().cpu().numpy(), process=False, maintain_order=True) + my_mesh_tri.export(root_out_path_details + name + '_res_ainit.obj') + + # silhouette loss + diff_silh = torch.abs(pred_silh_images[0, 0, :, :] - target_hg_silh) + losses['silhouette']['value'] = diff_silh.mean() + + # keypoint_loss + output_kp_resh = (pred_keyp[0, :, :]).reshape((-1, 2)) + losses['keyp']['value'] = ((((output_kp_resh - target_kp_resh)[weights_resh>0]**2).sum(axis=1).sqrt() * \ + weights_resh[weights_resh>0])*keyp_w_resh[weights_resh>0]).sum() / \ + max((weights_resh[weights_resh>0]*keyp_w_resh[weights_resh>0]).sum(), 1e-5) + # losses['keyp']['value'] = ((((output_kp_resh - target_kp_resh)[weights_resh>0]**2).sum(axis=1).sqrt()*weights_resh[weights_resh>0])*keyp_w_resh[weights_resh>0]).sum() / max((weights_resh[weights_resh>0]*keyp_w_resh[weights_resh>0]).sum(), 1e-5) + + # pose priors on refined pose + losses['pose_legs_side']['value'] = leg_sideway_error(optimed_pose_with_glob) + losses['pose_legs_tors']['value'] = leg_torsion_error(optimed_pose_with_glob) + losses['pose_tail_side']['value'] = tail_sideway_error(optimed_pose_with_glob) + losses['pose_tail_tors']['value'] = tail_torsion_error(optimed_pose_with_glob) + losses['pose_spine_side']['value'] = spine_sideway_error(optimed_pose_with_glob) + losses['pose_spine_tors']['value'] = spine_torsion_error(optimed_pose_with_glob) + + # ground contact loss + sel_verts = torch.index_select(smal_verts, dim=1, index=remeshing_relevant_faces.reshape((-1))).reshape((batch_size, remeshing_relevant_faces.shape[0], 3, 3)) + verts_remeshed = torch.einsum('ij,aijk->aik', remeshing_relevant_barys, sel_verts) + + # gc_errors_plane, gc_errors_under_plane = calculate_plane_errors_batch(verts_remeshed, target_gc_class_remeshed_prep, target_dict['has_gc'], target_dict['has_gc_is_touching']) + gc_errors_plane, gc_errors_under_plane = calculate_plane_errors_batch(verts_remeshed, target_gc_class_remeshed_prep, isflat, istouching) + + losses['gc_plane']['value'] = torch.mean(gc_errors_plane) + losses['gc_belowplane']['value'] = torch.mean(gc_errors_under_plane) + + # edge length of the predicted mesh + if (losses["edge"][current_weight_name] + losses["normal"][ current_weight_name] + losses["laplacian"][ current_weight_name]) > 0: + torch_mesh = Meshes(smal_verts, faces_prep.detach()) + losses["edge"]['value'] = mesh_edge_loss(torch_mesh) + # mesh normal consistency + losses["normal"]['value'] = mesh_normal_consistency(torch_mesh) + # mesh laplacian smoothing + losses["laplacian"]['value'] = mesh_laplacian_smoothing(torch_mesh, method="uniform") + + # arap loss + if losses["arap"][current_weight_name] > 0.0: + torch_mesh = Meshes(smal_verts, faces_prep.detach()) + losses["arap"]['value'] = arap_loss(torch_mesh) + + # laplacian loss for comparison (from coarse-to-fine paper) + if losses["lapctf"][current_weight_name] > 0.0: + verts_refine = smal_verts + loss_almost_arap, loss_smooth = laplacian_ctf(verts_refine, torch_verts_comparison) + losses["lapctf"]['value'] = loss_almost_arap + + # Weighted sum of the losses + total_loss = 0.0 + for k in ['keyp', 'silhouette', 'pose_legs_side', 'pose_legs_tors', 'pose_tail_side', 'pose_tail_tors', 'pose_spine_tors', 'pose_spine_side', 'gc_plane', 'gc_belowplane', 'edge', 'normal', 'laplacian', 'arap', 'lapctf']: + if losses[k][current_weight_name] > 0.0: + total_loss += losses[k]['value'] * losses[k][current_weight_name] + + # calculate gradient and make optimization step + total_loss.backward(retain_graph=True) # + current_optimizer.step() + current_scheduler.step(total_loss) + loop.set_description(f"Body Fitting = {total_loss.item():.3f}") + + # save the result three times (0, 150, 300) + if i % 150 == 0: + # save silhouette image + img_silh = Image.fromarray(np.uint8(255*pred_silh_images[0, 0, :, :].detach().cpu().numpy())).convert('RGB') + img_silh.save(root_out_path_details + name + '_silh_e' + format(i, '03d') + '.png') + # save image overlay + visualizations = silh_renderer.get_visualization_nograd(smal_verts, faces_prep, optimed_camera_flength, color=0) + pred_tex = visualizations[0, :, :, :].permute((1, 2, 0)).cpu().detach().numpy() / 256 + # out_path = root_out_path_details + name + '_tex_pred_e' + format(i, '03d') + '.png' + # plt.imsave(out_path, pred_tex) + input_image_np = img_inp.copy() + im_masked = cv2.addWeighted(input_image_np,0.2,pred_tex,0.8,0) + pred_tex_max = np.max(pred_tex, axis=2) + im_masked[pred_tex_max<0.01, :] = input_image_np[pred_tex_max<0.01, :] + out_path = root_out_path + name + '_comp_pred_e' + format(i, '03d') + '.png' + plt.imsave(out_path, im_masked) + # save mesh + my_mesh_tri = trimesh.Trimesh(vertices=smal_verts[0, ...].detach().cpu().numpy(), faces=faces_prep[0, ...].detach().cpu().numpy(), process=False, maintain_order=True) + my_mesh_tri.visual.vertex_colors = vert_colors + my_mesh_tri.export(root_out_path + name + '_res_e' + format(i, '03d') + '.obj') + # save focal length (together with the mesh this is enough to create an overlay in blender) + out_file_flength = root_out_path_details + name + '_flength_e' + format(i, '03d') # + '.npz' + np.save(out_file_flength, optimed_camera_flength.detach().cpu().numpy()) + current_i += 1 + + # prepare output mesh + mesh = my_mesh_tri # all_results[0]['mesh_posed'] + mesh.apply_transform([[-1, 0, 0, 0], + [0, -1, 0, 0], + [0, 0, 1, 1], + [0, 0, 0, 1]]) + result_path = os.path.join(save_imgs_path, test_name_list[0] + '_z') + mesh.export(file_obj=result_path + '.glb') + result_gltf = result_path + '.glb' + return result_gltf + + + + + +# -------------------------------------------------------------------------------------------------------------------- # + + +def run_complete_inference(img_path_or_img, crop_choice): + # depending on crop_choice: run faster r-cnn or take the input image directly + if crop_choice == "input image is cropped": + if isinstance(img_path_or_img, str): + img = cv2.imread(img_path_or_img) + output_interm_image = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + else: + output_interm_image = img_path_or_img + output_interm_bbox = None + else: + output_interm_image, output_interm_bbox = run_bbox_inference(img_path_or_img.copy()) + # run barc inference + result_gltf = run_bite_inference(img_path_or_img, output_interm_bbox) + # add white border to image for nicer alignment + output_interm_image_vis = np.concatenate((255*np.ones_like(output_interm_image), output_interm_image, 255*np.ones_like(output_interm_image)), axis=1) + return [result_gltf, result_gltf, output_interm_image_vis] + + + + +######################################################################################################################## + +# see: https://huggingface.co/spaces/radames/PIFu-Clothed-Human-Digitization/blob/main/PIFu/spaces.py + +description = ''' +# BITE + +#### Project Page +* https://bite.is.tue.mpg.de/ + +#### Description +This is a demo for BITE (*B*eyond Priors for *I*mproved *T*hree-{D} Dog Pose *E*stimation). +You can either submit a cropped image or choose the option to run a pretrained Faster R-CNN in order to obtain a bounding box. +Please have a look at the examples below. +
+ +More + +#### Citation + +``` +@inproceedings{bite2023rueegg, + title = {{BITE}: Beyond Priors for Improved Three-{D} Dog Pose Estimation}, + author = {R\"uegg, Nadine and Tripathi, Shashank and Schindler, Konrad and Black, Michael J. and Zuffi, Silvia}, + booktitle = {IEEE/CVF Conf.~on Computer Vision and Pattern Recognition (CVPR)}, + pages = {8867-8876}, + year = {2023}, +} +``` + +#### Image Sources +* Stanford extra image dataset +* Images from google search engine + * https://www.dogtrainingnation.com/wp-content/uploads/2015/02/keep-dog-training-sessions-short.jpg + * https://thumbs.dreamstime.com/b/hund-und-seine-neue-hundeh%C3%BCtte-36757551.jpg + * https://www.mydearwhippet.com/wp-content/uploads/2021/04/whippet-temperament-2.jpg + * https://media.istockphoto.com/photos/ibizan-hound-at-the-shore-in-winter-picture-id1092705644?k=20&m=1092705644&s=612x612&w=0&h=ppwg92s9jI8GWnk22SOR_DWWNP8b2IUmLXSQmVey5Ss= + + +
+''' + + + + + + +example_images = sorted(glob.glob(os.path.join(os.path.dirname(__file__), '../', 'datasets', 'test_image_crops', '*.jpg')) + glob.glob(os.path.join(os.path.dirname(__file__), '../', 'datasets', 'test_image_crops', '*.png'))) +random.shuffle(example_images) +# example_images.reverse() +# examples = [[img, "input image is cropped"] for img in example_images] +examples = [] +for img in example_images: + if os.path.basename(img)[:2] == 'z_': + examples.append([img, "use Faster R-CNN to get a bounding box"]) + else: + examples.append([img, "input image is cropped"]) + +demo = gr.Interface( + fn=run_complete_inference, + description=description, + # inputs=gr.Image(type="filepath", label="Input Image"), + inputs=[gr.Image(label="Input Image"), + gr.Radio(["input image is cropped", "use Faster R-CNN to get a bounding box"], value="use Faster R-CNN to get a bounding box", label="Crop Choice"), + ], + outputs=[ + gr.Model3D( + clear_color=[0.0, 0.0, 0.0, 0.0], label="3D Model"), + gr.File(label="Download 3D Model"), + gr.Image(label="Bounding Box (Faster R-CNN prediction)"), + + ], + examples=examples, + thumbnail="bite_thumbnail.png", + allow_flagging="never", + cache_examples=True, + examples_per_page=14, +) + +demo.launch(share=True) \ No newline at end of file diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/bps_2d/bps_for_segmentation.py b/src/bps_2d/bps_for_segmentation.py new file mode 100644 index 0000000000000000000000000000000000000000..ef7382c5e875f878b296321fed6e0c46b037781e --- /dev/null +++ b/src/bps_2d/bps_for_segmentation.py @@ -0,0 +1,114 @@ + +# code idea from https://github.com/sergeyprokudin/bps + +import os +import numpy as np +from PIL import Image +import time +import scipy +import scipy.spatial +import pymp + + +##################### +QUERY_POINTS = np.asarray([30, 34, 31, 55, 29, 84, 35, 108, 34, 145, 29, 171, 27, + 196, 29, 228, 58, 35, 61, 55, 57, 83, 56, 109, 63, 148, 58, 164, 57, 197, 60, + 227, 81, 26, 87, 58, 85, 87, 89, 117, 86, 142, 89, 172, 84, 197, 88, 227, 113, + 32, 116, 58, 112, 88, 118, 113, 109, 147, 114, 173, 119, 201, 113, 229, 139, + 29, 141, 59, 142, 93, 139, 117, 146, 147, 141, 173, 142, 201, 143, 227, 170, + 26, 173, 59, 166, 90, 174, 117, 176, 141, 169, 175, 167, 198, 172, 227, 198, + 30, 195, 59, 204, 85, 198, 116, 195, 140, 198, 175, 194, 193, 199, 227, 221, + 26, 223, 57, 227, 83, 227, 113, 227, 140, 226, 173, 230, 196, 228, 229]).reshape((64, 2)) +##################### + +class SegBPS(): + + def __init__(self, query_points=QUERY_POINTS, size=256): + self.size = size + self.query_points = query_points + row, col = np.indices((self.size, self.size)) + self.indices_rc = np.stack((row, col), axis=2) # (256, 256, 2) + self.pts_aranged = np.arange(64) + return + + def _do_kdtree(self, combined_x_y_arrays, points): + # see https://stackoverflow.com/questions/10818546/finding-index-of-nearest- + # point-in-numpy-arrays-of-x-and-y-coordinates + mytree = scipy.spatial.cKDTree(combined_x_y_arrays) + dist, indexes = mytree.query(points) + return indexes + + def calculate_bps_points(self, seg, thr=0.5, vis=False, out_path=None): + # seg: input segmentation image of shape (256, 256) with values between 0 and 1 + query_val = seg[self.query_points[:, 0], self.query_points[:, 1]] + pts_fg = self.pts_aranged[query_val>=thr] + pts_bg = self.pts_aranged[query_val=thr] + if candidate_inds_bg.shape[0] == 0: + candidate_inds_bg = np.ones((1, 2)) * 128 # np.zeros((1, 2)) + if candidate_inds_fg.shape[0] == 0: + candidate_inds_fg = np.ones((1, 2)) * 128 # np.zeros((1, 2)) + # calculate nearest points + all_nearest_points = np.zeros((64, 2)) + all_nearest_points[pts_fg, :] = candidate_inds_bg[self._do_kdtree(candidate_inds_bg, self.query_points[pts_fg, :]), :] + all_nearest_points[pts_bg, :] = candidate_inds_fg[self._do_kdtree(candidate_inds_fg, self.query_points[pts_bg, :]), :] + all_nearest_points_01 = all_nearest_points / 255. + if vis: + self.visualize_result(seg, all_nearest_points, out_path=out_path) + return all_nearest_points_01 + + def calculate_bps_points_batch(self, seg_batch, thr=0.5, vis=False, out_path=None): + # seg_batch: input segmentation image of shape (bs, 256, 256) with values between 0 and 1 + bs = seg_batch.shape[0] + all_nearest_points_01_batch = np.zeros((bs, self.query_points.shape[0], 2)) + for ind in range(0, bs): # 0.25 + seg = seg_batch[ind, :, :] + all_nearest_points_01 = self.calculate_bps_points(seg, thr=thr, vis=vis, out_path=out_path) + all_nearest_points_01_batch[ind, :, :] = all_nearest_points_01 + return all_nearest_points_01_batch + + def visualize_result(self, seg, all_nearest_points, out_path=None): + import matplotlib as mpl + mpl.use('Agg') + import matplotlib.pyplot as plt + # img: (256, 256, 3) + img = (np.stack((seg, seg, seg), axis=2) * 155).astype(np.int) + if out_path is None: + ind_img = 0 + out_path = '../test_img' + str(ind_img) + '.png' + fig, ax = plt.subplots() + plt.imshow(img) + plt.gca().set_axis_off() + plt.subplots_adjust(top = 1, bottom = 0, right = 1, left = 0, hspace = 0, wspace = 0) + plt.margins(0,0) + ratio_in_out = 1 # 255 + for idx, (y, x) in enumerate(self.query_points): + x = int(x*ratio_in_out) + y = int(y*ratio_in_out) + plt.scatter([x], [y], marker="x", s=50) + x2 = int(all_nearest_points[idx, 1]) + y2 = int(all_nearest_points[idx, 0]) + plt.scatter([x2], [y2], marker="o", s=50) + plt.plot([x, x2], [y, y2]) + plt.savefig(out_path, bbox_inches='tight', pad_inches=0) + plt.close() + return + + + + + +if __name__ == "__main__": + ind_img = 2 # 4 + path_seg_top = '...../pytorch-stacked-hourglass/results/dogs_hg8_ks_24_v1/test/' + path_seg = os.path.join(path_seg_top, 'seg_big_' + str(ind_img) + '.png') + img = np.asarray(Image.open(path_seg)) + # min is 0.004, max is 0.9 + # low values are background, high values are foreground + seg = img[:, :, 1] / 255. + # calculate points + bps = SegBPS() + bps.calculate_bps_points(seg, thr=0.5, vis=False, out_path=None) + + diff --git a/src/combined_model/__init__.py b/src/combined_model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/combined_model/helper.py b/src/combined_model/helper.py new file mode 100644 index 0000000000000000000000000000000000000000..99552a5714b39c4e1311ba84cd1ac40a5b6d34eb --- /dev/null +++ b/src/combined_model/helper.py @@ -0,0 +1,207 @@ + +import torch +import torch.nn as nn +import torch.backends.cudnn +import torch.nn.parallel +from tqdm import tqdm +import os +import pathlib +from matplotlib import pyplot as plt +import cv2 +import numpy as np +import torch +import trimesh + +import sys +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..')) +from stacked_hourglass.utils.evaluation import accuracy, AverageMeter, final_preds, get_preds, get_preds_soft +from stacked_hourglass.utils.visualization import save_input_image_with_keypoints, save_input_image +from metrics.metrics import Metrics +from configs.SMAL_configs import EVAL_KEYPOINTS, KEYPOINT_GROUPS + + +# GOAL: have all the functions from the validation and visual epoch together + + +''' +save_imgs_path = ... +prefix = '' +input # this is the image +data_info +target_dict +render_all +model + + +vertices_smal = output_reproj['vertices_smal'] +flength = output_unnorm['flength'] +hg_keyp_norm = output['keypoints_norm'] +hg_keyp_scores = output['keypoints_scores'] +betas = output_reproj['betas'] +betas_limbs = output_reproj['betas_limbs'] +zz = output_reproj['z'] +pose_rotmat = output_unnorm['pose_rotmat'] +trans = output_unnorm['trans'] +pred_keyp = output_reproj['keyp_2d'] +pred_silh = output_reproj['silh'] +''' + +################################################# + +def eval_save_visualizations_and_meshes(model, input, data_info, target_dict, test_name_list, vertices_smal, hg_keyp_norm, hg_keyp_scores, zz, betas, betas_limbs, pose_rotmat, trans, flength, pred_keyp, pred_silh, save_imgs_path, prefix, index, render_all=False): + device = input.device + curr_batch_size = input.shape[0] + # render predicted 3d models + visualizations = model.render_vis_nograd(vertices=vertices_smal, + focal_lengths=flength, + color=0) # color=2) + for ind_img in range(len(target_dict['index'])): + try: + # import pdb; pdb.set_trace() + if test_name_list is not None: + img_name = test_name_list[int(target_dict['index'][ind_img].cpu().detach().numpy())].replace('/', '_') + img_name = img_name.split('.')[0] + else: + img_name = str(index) + '_' + str(ind_img) + # save image with predicted keypoints + out_path = save_imgs_path + '/keypoints_pred_' + img_name + '.png' + pred_unp = (hg_keyp_norm[ind_img, :, :] + 1.) / 2 * (data_info.image_size - 1) + pred_unp_maxval = hg_keyp_scores[ind_img, :, :] + pred_unp_prep = torch.cat((pred_unp, pred_unp_maxval), 1) + inp_img = input[ind_img, :, :, :].detach().clone() + save_input_image_with_keypoints(inp_img, pred_unp_prep, out_path=out_path, threshold=0.1, print_scores=True, ratio_in_out=1.0) # threshold=0.3 + # save predicted 3d model (front view) + pred_tex = visualizations[ind_img, :, :, :].permute((1, 2, 0)).cpu().detach().numpy() / 256 + pred_tex_max = np.max(pred_tex, axis=2) + out_path = save_imgs_path + '/' + prefix + 'tex_pred_' + img_name + '.png' + plt.imsave(out_path, pred_tex) + input_image = input[ind_img, :, :, :].detach().clone() + for t, m, s in zip(input_image, data_info.rgb_mean, data_info.rgb_stddev): t.add_(m) + input_image_np = input_image.detach().cpu().numpy().transpose(1, 2, 0) + im_masked = cv2.addWeighted(input_image_np,0.2,pred_tex,0.8,0) + im_masked[pred_tex_max<0.01, :] = input_image_np[pred_tex_max<0.01, :] + out_path = save_imgs_path + '/' + prefix + 'comp_pred_' + img_name + '.png' + plt.imsave(out_path, im_masked) + # save predicted 3d model (side view) + vertices_cent = vertices_smal - vertices_smal.mean(dim=1)[:, None, :] + roll = np.pi / 2 * torch.ones(1).float().to(device) + pitch = np.pi / 2 * torch.ones(1).float().to(device) + tensor_0 = torch.zeros(1).float().to(device) + tensor_1 = torch.ones(1).float().to(device) + RX = torch.stack([torch.stack([tensor_1, tensor_0, tensor_0]), torch.stack([tensor_0, torch.cos(roll), -torch.sin(roll)]),torch.stack([tensor_0, torch.sin(roll), torch.cos(roll)])]).reshape(3,3) + RY = torch.stack([ + torch.stack([torch.cos(pitch), tensor_0, torch.sin(pitch)]), + torch.stack([tensor_0, tensor_1, tensor_0]), + torch.stack([-torch.sin(pitch), tensor_0, torch.cos(pitch)])]).reshape(3,3) + vertices_rot = (torch.matmul(RY, vertices_cent.reshape((-1, 3))[:, :, None])).reshape((curr_batch_size, -1, 3)) + vertices_rot[:, :, 2] = vertices_rot[:, :, 2] + torch.ones_like(vertices_rot[:, :, 2]) * 20 # 18 # *16 + + visualizations_rot = model.render_vis_nograd(vertices=vertices_rot, + focal_lengths=flength, + color=0) # 2) + pred_tex = visualizations_rot[ind_img, :, :, :].permute((1, 2, 0)).cpu().detach().numpy() / 256 + pred_tex_max = np.max(pred_tex, axis=2) + out_path = save_imgs_path + '/' + prefix + 'rot_tex_pred_' + img_name + '.png' + plt.imsave(out_path, pred_tex) + if render_all: + # save input image + inp_img = input[ind_img, :, :, :].detach().clone() + out_path = save_imgs_path + '/image_' + img_name + '.png' + save_input_image(inp_img, out_path) + # save mesh + V_posed = vertices_smal[ind_img, :, :].detach().cpu().numpy() + Faces = model.smal.f + mesh_posed = trimesh.Trimesh(vertices=V_posed, faces=Faces, process=False, maintain_order=True) + mesh_posed.export(save_imgs_path + '/' + prefix + 'mesh_posed_' + img_name + '.obj') + except: + print('dont save an image') + +############ + +def eval_prepare_pck_and_iou(model, input, data_info, target_dict, test_name_list, vertices_smal, hg_keyp_norm, hg_keyp_scores, zz, betas, betas_limbs, pose_rotmat, trans, flength, pred_keyp, pred_silh, save_imgs_path, prefix, index, pck_thresh, progress=None, skip_pck_and_iou=False): + preds = {} + preds['betas'] = betas.cpu().detach().numpy() + preds['betas_limbs'] = betas_limbs.cpu().detach().numpy() + preds['z'] = zz.cpu().detach().numpy() + preds['pose_rotmat'] = pose_rotmat.cpu().detach().numpy() + preds['flength'] = flength.cpu().detach().numpy() + preds['trans'] = trans.cpu().detach().numpy() + preds['breed_index'] = target_dict['breed_index'].cpu().detach().numpy().reshape((-1)) + img_names = [] + for ind_img2 in range(0, betas.shape[0]): + if test_name_list is not None: + img_name2 = test_name_list[int(target_dict['index'][ind_img2].cpu().detach().numpy())].replace('/', '_') + img_name2 = img_name2.split('.')[0] + else: + img_name2 = str(index) + '_' + str(ind_img2) + img_names.append(img_name2) + preds['image_names'] = img_names + if not skip_pck_and_iou: + # prepare keypoints for PCK calculation - predicted as well as ground truth + # pred_keyp = output_reproj['keyp_2d'] # 256 + gt_keypoints_256 = target_dict['tpts'][:, :, :2] / 64. * (256. - 1) + # gt_keypoints_norm = gt_keypoints_256 / 256 / 0.5 - 1 + gt_keypoints = torch.cat((gt_keypoints_256, target_dict['tpts'][:, :, 2:3]), dim=2) # gt_keypoints_norm + # prepare silhouette for IoU calculation - predicted as well as ground truth + has_seg = target_dict['has_seg'] + img_border_mask = target_dict['img_border_mask'][:, 0, :, :] + gtseg = target_dict['silh'] + synth_silhouettes = pred_silh[:, 0, :, :] # output_reproj['silh'] + synth_silhouettes[synth_silhouettes>0.5] = 1 + synth_silhouettes[synth_silhouettes<0.5] = 0 + # calculate PCK as well as IoU (similar to WLDO) + preds['acc_PCK'] = Metrics.PCK( + pred_keyp, gt_keypoints, + gtseg, has_seg, idxs=EVAL_KEYPOINTS, + thresh_range=[pck_thresh], # [0.15], + ) + preds['acc_IOU'] = Metrics.IOU( + synth_silhouettes, gtseg, + img_border_mask, mask=has_seg + ) + for group, group_kps in KEYPOINT_GROUPS.items(): + preds[f'{group}_PCK'] = Metrics.PCK( + pred_keyp, gt_keypoints, gtseg, has_seg, + thresh_range=[pck_thresh], # [0.15], + idxs=group_kps + ) + return preds + + +# preds['acc_PCK'] = Metrics.PCK(pred_keyp, gt_keypoints, gtseg, has_seg, idxs=EVAL_KEYPOINTS, thresh_range=[pck_thresh]) +# preds['acc_IOU'] = Metrics.IOU(synth_silhouettes, gtseg, img_border_mask, mask=has_seg) +############################# + +def eval_add_preds_to_summary(summary, preds, my_step, batch_size, curr_batch_size, skip_pck_and_iou=False): + if not skip_pck_and_iou: + if not (preds['acc_PCK'].data.cpu().numpy().shape == (summary['pck'][my_step * batch_size:my_step * batch_size + curr_batch_size]).shape): + import pdb; pdb.set_trace() + summary['pck'][my_step * batch_size:my_step * batch_size + curr_batch_size] = preds['acc_PCK'].data.cpu().numpy() + summary['acc_sil_2d'][my_step * batch_size:my_step * batch_size + curr_batch_size] = preds['acc_IOU'].data.cpu().numpy() + for part in summary['pck_by_part']: + summary['pck_by_part'][part][my_step * batch_size:my_step * batch_size + curr_batch_size] = preds[f'{part}_PCK'].data.cpu().numpy() + summary['betas'][my_step * batch_size:my_step * batch_size + curr_batch_size, ...] = preds['betas'] + summary['betas_limbs'][my_step * batch_size:my_step * batch_size + curr_batch_size, ...] = preds['betas_limbs'] + summary['z'][my_step * batch_size:my_step * batch_size + curr_batch_size, ...] = preds['z'] + summary['pose_rotmat'][my_step * batch_size:my_step * batch_size + curr_batch_size, ...] = preds['pose_rotmat'] + summary['flength'][my_step * batch_size:my_step * batch_size + curr_batch_size, ...] = preds['flength'] + summary['trans'][my_step * batch_size:my_step * batch_size + curr_batch_size, ...] = preds['trans'] + summary['breed_indices'][my_step * batch_size:my_step * batch_size + curr_batch_size] = preds['breed_index'] + summary['image_names'].extend(preds['image_names']) + return + + +def get_triangle_faces_from_pyvista_poly(poly): + """Fetch all triangle faces.""" + stream = poly.faces + tris = [] + i = 0 + while i < len(stream): + n = stream[i] + if n != 3: + i += n + 1 + continue + stop = i + n + 1 + tris.append(stream[i+1:stop]) + i = stop + return np.array(tris) \ No newline at end of file diff --git a/src/combined_model/helper3.py b/src/combined_model/helper3.py new file mode 100644 index 0000000000000000000000000000000000000000..7230caa0e0da72afd8b2271327bec453fe9f8f7b --- /dev/null +++ b/src/combined_model/helper3.py @@ -0,0 +1,17 @@ + +import numpy as np + +def get_triangle_faces_from_pyvista_poly(poly): + """Fetch all triangle faces.""" + stream = poly.faces + tris = [] + i = 0 + while i < len(stream): + n = stream[i] + if n != 3: + i += n + 1 + continue + stop = i + n + 1 + tris.append(stream[i+1:stop]) + i = stop + return np.array(tris) \ No newline at end of file diff --git a/src/combined_model/loss_image_to_3d_refinement.py b/src/combined_model/loss_image_to_3d_refinement.py new file mode 100644 index 0000000000000000000000000000000000000000..8b3b85001ca7457afd5cfab639094de69b3203a6 --- /dev/null +++ b/src/combined_model/loss_image_to_3d_refinement.py @@ -0,0 +1,216 @@ + + +import torch +import numpy as np +import pickle as pkl + +import os +import sys +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', 'src')) +# from priors.pose_prior_35 import Prior +# from priors.tiger_pose_prior.tiger_pose_prior import GaussianMixturePrior +from priors.normalizing_flow_prior.normalizing_flow_prior import NormalizingFlowPrior +from priors.shape_prior import ShapePrior +from lifting_to_3d.utils.geometry_utils import rot6d_to_rotmat, batch_rot2aa, geodesic_loss_R +from combined_model.loss_utils.loss_utils import leg_sideway_error, leg_torsion_error, tail_sideway_error, tail_torsion_error, spine_torsion_error, spine_sideway_error +from combined_model.loss_utils.loss_utils_gc import LossGConMesh, calculate_plane_errors_batch + +from priors.shape_prior import ShapePrior +from configs.SMAL_configs import SMAL_MODEL_CONFIG + +from priors.helper_3dcgmodel_loss import load_dog_betas_for_3dcgmodel_loss + + +class LossRef(torch.nn.Module): + def __init__(self, smal_model_type, data_info, nf_version=None): + super(LossRef, self).__init__() + self.criterion_regr = torch.nn.MSELoss() # takes the mean + self.criterion_class = torch.nn.CrossEntropyLoss() + + class_weights_isflat = torch.tensor([12, 2]) + self.criterion_class_isflat = torch.nn.CrossEntropyLoss(weight=class_weights_isflat) + self.criterion_l1 = torch.nn.L1Loss() + self.geodesic_loss = geodesic_loss_R(reduction='mean') + self.gc_loss_on_mesh = LossGConMesh() + self.data_info = data_info + self.smal_model_type = smal_model_type + self.register_buffer('keypoint_weights', torch.tensor(data_info.keypoint_weights)[None, :]) + # if nf_version is not None: + # self.normalizing_flow_pose_prior = NormalizingFlowPrior(nf_version=nf_version) + + self.smal_model_data_path = SMAL_MODEL_CONFIG[self.smal_model_type]['smal_model_data_path'] + self.shape_prior = ShapePrior(self.smal_model_data_path) # here we just need mean and cov + + remeshing_path = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/data/smal_data_remeshed/uniform_surface_sampling/my_smpl_39dogsnorm_Jr_4_dog_remesh4000_info.pkl' + with open(remeshing_path, 'rb') as fp: + self.remeshing_dict = pkl.load(fp) + self.remeshing_relevant_faces = torch.tensor(self.remeshing_dict['smal_faces'][self.remeshing_dict['faceid_closest']], dtype=torch.long) + self.remeshing_relevant_barys = torch.tensor(self.remeshing_dict['barys_closest'], dtype=torch.float32) + + + + # load 3d data for the unity dogs (an optional shape prior for 11 breeds) + self.unity_smal_shape_prior_dogs = SMAL_MODEL_CONFIG[self.smal_model_type]['unity_smal_shape_prior_dogs'] + if self.unity_smal_shape_prior_dogs is not None: + self.dog_betas_unity = load_dog_betas_for_3dcgmodel_loss(self.unity_smal_shape_prior_dogs, self.smal_model_type) + else: + self.dog_betas_unity = None + + + + + + + + def forward(self, output_ref, output_ref_comp, target_dict, weight_dict_ref): + # output_reproj: ['vertices_smal', 'keyp_3d', 'keyp_2d', 'silh_image'] + # target_dict: ['index', 'center', 'scale', 'pts', 'tpts', 'target_weight'] + batch_size = output_ref['keyp_2d'].shape[0] + loss_dict_temp = {} + + # loss on reprojected keypoints + output_kp_resh = (output_ref['keyp_2d']).reshape((-1, 2)) + target_kp_resh = (target_dict['tpts'][:, :, :2] / 64. * (256. - 1)).reshape((-1, 2)) + weights_resh = target_dict['tpts'][:, :, 2].reshape((-1)) + keyp_w_resh = self.keypoint_weights.repeat((batch_size, 1)).reshape((-1)) + loss_dict_temp['keyp_ref'] = ((((output_kp_resh - target_kp_resh)[weights_resh>0]**2).sum(axis=1).sqrt()*weights_resh[weights_resh>0])*keyp_w_resh[weights_resh>0]).sum() / \ + max((weights_resh[weights_resh>0]*keyp_w_resh[weights_resh>0]).sum(), 1e-5) + + # loss on reprojected silhouette + assert output_ref['silh'].shape == (target_dict['silh'][:, None, :, :]).shape + silh_loss_type = 'default' + if silh_loss_type == 'default': + with torch.no_grad(): + thr_silh = 20 + diff = torch.norm(output_kp_resh - target_kp_resh, dim=1) + diff_x = diff.reshape((batch_size, -1)) + weights_resh_x = weights_resh.reshape((batch_size, -1)) + unweighted_kp_mean_dist = (diff_x * weights_resh_x).sum(dim=1) / ((weights_resh_x).sum(dim=1)+1e-6) + loss_silh_bs = ((output_ref['silh'] - target_dict['silh'][:, None, :, :]) ** 2).sum(axis=3).sum(axis=2).sum(axis=1) / (output_ref['silh'].shape[2]*output_ref['silh'].shape[3]) + loss_dict_temp['silh_ref'] = loss_silh_bs[unweighted_kp_mean_dist 0: + if keep_smal_mesh: + target_gc_class = target_dict['gc'][:, :, 0] + gc_errors_plane = calculate_plane_errors_batch(output_ref['vertices_smal'], target_gc_class, target_dict['has_gc'], target_dict['has_gc_is_touching']) + loss_dict_temp['gc_plane'] = torch.mean(gc_errors_plane) + else: # use a uniformly sampled mesh + target_gc_class = target_dict['gc'][:, :, 0] + device = output_ref['vertices_smal'].device + remeshing_relevant_faces = self.remeshing_relevant_faces.to(device) + remeshing_relevant_barys = self.remeshing_relevant_barys.to(device) + + bs = output_ref['vertices_smal'].shape[0] + # verts_remeshed = torch.einsum('ij,aijk->aik', remeshing_relevant_barys, output_ref['vertices_smal'][:, self.remeshing_relevant_faces]) + # sel_verts_comparison = output_ref['vertices_smal'][:, self.remeshing_relevant_faces] + # verts_remeshed = torch.einsum('ij,aijk->aik', remeshing_relevant_barys, sel_verts_comparison) + sel_verts = torch.index_select(output_ref['vertices_smal'], dim=1, index=remeshing_relevant_faces.reshape((-1))).reshape((bs, remeshing_relevant_faces.shape[0], 3, 3)) + verts_remeshed = torch.einsum('ij,aijk->aik', remeshing_relevant_barys, sel_verts) + target_gc_class_remeshed = torch.einsum('ij,aij->ai', remeshing_relevant_barys, target_gc_class[:, self.remeshing_relevant_faces].to(device=device, dtype=torch.float32)) + target_gc_class_remeshed_prep = torch.round(target_gc_class_remeshed).to(torch.long) + gc_errors_plane, gc_errors_under_plane = calculate_plane_errors_batch(verts_remeshed, target_gc_class_remeshed_prep, target_dict['has_gc'], target_dict['has_gc_is_touching']) + loss_dict_temp['gc_plane'] = torch.mean(gc_errors_plane) + loss_dict_temp['gc_blowplane'] = torch.mean(gc_errors_under_plane) + + # error on classification if the ground plane is flat + if 'gc_isflat' in weight_dict_ref.keys(): + # import pdb; pdb.set_trace() + self.criterion_class_isflat.to(device) + loss_dict_temp['gc_isflat'] = self.criterion_class(output_ref['isflat'], target_dict['isflat'].to(device)) + + # if we refine the shape WITHIN the refinement newtork (shaperef_type is not inexistent) + # shape regularization + # 'smal': loss on betas (pca coefficients), betas should be close to 0 + # 'limbs...' loss on selected betas_limbs + device = output_ref_comp['ref_trans_notnorm'].device + loss_shape_weighted_list = [torch.zeros((1), device=device).mean()] + if 'shape_options' in weight_dict_ref.keys(): + for ind_sp, sp in enumerate(weight_dict_ref['shape_options']): + weight_sp = weight_dict_ref['shape'][ind_sp] + # self.logscale_part_list = ['legs_l', 'legs_f', 'tail_l', 'tail_f', 'ears_y', 'ears_l', 'head_l'] + if sp == 'smal': + loss_shape_tmp = self.shape_prior(output_ref['betas']) + elif sp == 'limbs': + loss_shape_tmp = torch.mean((output_ref['betas_limbs'])**2) + elif sp == 'limbs7': + limb_coeffs_list = [0.01, 1, 0.1, 1, 1, 0.1, 2] + limb_coeffs = torch.tensor(limb_coeffs_list).to(torch.float32).to(target_dict['tpts'].device) + loss_shape_tmp = torch.mean((output_ref['betas_limbs'] * limb_coeffs[None, :])**2) + else: + raise NotImplementedError + loss_shape_weighted_list.append(weight_sp * loss_shape_tmp) + loss_shape_weighted = torch.stack((loss_shape_weighted_list)).sum() + + + + + + # 3D loss for dogs for which we have a unity model or toy figure + loss_dict_temp['models3d'] = torch.zeros((1), device=device).mean().to(output_ref['betas'].device) + if 'models3d' in weight_dict_ref.keys(): + if weight_dict_ref['models3d'] > 0: + assert (self.dog_betas_unity is not None) + if weight_dict_ref['models3d'] > 0: + for ind_dog in range(target_dict['breed_index'].shape[0]): + breed_index = np.asscalar(target_dict['breed_index'][ind_dog].detach().cpu().numpy()) + if breed_index in self.dog_betas_unity.keys(): + betas_target = self.dog_betas_unity[breed_index][:output_ref['betas'].shape[1]].to(output_ref['betas'].device) + betas_output = output_ref['betas'][ind_dog, :] + betas_limbs_output = output_ref['betas_limbs'][ind_dog, :] + loss_dict_temp['models3d'] += ((betas_limbs_output**2).sum() + ((betas_output-betas_target)**2).sum()) / (output_ref['betas'].shape[1] + output_ref['betas_limbs'].shape[1]) + else: + weight_dict_ref['models3d'] = 0.0 + else: + weight_dict_ref['models3d'] = 0.0 + + + + + + + + + + + + # weight the losses + loss = torch.zeros((1)).mean().to(device=output_ref['keyp_2d'].device, dtype=output_ref['keyp_2d'].dtype) + loss_dict = {} + for loss_name in weight_dict_ref.keys(): + if not loss_name in ['shape', 'shape_options']: + if weight_dict_ref[loss_name] > 0: + loss_weighted = loss_dict_temp[loss_name] * weight_dict_ref[loss_name] + loss_dict[loss_name] = loss_weighted.item() + loss += loss_weighted + loss += loss_shape_weighted + loss_dict['loss'] = loss.item() + + return loss, loss_dict + + diff --git a/src/combined_model/loss_image_to_3d_withbreedrel.py b/src/combined_model/loss_image_to_3d_withbreedrel.py new file mode 100644 index 0000000000000000000000000000000000000000..e5c94f69dc91a80c20dff157b05fea74b4080e55 --- /dev/null +++ b/src/combined_model/loss_image_to_3d_withbreedrel.py @@ -0,0 +1,342 @@ + + +import torch +import numpy as np +import pickle as pkl + +import os +import sys +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', 'src')) +# from priors.pose_prior_35 import Prior +# from priors.tiger_pose_prior.tiger_pose_prior import GaussianMixturePrior +from priors.normalizing_flow_prior.normalizing_flow_prior import NormalizingFlowPrior +from priors.shape_prior import ShapePrior +from lifting_to_3d.utils.geometry_utils import rot6d_to_rotmat, batch_rot2aa +# from configs.SMAL_configs import SMAL_MODEL_DATA_PATH, UNITY_SMAL_SHAPE_PRIOR_DOGS, SMAL_MODEL_TYPE +from configs.SMAL_configs import SMAL_MODEL_CONFIG + +from priors.helper_3dcgmodel_loss import load_dog_betas_for_3dcgmodel_loss +from combined_model.loss_utils.loss_utils_gc import calculate_plane_errors_batch + + + +class Loss(torch.nn.Module): + def __init__(self, smal_model_type, data_info, nf_version=None): + super(Loss, self).__init__() + self.criterion_regr = torch.nn.MSELoss() # takes the mean + self.criterion_class = torch.nn.CrossEntropyLoss() + self.data_info = data_info + self.register_buffer('keypoint_weights', torch.tensor(data_info.keypoint_weights)[None, :]) + self.l_anchor = None + self.l_pos = None + self.l_neg = None + self.smal_model_type = smal_model_type + self.smal_model_data_path = SMAL_MODEL_CONFIG[self.smal_model_type]['smal_model_data_path'] + self.unity_smal_shape_prior_dogs = SMAL_MODEL_CONFIG[self.smal_model_type]['unity_smal_shape_prior_dogs'] + + if nf_version is not None: + self.normalizing_flow_pose_prior = NormalizingFlowPrior(nf_version=nf_version) + self.shape_prior = ShapePrior(self.smal_model_data_path) # here we just need mean and cov + self.criterion_triplet = torch.nn.TripletMarginLoss(margin=1) + + # load 3d data for the unity dogs (an optional shape prior for 11 breeds) + if self.unity_smal_shape_prior_dogs is not None: + self.dog_betas_unity = load_dog_betas_for_3dcgmodel_loss(self.unity_smal_shape_prior_dogs, self.smal_model_type) + else: + self.dog_betas_unity = None + + remeshing_path = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/data/smal_data_remeshed/uniform_surface_sampling/my_smpl_39dogsnorm_Jr_4_dog_remesh4000_info.pkl' + with open(remeshing_path, 'rb') as fp: + self.remeshing_dict = pkl.load(fp) + self.remeshing_relevant_faces = torch.tensor(self.remeshing_dict['smal_faces'][self.remeshing_dict['faceid_closest']], dtype=torch.long) + self.remeshing_relevant_barys = torch.tensor(self.remeshing_dict['barys_closest'], dtype=torch.float32) + + + def prepare_anchor_pos_neg(self, batch_size, device): + l0 = np.arange(0, batch_size, 2) + l_anchor = [] + l_pos = [] + l_neg = [] + for ind in l0: + xx = set(np.arange(0, batch_size)) + xx.discard(ind) + xx.discard(ind+1) + for ind2 in xx: + if ind2 % 2 == 0: + l_anchor.append(ind) + l_pos.append(ind + 1) + else: + l_anchor.append(ind + 1) + l_pos.append(ind) + l_neg.append(ind2) + self.l_anchor = torch.Tensor(l_anchor).to(torch.int64).to(device) + self.l_pos = torch.Tensor(l_pos).to(torch.int64).to(device) + self.l_neg = torch.Tensor(l_neg).to(torch.int64).to(device) + return + + + def forward(self, output_reproj, target_dict, weight_dict=None): + + # output_reproj: ['vertices_smal', 'keyp_3d', 'keyp_2d', 'silh_image'] + # target_dict: ['index', 'center', 'scale', 'pts', 'tpts', 'target_weight'] + batch_size = output_reproj['keyp_2d'].shape[0] + device = output_reproj['keyp_2d'].device + + # loss on reprojected keypoints + output_kp_resh = (output_reproj['keyp_2d']).reshape((-1, 2)) + target_kp_resh = (target_dict['tpts'][:, :, :2] / 64. * (256. - 1)).reshape((-1, 2)) + weights_resh = target_dict['tpts'][:, :, 2].reshape((-1)) + keyp_w_resh = self.keypoint_weights.repeat((batch_size, 1)).reshape((-1)) + loss_keyp = ((((output_kp_resh - target_kp_resh)[weights_resh>0]**2).sum(axis=1).sqrt()*weights_resh[weights_resh>0])*keyp_w_resh[weights_resh>0]).sum() / \ + max((weights_resh[weights_resh>0]*keyp_w_resh[weights_resh>0]).sum(), 1e-5) + + # loss on reprojected silhouette + assert output_reproj['silh'].shape == (target_dict['silh'][:, None, :, :]).shape + silh_loss_type = 'default' + if silh_loss_type == 'default': + with torch.no_grad(): + thr_silh = 20 + diff = torch.norm(output_kp_resh - target_kp_resh, dim=1) + diff_x = diff.reshape((batch_size, -1)) + weights_resh_x = weights_resh.reshape((batch_size, -1)) + unweighted_kp_mean_dist = (diff_x * weights_resh_x).sum(dim=1) / ((weights_resh_x).sum(dim=1)+1e-6) + loss_silh_bs = ((output_reproj['silh'] - target_dict['silh'][:, None, :, :]) ** 2).sum(axis=3).sum(axis=2).sum(axis=1) / (output_reproj['silh'].shape[2]*output_reproj['silh'].shape[3]) + loss_silh = loss_silh_bs[unweighted_kp_mean_dist 0: + assert (self.dog_betas_unity is not None) + if weight_dict['models3d'] > 0: + for ind_dog in range(target_dict['breed_index'].shape[0]): + breed_index = np.asscalar(target_dict['breed_index'][ind_dog].detach().cpu().numpy()) + if breed_index in self.dog_betas_unity.keys(): + betas_target = self.dog_betas_unity[breed_index][:output_reproj['betas'].shape[1]].to(output_reproj['betas'].device) + betas_output = output_reproj['betas'][ind_dog, :] + betas_limbs_output = output_reproj['betas_limbs'][ind_dog, :] + loss_models3d += ((betas_limbs_output**2).sum() + ((betas_output-betas_target)**2).sum()) / (output_reproj['betas'].shape[1] + output_reproj['betas_limbs'].shape[1]) + else: + weight_dict['models3d'] = 0.0 + else: + weight_dict['models3d'] = 0.0 + + # shape resularization loss on shapedirs + # -> in the current version shapedirs are kept fixed, so we don't need those losses + if weight_dict['shapedirs'] > 0: + raise NotImplementedError + else: + loss_shapedirs = torch.zeros((1), device=device).mean().to(output_reproj['betas'].device) + + # prior on back joints (not used in cvpr 2022 paper) + # -> elementwise MSE loss on all 6 coefficients of 6d rotation representation + if 'pose_0' in weight_dict.keys(): + if weight_dict['pose_0'] > 0: + pred_pose_rot6d = output_reproj['pose_rot6d'] + w_rj_np = np.zeros((pred_pose_rot6d.shape[1])) + w_rj_np[[2, 3, 4, 5]] = 1.0 # back + w_rj = torch.tensor(w_rj_np).to(torch.float32).to(pred_pose_rot6d.device) + zero_rot = torch.tensor([1, 0, 0, 1, 0, 0]).to(pred_pose_rot6d.device).to(torch.float32)[None, None, :].repeat((batch_size, pred_pose_rot6d.shape[1], 1)) + loss_pose = self.criterion_regr(pred_pose_rot6d*w_rj[None, :, None], zero_rot*w_rj[None, :, None]) + else: + loss_pose = torch.zeros((1), device=device).mean() + + # pose prior + # -> we did experiment with different pose priors, for example: + # * similart to SMALify (https://github.com/benjiebob/SMALify/blob/master/smal_fitter/smal_fitter.py, + # https://github.com/benjiebob/SMALify/blob/master/smal_fitter/priors/pose_prior_35.py) + # * vae + # * normalizing flow pose prior + # -> our cvpr 2022 paper uses the normalizing flow pose prior as implemented below + if 'poseprior' in weight_dict.keys(): + if weight_dict['poseprior'] > 0: + pred_pose_rot6d = output_reproj['pose_rot6d'] + pred_pose = rot6d_to_rotmat(pred_pose_rot6d.reshape((-1, 6))).reshape((batch_size, -1, 3, 3)) + if 'normalizing_flow_tiger' in weight_dict['poseprior_options']: + if output_reproj['normflow_z'] is not None: + loss_poseprior = self.normalizing_flow_pose_prior.calculate_loss_from_z(output_reproj['normflow_z'], type='square') + else: + loss_poseprior = self.normalizing_flow_pose_prior.calculate_loss(pred_pose_rot6d, type='square') + elif 'normalizing_flow_tiger_logprob' in weight_dict['poseprior_options']: + if output_reproj['normflow_z'] is not None: + loss_poseprior = self.normalizing_flow_pose_prior.calculate_loss_from_z(output_reproj['normflow_z'], type='neg_log_prob') + else: + loss_poseprior = self.normalizing_flow_pose_prior.calculate_loss(pred_pose_rot6d, type='neg_log_prob') + else: + raise NotImplementedError + else: + loss_poseprior = torch.zeros((1), device=device).mean() + else: + weight_dict['poseprior'] = 0 + loss_poseprior = torch.zeros((1), device=device).mean() + + # add a prior which penalizes side-movement angles for legs + if 'poselegssidemovement' in weight_dict.keys(): + if weight_dict['poselegssidemovement'] > 0: + use_pose_legs_side_loss = True + else: + use_pose_legs_side_loss = False + else: + use_pose_legs_side_loss = False + if use_pose_legs_side_loss: + leg_indices_right = np.asarray([7, 8, 9, 10, 17, 18, 19, 20]) # front, back + leg_indices_left = np.asarray([11, 12, 13, 14, 21, 22, 23, 24]) # front, back + vec = torch.zeros((3, 1)).to(device=pred_pose.device, dtype=pred_pose.dtype) + vec[2] = -1 + x0_rotmat = pred_pose + x0_rotmat_legs_left = x0_rotmat[:, leg_indices_left, :, :] + x0_rotmat_legs_right = x0_rotmat[:, leg_indices_right, :, :] + x0_legs_left = x0_rotmat_legs_left.reshape((-1, 3, 3))@vec + x0_legs_right = x0_rotmat_legs_right.reshape((-1, 3, 3))@vec + eps=0 # 1e-7 + # use the component of the vector which points to the side + loss_poselegssidemovement = (x0_legs_left[:, 1]**2).mean() + (x0_legs_right[:, 1]**2).mean() + else: + loss_poselegssidemovement = torch.zeros((1), device=device).mean() + weight_dict['poselegssidemovement'] = 0 + + # dog breed classification loss + dog_breed_gt = target_dict['breed_index'] + dog_breed_pred = output_reproj['dog_breed'] + loss_class = self.criterion_class(dog_breed_pred, dog_breed_gt) + + # dog breed relationship loss + # -> we did experiment with many other options, but none was significantly better + if '4' in weight_dict['breed_options']: # we have pairs of dogs of the same breed + if weight_dict['breed'] > 0: + assert output_reproj['dog_breed'].shape[0] == 12 + # assert weight_dict['breed'] > 0 + z = output_reproj['z'] + # go through all pairs and compare them to each other sample + if self.l_anchor is None: + self.prepare_anchor_pos_neg(batch_size, z.device) + anchor = torch.index_select(z, 0, self.l_anchor) + positive = torch.index_select(z, 0, self.l_pos) + negative = torch.index_select(z, 0, self.l_neg) + loss_breed = self.criterion_triplet(anchor, positive, negative) + else: + loss_breed = torch.zeros((1), device=device).mean() + else: + loss_breed = torch.zeros((1), device=device).mean() + + # regularizarion for focal length + loss_flength_near_mean = torch.mean(output_reproj['flength']**2) + loss_flength = loss_flength_near_mean + + # bodypart segmentation loss + if 'partseg' in weight_dict.keys(): + if weight_dict['partseg'] > 0: + raise NotImplementedError + else: + loss_partseg = torch.zeros((1), device=device).mean() + else: + weight_dict['partseg'] = 0 + loss_partseg = torch.zeros((1), device=device).mean() + + + # NEW: ground contact loss for main network + keep_smal_mesh = False + if 'gc_plane' in weight_dict.keys(): + if weight_dict['gc_plane'] > 0: + if keep_smal_mesh: + target_gc_class = target_dict['gc'][:, :, 0] + gc_errors_plane = calculate_plane_errors_batch(output_reproj['vertices_smal'], target_gc_class, target_dict['has_gc'], target_dict['has_gc_is_touching']) + loss_gc_plane = torch.mean(gc_errors_plane) + else: # use a uniformly sampled mesh + target_gc_class = target_dict['gc'][:, :, 0] + device = output_reproj['vertices_smal'].device + remeshing_relevant_faces = self.remeshing_relevant_faces.to(device) + remeshing_relevant_barys = self.remeshing_relevant_barys.to(device) + + bs = output_reproj['vertices_smal'].shape[0] + # verts_remeshed = torch.einsum('ij,aijk->aik', remeshing_relevant_barys, output_reproj['vertices_smal'][:, self.remeshing_relevant_faces]) + # sel_verts_comparison = output_reproj['vertices_smal'][:, self.remeshing_relevant_faces] + # verts_remeshed = torch.einsum('ij,aijk->aik', remeshing_relevant_barys, sel_verts_comparison) + sel_verts = torch.index_select(output_reproj['vertices_smal'], dim=1, index=remeshing_relevant_faces.reshape((-1))).reshape((bs, remeshing_relevant_faces.shape[0], 3, 3)) + verts_remeshed = torch.einsum('ij,aijk->aik', remeshing_relevant_barys, sel_verts) + target_gc_class_remeshed = torch.einsum('ij,aij->ai', remeshing_relevant_barys, target_gc_class[:, self.remeshing_relevant_faces].to(device=device, dtype=torch.float32)) + target_gc_class_remeshed_prep = torch.round(target_gc_class_remeshed).to(torch.long) + gc_errors_plane, gc_errors_under_plane = calculate_plane_errors_batch(verts_remeshed, target_gc_class_remeshed_prep, target_dict['has_gc'], target_dict['has_gc_is_touching']) + loss_gc_plane = torch.mean(gc_errors_plane) + loss_gc_belowplane = torch.mean(gc_errors_under_plane) + # loss_dict_temp['gc_plane'] = torch.mean(gc_errors_plane) + else: + loss_gc_plane = torch.zeros((1), device=device).mean() + loss_gc_belowplane = torch.zeros((1), device=device).mean() + else: + loss_gc_plane = torch.zeros((1), device=device).mean() + loss_gc_belowplane = torch.zeros((1), device=device).mean() + weight_dict['gc_plane'] = 0 + weight_dict['gc_belowplane'] = 0 + + + + # weight and combine losses + loss_keyp_weighted = loss_keyp * weight_dict['keyp'] + loss_silh_weighted = loss_silh * weight_dict['silh'] + loss_shapedirs_weighted = loss_shapedirs * weight_dict['shapedirs'] + loss_pose_weighted = loss_pose * weight_dict['pose_0'] + loss_class_weighted = loss_class * weight_dict['class'] + loss_breed_weighted = loss_breed * weight_dict['breed'] + loss_flength_weighted = loss_flength * weight_dict['flength'] + loss_poseprior_weighted = loss_poseprior * weight_dict['poseprior'] + loss_partseg_weighted = loss_partseg * weight_dict['partseg'] + loss_models3d_weighted = loss_models3d * weight_dict['models3d'] + loss_poselegssidemovement_weighted = loss_poselegssidemovement * weight_dict['poselegssidemovement'] + + loss_gc_plane_weighted = loss_gc_plane * weight_dict['gc_plane'] + loss_gc_belowplane_weighted = loss_gc_belowplane * weight_dict['gc_belowplane'] + + + #################################################################################################### + loss = loss_keyp_weighted + loss_silh_weighted + loss_shape_weighted + loss_pose_weighted + loss_class_weighted + \ + loss_shapedirs_weighted + loss_breed_weighted + loss_flength_weighted + loss_poseprior_weighted + \ + loss_partseg_weighted + loss_models3d_weighted + loss_poselegssidemovement_weighted + \ + loss_gc_plane_weighted + loss_gc_belowplane_weighted + #################################################################################################### + + loss_dict = {'loss': loss.item(), + 'loss_keyp_weighted': loss_keyp_weighted.item(), \ + 'loss_silh_weighted': loss_silh_weighted.item(), \ + 'loss_shape_weighted': loss_shape_weighted.item(), \ + 'loss_shapedirs_weighted': loss_shapedirs_weighted.item(), \ + 'loss_pose0_weighted': loss_pose_weighted.item(), \ + 'loss_class_weighted': loss_class_weighted.item(), \ + 'loss_breed_weighted': loss_breed_weighted.item(), \ + 'loss_flength_weighted': loss_flength_weighted.item(), \ + 'loss_poseprior_weighted': loss_poseprior_weighted.item(), \ + 'loss_partseg_weighted': loss_partseg_weighted.item(), \ + 'loss_models3d_weighted': loss_models3d_weighted.item(), \ + 'loss_poselegssidemovement_weighted': loss_poselegssidemovement_weighted.item(), \ + 'loss_gc_plane_weighted': loss_gc_plane_weighted.item(), \ + 'loss_gc_belowplane_weighted': loss_gc_belowplane_weighted.item() + } + + return loss, loss_dict + + + + diff --git a/src/combined_model/loss_utils/loss_arap.py b/src/combined_model/loss_utils/loss_arap.py new file mode 100644 index 0000000000000000000000000000000000000000..6ccc61a6ab40b4dc9b706489e6c78e9bf3c4f449 --- /dev/null +++ b/src/combined_model/loss_utils/loss_arap.py @@ -0,0 +1,153 @@ +import torch + +# code from https://raw.githubusercontent.com/yufu-wang/aves/main/optimization/loss_arap.py + + +class Arap_Loss(): + ''' + Pytorch implementaion: As-rigid-as-possible loss class + + ''' + + def __init__(self, meshes, device='cpu', vertex_w=None): + + with torch.no_grad(): # new nadine + + self.device = device + self.bn = len(meshes) + + # get lapacian cotangent matrix + L = self.get_laplacian_cot(meshes) + self.wij = L.values().clone() + self.wij[self.wij<0] = 0. + + # get ajacency matrix + V = meshes.num_verts_per_mesh().sum() + edges_packed = meshes.edges_packed() + e0, e1 = edges_packed.unbind(1) + idx01 = torch.stack([e0, e1], dim=1) + idx10 = torch.stack([e1, e0], dim=1) + idx = torch.cat([idx01, idx10], dim=0).t() + + ones = torch.ones(idx.shape[1], dtype=torch.float32).to(device) + A = torch.sparse.FloatTensor(idx, ones, (V, V)) + self.deg = torch.sparse.sum(A, dim=1).to_dense().long() + self.idx = self.sort_idx(idx) + + # get edges of default mesh + self.eij = self.get_edges(meshes) + + # get per vertex regularization strength + self.vertex_w = vertex_w + + + def __call__(self, new_meshes): + new_meshes._compute_packed() + + optimal_R = self.step_1(new_meshes) + arap_loss = self.step_2(optimal_R, new_meshes) + return arap_loss + + + def step_1(self, new_meshes): + bn = self.bn + eij = self.eij.view(bn, -1, 3).cpu() + + with torch.no_grad(): + eij_ = self.get_edges(new_meshes) + + eij_ = eij_.view(bn, -1, 3).cpu() + wij = self.wij.view(bn, -1).cpu() + + deg_1 = self.deg.view(bn, -1)[0].cpu() # assuming same topology + S = torch.zeros([bn, len(deg_1), 3, 3]) + for i in range(len(deg_1)): + start, end = deg_1[:i].sum(), deg_1[:i+1].sum() + P = eij[:, start : end] + P_ = eij_[:, start : end] + D = wij[:, start : end] + D = torch.diag_embed(D) + S[:, i] = P.transpose(-2,-1) @ D @ P_ + + S = S.view(-1, 3, 3) + + u, _, v = torch.svd(S) + R = v @ u.transpose(-2, -1) + det = torch.det(R) + + u[det<0, :, -1] *= -1 + R = v @ u.transpose(-2, -1) + R = R.to(self.device) + + return R + + + def step_2(self, R, new_meshes): + R = torch.repeat_interleave(R, self.deg, dim=0) + Reij = R @ self.eij.unsqueeze(2) + Reij = Reij.squeeze() + + eij_ = self.get_edges(new_meshes) + arap_loss = self.wij * (eij_ - Reij).norm(dim=1) + + if self.vertex_w is not None: + vertex_w = torch.repeat_interleave(self.vertex_w, self.deg, dim=0) + arap_loss = arap_loss * vertex_w + + arap_loss = arap_loss.sum() / self.bn + + return arap_loss + + + def get_edges(self, meshes): + verts_packed = meshes.verts_packed() + vi = torch.repeat_interleave(verts_packed, self.deg, dim=0) + vj = verts_packed[self.idx[1]] + eij = vi - vj + return eij + + + def sort_idx(self, idx): + _, order = (idx[0] + idx[1]*1e-6).sort() + + return idx[:, order] + + + def get_laplacian_cot(self, meshes): + ''' + Routine modified from : + pytorch3d/loss/mesh_laplacian_smoothing.py + ''' + verts_packed = meshes.verts_packed() + faces_packed = meshes.faces_packed() + V, F = verts_packed.shape[0], faces_packed.shape[0] + + face_verts = verts_packed[faces_packed] + v0, v1, v2 = face_verts[:,0], face_verts[:,1], face_verts[:,2] + + A = (v1-v2).norm(dim=1) + B = (v0-v2).norm(dim=1) + C = (v0-v1).norm(dim=1) + + s = 0.5 * (A+B+C) + area = (s * (s - A) * (s - B) * (s - C)).clamp_(min=1e-12).sqrt() + + A2, B2, C2 = A * A, B * B, C * C + cota = (B2 + C2 - A2) / area + cotb = (A2 + C2 - B2) / area + cotc = (A2 + B2 - C2) / area + cot = torch.stack([cota, cotb, cotc], dim=1) + cot /= 4.0 + + ii = faces_packed[:, [1,2,0]] + jj = faces_packed[:, [2,0,1]] + idx = torch.stack([ii, jj], dim=0).view(2, F*3) + L = torch.sparse.FloatTensor(idx, cot.view(-1), (V, V)) + L += L.t() + L = L.coalesce() + L /= 2.0 # normalized according to arap paper + + return L + + + diff --git a/src/combined_model/loss_utils/loss_laplacian_mesh_comparison.py b/src/combined_model/loss_utils/loss_laplacian_mesh_comparison.py new file mode 100644 index 0000000000000000000000000000000000000000..f4c87dc43f30150b7dfff23930c49a9735230040 --- /dev/null +++ b/src/combined_model/loss_utils/loss_laplacian_mesh_comparison.py @@ -0,0 +1,45 @@ +# code from: https://github.com/chaneyddtt/Coarse-to-fine-3D-Animal/blob/main/util/loss_utils.py + +import numpy as np +import torch + + +# Laplacian loss, calculate the Laplacian coordiante of both coarse and refined vertices and then compare the difference +class LaplacianCTF(torch.nn.Module): + def __init__(self, adjmat, device): + ''' + Args: + adjmat: adjacency matrix of the input graph data + device: specify device for training + ''' + super(LaplacianCTF, self).__init__() + adjmat.data = np.ones_like(adjmat.data) + adjmat = torch.from_numpy(adjmat.todense()).float() + dg = torch.sum(adjmat, dim=-1) + dg_m = torch.diag(dg) + ls = dg_m - adjmat + self.ls = ls.unsqueeze(0).to(device) # Should be normalized by the diagonal elements according to + # the origial definition, this one also works fine. + + def forward(self, verts_pred, verts_gt, smooth=False): + verts_pred = torch.matmul(self.ls, verts_pred) + verts_gt = torch.matmul(self.ls, verts_gt) + loss = torch.norm(verts_pred - verts_gt, dim=-1).mean() + if smooth: + loss_smooth = torch.norm(torch.matmul(self.ls, verts_pred), dim=-1).mean() + return loss, loss_smooth + return loss, None + + + + +# +# read the adjacency matrix, which will used in the Laplacian regularizer +# data = np.load('./data/mesh_down_sampling_4.npz', encoding='latin1', allow_pickle=True) +# adjmat = data['A'][0] +# laplacianloss = Laplacian(adjmat, device) +# +# verts_clone = verts.detach().clone() +# loss_arap, loss_smooth = laplacianloss(verts_refine, verts_clone) +# loss_arap = args.w_arap * loss_arap +# \ No newline at end of file diff --git a/src/combined_model/loss_utils/loss_sdf.py b/src/combined_model/loss_utils/loss_sdf.py new file mode 100644 index 0000000000000000000000000000000000000000..60c19179c86d5896cf8ca227c928c0653e44692e --- /dev/null +++ b/src/combined_model/loss_utils/loss_sdf.py @@ -0,0 +1,122 @@ + +# code from: https://github.com/chaneyddtt/Coarse-to-fine-3D-Animal/blob/main/util/loss_sdf.py + +import torch +import numpy as np +from scipy.ndimage import distance_transform_edt as distance +from skimage import segmentation as skimage_seg +import matplotlib.pyplot as plt + + +def dice_loss(score, target): + # implemented from paper https://arxiv.org/pdf/1606.04797.pdf + target = target.float() + smooth = 1e-5 + intersect = torch.sum(score * target) + y_sum = torch.sum(target * target) + z_sum = torch.sum(score * score) + loss = (2 * intersect + smooth) / (z_sum + y_sum + smooth) + loss = 1 - loss + return loss + + +class tversky_loss(torch.nn.Module): + # implemented from https://arxiv.org/pdf/1706.05721.pdf + def __init__(self, alpha, beta): + ''' + Args: + alpha: coefficient for false positive prediction + beta: coefficient for false negtive prediction + ''' + super(tversky_loss, self).__init__() + self.alpha = alpha + self.beta = beta + + def __call__(self, score, target): + target = target.float() + smooth = 1e-5 + tp = torch.sum(score * target) + fn = torch.sum(target * (1 - score)) + fp = torch.sum((1-target) * score) + loss = (tp + smooth) / (tp + self.alpha * fp + self.beta * fn + smooth) + loss = 1 - loss + return loss + + +def compute_sdf1_1(img_gt, out_shape): + """ + compute the normalized signed distance map of binary mask + input: segmentation, shape = (batch_size, x, y, z) + output: the Signed Distance Map (SDM) + sdf(x) = 0; x in segmentation boundary + -inf|x-y|; x in segmentation + +inf|x-y|; x out of segmentation + normalize sdf to [-1, 1] + """ + + img_gt = img_gt.astype(np.uint8) + + normalized_sdf = np.zeros(out_shape) + + for b in range(out_shape[0]): # batch size + # ignore background + for c in range(1, out_shape[1]): + posmask = img_gt[b] + negmask = 1-posmask + posdis = distance(posmask) + negdis = distance(negmask) + boundary = skimage_seg.find_boundaries(posmask, mode='inner').astype(np.uint8) + sdf = (negdis-np.min(negdis))/(np.max(negdis)-np.min(negdis)) - (posdis-np.min(posdis))/(np.max(posdis)-np.min(posdis)) + sdf[boundary==1] = 0 + normalized_sdf[b][c] = sdf + assert np.min(sdf) == -1.0, print(np.min(posdis), np.min(negdis), np.max(posdis), np.max(negdis)) + assert np.max(sdf) == 1.0, print(np.min(posdis), np.min(negdis), np.max(posdis), np.max(negdis)) + + return normalized_sdf + + +def compute_sdf(img_gt, out_shape): + """ + compute the signed distance map of binary mask + input: segmentation, shape = (batch_size, x, y, z) + output: the Signed Distance Map (SDM) + sdf(x) = 0; x in segmentation boundary + -inf|x-y|; x in segmentation + +inf|x-y|; x out of segmentation + """ + + img_gt = img_gt.astype(np.uint8) + + gt_sdf = np.zeros(out_shape) + debug = False + for b in range(out_shape[0]): # batch size + for c in range(0, out_shape[1]): + posmask = img_gt[b] + negmask = 1-posmask + posdis = distance(posmask) + negdis = distance(negmask) + boundary = skimage_seg.find_boundaries(posmask, mode='inner').astype(np.uint8) + sdf = negdis - posdis + sdf[boundary==1] = 0 + gt_sdf[b][c] = sdf + if debug: + plt.figure() + plt.subplot(1, 2, 1), plt.imshow(img_gt[b, 0, :, :]), plt.colorbar() + plt.subplot(1, 2, 2), plt.imshow(gt_sdf[b, 0, :, :]), plt.colorbar() + plt.show() + + return gt_sdf + + +def boundary_loss(output, gt): + """ + compute boundary loss for binary segmentation + input: outputs_soft: softmax results, shape=(b,2,x,y,z) + gt_sdf: sdf of ground truth (can be original or normalized sdf); shape=(b,2,x,y,z) + output: boundary_loss; sclar + adopted from http://proceedings.mlr.press/v102/kervadec19a/kervadec19a.pdf + """ + multipled = torch.einsum('bcxy, bcxy->bcxy', output, gt) + bd_loss = multipled.mean() + + return bd_loss \ No newline at end of file diff --git a/src/combined_model/loss_utils/loss_utils.py b/src/combined_model/loss_utils/loss_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..eea046139a12f3864142b53c88ec059e66f8af77 --- /dev/null +++ b/src/combined_model/loss_utils/loss_utils.py @@ -0,0 +1,191 @@ + +import torch +import numpy as np + + +''' +def keyp_rep_error_l1(smpl_keyp_2d, keyp_hourglass, keyp_hourglass_scores, thr_kp=0.3): + # step 1: make sure that the hg prediction and barc are close + with torch.no_grad(): + kp_weights = keyp_hourglass_scores + kp_weights[keyp_hourglass_scores 0: + # bug corrected 07.11.22 + # error_under_plane = nonplane_points_projected[nonplane_points_projected<0].sum() / 100 + error_under_plane = - nonplane_points_projected[nonplane_points_projected<0].sum() / 100 + else: + error_under_plane = nonplane_points_projected[nonplane_points_projected>0].sum() / 100 + error_under_plane_list.append(error_under_plane) + except: + print('was not able to calculate plane error for this image') + error_list.append(torch.zeros((1), dtype=vertices.dtype, device=vertices.device)[0]) + error_under_plane_list.append(torch.zeros((1), dtype=vertices.dtype, device=vertices.device)[0]) + else: + error_list.append(torch.zeros((1), dtype=vertices.dtype, device=vertices.device)[0]) + error_under_plane_list.append(torch.zeros((1), dtype=vertices.dtype, device=vertices.device)[0]) + errors = torch.stack(error_list, dim=0) + errors_under_plane = torch.stack(error_under_plane_list, dim=0) + + if return_error_under_plane: + return errors, errors_under_plane + else: + return errors + + + +# def calculate_vertex_wise_labeling_error(): + # vertexwise_ground_contact + + + + + + + + + + + + + +''' + +def paws_to_groundplane_error_batch(vertices, return_details=False): + # list of feet vertices (some of them) + # remark: we did annotate left indices and find the right insices using sym_ids_dict + # REMARK: this loss is not yet for batches! + import pdb; pdb.set_trace() + print('this loss is not yet for batches!') + list_back_left = [1524, 1517, 1512, 1671, 1678, 1664, 1956, 1680, 1685, 1602, 1953, 1569] + list_front_left = [1331, 1327, 1332, 1764, 1767, 1747, 1779, 1789, 1944, 1339, 1323, 1420] + list_back_right = [3476, 3469, 3464, 3623, 3630, 3616, 3838, 3632, 3637, 3554, 3835, 3521] + list_front_right = [3283, 3279, 3284, 3715, 3718, 3698, 3730, 3740, 3826, 3291, 3275, 3372] + assert vertices.shape[0] == 3889 + assert vertices.shape[1] == 3 + all_paw_vert_idxs = list_back_left + list_front_left + list_back_right + list_front_right + verts_paws = vertices[all_paw_vert_idxs, :] + plane_centroid, plane_normal, error = fit_plane_batch(verts_paws) + if return_details: + return plane_centroid, plane_normal, error + else: + return error + +def paws_to_groundplane_error_batch_new(vertices, return_details=False): + # list of feet vertices (some of them) + # remark: we did annotate left indices and find the right insices using sym_ids_dict + # REMARK: this loss is not yet for batches! + import pdb; pdb.set_trace() + print('this loss is not yet for batches!') + list_back_left = [1524, 1517, 1512, 1671, 1678, 1664, 1956, 1680, 1685, 1602, 1953, 1569] + list_front_left = [1331, 1327, 1332, 1764, 1767, 1747, 1779, 1789, 1944, 1339, 1323, 1420] + list_back_right = [3476, 3469, 3464, 3623, 3630, 3616, 3838, 3632, 3637, 3554, 3835, 3521] + list_front_right = [3283, 3279, 3284, 3715, 3718, 3698, 3730, 3740, 3826, 3291, 3275, 3372] + assert vertices.shape[0] == 3889 + assert vertices.shape[1] == 3 + all_paw_vert_idxs = list_back_left + list_front_left + list_back_right + list_front_right + verts_paws = vertices[all_paw_vert_idxs, :] + plane_centroid, plane_normal, error = fit_plane_batch(verts_paws) + print('this loss is not yet for batches!') + points = torch.transpose(points_npx3, 0, 1) # (3, n_points) + points_centroid = torch.mean(points, dim=1) + input_svd = points - points_centroid[:, None] + U_svd, sigma_svd, V_svd = torch.svd(input_svd, compute_uv=True) + plane_normal = U_svd[:, 2] + plane_squaredsumofdists = sigma_svd[2] + error = plane_squaredsumofdists + print('error: ' + str(error.item())) + return error +''' \ No newline at end of file diff --git a/src/combined_model/model_shape_v7_withref_withgraphcnn.py b/src/combined_model/model_shape_v7_withref_withgraphcnn.py new file mode 100644 index 0000000000000000000000000000000000000000..58a6c886c2e706e385ae2163e222f35a723bef9a --- /dev/null +++ b/src/combined_model/model_shape_v7_withref_withgraphcnn.py @@ -0,0 +1,927 @@ + +import pickle as pkl +import numpy as np +import torchvision.models as models +from torchvision import transforms +import torch +from torch import nn +from torch.nn.parameter import Parameter +from kornia.geometry.subpix import dsnt # kornia 0.4.0 + +import os +import sys +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..')) +from stacked_hourglass.utils.evaluation import get_preds_soft +from stacked_hourglass import hg1, hg2, hg8 +from lifting_to_3d.linear_model import LinearModelComplete, LinearModel +from lifting_to_3d.inn_model_for_shape import INNForShape +from lifting_to_3d.utils.geometry_utils import rot6d_to_rotmat, rotmat_to_rot6d +from smal_pytorch.smal_model.smal_torch_new import SMAL +from smal_pytorch.renderer.differentiable_renderer import SilhRenderer +from bps_2d.bps_for_segmentation import SegBPS +# from configs.SMAL_configs import SMAL_MODEL_DATA_PATH as SHAPE_PRIOR +from configs.SMAL_configs import SMAL_MODEL_CONFIG +from configs.SMAL_configs import MEAN_DOG_BONE_LENGTHS_NO_RED, VERTEX_IDS_TAIL + +# NEW: for graph cnn part +from smal_pytorch.smal_model.smal_torch_new import SMAL +from configs.SMAL_configs import SMAL_MODEL_CONFIG +from graph_networks.graphcmr.utils_mesh import Mesh +from graph_networks.graphcmr.graph_cnn_groundcontact_multistage import GraphCNNMS + + + + +class SmallLinear(nn.Module): + def __init__(self, input_size=64, output_size=30, linear_size=128): + super(SmallLinear, self).__init__() + self.relu = nn.ReLU(inplace=True) + self.w1 = nn.Linear(input_size, linear_size) + self.w2 = nn.Linear(linear_size, linear_size) + self.w3 = nn.Linear(linear_size, output_size) + def forward(self, x): + # pre-processing + y = self.w1(x) + y = self.relu(y) + y = self.w2(y) + y = self.relu(y) + y = self.w3(y) + return y + + +class MyConv1d(nn.Module): + def __init__(self, input_size=37, output_size=30, start=True): + super(MyConv1d, self).__init__() + self.input_size = input_size + self.output_size = output_size + self.start = start + self.weight = Parameter(torch.ones((self.output_size))) + self.bias = Parameter(torch.zeros((self.output_size))) + def forward(self, x): + # pre-processing + if self.start: + y = x[:, :self.output_size] + else: + y = x[:, -self.output_size:] + y = y * self.weight[None, :] + self.bias[None, :] + return y + + +class ModelShapeAndBreed(nn.Module): + def __init__(self, smal_model_type, n_betas=10, n_betas_limbs=13, n_breeds=121, n_z=512, structure_z_to_betas='default'): + super(ModelShapeAndBreed, self).__init__() + self.n_betas = n_betas + self.n_betas_limbs = n_betas_limbs # n_betas_logscale + self.n_breeds = n_breeds + self.structure_z_to_betas = structure_z_to_betas + if self.structure_z_to_betas == '1dconv': + if not (n_z == self.n_betas+self.n_betas_limbs): + raise ValueError + self.smal_model_type = smal_model_type + # shape branch + self.resnet = models.resnet34(pretrained=False) + # replace the first layer + n_in = 3 + 1 + self.resnet.conv1 = nn.Conv2d(n_in, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False) + # replace the last layer + self.resnet.fc = nn.Linear(512, n_z) + # softmax + self.soft_max = torch.nn.Softmax(dim=1) + # fc network (and other versions) to connect z with betas + p_dropout = 0.2 + if self.structure_z_to_betas == 'default': + self.linear_betas = LinearModel(linear_size=1024, + num_stage=1, + p_dropout=p_dropout, + input_size=n_z, + output_size=self.n_betas) + self.linear_betas_limbs = LinearModel(linear_size=1024, + num_stage=1, + p_dropout=p_dropout, + input_size=n_z, + output_size=self.n_betas_limbs) + elif self.structure_z_to_betas == 'lin': + self.linear_betas = nn.Linear(n_z, self.n_betas) + self.linear_betas_limbs = nn.Linear(n_z, self.n_betas_limbs) + elif self.structure_z_to_betas == 'fc_0': + self.linear_betas = SmallLinear(linear_size=128, # 1024, + input_size=n_z, + output_size=self.n_betas) + self.linear_betas_limbs = SmallLinear(linear_size=128, # 1024, + input_size=n_z, + output_size=self.n_betas_limbs) + elif structure_z_to_betas == 'fc_1': + self.linear_betas = LinearModel(linear_size=64, # 1024, + num_stage=1, + p_dropout=0, + input_size=n_z, + output_size=self.n_betas) + self.linear_betas_limbs = LinearModel(linear_size=64, # 1024, + num_stage=1, + p_dropout=0, + input_size=n_z, + output_size=self.n_betas_limbs) + elif self.structure_z_to_betas == '1dconv': + self.linear_betas = MyConv1d(n_z, self.n_betas, start=True) + self.linear_betas_limbs = MyConv1d(n_z, self.n_betas_limbs, start=False) + elif self.structure_z_to_betas == 'inn': + self.linear_betas_and_betas_limbs = INNForShape(self.n_betas, self.n_betas_limbs, betas_scale=1.0, betas_limbs_scale=1.0) + else: + raise ValueError + # network to connect latent shape vector z with dog breed classification + self.linear_breeds = LinearModel(linear_size=1024, # 1024, + num_stage=1, + p_dropout=p_dropout, + input_size=n_z, + output_size=self.n_breeds) + # shape multiplicator + self.shape_multiplicator_np = np.ones(self.n_betas) + with open(SMAL_MODEL_CONFIG[self.smal_model_type]['smal_model_data_path'], 'rb') as file: + u = pkl._Unpickler(file) + u.encoding = 'latin1' + res = u.load() + # shape predictions are centered around the mean dog of our dog model + if 'dog_cluster_mean' in res.keys(): + self.betas_mean_np = res['dog_cluster_mean'] + else: + assert res['cluster_means'].shape[0]==1 + self.betas_mean_np = res['cluster_means'][0, :] + + + def forward(self, img, seg_raw=None, seg_prep=None): + # img is the network input image + # seg_raw is before softmax and subtracting 0.5 + # seg_prep would be the prepared_segmentation + if seg_prep is None: + seg_prep = self.soft_max(seg_raw)[:, 1:2, :, :] - 0.5 + input_img_and_seg = torch.cat((img, seg_prep), axis=1) + res_output = self.resnet(input_img_and_seg) + dog_breed_output = self.linear_breeds(res_output) + if self.structure_z_to_betas == 'inn': + shape_output_orig, shape_limbs_output_orig = self.linear_betas_and_betas_limbs(res_output) + else: + shape_output_orig = self.linear_betas(res_output) * 0.1 + betas_mean = torch.tensor(self.betas_mean_np).float().to(img.device) + shape_output = shape_output_orig + betas_mean[None, 0:self.n_betas] + shape_limbs_output_orig = self.linear_betas_limbs(res_output) + shape_limbs_output = shape_limbs_output_orig * 0.1 + output_dict = {'z': res_output, + 'breeds': dog_breed_output, + 'betas': shape_output_orig, + 'betas_limbs': shape_limbs_output_orig} + return output_dict + + + +class LearnableShapedirs(nn.Module): + def __init__(self, sym_ids_dict, shapedirs_init, n_betas, n_betas_fixed=10): + super(LearnableShapedirs, self).__init__() + # shapedirs_init = self.smal.shapedirs.detach() + self.n_betas = n_betas + self.n_betas_fixed = n_betas_fixed + self.sym_ids_dict = sym_ids_dict + sym_left_ids = self.sym_ids_dict['left'] + sym_right_ids = self.sym_ids_dict['right'] + sym_center_ids = self.sym_ids_dict['center'] + self.n_center = sym_center_ids.shape[0] + self.n_left = sym_left_ids.shape[0] + self.n_sd = self.n_betas - self.n_betas_fixed # number of learnable shapedirs + # get indices to go from half_shapedirs to shapedirs + inds_back = np.zeros((3889)) + for ind in range(0, sym_center_ids.shape[0]): + ind_in_forward = sym_center_ids[ind] + inds_back[ind_in_forward] = ind + for ind in range(0, sym_left_ids.shape[0]): + ind_in_forward = sym_left_ids[ind] + inds_back[ind_in_forward] = sym_center_ids.shape[0] + ind + for ind in range(0, sym_right_ids.shape[0]): + ind_in_forward = sym_right_ids[ind] + inds_back[ind_in_forward] = sym_center_ids.shape[0] + sym_left_ids.shape[0] + ind + self.register_buffer('inds_back_torch', torch.Tensor(inds_back).long()) + # self.smal.shapedirs: (51, 11667) + # shapedirs: (3889, 3, n_sd) + # shapedirs_half: (2012, 3, n_sd) + sd = shapedirs_init[:self.n_betas, :].permute((1, 0)).reshape((-1, 3, self.n_betas)) + self.register_buffer('sd', sd) + sd_center = sd[sym_center_ids, :, self.n_betas_fixed:] + sd_left = sd[sym_left_ids, :, self.n_betas_fixed:] + self.register_parameter('learnable_half_shapedirs_c0', torch.nn.Parameter(sd_center[:, 0, :].detach())) + self.register_parameter('learnable_half_shapedirs_c2', torch.nn.Parameter(sd_center[:, 2, :].detach())) + self.register_parameter('learnable_half_shapedirs_l0', torch.nn.Parameter(sd_left[:, 0, :].detach())) + self.register_parameter('learnable_half_shapedirs_l1', torch.nn.Parameter(sd_left[:, 1, :].detach())) + self.register_parameter('learnable_half_shapedirs_l2', torch.nn.Parameter(sd_left[:, 2, :].detach())) + def forward(self): + device = self.learnable_half_shapedirs_c0.device + half_shapedirs_center = torch.stack((self.learnable_half_shapedirs_c0, \ + torch.zeros((self.n_center, self.n_sd)).to(device), \ + self.learnable_half_shapedirs_c2), axis=1) + half_shapedirs_left = torch.stack((self.learnable_half_shapedirs_l0, \ + self.learnable_half_shapedirs_l1, \ + self.learnable_half_shapedirs_l2), axis=1) + half_shapedirs_right = torch.stack((self.learnable_half_shapedirs_l0, \ + - self.learnable_half_shapedirs_l1, \ + self.learnable_half_shapedirs_l2), axis=1) + half_shapedirs_tot = torch.cat((half_shapedirs_center, half_shapedirs_left, half_shapedirs_right)) + shapedirs = torch.index_select(half_shapedirs_tot, dim=0, index=self.inds_back_torch) + shapedirs_complete = torch.cat((self.sd[:, :, :self.n_betas_fixed], shapedirs), axis=2) # (3889, 3, n_sd) + shapedirs_complete_prepared = torch.cat((self.sd[:, :, :10], shapedirs), axis=2).reshape((-1, 30)).permute((1, 0)) # (n_sd, 11667) + return shapedirs_complete, shapedirs_complete_prepared + + +class ModelRefinement(nn.Module): + def __init__(self, n_betas=10, n_betas_limbs=7, n_breeds=121, n_keyp=20, n_joints=35, ref_net_type='add', graphcnn_type='inexistent', isflat_type='inexistent', shaperef_type='inexistent'): + super(ModelRefinement, self).__init__() + self.n_betas = n_betas + self.n_betas_limbs = n_betas_limbs + self.n_breeds = n_breeds + self.n_keyp = n_keyp + self.n_joints = n_joints + self.n_out_seg = 256 + self.n_out_keyp = 256 + self.n_out_enc = 256 + self.linear_size = 1024 + self.linear_size_small = 128 + self.ref_net_type = ref_net_type + self.graphcnn_type = graphcnn_type + self.isflat_type = isflat_type + self.shaperef_type = shaperef_type + p_dropout = 0.2 + # --- segmentation encoder + if self.ref_net_type in ['multrot_res34', 'multrot01all_res34']: + self.ref_res = models.resnet34(pretrained=False) + else: + self.ref_res = models.resnet18(pretrained=False) + # replace the first layer + self.ref_res.conv1 = nn.Conv2d(2, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False) + # replace the last layer + self.ref_res.fc = nn.Linear(512, self.n_out_seg) + # softmax + self.soft_max = torch.nn.Softmax(dim=1) + # --- keypoint encoder + self.linear_keyp = LinearModel(linear_size=self.linear_size, + num_stage=1, + p_dropout=p_dropout, + input_size=n_keyp*2*2, + output_size=self.n_out_keyp) + # --- decoder + self.linear_combined = LinearModel(linear_size=self.linear_size, + num_stage=1, + p_dropout=p_dropout, + input_size=self.n_out_seg+self.n_out_keyp, + output_size=self.n_out_enc) + # output info + pose = {'name': 'pose', 'n': self.n_joints*6, 'out_shape':[self.n_joints, 6]} + trans = {'name': 'trans_notnorm', 'n': 3} + cam = {'name': 'flength_notnorm', 'n': 1} + betas = {'name': 'betas', 'n': self.n_betas} + betas_limbs = {'name': 'betas_limbs', 'n': self.n_betas_limbs} + if self.shaperef_type=='inexistent': + self.output_info = [pose, trans, cam] # , betas] + else: + self.output_info = [pose, trans, cam, betas, betas_limbs] + # output branches + self.output_info_linear_models = [] + for ind_el, element in enumerate(self.output_info): + n_in = self.n_out_enc + element['n'] + self.output_info_linear_models.append(LinearModel(linear_size=self.linear_size, + num_stage=1, + p_dropout=p_dropout, + input_size=n_in, + output_size=element['n'])) + element['linear_model_index'] = ind_el + self.output_info_linear_models = nn.ModuleList(self.output_info_linear_models) + # new: predict if the ground is flat + if not self.isflat_type=='inexistent': + self.linear_isflat = LinearModel(linear_size=self.linear_size_small, + num_stage=1, + p_dropout=p_dropout, + input_size=self.n_out_enc, + output_size=2) # answer is just yes or no + + + # new for ground contact prediction: graph cnn + if not self.graphcnn_type=='inexistent': + num_downsampling = 1 + smal_model_type = '39dogs_norm' + smal = SMAL(smal_model_type=smal_model_type, template_name='neutral') + ROOT_smal_downsampling = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/src/graph_networks/graphcmr/data/' + smal_downsampling_npz_name = 'mesh_downsampling_' + os.path.basename(SMAL_MODEL_CONFIG[smal_model_type]['smal_model_path']).replace('.pkl', '_template.npz') + smal_downsampling_npz_path = ROOT_smal_downsampling + smal_downsampling_npz_name # 'data/mesh_downsampling.npz' + self.my_custom_smal_dog_mesh = Mesh(filename=smal_downsampling_npz_path, num_downsampling=num_downsampling, nsize=1, body_model=smal) # , device=device) + # create GraphCNN + num_layers = 2 # <= len(my_custom_mesh._A)-1 + n_resnet_out = self.n_out_enc # 256 + num_channels = 256 # 512 + self.graph_cnn = GraphCNNMS(mesh=self.my_custom_smal_dog_mesh, + num_downsample = num_downsampling, + num_layers = num_layers, + n_resnet_out = n_resnet_out, + num_channels = num_channels) # .to(device) + + + + def forward(self, keyp_sh, keyp_pred, in_pose_3x3, in_trans_notnorm, in_cam_notnorm, in_betas, in_betas_limbs, seg_pred_prep=None, seg_sh_raw=None, seg_sh_prep=None): + # img is the network input image + # seg_raw is before softmax and subtracting 0.5 + # seg_prep would be the prepared_segmentation + batch_size = in_pose_3x3.shape[0] + device = in_pose_3x3.device + dtype = in_pose_3x3.dtype + # --- segmentation encoder + if seg_sh_prep is None: + seg_sh_prep = self.soft_max(seg_sh_raw)[:, 1:2, :, :] - 0.5 # class 1 is the dog + input_seg_conc = torch.cat((seg_sh_prep, seg_pred_prep), axis=1) + network_output_seg = self.ref_res(input_seg_conc) + # --- keypoint encoder + keyp_conc = torch.cat((keyp_sh.reshape((-1, keyp_sh.shape[1]*keyp_sh.shape[2])), keyp_pred.reshape((-1, keyp_sh.shape[1]*keyp_sh.shape[2]))), axis=1) + network_output_keyp = self.linear_keyp(keyp_conc) + # --- decoder + x = torch.cat((network_output_seg, network_output_keyp), axis=1) + y_comb = self.linear_combined(x) + in_pose_6d = rotmat_to_rot6d(in_pose_3x3.reshape((-1, 3, 3))).reshape((in_pose_3x3.shape[0], -1, 6)) + in_dict = {'pose': in_pose_6d, + 'trans_notnorm': in_trans_notnorm, + 'flength_notnorm': in_cam_notnorm, + 'betas': in_betas, + 'betas_limbs': in_betas_limbs} + results = {} + for element in self.output_info: + # import pdb; pdb.set_trace() + + linear_model = self.output_info_linear_models[element['linear_model_index']] + y = torch.cat((y_comb, in_dict[element['name']].reshape((-1, element['n']))), axis=1) + if 'out_shape' in element.keys(): + if element['name'] == 'pose': + if self.ref_net_type in ['multrot', 'multrot01', 'multrot01all', 'multrotxx', 'multrot_res34', 'multrot01all_res34']: # if self.ref_net_type == 'multrot' or self.ref_net_type == 'multrot_res34': + # multiply the rotations with each other -> just predict a correction + # the correction should be initialized as identity + # res_pose_out = (linear_model(y)).reshape((-1, element['out_shape'][0], element['out_shape'][1])) + in_dict[element['name']] + identity_rot6d = torch.tensor(([1., 0., 0., 1., 0., 0.])).repeat((in_pose_3x3.shape[0]*in_pose_3x3.shape[1], 1)).to(device=device, dtype=dtype) + if self.ref_net_type in ['multrot01', 'multrot01all', 'multrot01all_res34']: + res_pose_out = identity_rot6d + 0.1*(linear_model(y)).reshape((-1, element['out_shape'][1])) + elif self.ref_net_type == 'multrotxx': + res_pose_out = identity_rot6d + 0.0*(linear_model(y)).reshape((-1, element['out_shape'][1])) + else: + res_pose_out = identity_rot6d + (linear_model(y)).reshape((-1, element['out_shape'][1])) + res_pose_rotmat = rot6d_to_rotmat(res_pose_out.reshape((-1, 6))) # (bs*35, 3, 3) .reshape((batch_size, -1, 3, 3)) + res_tot_rotmat = torch.bmm(res_pose_rotmat.reshape((-1, 3, 3)), in_pose_3x3.reshape((-1, 3, 3))).reshape((batch_size, -1, 3, 3)) # (bs, 5, 3, 3) + results['pose_rotmat'] = res_tot_rotmat + elif self.ref_net_type == 'add': + res_6d = (linear_model(y)).reshape((-1, element['out_shape'][0], element['out_shape'][1])) + in_dict['pose'] + results['pose_rotmat'] = rot6d_to_rotmat(res_6d.reshape((-1, 6))).reshape((batch_size, -1, 3, 3)) + else: + raise ValueError + else: + if self.ref_net_type in ['multrot01all', 'multrot01all_res34']: + results[element['name']] = (0.1*linear_model(y)).reshape((-1, element['out_shape'][0], element['out_shape'][1])) + in_dict[element['name']] + else: + results[element['name']] = (linear_model(y)).reshape((-1, element['out_shape'][0], element['out_shape'][1])) + in_dict[element['name']] + else: + if self.ref_net_type in ['multrot01all', 'multrot01all_res34']: + results[element['name']] = 0.1*linear_model(y) + in_dict[element['name']] + else: + results[element['name']] = linear_model(y) + in_dict[element['name']] + + # add prediction if ground is flat + if not self.isflat_type=='inexistent': + isflat = self.linear_isflat(y_comb) + results['isflat'] = isflat + + # add graph cnn + if not self.graphcnn_type=='inexistent': + ground_contact_downsampled, ground_cntact_all_stages_output = self.graph_cnn(y_comb) + ground_contact = self.my_custom_smal_dog_mesh.upsample(ground_contact_downsampled.transpose(1,2)) + results['vertexwise_ground_contact'] = ground_contact + + return results + + + + +class ModelImageToBreed(nn.Module): + def __init__(self, smal_model_type, arch='hg8', n_joints=35, n_classes=20, n_partseg=15, n_keyp=20, n_bones=24, n_betas=10, n_betas_limbs=7, n_breeds=121, image_size=256, n_z=512, thr_keyp_sc=None, add_partseg=True): + super(ModelImageToBreed, self).__init__() + self.n_classes = n_classes + self.n_partseg = n_partseg + self.n_betas = n_betas + self.n_betas_limbs = n_betas_limbs + self.n_keyp = n_keyp + self.n_bones = n_bones + self.n_breeds = n_breeds + self.image_size = image_size + self.upsample_seg = True + self.threshold_scores = thr_keyp_sc + self.n_z = n_z + self.add_partseg = add_partseg + self.smal_model_type = smal_model_type + # ------------------------------ STACKED HOUR GLASS ------------------------------ + if arch == 'hg8': + self.stacked_hourglass = hg8(pretrained=False, num_classes=self.n_classes, num_partseg=self.n_partseg, upsample_seg=self.upsample_seg, add_partseg=self.add_partseg) + else: + raise Exception('unrecognised model architecture: ' + arch) + # ------------------------------ SHAPE AND BREED MODEL ------------------------------ + self.breed_model = ModelShapeAndBreed(smal_model_type=self.smal_model_type, n_betas=self.n_betas, n_betas_limbs=self.n_betas_limbs, n_breeds=self.n_breeds, n_z=self.n_z) + def forward(self, input_img, norm_dict=None, bone_lengths_prepared=None, betas=None): + batch_size = input_img.shape[0] + device = input_img.device + # ------------------------------ STACKED HOUR GLASS ------------------------------ + hourglass_out_dict = self.stacked_hourglass(input_img) + last_seg = hourglass_out_dict['seg_final'] + last_heatmap = hourglass_out_dict['out_list_kp'][-1] + # - prepare keypoints (from heatmap) + # normalize predictions -> from logits to probability distribution + # last_heatmap_norm = dsnt.spatial_softmax2d(last_heatmap, temperature=torch.tensor(1)) + # keypoints = dsnt.spatial_expectation2d(last_heatmap_norm, normalized_coordinates=False) + 1 # (bs, 20, 2) + # keypoints_norm = dsnt.spatial_expectation2d(last_heatmap_norm, normalized_coordinates=True) # (bs, 20, 2) + keypoints_norm, scores = get_preds_soft(last_heatmap, return_maxval=True, norm_coords=True) + if self.threshold_scores is not None: + scores[scores>self.threshold_scores] = 1.0 + scores[scores<=self.threshold_scores] = 0.0 + # ------------------------------ SHAPE AND BREED MODEL ------------------------------ + # breed_model takes as input the image as well as the predicted segmentation map + # -> we need to split up ModelImageTo3d, such that we can use the silhouette + resnet_output = self.breed_model(img=input_img, seg_raw=last_seg) + pred_breed = resnet_output['breeds'] # (bs, n_breeds) + pred_betas = resnet_output['betas'] + pred_betas_limbs = resnet_output['betas_limbs'] + small_output = {'keypoints_norm': keypoints_norm, + 'keypoints_scores': scores} + small_output_reproj = {'betas': pred_betas, + 'betas_limbs': pred_betas_limbs, + 'dog_breed': pred_breed} + return small_output, None, small_output_reproj + +class ModelImageTo3d_withshape_withproj(nn.Module): + def __init__(self, smal_model_type, smal_keyp_conf=None, arch='hg8', num_stage_comb=2, num_stage_heads=1, num_stage_heads_pose=1, trans_sep=False, n_joints=35, n_classes=20, n_partseg=15, n_keyp=20, n_bones=24, n_betas=10, n_betas_limbs=6, n_breeds=121, image_size=256, n_z=512, n_segbps=64*2, thr_keyp_sc=None, add_z_to_3d_input=True, add_segbps_to_3d_input=False, add_partseg=True, silh_no_tail=True, fix_flength=False, render_partseg=False, structure_z_to_betas='default', structure_pose_net='default', nf_version=None, ref_net_type='add', ref_detach_shape=True, graphcnn_type='inexistent', isflat_type='inexistent', shaperef_type='inexistent'): + super(ModelImageTo3d_withshape_withproj, self).__init__() + self.n_classes = n_classes + self.n_partseg = n_partseg + self.n_betas = n_betas + self.n_betas_limbs = n_betas_limbs + self.n_keyp = n_keyp + self.n_joints = n_joints + self.n_bones = n_bones + self.n_breeds = n_breeds + self.image_size = image_size + self.threshold_scores = thr_keyp_sc + self.upsample_seg = True + self.silh_no_tail = silh_no_tail + self.add_z_to_3d_input = add_z_to_3d_input + self.add_segbps_to_3d_input = add_segbps_to_3d_input + self.add_partseg = add_partseg + self.ref_net_type = ref_net_type + self.ref_detach_shape = ref_detach_shape + self.graphcnn_type = graphcnn_type + self.isflat_type = isflat_type + self.shaperef_type = shaperef_type + assert (not self.add_segbps_to_3d_input) or (not self.add_z_to_3d_input) + self.n_z = n_z + if add_segbps_to_3d_input: + self.n_segbps = n_segbps # 64 + self.segbps_model = SegBPS() + else: + self.n_segbps = 0 + self.fix_flength = fix_flength + self.render_partseg = render_partseg + self.structure_z_to_betas = structure_z_to_betas + self.structure_pose_net = structure_pose_net + assert self.structure_pose_net in ['default', 'vae', 'normflow'] + self.nf_version = nf_version + self.smal_model_type = smal_model_type + assert (smal_keyp_conf is not None) + self.smal_keyp_conf = smal_keyp_conf + self.register_buffer('betas_zeros', torch.zeros((1, self.n_betas))) + self.register_buffer('mean_dog_bone_lengths', torch.tensor(MEAN_DOG_BONE_LENGTHS_NO_RED, dtype=torch.float32)) + p_dropout = 0.2 # 0.5 + # ------------------------------ SMAL MODEL ------------------------------ + self.smal = SMAL(smal_model_type=self.smal_model_type, template_name='neutral') + print('SMAL model type: ' + self.smal.smal_model_type) + # New for rendering without tail + f_np = self.smal.faces.detach().cpu().numpy() + self.f_no_tail_np = f_np[np.isin(f_np[:,:], VERTEX_IDS_TAIL).sum(axis=1)==0, :] + # in theory we could optimize for improved shapedirs, but we do not do that + # -> would need to implement regularizations + # -> there are better ways than changing the shapedirs + self.model_learnable_shapedirs = LearnableShapedirs(self.smal.sym_ids_dict, self.smal.shapedirs.detach(), self.n_betas, 10) + # ------------------------------ STACKED HOUR GLASS ------------------------------ + if arch == 'hg8': + self.stacked_hourglass = hg8(pretrained=False, num_classes=self.n_classes, num_partseg=self.n_partseg, upsample_seg=self.upsample_seg, add_partseg=self.add_partseg) + else: + raise Exception('unrecognised model architecture: ' + arch) + # ------------------------------ SHAPE AND BREED MODEL ------------------------------ + self.breed_model = ModelShapeAndBreed(self.smal_model_type, n_betas=self.n_betas, n_betas_limbs=self.n_betas_limbs, n_breeds=self.n_breeds, n_z=self.n_z, structure_z_to_betas=self.structure_z_to_betas) + # ------------------------------ LINEAR 3D MODEL ------------------------------ + # 3d model -> from image to 3d parameters {2d keypoints from heatmap, pose, trans, flength} + self.soft_max = torch.nn.Softmax(dim=1) + input_size = self.n_keyp*3 + self.n_bones + self.model_3d = LinearModelComplete(linear_size=1024, + num_stage_comb=num_stage_comb, + num_stage_heads=num_stage_heads, + num_stage_heads_pose=num_stage_heads_pose, + trans_sep=trans_sep, + p_dropout=p_dropout, # 0.5, + input_size=input_size, + intermediate_size=1024, + output_info=None, + n_joints=self.n_joints, + n_z=self.n_z, + add_z_to_3d_input=self.add_z_to_3d_input, + n_segbps=self.n_segbps, + add_segbps_to_3d_input=self.add_segbps_to_3d_input, + structure_pose_net=self.structure_pose_net, + nf_version = self.nf_version) + # ------------------------------ RENDERING ------------------------------ + self.silh_renderer = SilhRenderer(image_size) + # ------------------------------ REFINEMENT ----------------------------- + self.refinement_model = ModelRefinement(n_betas=self.n_betas, n_betas_limbs=self.n_betas_limbs, n_breeds=self.n_breeds, n_keyp=self.n_keyp, n_joints=self.n_joints, ref_net_type=self.ref_net_type, graphcnn_type=self.graphcnn_type, isflat_type=self.isflat_type, shaperef_type=self.shaperef_type) + + + def forward(self, input_img, norm_dict=None, bone_lengths_prepared=None, betas=None): + batch_size = input_img.shape[0] + device = input_img.device + # ------------------------------ STACKED HOUR GLASS ------------------------------ + hourglass_out_dict = self.stacked_hourglass(input_img) + last_seg = hourglass_out_dict['seg_final'] + last_heatmap = hourglass_out_dict['out_list_kp'][-1] + # - prepare keypoints (from heatmap) + # normalize predictions -> from logits to probability distribution + # last_heatmap_norm = dsnt.spatial_softmax2d(last_heatmap, temperature=torch.tensor(1)) + # keypoints = dsnt.spatial_expectation2d(last_heatmap_norm, normalized_coordinates=False) + 1 # (bs, 20, 2) + # keypoints_norm = dsnt.spatial_expectation2d(last_heatmap_norm, normalized_coordinates=True) # (bs, 20, 2) + keypoints_norm, scores = get_preds_soft(last_heatmap, return_maxval=True, norm_coords=True) + if self.threshold_scores is not None: + scores[scores>self.threshold_scores] = 1.0 + scores[scores<=self.threshold_scores] = 0.0 + # ------------------------------ LEARNABLE SHAPE MODEL ------------------------------ + # in our cvpr 2022 paper we do not change the shapedirs + # learnable_sd_complete has shape (3889, 3, n_sd) + # learnable_sd_complete_prepared has shape (n_sd, 11667) + learnable_sd_complete, learnable_sd_complete_prepared = self.model_learnable_shapedirs() + shapedirs_sel = learnable_sd_complete_prepared # None + # ------------------------------ SHAPE AND BREED MODEL ------------------------------ + # breed_model takes as input the image as well as the predicted segmentation map + # -> we need to split up ModelImageTo3d, such that we can use the silhouette + resnet_output = self.breed_model(img=input_img, seg_raw=last_seg) + pred_breed = resnet_output['breeds'] # (bs, n_breeds) + pred_z = resnet_output['z'] + # - prepare shape + pred_betas = resnet_output['betas'] + pred_betas_limbs = resnet_output['betas_limbs'] + # - calculate bone lengths + with torch.no_grad(): + use_mean_bone_lengths = False + if use_mean_bone_lengths: + bone_lengths_prepared = torch.cat(batch_size*[self.mean_dog_bone_lengths.reshape((1, -1))]) + else: + assert (bone_lengths_prepared is None) + bone_lengths_prepared = self.smal.caclulate_bone_lengths(pred_betas, pred_betas_limbs, shapedirs_sel=shapedirs_sel, short=True) + # ------------------------------ LINEAR 3D MODEL ------------------------------ + # 3d model -> from image to 3d parameters {2d keypoints from heatmap, pose, trans, flength} + # prepare input for 2d-to-3d network + keypoints_prepared = torch.cat((keypoints_norm, scores), axis=2) + if bone_lengths_prepared is None: + bone_lengths_prepared = torch.cat(batch_size*[self.mean_dog_bone_lengths.reshape((1, -1))]) + # should we add silhouette to 3d input? should we add z? + if self.add_segbps_to_3d_input: + seg_raw = last_seg + seg_prep_bps = self.soft_max(seg_raw)[:, 1, :, :] # class 1 is the dog + with torch.no_grad(): + seg_prep_np = seg_prep_bps.detach().cpu().numpy() + bps_output_np = self.segbps_model.calculate_bps_points_batch(seg_prep_np) # (bs, 64, 2) + bps_output = torch.tensor(bps_output_np, dtype=torch.float32).to(device).reshape((batch_size, -1)) + bps_output_prep = bps_output * 2. - 1 + input_vec_keyp_bones = torch.cat((keypoints_prepared.reshape((batch_size, -1)), bone_lengths_prepared), axis=1) + input_vec = torch.cat((input_vec_keyp_bones, bps_output_prep), dim=1) + elif self.add_z_to_3d_input: + # we do not use this in our cvpr 2022 version + input_vec_keyp_bones = torch.cat((keypoints_prepared.reshape((batch_size, -1)), bone_lengths_prepared), axis=1) + input_vec_additional = pred_z + input_vec = torch.cat((input_vec_keyp_bones, input_vec_additional), dim=1) + else: + input_vec = torch.cat((keypoints_prepared.reshape((batch_size, -1)), bone_lengths_prepared), axis=1) + # predict 3d parameters (those are normalized, we need to correct mean and std in a next step) + output = self.model_3d(input_vec) + # add predicted keypoints to the output dict + output['keypoints_norm'] = keypoints_norm + output['keypoints_scores'] = scores + # add predicted segmentation to output dictc + output['seg_hg'] = hourglass_out_dict['seg_final'] + # - denormalize 3d parameters -> so far predictions were normalized, now we denormalize them again + pred_trans = output['trans'] * norm_dict['trans_std'][None, :] + norm_dict['trans_mean'][None, :] # (bs, 3) + if self.structure_pose_net == 'default': + pred_pose_rot6d = output['pose'] + norm_dict['pose_rot6d_mean'][None, :] + elif self.structure_pose_net == 'normflow': + pose_rot6d_mean_zeros = torch.zeros_like(norm_dict['pose_rot6d_mean'][None, :]) + pose_rot6d_mean_zeros[:, 0, :] = norm_dict['pose_rot6d_mean'][None, 0, :] + pred_pose_rot6d = output['pose'] + pose_rot6d_mean_zeros + else: + pose_rot6d_mean_zeros = torch.zeros_like(norm_dict['pose_rot6d_mean'][None, :]) + pose_rot6d_mean_zeros[:, 0, :] = norm_dict['pose_rot6d_mean'][None, 0, :] + pred_pose_rot6d = output['pose'] + pose_rot6d_mean_zeros + pred_pose_reshx33 = rot6d_to_rotmat(pred_pose_rot6d.reshape((-1, 6))) + pred_pose = pred_pose_reshx33.reshape((batch_size, -1, 3, 3)) + pred_pose_rot6d = rotmat_to_rot6d(pred_pose_reshx33).reshape((batch_size, -1, 6)) + + if self.fix_flength: + output['flength'] = torch.zeros_like(output['flength']) + pred_flength = torch.ones_like(output['flength'])*2100 # norm_dict['flength_mean'][None, :] + else: + pred_flength_orig = output['flength'] * norm_dict['flength_std'][None, :] + norm_dict['flength_mean'][None, :] # (bs, 1) + pred_flength = pred_flength_orig.clone() # torch.abs(pred_flength_orig) + pred_flength[pred_flength_orig<=0] = norm_dict['flength_mean'][None, :] + + # ------------------------------ RENDERING ------------------------------ + # get 3d model (SMAL) + V, keyp_green_3d, _ = self.smal(beta=pred_betas, betas_limbs=pred_betas_limbs, pose=pred_pose, trans=pred_trans, get_skin=True, keyp_conf=self.smal_keyp_conf, shapedirs_sel=shapedirs_sel) + keyp_3d = keyp_green_3d[:, :self.n_keyp, :] # (bs, 20, 3) + # render silhouette + faces_prep = self.smal.faces.unsqueeze(0).expand((batch_size, -1, -1)) + if not self.silh_no_tail: + pred_silh_images, pred_keyp = self.silh_renderer(vertices=V, + points=keyp_3d, faces=faces_prep, focal_lengths=pred_flength) + else: + faces_no_tail_prep = torch.tensor(self.f_no_tail_np).to(device).expand((batch_size, -1, -1)) + pred_silh_images, pred_keyp = self.silh_renderer(vertices=V, + points=keyp_3d, faces=faces_no_tail_prep, focal_lengths=pred_flength) + # get torch 'Meshes' + torch_meshes = self.silh_renderer.get_torch_meshes(vertices=V, faces=faces_prep) + + # render body parts (not part of cvpr 2022 version) + if self.render_partseg: + raise NotImplementedError + else: + partseg_images = None + partseg_images_hg = None + + + # ------------------------------ REFINEMENT MODEL ------------------------------ + + # refinement model + pred_keyp_norm = (pred_keyp.detach() / (self.image_size - 1) - 0.5)*2 + '''output_ref = self.refinement_model(keypoints_norm.detach(), pred_keyp_norm, \ + seg_sh_raw=last_seg[:, :, :, :].detach(), seg_pred_prep=pred_silh_images[:, :, :, :].detach()-0.5, \ + in_pose=output['pose'].detach(), in_trans=output['trans'].detach(), in_cam=output['flength'].detach(), in_betas=pred_betas.detach())''' + output_ref = self.refinement_model(keypoints_norm.detach(), pred_keyp_norm, \ + seg_sh_raw=last_seg[:, :, :, :].detach(), seg_pred_prep=pred_silh_images[:, :, :, :].detach()-0.5, \ + in_pose_3x3=pred_pose.detach(), in_trans_notnorm=output['trans'].detach(), in_cam_notnorm=output['flength'].detach(), in_betas=pred_betas.detach(), in_betas_limbs=pred_betas_limbs.detach()) + # a better alternative would be to submit pred_pose_reshx33 + + + + # nothing changes for betas or shapedirs or z ##################### should probably not be detached in the end + if self.shaperef_type == 'inexistent': + if self.ref_detach_shape: + output_ref['betas'] = pred_betas.detach() + output_ref['betas_limbs'] = pred_betas_limbs.detach() + output_ref['z'] = pred_z.detach() + output_ref['shapedirs'] = shapedirs_sel.detach() + else: + output_ref['betas'] = pred_betas + output_ref['betas_limbs'] = pred_betas_limbs + output_ref['z'] = pred_z + output_ref['shapedirs'] = shapedirs_sel + else: + assert ('betas' in output_ref.keys()) + assert ('betas_limbs' in output_ref.keys()) + output_ref['shapedirs'] = shapedirs_sel + + + # we denormalize flength and trans, but pose is handled differently + if self.fix_flength: + output_ref['flength_notnorm'] = torch.zeros_like(output['flength']) + ref_pred_flength = torch.ones_like(output['flength_notnorm'])*2100 # norm_dict['flength_mean'][None, :] + raise ValueError # not sure if we want to have a fixed flength in refinement + else: + ref_pred_flength_orig = output_ref['flength_notnorm'] * norm_dict['flength_std'][None, :] + norm_dict['flength_mean'][None, :] # (bs, 1) + ref_pred_flength = ref_pred_flength_orig.clone() # torch.abs(pred_flength_orig) + ref_pred_flength[ref_pred_flength_orig<=0] = norm_dict['flength_mean'][None, :] + ref_pred_trans = output_ref['trans_notnorm'] * norm_dict['trans_std'][None, :] + norm_dict['trans_mean'][None, :] # (bs, 3) + + + + + # ref_pred_pose_rot6d = output_ref['pose'] + # ref_pred_pose_reshx33 = rot6d_to_rotmat(output_ref['pose'].reshape((-1, 6))).reshape((batch_size, -1, 3, 3)) + ref_pred_pose_reshx33 = output_ref['pose_rotmat'].reshape((batch_size, -1, 3, 3)) + ref_pred_pose_rot6d = rotmat_to_rot6d(ref_pred_pose_reshx33.reshape((-1, 3, 3))).reshape((batch_size, -1, 6)) + + ref_V, ref_keyp_green_3d, _ = self.smal(beta=output_ref['betas'], betas_limbs=output_ref['betas_limbs'], + pose=ref_pred_pose_reshx33, trans=ref_pred_trans, get_skin=True, keyp_conf=self.smal_keyp_conf, + shapedirs_sel=output_ref['shapedirs']) + ref_keyp_3d = ref_keyp_green_3d[:, :self.n_keyp, :] # (bs, 20, 3) + + if not self.silh_no_tail: + faces_prep = self.smal.faces.unsqueeze(0).expand((batch_size, -1, -1)) + ref_pred_silh_images, ref_pred_keyp = self.silh_renderer(vertices=ref_V, + points=ref_keyp_3d, faces=faces_prep, focal_lengths=ref_pred_flength) + else: + faces_no_tail_prep = torch.tensor(self.f_no_tail_np).to(device).expand((batch_size, -1, -1)) + ref_pred_silh_images, ref_pred_keyp = self.silh_renderer(vertices=ref_V, + points=ref_keyp_3d, faces=faces_no_tail_prep, focal_lengths=ref_pred_flength) + + output_ref_unnorm = {'vertices_smal': ref_V, + 'keyp_3d': ref_keyp_3d, + 'keyp_2d': ref_pred_keyp, + 'silh': ref_pred_silh_images, + 'trans': ref_pred_trans, + 'flength': ref_pred_flength, + 'betas': output_ref['betas'], + 'betas_limbs': output_ref['betas_limbs'], + # 'z': output_ref['z'], + 'pose_rot6d': ref_pred_pose_rot6d, + 'pose_rotmat': ref_pred_pose_reshx33} + # 'shapedirs': shapedirs_sel} + + if not self.graphcnn_type == 'inexistent': + output_ref_unnorm['vertexwise_ground_contact'] = output_ref['vertexwise_ground_contact'] + if not self.isflat_type=='inexistent': + output_ref_unnorm['isflat'] = output_ref['isflat'] + if self.shaperef_type == 'inexistent': + output_ref_unnorm['z'] = output_ref['z'] + + # REMARK: we will want to have the predicted differences, for pose this would + # be a rotation matrix, ... + # -> TODO: adjust output_orig_ref_comparison + output_orig_ref_comparison = {#'pose': output['pose'].detach(), + #'trans': output['trans'].detach(), + #'flength': output['flength'].detach(), + # 'pose': output['pose'], + 'old_pose_rotmat': pred_pose_reshx33, + 'old_trans_notnorm': output['trans'], + 'old_flength_notnorm': output['flength'], + # 'ref_pose': output_ref['pose'], + 'ref_pose_rotmat': ref_pred_pose_reshx33, + 'ref_trans_notnorm': output_ref['trans_notnorm'], + 'ref_flength_notnorm': output_ref['flength_notnorm']} + + + + # ------------------------------ PREPARE OUTPUT ------------------------------ + # create output dictionarys + # output: contains all output from model_image_to_3d + # output_unnorm: same as output, but normalizations are undone + # output_reproj: smal output and reprojected keypoints as well as silhouette + keypoints_heatmap_256 = (output['keypoints_norm'] / 2. + 0.5) * (self.image_size - 1) + output_unnorm = {'pose_rotmat': pred_pose, + 'flength': pred_flength, + 'trans': pred_trans, + 'keypoints':keypoints_heatmap_256} + output_reproj = {'vertices_smal': V, + 'torch_meshes': torch_meshes, + 'keyp_3d': keyp_3d, + 'keyp_2d': pred_keyp, + 'silh': pred_silh_images, + 'betas': pred_betas, + 'betas_limbs': pred_betas_limbs, + 'pose_rot6d': pred_pose_rot6d, # used for pose prior... + 'dog_breed': pred_breed, + 'shapedirs': shapedirs_sel, + 'z': pred_z, + 'flength_unnorm': pred_flength, + 'flength': output['flength'], + 'partseg_images_rend': partseg_images, + 'partseg_images_hg_nograd': partseg_images_hg, + 'normflow_z': output['normflow_z']} + + return output, output_unnorm, output_reproj, output_ref_unnorm, output_orig_ref_comparison + + + def forward_with_multiple_refinements(self, input_img, norm_dict=None, bone_lengths_prepared=None, betas=None): + + # import pdb; pdb.set_trace() + + # run normal network part + output, output_unnorm, output_reproj, output_ref_unnorm, output_orig_ref_comparison = self.forward(input_img, norm_dict=norm_dict, bone_lengths_prepared=bone_lengths_prepared, betas=betas) + + # prepare input for second refinement stage + batch_size = output['keypoints_norm'].shape[0] + keypoints_norm = output['keypoints_norm'] + pred_keyp_norm = (output_ref_unnorm['keyp_2d'].detach() / (self.image_size - 1) - 0.5)*2 + + last_seg = output['seg_hg'] + pred_silh_images = output_ref_unnorm['silh'].detach() + + trans_notnorm = output_orig_ref_comparison['ref_trans_notnorm'] + flength_notnorm = output_orig_ref_comparison['ref_flength_notnorm'] + # trans_notnorm = output_orig_ref_comparison['ref_pose_rotmat'] + pred_pose = output_ref_unnorm['pose_rotmat'].reshape((batch_size, -1, 3, 3)) + + # run second refinement step + output_ref_new = self.refinement_model(keypoints_norm.detach(), pred_keyp_norm, \ + seg_sh_raw=last_seg[:, :, :, :].detach(), seg_pred_prep=pred_silh_images[:, :, :, :].detach()-0.5, \ + in_pose_3x3=pred_pose.detach(), in_trans_notnorm=trans_notnorm.detach(), in_cam_notnorm=flength_notnorm.detach(), \ + in_betas=output_ref_unnorm['betas'].detach(), in_betas_limbs=output_ref_unnorm['betas_limbs'].detach()) + # output_ref_new = self.refinement_model(keypoints_norm.detach(), pred_keyp_norm, seg_sh_raw=last_seg[:, :, :, :].detach(), seg_pred_prep=pred_silh_images[:, :, :, :].detach()-0.5, in_pose_3x3=pred_pose.detach(), in_trans_notnorm=trans_notnorm.detach(), in_cam_notnorm=flength_notnorm.detach(), in_betas=output_ref_unnorm['betas'].detach(), in_betas_limbs=output_ref_unnorm['betas_limbs'].detach()) + + + # new shape + if self.shaperef_type == 'inexistent': + if self.ref_detach_shape: + output_ref_new['betas'] = output_ref_unnorm['betas'].detach() + output_ref_new['betas_limbs'] = output_ref_unnorm['betas_limbs'].detach() + output_ref_new['z'] = output_ref_unnorm['z'].detach() + output_ref_new['shapedirs'] = output_reproj['shapedirs'].detach() + else: + output_ref_new['betas'] = output_ref_unnorm['betas'] + output_ref_new['betas_limbs'] = output_ref_unnorm['betas_limbs'] + output_ref_new['z'] = output_ref_unnorm['z'] + output_ref_new['shapedirs'] = output_reproj['shapedirs'] + else: + assert ('betas' in output_ref_new.keys()) + assert ('betas_limbs' in output_ref_new.keys()) + output_ref_new['shapedirs'] = output_reproj['shapedirs'] + + # we denormalize flength and trans, but pose is handled differently + if self.fix_flength: + raise ValueError # not sure if we want to have a fixed flength in refinement + else: + ref_pred_flength_orig = output_ref_new['flength_notnorm'] * norm_dict['flength_std'][None, :] + norm_dict['flength_mean'][None, :] # (bs, 1) + ref_pred_flength = ref_pred_flength_orig.clone() # torch.abs(pred_flength_orig) + ref_pred_flength[ref_pred_flength_orig<=0] = norm_dict['flength_mean'][None, :] + ref_pred_trans = output_ref_new['trans_notnorm'] * norm_dict['trans_std'][None, :] + norm_dict['trans_mean'][None, :] # (bs, 3) + + + ref_pred_pose_reshx33 = output_ref_new['pose_rotmat'].reshape((batch_size, -1, 3, 3)) + ref_pred_pose_rot6d = rotmat_to_rot6d(ref_pred_pose_reshx33.reshape((-1, 3, 3))).reshape((batch_size, -1, 6)) + + ref_V, ref_keyp_green_3d, _ = self.smal(beta=output_ref_new['betas'], betas_limbs=output_ref_new['betas_limbs'], + pose=ref_pred_pose_reshx33, trans=ref_pred_trans, get_skin=True, keyp_conf=self.smal_keyp_conf, + shapedirs_sel=output_ref_new['shapedirs']) + + # ref_V, ref_keyp_green_3d, _ = self.smal(beta=output_ref_new['betas'], betas_limbs=output_ref_new['betas_limbs'], pose=ref_pred_pose_reshx33, trans=ref_pred_trans, get_skin=True, keyp_conf=self.smal_keyp_conf, shapedirs_sel=output_ref_new['shapedirs']) + ref_keyp_3d = ref_keyp_green_3d[:, :self.n_keyp, :] # (bs, 20, 3) + + if not self.silh_no_tail: + faces_prep = self.smal.faces.unsqueeze(0).expand((batch_size, -1, -1)) + ref_pred_silh_images, ref_pred_keyp = self.silh_renderer(vertices=ref_V, + points=ref_keyp_3d, faces=faces_prep, focal_lengths=ref_pred_flength) + else: + faces_no_tail_prep = torch.tensor(self.f_no_tail_np).to(device).expand((batch_size, -1, -1)) + ref_pred_silh_images, ref_pred_keyp = self.silh_renderer(vertices=ref_V, + points=ref_keyp_3d, faces=faces_no_tail_prep, focal_lengths=ref_pred_flength) + + output_ref_unnorm_new = {'vertices_smal': ref_V, + 'keyp_3d': ref_keyp_3d, + 'keyp_2d': ref_pred_keyp, + 'silh': ref_pred_silh_images, + 'trans': ref_pred_trans, + 'flength': ref_pred_flength, + 'betas': output_ref_new['betas'], + 'betas_limbs': output_ref_new['betas_limbs'], + 'pose_rot6d': ref_pred_pose_rot6d, + 'pose_rotmat': ref_pred_pose_reshx33} + + if not self.graphcnn_type == 'inexistent': + output_ref_unnorm_new['vertexwise_ground_contact'] = output_ref_new['vertexwise_ground_contact'] + if not self.isflat_type=='inexistent': + output_ref_unnorm_new['isflat'] = output_ref_new['isflat'] + if self.shaperef_type == 'inexistent': + output_ref_unnorm_new['z'] = output_ref_new['z'] + + output_orig_ref_comparison_new = {'ref_pose_rotmat': ref_pred_pose_reshx33, + 'ref_trans_notnorm': output_ref_new['trans_notnorm'], + 'ref_flength_notnorm': output_ref_new['flength_notnorm']} + + results = { + 'output': output, + 'output_unnorm': output_unnorm, + 'output_reproj':output_reproj, + 'output_ref_unnorm': output_ref_unnorm, + 'output_orig_ref_comparison':output_orig_ref_comparison, + 'output_ref_unnorm_new': output_ref_unnorm_new, + 'output_orig_ref_comparison_new': output_orig_ref_comparison_new} + return results + + + + + + + + + + + + + + + + + + + + + + def render_vis_nograd(self, vertices, focal_lengths, color=0): + # this function is for visualization only + # vertices: (bs, n_verts, 3) + # focal_lengths: (bs, 1) + # color: integer, either 0 or 1 + # returns a torch tensor of shape (bs, image_size, image_size, 3) + with torch.no_grad(): + batch_size = vertices.shape[0] + faces_prep = self.smal.faces.unsqueeze(0).expand((batch_size, -1, -1)) + visualizations = self.silh_renderer.get_visualization_nograd(vertices, + faces_prep, focal_lengths, color=color) + return visualizations + diff --git a/src/combined_model/train_main_image_to_3d_wbr_withref.py b/src/combined_model/train_main_image_to_3d_wbr_withref.py new file mode 100644 index 0000000000000000000000000000000000000000..4ca1fc2f372036d36209fca26ed09446a5a934c8 --- /dev/null +++ b/src/combined_model/train_main_image_to_3d_wbr_withref.py @@ -0,0 +1,955 @@ + +import torch +import torch.nn as nn +import torch.backends.cudnn +import torch.nn.parallel +from tqdm import tqdm +import os +import pathlib +from matplotlib import pyplot as plt +import cv2 +import numpy as np +import torch +import trimesh +import pickle as pkl +import csv +from scipy.spatial.transform import Rotation as R_sc + + +import sys +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..')) +from stacked_hourglass.utils.evaluation import accuracy, AverageMeter, final_preds, get_preds, get_preds_soft +from stacked_hourglass.utils.visualization import save_input_image_with_keypoints, save_input_image +from metrics.metrics import Metrics +from configs.SMAL_configs import EVAL_KEYPOINTS, KEYPOINT_GROUPS, SMAL_KEYPOINT_NAMES_FOR_3D_EVAL, SMAL_KEYPOINT_INDICES_FOR_3D_EVAL, SMAL_KEYPOINT_WHICHTOUSE_FOR_3D_EVAL +from combined_model.helper import eval_save_visualizations_and_meshes, eval_prepare_pck_and_iou, eval_add_preds_to_summary + +from smal_pytorch.smal_model.smal_torch_new import SMAL # for gc visualization +from src.combined_model.loss_utils.loss_utils import fit_plane +# from src.evaluation.sketchfab_evaluation.alignment_utils.calculate_v2v_error_release import compute_similarity_transform +# from src.evaluation.sketchfab_evaluation.alignment_utils.calculate_alignment_error import calculate_alignemnt_errors + +# --------------------------------------------------------------------------------------------------------------------------- +def do_training_epoch(train_loader, model, loss_module, loss_module_ref, device, data_info, optimiser, quiet=False, acc_joints=None, weight_dict=None, weight_dict_ref=None): + losses = AverageMeter() + losses_keyp = AverageMeter() + losses_silh = AverageMeter() + losses_shape = AverageMeter() + losses_pose = AverageMeter() + losses_class = AverageMeter() + losses_breed = AverageMeter() + losses_partseg = AverageMeter() + losses_ref_keyp = AverageMeter() + losses_ref_silh = AverageMeter() + losses_ref_pose = AverageMeter() + losses_ref_reg = AverageMeter() + accuracies = AverageMeter() + # Put the model in training mode. + model.train() + # prepare progress bar + iterable = enumerate(train_loader) + progress = None + if not quiet: + progress = tqdm(iterable, desc='Train', total=len(train_loader), ascii=True, leave=False) + iterable = progress + # information for normalization + norm_dict = { + 'pose_rot6d_mean': torch.from_numpy(data_info.pose_rot6d_mean).float().to(device), + 'trans_mean': torch.from_numpy(data_info.trans_mean).float().to(device), + 'trans_std': torch.from_numpy(data_info.trans_std).float().to(device), + 'flength_mean': torch.from_numpy(data_info.flength_mean).float().to(device), + 'flength_std': torch.from_numpy(data_info.flength_std).float().to(device)} + # prepare variables, put them on the right device + for i, (input, target_dict) in iterable: + batch_size = input.shape[0] + for key in target_dict.keys(): + if key == 'breed_index': + target_dict[key] = target_dict[key].long().to(device) + elif key in ['index', 'pts', 'tpts', 'target_weight', 'silh', 'silh_distmat_tofg', 'silh_distmat_tobg', 'sim_breed_index', 'img_border_mask']: + target_dict[key] = target_dict[key].float().to(device) + elif key in ['has_seg', 'gc']: + target_dict[key] = target_dict[key].to(device) + else: + pass + input = input.float().to(device) + + # ----------------------- do training step ----------------------- + assert model.training, 'model must be in training mode.' + with torch.enable_grad(): + # ----- forward pass ----- + output, output_unnorm, output_reproj, output_ref, output_ref_comp = model(input, norm_dict=norm_dict) + # ----- loss ----- + # --- from main network + loss, loss_dict = loss_module(output_reproj=output_reproj, + target_dict=target_dict, + weight_dict=weight_dict) + # ---from refinement network + loss_ref, loss_dict_ref = loss_module_ref(output_ref=output_ref, + output_ref_comp=output_ref_comp, + target_dict=target_dict, + weight_dict_ref=weight_dict_ref) + loss_total = loss + loss_ref + # ----- backward pass and parameter update ----- + optimiser.zero_grad() + loss_total.backward() + optimiser.step() + # ---------------------------------------------------------------- + + # prepare losses for progress bar + bs_fake = 1 # batch_size + losses.update(loss_dict['loss'] + loss_dict_ref['loss'], bs_fake) + losses_keyp.update(loss_dict['loss_keyp_weighted'], bs_fake) + losses_silh.update(loss_dict['loss_silh_weighted'], bs_fake) + losses_shape.update(loss_dict['loss_shape_weighted'], bs_fake) + losses_pose.update(loss_dict['loss_poseprior_weighted'], bs_fake) + losses_class.update(loss_dict['loss_class_weighted'], bs_fake) + losses_breed.update(loss_dict['loss_breed_weighted'], bs_fake) + losses_partseg.update(loss_dict['loss_partseg_weighted'], bs_fake) + losses_ref_keyp.update(loss_dict_ref['keyp_ref'], bs_fake) + losses_ref_silh.update(loss_dict_ref['silh_ref'], bs_fake) + loss_ref_pose = 0 + for l_name in ['pose_legs_side', 'pose_legs_tors', 'pose_tail_side', 'pose_tail_tors', 'pose_spine_side', 'pose_spine_tors']: + if l_name in loss_dict_ref.keys(): + loss_ref_pose += loss_dict_ref[l_name] + losses_ref_pose.update(loss_ref_pose, bs_fake) + loss_ref_reg = 0 + for l_name in ['reg_trans', 'reg_flength', 'reg_pose']: + if l_name in loss_dict_ref.keys(): + loss_ref_reg += loss_dict_ref[l_name] + losses_ref_reg.update(loss_ref_reg, bs_fake) + acc = - loss_dict['loss_keyp_weighted'] # this will be used to keep track of the 'best model' + accuracies.update(acc, bs_fake) + # Show losses as part of the progress bar. + if progress is not None: + my_string = 'Loss: {loss:0.4f}, loss_keyp: {loss_keyp:0.4f}, loss_silh: {loss_silh:0.4f}, loss_partseg: {loss_partseg:0.4f}, loss_shape: {loss_shape:0.4f}, loss_pose: {loss_pose:0.4f}, loss_class: {loss_class:0.4f}, loss_breed: {loss_breed:0.4f}, loss_ref_keyp: {loss_ref_keyp:0.4f}, loss_ref_silh: {loss_ref_silh:0.4f}, loss_ref_pose: {loss_ref_pose:0.4f}, loss_ref_reg: {loss_ref_reg:0.4f}'.format( + loss=losses.avg, + loss_keyp=losses_keyp.avg, + loss_silh=losses_silh.avg, + loss_shape=losses_shape.avg, + loss_pose=losses_pose.avg, + loss_class=losses_class.avg, + loss_breed=losses_breed.avg, + loss_partseg=losses_partseg.avg, + loss_ref_keyp=losses_ref_keyp.avg, + loss_ref_silh=losses_ref_silh.avg, + loss_ref_pose=losses_ref_pose.avg, + loss_ref_reg=losses_ref_reg.avg) + my_string_short = 'Loss: {loss:0.4f}, loss_keyp: {loss_keyp:0.4f}, loss_silh: {loss_silh:0.4f}, loss_ref_keyp: {loss_ref_keyp:0.4f}, loss_ref_silh: {loss_ref_silh:0.4f}, loss_ref_pose: {loss_ref_pose:0.4f}, loss_ref_reg: {loss_ref_reg:0.4f}'.format( + loss=losses.avg, + loss_keyp=losses_keyp.avg, + loss_silh=losses_silh.avg, + loss_ref_keyp=losses_ref_keyp.avg, + loss_ref_silh=losses_ref_silh.avg, + loss_ref_pose=losses_ref_pose.avg, + loss_ref_reg=losses_ref_reg.avg) + progress.set_postfix_str(my_string_short) + + return my_string, accuracies.avg + + +# --------------------------------------------------------------------------------------------------------------------------- +def do_validation_epoch(val_loader, model, loss_module, loss_module_ref, device, data_info, flip=False, quiet=False, acc_joints=None, save_imgs_path=None, weight_dict=None, weight_dict_ref=None, metrics=None, val_opt='default', test_name_list=None, render_all=False, pck_thresh=0.15, len_dataset=None): + losses = AverageMeter() + losses_keyp = AverageMeter() + losses_silh = AverageMeter() + losses_shape = AverageMeter() + losses_pose = AverageMeter() + losses_class = AverageMeter() + losses_breed = AverageMeter() + losses_partseg = AverageMeter() + losses_ref_keyp = AverageMeter() + losses_ref_silh = AverageMeter() + losses_ref_pose = AverageMeter() + losses_ref_reg = AverageMeter() + accuracies = AverageMeter() + if save_imgs_path is not None: + pathlib.Path(save_imgs_path).mkdir(parents=True, exist_ok=True) + # Put the model in evaluation mode. + model.eval() + # prepare progress bar + iterable = enumerate(val_loader) + progress = None + if not quiet: + progress = tqdm(iterable, desc='Valid', total=len(val_loader), ascii=True, leave=False) + iterable = progress + # summarize information for normalization + norm_dict = { + 'pose_rot6d_mean': torch.from_numpy(data_info.pose_rot6d_mean).float().to(device), + 'trans_mean': torch.from_numpy(data_info.trans_mean).float().to(device), + 'trans_std': torch.from_numpy(data_info.trans_std).float().to(device), + 'flength_mean': torch.from_numpy(data_info.flength_mean).float().to(device), + 'flength_std': torch.from_numpy(data_info.flength_std).float().to(device)} + batch_size = val_loader.batch_size + + return_mesh_with_gt_groundplane = True + if return_mesh_with_gt_groundplane: + remeshing_path = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/data/smal_data_remeshed/uniform_surface_sampling/my_smpl_39dogsnorm_Jr_4_dog_remesh4000_info.pkl' + with open(remeshing_path, 'rb') as fp: + remeshing_dict = pkl.load(fp) + remeshing_relevant_faces = torch.tensor(remeshing_dict['smal_faces'][remeshing_dict['faceid_closest']], dtype=torch.long, device=device) + remeshing_relevant_barys = torch.tensor(remeshing_dict['barys_closest'], dtype=torch.float32, device=device) + + + # from smal_pytorch.smal_model.smal_torch_new import SMAL + print('start: load smal default model (barc), but only for vertices') + smal = SMAL() + print('end: load smal default model (barc), but only for vertices') + smal_template_verts = smal.v_template.detach().cpu().numpy() + smal_faces = smal.faces.detach().cpu().numpy() + + + my_step = 0 + for index, (input, target_dict) in iterable: + + # prepare variables, put them on the right device + curr_batch_size = input.shape[0] + for key in target_dict.keys(): + if key == 'breed_index': + target_dict[key] = target_dict[key].long().to(device) + elif key in ['index', 'pts', 'tpts', 'target_weight', 'silh', 'silh_distmat_tofg', 'silh_distmat_tobg', 'sim_breed_index', 'img_border_mask']: + target_dict[key] = target_dict[key].float().to(device) + elif key in ['has_seg', 'gc']: + target_dict[key] = target_dict[key].to(device) + else: + pass + input = input.float().to(device) + + # ----------------------- do validation step ----------------------- + with torch.no_grad(): + # ----- forward pass ----- + # output: (['pose', 'flength', 'trans', 'keypoints_norm', 'keypoints_scores']) + # output_unnorm: (['pose_rotmat', 'flength', 'trans', 'keypoints']) + # output_reproj: (['vertices_smal', 'torch_meshes', 'keyp_3d', 'keyp_2d', 'silh', 'betas', 'pose_rot6d', 'dog_breed', 'shapedirs', 'z', 'flength_unnorm', 'flength']) + # target_dict: (['index', 'center', 'scale', 'pts', 'tpts', 'target_weight', 'breed_index', 'sim_breed_index', 'ind_dataset', 'silh']) + output, output_unnorm, output_reproj, output_ref, output_ref_comp = model(input, norm_dict=norm_dict) + # ----- loss ----- + if metrics == 'no_loss': + # --- from main network + loss, loss_dict = loss_module(output_reproj=output_reproj, + target_dict=target_dict, + weight_dict=weight_dict) + # ---from refinement network + loss_ref, loss_dict_ref = loss_module_ref(output_ref=output_ref, + output_ref_comp=output_ref_comp, + target_dict=target_dict, + weight_dict_ref=weight_dict_ref) + loss_total = loss + loss_ref + + # ---------------------------------------------------------------- + + + for result_network in ['normal', 'ref']: + # variabled that are not refined + hg_keyp_norm = output['keypoints_norm'] + hg_keyp_scores = output['keypoints_scores'] + betas = output_reproj['betas'] + betas_limbs = output_reproj['betas_limbs'] + zz = output_reproj['z'] + if result_network == 'normal': + # STEP 1: normal network + vertices_smal = output_reproj['vertices_smal'] + flength = output_unnorm['flength'] + pose_rotmat = output_unnorm['pose_rotmat'] + trans = output_unnorm['trans'] + pred_keyp = output_reproj['keyp_2d'] + pred_silh = output_reproj['silh'] + prefix = 'normal_' + else: + # STEP 1: refinement network + vertices_smal = output_ref['vertices_smal'] + flength = output_ref['flength'] + pose_rotmat = output_ref['pose_rotmat'] + trans = output_ref['trans'] + pred_keyp = output_ref['keyp_2d'] + pred_silh = output_ref['silh'] + prefix = 'ref_' + if return_mesh_with_gt_groundplane and 'gc' in target_dict.keys(): + bs = vertices_smal.shape[0] + target_gc_class = target_dict['gc'][:, :, 0] + sel_verts = torch.index_select(output_ref['vertices_smal'], dim=1, index=remeshing_relevant_faces.reshape((-1))).reshape((bs, remeshing_relevant_faces.shape[0], 3, 3)) + verts_remeshed = torch.einsum('ij,aijk->aik', remeshing_relevant_barys, sel_verts) + target_gc_class_remeshed = torch.einsum('ij,aij->ai', remeshing_relevant_barys, target_gc_class[:, remeshing_relevant_faces].to(device=device, dtype=torch.float32)) + target_gc_class_remeshed_prep = torch.round(target_gc_class_remeshed).to(torch.long) + + + + + + # import pdb; pdb.set_trace() + + # new for vertex wise ground contact + if (not model.graphcnn_type == 'inexistent') and (save_imgs_path is not None): + # import pdb; pdb.set_trace() + + sm = torch.nn.Softmax(dim=2) + ground_contact_probs = sm(output_ref['vertexwise_ground_contact']) + + for ind_img in range(ground_contact_probs.shape[0]): + # ind_img = 0 + if test_name_list is not None: + img_name = test_name_list[int(target_dict['index'][ind_img].cpu().detach().numpy())].replace('/', '_') + img_name = img_name.split('.')[0] + else: + img_name = str(index) + '_' + str(ind_img) + out_path_gcmesh = save_imgs_path + '/' + prefix + 'gcmesh_' + img_name + '.obj' + + gc_prob = ground_contact_probs[ind_img, :, 1] # contact probability + vert_colors = np.repeat(255*gc_prob.detach().cpu().numpy()[:, None], 3, 1) + my_mesh = trimesh.Trimesh(vertices=smal_template_verts, faces=smal_faces, process=False, maintain_order=True) + my_mesh.visual.vertex_colors = vert_colors + save_gc_mesh = True # False + if save_gc_mesh: + my_mesh.export(out_path_gcmesh) + + ''' + input_image = input[ind_img, :, :, :].detach().clone() + for t, m, s in zip(input_image, data_info.rgb_mean,data_info.rgb_stddev): t.add_(m) + input_image_np = input_image.detach().cpu().numpy().transpose(1, 2, 0) + out_path = save_debug_path + 'b' + str(ind_img) +'_input.png' + plt.imsave(out_path, input_image_np) + ''' + + # ------------------------------------- + + # import pdb; pdb.set_trace() + + + ''' + target_gc_class = target_dict['gc'][ind_img, :, 0] + + current_vertices_smal = vertices_smal[ind_img, :, :] + + points_centroid, plane_normal, error = fit_plane(current_vertices_smal[target_gc_class==1, :]) + ''' + + # calculate ground plane + # (see /is/cluster/work/nrueegg/icon_pifu_related/ICON/debug_code/curve_fitting_v2.py) + if return_mesh_with_gt_groundplane and 'gc' in target_dict.keys(): + + current_verts_remeshed = verts_remeshed[ind_img, :, :] + current_target_gc_class_remeshed_prep = target_gc_class_remeshed_prep[ind_img, ...] + + if current_target_gc_class_remeshed_prep.sum() > 3: + points_on_plane = current_verts_remeshed[current_target_gc_class_remeshed_prep==1, :] + data_centroid, plane_normal, error = fit_plane(points_on_plane) + nonplane_points_centered = current_verts_remeshed[current_target_gc_class_remeshed_prep==0, :] - data_centroid[None, :] + nonplane_points_projected = torch.matmul(plane_normal[None, :], nonplane_points_centered.transpose(0,1)) + + if nonplane_points_projected.sum() > 0: # plane normal points towards the animal + plane_normal = plane_normal.detach().cpu().numpy() + else: + plane_normal = - plane_normal.detach().cpu().numpy() + data_centroid = data_centroid.detach().cpu().numpy() + + + + # import pdb; pdb.set_trace() + + + desired_plane_normal_vector = np.asarray([[0, -1, 0]]) + # new approach: use cross product + rotation_axis = np.cross(plane_normal, desired_plane_normal_vector) # np.cross(plane_normal, desired_plane_normal_vector) + lengt_rotation_axis = np.linalg.norm(rotation_axis) # = sin(alpha) (because vectors have unit length) + angle = np.sin(lengt_rotation_axis) + rot = R_sc.from_rotvec(angle * rotation_axis * 1/lengt_rotation_axis) + rot_mat = rot[0].as_matrix() + rot_upsidedown = R_sc.from_rotvec(np.pi * np.asarray([[1, 0, 0]])) + # rot_upsidedown[0].apply(rot[0].apply(plane_normal)) + current_vertices_smal = vertices_smal[ind_img, :, :].detach().cpu().numpy() + new_smal_vertices = rot_upsidedown[0].apply(rot[0].apply(current_vertices_smal - data_centroid[None, :])) + my_mesh = trimesh.Trimesh(vertices=new_smal_vertices, faces=smal_faces, process=False, maintain_order=True) + vert_colors[:, 2] = 255 + my_mesh.visual.vertex_colors = vert_colors + out_path_gc_rotated = save_imgs_path + '/' + prefix + 'gc_rotated_' + img_name + '_new.obj' + my_mesh.export(out_path_gc_rotated) + + + + + + + '''# rot = R_sc.align_vectors(plane_normal.reshape((1, -1)), desired_plane_normal_vector) + desired_plane_normal_vector = np.asarray([[0, 1, 0]]) + + rot = R_sc.align_vectors(desired_plane_normal_vector, plane_normal.reshape((1, -1))) # inv + rot_mat = rot[0].as_matrix() + + + current_vertices_smal = vertices_smal[ind_img, :, :].detach().cpu().numpy() + new_smal_vertices = rot[0].apply((current_vertices_smal - data_centroid[None, :])) + + my_mesh = trimesh.Trimesh(vertices=new_smal_vertices, faces=smal_faces, process=False, maintain_order=True) + my_mesh.visual.vertex_colors = vert_colors + out_path_gc_rotated = save_imgs_path + '/' + prefix + 'gc_rotated_' + img_name + '_y.obj' + my_mesh.export(out_path_gc_rotated) + ''' + + + + + + + + + + # ---- + + + # ------------------------------------- + + + + + if index == 0: + if len_dataset is None: + len_data = val_loader.batch_size * len(val_loader) # 1703 + else: + len_data = len_dataset + if metrics == 'all' or metrics == 'no_loss': + if result_network == 'normal': + summaries = {'normal': dict(), 'ref': dict()} + summary = summaries['normal'] + else: + summary = summaries['ref'] + summary['pck'] = np.zeros((len_data)) + summary['pck_by_part'] = {group:np.zeros((len_data)) for group in KEYPOINT_GROUPS} + summary['acc_sil_2d'] = np.zeros(len_data) + summary['betas'] = np.zeros((len_data,betas.shape[1])) + summary['betas_limbs'] = np.zeros((len_data, betas_limbs.shape[1])) + summary['z'] = np.zeros((len_data, zz.shape[1])) + summary['pose_rotmat'] = np.zeros((len_data, pose_rotmat.shape[1], 3, 3)) + summary['flength'] = np.zeros((len_data, flength.shape[1])) + summary['trans'] = np.zeros((len_data, trans.shape[1])) + summary['breed_indices'] = np.zeros((len_data)) + summary['image_names'] = [] # len_data * [None] + else: + if result_network == 'normal': + summary = summaries['normal'] + else: + summary = summaries['ref'] + + if save_imgs_path is not None: + eval_save_visualizations_and_meshes(model, input, data_info, target_dict, test_name_list, vertices_smal, hg_keyp_norm, hg_keyp_scores, zz, betas, betas_limbs, pose_rotmat, trans, flength, pred_keyp, pred_silh, save_imgs_path, prefix, index, render_all=render_all) + + if metrics == 'all' or metrics == 'no_loss': + preds = eval_prepare_pck_and_iou(model, input, data_info, target_dict, test_name_list, vertices_smal, hg_keyp_norm, hg_keyp_scores, zz, betas, betas_limbs, pose_rotmat, trans, flength, pred_keyp, pred_silh, save_imgs_path, prefix, index, pck_thresh, progress=progress) + # add results for all images in this batch to lists + curr_batch_size = pred_keyp.shape[0] + eval_add_preds_to_summary(summary, preds, my_step, batch_size, curr_batch_size) + else: + # measure accuracy and record loss + bs_fake = 1 # batch_size + # import pdb; pdb.set_trace() + + + # save_imgs_path + '/' + prefix + 'rot_tex_pred_' + img_name + '.png' + # import pdb; pdb.set_trace() + ''' + for ind_img in range(len(target_dict['index'])): + try: + if test_name_list is not None: + img_name = test_name_list[int(target_dict['index'][ind_img].cpu().detach().numpy())].replace('/', '_') + img_name = img_name.split('.')[0] + else: + img_name = str(index) + '_' + str(ind_img) + all_image_names = ['keypoints_pred_' + img_name + '.png', 'normal_comp_pred_' + img_name + '.png', 'normal_rot_tex_pred_' + img_name + '.png', 'ref_comp_pred_' + img_name + '.png', 'ref_rot_tex_pred_' + img_name + '.png'] + all_saved_images = [] + for sub_img_name in all_image_names: + saved_img = cv2.imread(save_imgs_path + '/' + sub_img_name) + if not (saved_img.shape[0] == 256 and saved_img.shape[1] == 256): + saved_img = cv2.resize(saved_img, (256, 256)) + all_saved_images.append(saved_img) + final_image = np.concatenate(all_saved_images, axis=1) + save_imgs_path_sum = save_imgs_path.replace('test_', 'summary_test_') + if not os.path.exists(save_imgs_path_sum): os.makedirs(save_imgs_path_sum) + final_image_path = save_imgs_path_sum + '/summary_' + img_name + '.png' + cv2.imwrite(final_image_path, final_image) + except: + print('dont save a summary image') + ''' + + + bs_fake = 1 + if metrics == 'all' or metrics == 'no_loss': + # update progress bar + if progress is not None: + '''my_string = "PCK: {0:.2f}, IOU: {1:.2f}".format( + pck[:(my_step * batch_size + curr_batch_size)].mean(), + acc_sil_2d[:(my_step * batch_size + curr_batch_size)].mean())''' + my_string = "normal_PCK: {0:.2f}, normal_IOU: {1:.2f}, ref_PCK: {2:.2f}, ref_IOU: {3:.2f}".format( + summaries['normal']['pck'][:(my_step * batch_size + curr_batch_size)].mean(), + summaries['normal']['acc_sil_2d'][:(my_step * batch_size + curr_batch_size)].mean(), + summaries['ref']['pck'][:(my_step * batch_size + curr_batch_size)].mean(), + summaries['ref']['acc_sil_2d'][:(my_step * batch_size + curr_batch_size)].mean()) + progress.set_postfix_str(my_string) + else: + losses.update(loss_dict['loss'] + loss_dict_ref['loss'], bs_fake) + losses_keyp.update(loss_dict['loss_keyp_weighted'], bs_fake) + losses_silh.update(loss_dict['loss_silh_weighted'], bs_fake) + losses_shape.update(loss_dict['loss_shape_weighted'], bs_fake) + losses_pose.update(loss_dict['loss_poseprior_weighted'], bs_fake) + losses_class.update(loss_dict['loss_class_weighted'], bs_fake) + losses_breed.update(loss_dict['loss_breed_weighted'], bs_fake) + losses_partseg.update(loss_dict['loss_partseg_weighted'], bs_fake) + losses_ref_keyp.update(loss_dict_ref['keyp_ref'], bs_fake) + losses_ref_silh.update(loss_dict_ref['silh_ref'], bs_fake) + loss_ref_pose = 0 + for l_name in ['pose_legs_side', 'pose_legs_tors', 'pose_tail_side', 'pose_tail_tors', 'pose_spine_side', 'pose_spine_tors']: + loss_ref_pose += loss_dict_ref[l_name] + losses_ref_pose.update(loss_ref_pose, bs_fake) + loss_ref_reg = 0 + for l_name in ['reg_trans', 'reg_flength', 'reg_pose']: + loss_ref_reg += loss_dict_ref[l_name] + losses_ref_reg.update(loss_ref_reg, bs_fake) + acc = - loss_dict['loss_keyp_weighted'] # this will be used to keep track of the 'best model' + accuracies.update(acc, bs_fake) + # Show losses as part of the progress bar. + if progress is not None: + my_string = 'Loss: {loss:0.4f}, loss_keyp: {loss_keyp:0.4f}, loss_silh: {loss_silh:0.4f}, loss_partseg: {loss_partseg:0.4f}, loss_shape: {loss_shape:0.4f}, loss_pose: {loss_pose:0.4f}, loss_class: {loss_class:0.4f}, loss_breed: {loss_breed:0.4f}, loss_ref_keyp: {loss_ref_keyp:0.4f}, loss_ref_silh: {loss_ref_silh:0.4f}, loss_ref_pose: {loss_ref_pose:0.4f}, loss_ref_reg: {loss_ref_reg:0.4f}'.format( + loss=losses.avg, + loss_keyp=losses_keyp.avg, + loss_silh=losses_silh.avg, + loss_shape=losses_shape.avg, + loss_pose=losses_pose.avg, + loss_class=losses_class.avg, + loss_breed=losses_breed.avg, + loss_partseg=losses_partseg.avg, + loss_ref_keyp=losses_ref_keyp.avg, + loss_ref_silh=losses_ref_silh.avg, + loss_ref_pose=losses_ref_pose.avg, + loss_ref_reg=losses_ref_reg.avg) + my_string_short = 'Loss: {loss:0.4f}, loss_keyp: {loss_keyp:0.4f}, loss_silh: {loss_silh:0.4f}, loss_ref_keyp: {loss_ref_keyp:0.4f}, loss_ref_silh: {loss_ref_silh:0.4f}, loss_ref_pose: {loss_ref_pose:0.4f}, loss_ref_reg: {loss_ref_reg:0.4f}'.format( + loss=losses.avg, + loss_keyp=losses_keyp.avg, + loss_silh=losses_silh.avg, + loss_ref_keyp=losses_ref_keyp.avg, + loss_ref_silh=losses_ref_silh.avg, + loss_ref_pose=losses_ref_pose.avg, + loss_ref_reg=losses_ref_reg.avg) + progress.set_postfix_str(my_string_short) + my_step += 1 + if metrics == 'all': + return my_string, summaries # summary + elif metrics == 'no_loss': + return my_string, np.average(np.asarray(summaries['ref']['acc_sil_2d'])) # np.average(np.asarray(summary['acc_sil_2d'])) + else: + return my_string, accuracies.avg + + +# --------------------------------------------------------------------------------------------------------------------------- +def do_visual_epoch(val_loader, model, device, data_info, flip=False, quiet=False, acc_joints=None, save_imgs_path=None, weight_dict=None, weight_dict_ref=None, metrics=None, val_opt='default', test_name_list=None, render_all=False, pck_thresh=0.15, return_results=False, len_dataset=None): + if save_imgs_path is not None: + pathlib.Path(save_imgs_path).mkdir(parents=True, exist_ok=True) + all_results = [] + + # Put the model in evaluation mode. + model.eval() + + iterable = enumerate(val_loader) + + # information for normalization + norm_dict = { + 'pose_rot6d_mean': torch.from_numpy(data_info.pose_rot6d_mean).float().to(device), + 'trans_mean': torch.from_numpy(data_info.trans_mean).float().to(device), + 'trans_std': torch.from_numpy(data_info.trans_std).float().to(device), + 'flength_mean': torch.from_numpy(data_info.flength_mean).float().to(device), + 'flength_std': torch.from_numpy(data_info.flength_std).float().to(device)} + + + return_mesh_with_gt_groundplane = True + if return_mesh_with_gt_groundplane: + remeshing_path = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/data/smal_data_remeshed/uniform_surface_sampling/my_smpl_39dogsnorm_Jr_4_dog_remesh4000_info.pkl' + with open(remeshing_path, 'rb') as fp: + remeshing_dict = pkl.load(fp) + remeshing_relevant_faces = torch.tensor(remeshing_dict['smal_faces'][remeshing_dict['faceid_closest']], dtype=torch.long, device=device) + remeshing_relevant_barys = torch.tensor(remeshing_dict['barys_closest'], dtype=torch.float32, device=device) + + # from smal_pytorch.smal_model.smal_torch_new import SMAL + print('start: load smal default model (barc), but only for vertices') + smal = SMAL() + print('end: load smal default model (barc), but only for vertices') + smal_template_verts = smal.v_template.detach().cpu().numpy() + smal_faces = smal.faces.detach().cpu().numpy() + + file_alignment_errors = open(save_imgs_path + '/a_ref_procrustes_alignmnet_errors.txt', 'a') # append mode + file_alignment_errors.write(" ----------- start evaluation ------------- \n ") + + csv_file_alignment_errors = open(save_imgs_path + '/a_ref_procrustes_alignmnet_errors.csv', 'w') # write mode + fieldnames = ['name', 'error'] + writer = csv.DictWriter(csv_file_alignment_errors, fieldnames=fieldnames) + writer.writeheader() + + my_step = 0 + for index, (input, target_dict) in iterable: + batch_size = input.shape[0] + input = input.float().to(device) + partial_results = {} + + # ----------------------- do visualization step ----------------------- + with torch.no_grad(): + output, output_unnorm, output_reproj, output_ref, output_ref_comp = model(input, norm_dict=norm_dict) + + + # import pdb; pdb.set_trace() + + + sm = torch.nn.Softmax(dim=2) + ground_contact_probs = sm(output_ref['vertexwise_ground_contact']) + + for result_network in ['normal', 'ref']: + # variabled that are not refined + hg_keyp_norm = output['keypoints_norm'] + hg_keyp_scores = output['keypoints_scores'] + betas = output_reproj['betas'] + betas_limbs = output_reproj['betas_limbs'] + zz = output_reproj['z'] + if result_network == 'normal': + # STEP 1: normal network + vertices_smal = output_reproj['vertices_smal'] + flength = output_unnorm['flength'] + pose_rotmat = output_unnorm['pose_rotmat'] + trans = output_unnorm['trans'] + pred_keyp = output_reproj['keyp_2d'] + pred_silh = output_reproj['silh'] + prefix = 'normal_' + else: + # STEP 1: refinement network + vertices_smal = output_ref['vertices_smal'] + flength = output_ref['flength'] + pose_rotmat = output_ref['pose_rotmat'] + trans = output_ref['trans'] + pred_keyp = output_ref['keyp_2d'] + pred_silh = output_ref['silh'] + prefix = 'ref_' + + bs = vertices_smal.shape[0] + # target_gc_class = target_dict['gc'][:, :, 0] + target_gc_class = torch.round(ground_contact_probs).long()[:, :, 1] + sel_verts = torch.index_select(output_ref['vertices_smal'], dim=1, index=remeshing_relevant_faces.reshape((-1))).reshape((bs, remeshing_relevant_faces.shape[0], 3, 3)) + verts_remeshed = torch.einsum('ij,aijk->aik', remeshing_relevant_barys, sel_verts) + target_gc_class_remeshed = torch.einsum('ij,aij->ai', remeshing_relevant_barys, target_gc_class[:, remeshing_relevant_faces].to(device=device, dtype=torch.float32)) + target_gc_class_remeshed_prep = torch.round(target_gc_class_remeshed).to(torch.long) + + + + + # index = i + # ind_img = 0 + for ind_img in range(batch_size): # range(min(12, batch_size)): # range(12): # [0]: #range(0, batch_size): + + # ind_img = 0 + if test_name_list is not None: + img_name = test_name_list[int(target_dict['index'][ind_img].cpu().detach().numpy())].replace('/', '_') + img_name = img_name.split('.')[0] + else: + img_name = str(index) + '_' + str(ind_img) + out_path_gcmesh = save_imgs_path + '/' + prefix + 'gcmesh_' + img_name + '.obj' + + gc_prob = ground_contact_probs[ind_img, :, 1] # contact probability + vert_colors = np.repeat(255*gc_prob.detach().cpu().numpy()[:, None], 3, 1) + my_mesh = trimesh.Trimesh(vertices=smal_template_verts, faces=smal_faces, process=False, maintain_order=True) + my_mesh.visual.vertex_colors = vert_colors + save_gc_mesh = False + if save_gc_mesh: + my_mesh.export(out_path_gcmesh) + + current_verts_remeshed = verts_remeshed[ind_img, :, :] + current_target_gc_class_remeshed_prep = target_gc_class_remeshed_prep[ind_img, ...] + + if current_target_gc_class_remeshed_prep.sum() > 3: + points_on_plane = current_verts_remeshed[current_target_gc_class_remeshed_prep==1, :] + data_centroid, plane_normal, error = fit_plane(points_on_plane) + nonplane_points_centered = current_verts_remeshed[current_target_gc_class_remeshed_prep==0, :] - data_centroid[None, :] + nonplane_points_projected = torch.matmul(plane_normal[None, :], nonplane_points_centered.transpose(0,1)) + + if nonplane_points_projected.sum() > 0: # plane normal points towards the animal + plane_normal = plane_normal.detach().cpu().numpy() + else: + plane_normal = - plane_normal.detach().cpu().numpy() + data_centroid = data_centroid.detach().cpu().numpy() + + + + # import pdb; pdb.set_trace() + + + desired_plane_normal_vector = np.asarray([[0, -1, 0]]) + # new approach: use cross product + rotation_axis = np.cross(plane_normal, desired_plane_normal_vector) # np.cross(plane_normal, desired_plane_normal_vector) + lengt_rotation_axis = np.linalg.norm(rotation_axis) # = sin(alpha) (because vectors have unit length) + angle = np.sin(lengt_rotation_axis) + rot = R_sc.from_rotvec(angle * rotation_axis * 1/lengt_rotation_axis) + rot_mat = rot[0].as_matrix() + rot_upsidedown = R_sc.from_rotvec(np.pi * np.asarray([[1, 0, 0]])) + # rot_upsidedown[0].apply(rot[0].apply(plane_normal)) + current_vertices_smal = vertices_smal[ind_img, :, :].detach().cpu().numpy() + new_smal_vertices = rot_upsidedown[0].apply(rot[0].apply(current_vertices_smal - data_centroid[None, :])) + my_mesh = trimesh.Trimesh(vertices=new_smal_vertices, faces=smal_faces, process=False, maintain_order=True) + vert_colors[:, 2] = 255 + my_mesh.visual.vertex_colors = vert_colors + out_path_gc_rotated = save_imgs_path + '/' + prefix + 'gc_rotated_' + img_name + '_new.obj' + my_mesh.export(out_path_gc_rotated) + + + + ''' + import pdb; pdb.set_trace() + + from src.evaluation.registration import preprocess_point_cloud, o3d_ransac, draw_registration_result + import open3d as o3d + import copy + + + mesh_gt_path = target_dict['mesh_path'][ind_img] + mesh_gt = o3d.io.read_triangle_mesh(mesh_gt_path) + + mesh_gt_verts = np.asarray(mesh_gt.vertices) + mesh_gt_faces = np.asarray(mesh_gt.triangles) + diag_gt = np.sqrt(sum((mesh_gt_verts.max(axis=0) - mesh_gt_verts.min(axis=0))**2)) + + mesh_pred_verts = np.asarray(new_smal_vertices) + mesh_pred_faces = np.asarray(smal_faces) + diag_pred = np.sqrt(sum((mesh_pred_verts.max(axis=0) - mesh_pred_verts.min(axis=0))**2)) + mesh_pred = o3d.geometry.TriangleMesh() + mesh_pred.vertices = o3d.utility.Vector3dVector(mesh_pred_verts) + mesh_pred.triangles = o3d.utility.Vector3iVector(mesh_pred_faces) + + # center the predicted mesh around 0 + trans = - mesh_pred_verts.mean(axis=0) + mesh_pred_verts_new = mesh_pred_verts + trans + # change the size of the predicted mesh + mesh_pred_verts_new = mesh_pred_verts_new * diag_gt / diag_pred + + # transform the predicted mesh (rough alignment) + mesh_pred_new = copy.deepcopy(mesh_pred) + mesh_pred_new.vertices = o3d.utility.Vector3dVector(np.asarray(mesh_pred_verts_new)) # normals should not have changed + voxel_size = 0.01 # 0.5 + distance_threshold = 0.015 # 0.005 # 0.02 # 1.0 + result, src_down, src_fpfh, dst_down, dst_fpfh = o3d_ransac(mesh_pred_new, mesh_gt, voxel_size=voxel_size, distance_threshold=distance_threshold, return_all=True) + transform = result.transformation + mesh_pred_transf = copy.deepcopy(mesh_pred_new).transform(transform) + + out_path_pred_transf = save_imgs_path + '/' + prefix + 'alignment_initial_' + img_name + '.obj' + o3d.io.write_triangle_mesh(out_path_pred_transf, mesh_pred_transf) + + # img_name_part = img_name.split(img_name.split('_')[-1] + '_')[0] + # out_path_gt = save_imgs_path + '/' + prefix + 'ground_truth_' + img_name_part + '.obj' + # o3d.io.write_triangle_mesh(out_path_gt, mesh_gt) + + + trans_init = transform + threshold = 0.02 # 0.1 # 0.02 + + n_points = 10000 + src = mesh_pred_new.sample_points_uniformly(number_of_points=n_points) + dst = mesh_gt.sample_points_uniformly(number_of_points=n_points) + + # reg_p2p = o3d.pipelines.registration.registration_icp(src_down, dst_down, threshold, trans_init, o3d.pipelines.registration.TransformationEstimationPointToPoint(), o3d.pipelines.registration.ICPConvergenceCriteria(max_iteration=2000)) + reg_p2p = o3d.pipelines.registration.registration_icp(src, dst, threshold, trans_init, o3d.pipelines.registration.TransformationEstimationPointToPoint(), o3d.pipelines.registration.ICPConvergenceCriteria(max_iteration=2000)) + + # mesh_pred_transf_refined = copy.deepcopy(mesh_pred_new).transform(reg_p2p.transformation) + # out_path_pred_transf_refined = save_imgs_path + '/' + prefix + 'alignment_final_' + img_name + '.obj' + # o3d.io.write_triangle_mesh(out_path_pred_transf_refined, mesh_pred_transf_refined) + + + aligned_mesh_final = trimesh.Trimesh(mesh_pred_new.vertices, mesh_pred_new.triangles, vertex_colors=[0, 255, 0]) + gt_mesh = trimesh.Trimesh(mesh_gt.vertices, mesh_gt.triangles, vertex_colors=[255, 0, 0]) + scene = trimesh.Scene([aligned_mesh_final, gt_mesh]) + out_path_alignment_with_gt = save_imgs_path + '/' + prefix + 'alignment_with_gt_' + img_name + '.obj' + + scene.export(out_path_alignment_with_gt) + ''' + + # import pdb; pdb.set_trace() + + + # SMAL_KEYPOINT_NAMES_FOR_3D_EVAL # 17 keypoints + # prepare target + target_keyp_isvalid = target_dict['keypoints_3d'][ind_img, :, 3].detach().cpu().numpy() + keyp_to_use = (np.asarray(SMAL_KEYPOINT_WHICHTOUSE_FOR_3D_EVAL)==1)*(target_keyp_isvalid==1) + target_keyp_raw = target_dict['keypoints_3d'][ind_img, :, :3].detach().cpu().numpy() + target_keypoints = target_keyp_raw[keyp_to_use, :] + target_pointcloud = target_dict['pointcloud_points'][ind_img, :, :].detach().cpu().numpy() + # prepare prediction + pred_keypoints_raw = output_ref['vertices_smal'][ind_img, SMAL_KEYPOINT_INDICES_FOR_3D_EVAL, :].detach().cpu().numpy() + pred_keypoints = pred_keypoints_raw[keyp_to_use, :] + pred_pointcloud = verts_remeshed[ind_img, :, :].detach().cpu().numpy() + + + + + ''' + pred_keypoints_transf, pred_pointcloud_transf, procrustes_params = compute_similarity_transform(pred_keypoints, target_keypoints, num_joints=None, verts=pred_pointcloud) + pa_error = np.sqrt(np.sum((target_keypoints - pred_keypoints_transf) ** 2, axis=1)) + error_procrustes = np.mean(pa_error) + + + col_target = np.zeros((target_pointcloud.shape[0], 3), dtype=np.uint8) + col_target[:, 0] = 255 + col_pred = np.zeros((pred_pointcloud_transf.shape[0], 3), dtype=np.uint8) + col_pred[:, 1] = 255 + pc = trimesh.points.PointCloud(np.concatenate((target_pointcloud, pred_pointcloud_transf)), colors=np.concatenate((col_target, col_pred))) + out_path_pc = save_imgs_path + '/' + prefix + 'pointclouds_aligned_' + img_name + '.obj' + pc.export(out_path_pc) + + print(target_dict['mesh_path'][ind_img]) + print(error_procrustes) + file_alignment_errors.write(target_dict['mesh_path'][ind_img] + '\n') + file_alignment_errors.write('error: ' + str(error_procrustes) + ' \n') + + writer.writerow({'name': (target_dict['mesh_path'][ind_img]).split('/')[-1], 'error': str(error_procrustes)}) + + # import pdb; pdb.set_trace() + # alignment_dict = calculate_alignemnt_errors(output_ref['vertices_smal'][ind_img, :, :], target_dict['keypoints_3d'][ind_img, :, :], target_dict['pointcloud_points'][ind_img, :, :]) + # file_alignment_errors.write('error: ' + str(alignment_dict['error_procrustes']) + ' \n') + ''' + + + + + + + if index == 0: + if len_dataset is None: + len_data = val_loader.batch_size * len(val_loader) # 1703 + else: + len_data = len_dataset + if result_network == 'normal': + summaries = {'normal': dict(), 'ref': dict()} + summary = summaries['normal'] + else: + summary = summaries['ref'] + summary['pck'] = np.zeros((len_data)) + summary['pck_by_part'] = {group:np.zeros((len_data)) for group in KEYPOINT_GROUPS} + summary['acc_sil_2d'] = np.zeros(len_data) + summary['betas'] = np.zeros((len_data,betas.shape[1])) + summary['betas_limbs'] = np.zeros((len_data, betas_limbs.shape[1])) + summary['z'] = np.zeros((len_data, zz.shape[1])) + summary['pose_rotmat'] = np.zeros((len_data, pose_rotmat.shape[1], 3, 3)) + summary['flength'] = np.zeros((len_data, flength.shape[1])) + summary['trans'] = np.zeros((len_data, trans.shape[1])) + summary['breed_indices'] = np.zeros((len_data)) + summary['image_names'] = [] # len_data * [None] + # ['vertices_smal'] = np.zeros((len_data, vertices_smal.shape[1], 3)) + else: + if result_network == 'normal': + summary = summaries['normal'] + else: + summary = summaries['ref'] + + + # import pdb; pdb.set_trace() + + + eval_save_visualizations_and_meshes(model, input, data_info, target_dict, test_name_list, vertices_smal, hg_keyp_norm, hg_keyp_scores, zz, betas, betas_limbs, pose_rotmat, trans, flength, pred_keyp, pred_silh, save_imgs_path, prefix, index, render_all=render_all) + + + preds = eval_prepare_pck_and_iou(model, input, data_info, target_dict, test_name_list, vertices_smal, hg_keyp_norm, hg_keyp_scores, zz, betas, betas_limbs, pose_rotmat, trans, flength, pred_keyp, pred_silh, save_imgs_path, prefix, index, pck_thresh=None, skip_pck_and_iou=True) + # add results for all images in this batch to lists + curr_batch_size = pred_keyp.shape[0] + eval_add_preds_to_summary(summary, preds, my_step, batch_size, curr_batch_size, skip_pck_and_iou=True) + + # summary['vertices_smal'][my_step * batch_size:my_step * batch_size + curr_batch_size] = vertices_smal.detach().cpu().numpy() + + + + + + + + + + + + + + + + ''' + try: + if test_name_list is not None: + img_name = test_name_list[int(target_dict['index'][ind_img].cpu().detach().numpy())].replace('/', '_') + img_name = img_name.split('.')[0] + else: + img_name = str(index) + '_' + str(ind_img) + partial_results['img_name'] = img_name + visualizations = model.render_vis_nograd(vertices=output_reproj['vertices_smal'], + focal_lengths=output_unnorm['flength'], + color=0) # 2) + # save image with predicted keypoints + pred_unp = (output['keypoints_norm'][ind_img, :, :] + 1.) / 2 * (data_info.image_size - 1) + pred_unp_maxval = output['keypoints_scores'][ind_img, :, :] + pred_unp_prep = torch.cat((pred_unp, pred_unp_maxval), 1) + inp_img = input[ind_img, :, :, :].detach().clone() + if save_imgs_path is not None: + out_path = save_imgs_path + '/keypoints_pred_' + img_name + '.png' + save_input_image_with_keypoints(inp_img, pred_unp_prep, out_path=out_path, threshold=0.1, print_scores=True, ratio_in_out=1.0) # threshold=0.3 + # save predicted 3d model + # (1) front view + pred_tex = visualizations[ind_img, :, :, :].permute((1, 2, 0)).cpu().detach().numpy() / 256 + pred_tex_max = np.max(pred_tex, axis=2) + partial_results['tex_pred'] = pred_tex + if save_imgs_path is not None: + out_path = save_imgs_path + '/tex_pred_' + img_name + '.png' + plt.imsave(out_path, pred_tex) + input_image = input[ind_img, :, :, :].detach().clone() + for t, m, s in zip(input_image, data_info.rgb_mean, data_info.rgb_stddev): t.add_(m) + input_image_np = input_image.detach().cpu().numpy().transpose(1, 2, 0) + im_masked = cv2.addWeighted(input_image_np,0.2,pred_tex,0.8,0) + im_masked[pred_tex_max<0.01, :] = input_image_np[pred_tex_max<0.01, :] + partial_results['comp_pred'] = im_masked + if save_imgs_path is not None: + out_path = save_imgs_path + '/comp_pred_' + img_name + '.png' + plt.imsave(out_path, im_masked) + # (2) side view + vertices_cent = output_reproj['vertices_smal'] - output_reproj['vertices_smal'].mean(dim=1)[:, None, :] + roll = np.pi / 2 * torch.ones(1).float().to(device) + pitch = np.pi / 2 * torch.ones(1).float().to(device) + tensor_0 = torch.zeros(1).float().to(device) + tensor_1 = torch.ones(1).float().to(device) + RX = torch.stack([torch.stack([tensor_1, tensor_0, tensor_0]), torch.stack([tensor_0, torch.cos(roll), -torch.sin(roll)]),torch.stack([tensor_0, torch.sin(roll), torch.cos(roll)])]).reshape(3,3) + RY = torch.stack([ + torch.stack([torch.cos(pitch), tensor_0, torch.sin(pitch)]), + torch.stack([tensor_0, tensor_1, tensor_0]), + torch.stack([-torch.sin(pitch), tensor_0, torch.cos(pitch)])]).reshape(3,3) + vertices_rot = (torch.matmul(RY, vertices_cent.reshape((-1, 3))[:, :, None])).reshape((batch_size, -1, 3)) + vertices_rot[:, :, 2] = vertices_rot[:, :, 2] + torch.ones_like(vertices_rot[:, :, 2]) * 20 # 18 # *16 + visualizations_rot = model.render_vis_nograd(vertices=vertices_rot, + focal_lengths=output_unnorm['flength'], + color=0) # 2) + pred_tex = visualizations_rot[ind_img, :, :, :].permute((1, 2, 0)).cpu().detach().numpy() / 256 + pred_tex_max = np.max(pred_tex, axis=2) + partial_results['rot_tex_pred'] = pred_tex + if save_imgs_path is not None: + out_path = save_imgs_path + '/rot_tex_pred_' + img_name + '.png' + plt.imsave(out_path, pred_tex) + render_all = True + if render_all: + # save input image + inp_img = input[ind_img, :, :, :].detach().clone() + if save_imgs_path is not None: + out_path = save_imgs_path + '/image_' + img_name + '.png' + save_input_image(inp_img, out_path) + # save posed mesh + V_posed = output_reproj['vertices_smal'][ind_img, :, :].detach().cpu().numpy() + Faces = model.smal.f + mesh_posed = trimesh.Trimesh(vertices=V_posed, faces=Faces, process=False, maintain_order=True) + partial_results['mesh_posed'] = mesh_posed + if save_imgs_path is not None: + mesh_posed.export(save_imgs_path + '/mesh_posed_' + img_name + '.obj') + except: + print('pass...') + all_results.append(partial_results) + ''' + + my_step += 1 + + + file_alignment_errors.close() + csv_file_alignment_errors.close() + + + if return_results: + return all_results + else: + return summaries \ No newline at end of file diff --git a/src/combined_model/train_main_image_to_3d_withbreedrel.py b/src/combined_model/train_main_image_to_3d_withbreedrel.py new file mode 100644 index 0000000000000000000000000000000000000000..59308fa38badbbce05250edd44b8435a1896838d --- /dev/null +++ b/src/combined_model/train_main_image_to_3d_withbreedrel.py @@ -0,0 +1,496 @@ + +import torch +import torch.nn as nn +import torch.backends.cudnn +import torch.nn.parallel +from tqdm import tqdm +import os +import pathlib +from matplotlib import pyplot as plt +import cv2 +import numpy as np +import torch +import trimesh + +import sys +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..')) +from stacked_hourglass.utils.evaluation import accuracy, AverageMeter, final_preds, get_preds, get_preds_soft +from stacked_hourglass.utils.visualization import save_input_image_with_keypoints, save_input_image +from metrics.metrics import Metrics +from configs.SMAL_configs import EVAL_KEYPOINTS, KEYPOINT_GROUPS + + +# --------------------------------------------------------------------------------------------------------------------------- +def do_training_epoch(train_loader, model, loss_module, device, data_info, optimiser, quiet=False, acc_joints=None, weight_dict=None): + losses = AverageMeter() + losses_keyp = AverageMeter() + losses_silh = AverageMeter() + losses_shape = AverageMeter() + losses_pose = AverageMeter() + losses_class = AverageMeter() + losses_breed = AverageMeter() + losses_partseg = AverageMeter() + accuracies = AverageMeter() + # Put the model in training mode. + model.train() + # prepare progress bar + iterable = enumerate(train_loader) + progress = None + if not quiet: + progress = tqdm(iterable, desc='Train', total=len(train_loader), ascii=True, leave=False) + iterable = progress + # information for normalization + norm_dict = { + 'pose_rot6d_mean': torch.from_numpy(data_info.pose_rot6d_mean).float().to(device), + 'trans_mean': torch.from_numpy(data_info.trans_mean).float().to(device), + 'trans_std': torch.from_numpy(data_info.trans_std).float().to(device), + 'flength_mean': torch.from_numpy(data_info.flength_mean).float().to(device), + 'flength_std': torch.from_numpy(data_info.flength_std).float().to(device)} + # prepare variables, put them on the right device + for i, (input, target_dict) in iterable: + batch_size = input.shape[0] + for key in target_dict.keys(): + if key == 'breed_index': + target_dict[key] = target_dict[key].long().to(device) + elif key in ['index', 'pts', 'tpts', 'target_weight', 'silh', 'silh_distmat_tofg', 'silh_distmat_tobg', 'sim_breed_index', 'img_border_mask']: + target_dict[key] = target_dict[key].float().to(device) + elif key in ['has_seg', 'gc']: + target_dict[key] = target_dict[key].to(device) + else: + pass + input = input.float().to(device) + + # ----------------------- do training step ----------------------- + assert model.training, 'model must be in training mode.' + with torch.enable_grad(): + # ----- forward pass ----- + output, output_unnorm, output_reproj = model(input, norm_dict=norm_dict) + # ----- loss ----- + loss, loss_dict = loss_module(output_reproj=output_reproj, + target_dict=target_dict, + weight_dict=weight_dict) + # ----- backward pass and parameter update ----- + optimiser.zero_grad() + loss.backward() + optimiser.step() + # ---------------------------------------------------------------- + + # prepare losses for progress bar + bs_fake = 1 # batch_size + losses.update(loss_dict['loss'], bs_fake) + losses_keyp.update(loss_dict['loss_keyp_weighted'], bs_fake) + losses_silh.update(loss_dict['loss_silh_weighted'], bs_fake) + losses_shape.update(loss_dict['loss_shape_weighted'], bs_fake) + losses_pose.update(loss_dict['loss_poseprior_weighted'], bs_fake) + losses_class.update(loss_dict['loss_class_weighted'], bs_fake) + losses_breed.update(loss_dict['loss_breed_weighted'], bs_fake) + losses_partseg.update(loss_dict['loss_partseg_weighted'], bs_fake) + acc = - loss_dict['loss_keyp_weighted'] # this will be used to keep track of the 'best model' + accuracies.update(acc, bs_fake) + # Show losses as part of the progress bar. + if progress is not None: + my_string = 'Loss: {loss:0.4f}, loss_keyp: {loss_keyp:0.4f}, loss_silh: {loss_silh:0.4f}, loss_partseg: {loss_partseg:0.4f}, loss_shape: {loss_shape:0.4f}, loss_pose: {loss_pose:0.4f}, loss_class: {loss_class:0.4f}, loss_breed: {loss_breed:0.4f}'.format( + loss=losses.avg, + loss_keyp=losses_keyp.avg, + loss_silh=losses_silh.avg, + loss_shape=losses_shape.avg, + loss_pose=losses_pose.avg, + loss_class=losses_class.avg, + loss_breed=losses_breed.avg, + loss_partseg=losses_partseg.avg + ) + progress.set_postfix_str(my_string) + + return my_string, accuracies.avg + + +# --------------------------------------------------------------------------------------------------------------------------- +def do_validation_epoch(val_loader, model, loss_module, device, data_info, flip=False, quiet=False, acc_joints=None, save_imgs_path=None, weight_dict=None, metrics=None, val_opt='default', test_name_list=None, render_all=False, pck_thresh=0.15, len_dataset=None): + losses = AverageMeter() + losses_keyp = AverageMeter() + losses_silh = AverageMeter() + losses_shape = AverageMeter() + losses_pose = AverageMeter() + losses_class = AverageMeter() + losses_breed = AverageMeter() + losses_partseg = AverageMeter() + accuracies = AverageMeter() + if save_imgs_path is not None: + pathlib.Path(save_imgs_path).mkdir(parents=True, exist_ok=True) + # Put the model in evaluation mode. + model.eval() + # prepare progress bar + iterable = enumerate(val_loader) + progress = None + if not quiet: + progress = tqdm(iterable, desc='Valid', total=len(val_loader), ascii=True, leave=False) + iterable = progress + # summarize information for normalization + norm_dict = { + 'pose_rot6d_mean': torch.from_numpy(data_info.pose_rot6d_mean).float().to(device), + 'trans_mean': torch.from_numpy(data_info.trans_mean).float().to(device), + 'trans_std': torch.from_numpy(data_info.trans_std).float().to(device), + 'flength_mean': torch.from_numpy(data_info.flength_mean).float().to(device), + 'flength_std': torch.from_numpy(data_info.flength_std).float().to(device)} + batch_size = val_loader.batch_size + # prepare variables, put them on the right device + my_step = 0 + for i, (input, target_dict) in iterable: + curr_batch_size = input.shape[0] + for key in target_dict.keys(): + if key == 'breed_index': + target_dict[key] = target_dict[key].long().to(device) + elif key in ['index', 'pts', 'tpts', 'target_weight', 'silh', 'silh_distmat_tofg', 'silh_distmat_tobg', 'sim_breed_index', 'img_border_mask']: + target_dict[key] = target_dict[key].float().to(device) + elif key in ['has_seg', 'gc']: + target_dict[key] = target_dict[key].to(device) + else: + pass + input = input.float().to(device) + + # ----------------------- do validation step ----------------------- + with torch.no_grad(): + # ----- forward pass ----- + # output: (['pose', 'flength', 'trans', 'keypoints_norm', 'keypoints_scores']) + # output_unnorm: (['pose_rotmat', 'flength', 'trans', 'keypoints']) + # output_reproj: (['vertices_smal', 'torch_meshes', 'keyp_3d', 'keyp_2d', 'silh', 'betas', 'pose_rot6d', 'dog_breed', 'shapedirs', 'z', 'flength_unnorm', 'flength']) + # target_dict: (['index', 'center', 'scale', 'pts', 'tpts', 'target_weight', 'breed_index', 'sim_breed_index', 'ind_dataset', 'silh']) + output, output_unnorm, output_reproj = model(input, norm_dict=norm_dict) + # ----- loss ----- + if metrics == 'no_loss': + loss, loss_dict = loss_module(output_reproj=output_reproj, + target_dict=target_dict, + weight_dict=weight_dict) + # ---------------------------------------------------------------- + + if i == 0: + if len_dataset is None: + len_data = val_loader.batch_size * len(val_loader) # 1703 + else: + len_data = len_dataset + if metrics == 'all' or metrics == 'no_loss': + pck = np.zeros((len_data)) + pck_by_part = {group:np.zeros((len_data)) for group in KEYPOINT_GROUPS} + acc_sil_2d = np.zeros(len_data) + + all_betas = np.zeros((len_data, output_reproj['betas'].shape[1])) + all_betas_limbs = np.zeros((len_data, output_reproj['betas_limbs'].shape[1])) + all_z = np.zeros((len_data, output_reproj['z'].shape[1])) + all_pose_rotmat = np.zeros((len_data, output_unnorm['pose_rotmat'].shape[1], 3, 3)) + all_flength = np.zeros((len_data, output_unnorm['flength'].shape[1])) + all_trans = np.zeros((len_data, output_unnorm['trans'].shape[1])) + all_breed_indices = np.zeros((len_data)) + all_image_names = [] # len_data * [None] + + index = i + ind_img = 0 + if save_imgs_path is not None: + # render predicted 3d models + visualizations = model.render_vis_nograd(vertices=output_reproj['vertices_smal'], + focal_lengths=output_unnorm['flength'], + color=0) # color=2) + for ind_img in range(len(target_dict['index'])): + try: + if test_name_list is not None: + img_name = test_name_list[int(target_dict['index'][ind_img].cpu().detach().numpy())].replace('/', '_') + img_name = img_name.split('.')[0] + else: + img_name = str(index) + '_' + str(ind_img) + # save image with predicted keypoints + out_path = save_imgs_path + '/keypoints_pred_' + img_name + '.png' + pred_unp = (output['keypoints_norm'][ind_img, :, :] + 1.) / 2 * (data_info.image_size - 1) + pred_unp_maxval = output['keypoints_scores'][ind_img, :, :] + pred_unp_prep = torch.cat((pred_unp, pred_unp_maxval), 1) + inp_img = input[ind_img, :, :, :].detach().clone() + save_input_image_with_keypoints(inp_img, pred_unp_prep, out_path=out_path, threshold=0.1, print_scores=True, ratio_in_out=1.0) # threshold=0.3 + # save predicted 3d model (front view) + pred_tex = visualizations[ind_img, :, :, :].permute((1, 2, 0)).cpu().detach().numpy() / 256 + pred_tex_max = np.max(pred_tex, axis=2) + out_path = save_imgs_path + '/tex_pred_' + img_name + '.png' + plt.imsave(out_path, pred_tex) + input_image = input[ind_img, :, :, :].detach().clone() + for t, m, s in zip(input_image, data_info.rgb_mean, data_info.rgb_stddev): t.add_(m) + input_image_np = input_image.detach().cpu().numpy().transpose(1, 2, 0) + im_masked = cv2.addWeighted(input_image_np,0.2,pred_tex,0.8,0) + im_masked[pred_tex_max<0.01, :] = input_image_np[pred_tex_max<0.01, :] + out_path = save_imgs_path + '/comp_pred_' + img_name + '.png' + plt.imsave(out_path, im_masked) + # save predicted 3d model (side view) + vertices_cent = output_reproj['vertices_smal'] - output_reproj['vertices_smal'].mean(dim=1)[:, None, :] + roll = np.pi / 2 * torch.ones(1).float().to(device) + pitch = np.pi / 2 * torch.ones(1).float().to(device) + tensor_0 = torch.zeros(1).float().to(device) + tensor_1 = torch.ones(1).float().to(device) + RX = torch.stack([torch.stack([tensor_1, tensor_0, tensor_0]), torch.stack([tensor_0, torch.cos(roll), -torch.sin(roll)]),torch.stack([tensor_0, torch.sin(roll), torch.cos(roll)])]).reshape(3,3) + RY = torch.stack([ + torch.stack([torch.cos(pitch), tensor_0, torch.sin(pitch)]), + torch.stack([tensor_0, tensor_1, tensor_0]), + torch.stack([-torch.sin(pitch), tensor_0, torch.cos(pitch)])]).reshape(3,3) + vertices_rot = (torch.matmul(RY, vertices_cent.reshape((-1, 3))[:, :, None])).reshape((curr_batch_size, -1, 3)) + vertices_rot[:, :, 2] = vertices_rot[:, :, 2] + torch.ones_like(vertices_rot[:, :, 2]) * 20 # 18 # *16 + + visualizations_rot = model.render_vis_nograd(vertices=vertices_rot, + focal_lengths=output_unnorm['flength'], + color=0) # 2) + pred_tex = visualizations_rot[ind_img, :, :, :].permute((1, 2, 0)).cpu().detach().numpy() / 256 + pred_tex_max = np.max(pred_tex, axis=2) + out_path = save_imgs_path + '/rot_tex_pred_' + img_name + '.png' + plt.imsave(out_path, pred_tex) + if render_all: + # save input image + inp_img = input[ind_img, :, :, :].detach().clone() + out_path = save_imgs_path + '/image_' + img_name + '.png' + save_input_image(inp_img, out_path) + # save mesh + V_posed = output_reproj['vertices_smal'][ind_img, :, :].detach().cpu().numpy() + Faces = model.smal.f + mesh_posed = trimesh.Trimesh(vertices=V_posed, faces=Faces, process=False, maintain_order=True) + mesh_posed.export(save_imgs_path + '/mesh_posed_' + img_name + '.obj') + except: + print('dont save an image') + + if metrics == 'all' or metrics == 'no_loss': + # prepare a dictionary with all the predicted results + preds = {} + preds['betas'] = output_reproj['betas'].cpu().detach().numpy() + preds['betas_limbs'] = output_reproj['betas_limbs'].cpu().detach().numpy() + preds['z'] = output_reproj['z'].cpu().detach().numpy() + preds['pose_rotmat'] = output_unnorm['pose_rotmat'].cpu().detach().numpy() + preds['flength'] = output_unnorm['flength'].cpu().detach().numpy() + preds['trans'] = output_unnorm['trans'].cpu().detach().numpy() + preds['breed_index'] = target_dict['breed_index'].cpu().detach().numpy().reshape((-1)) + img_names = [] + for ind_img2 in range(0, output_reproj['betas'].shape[0]): + if test_name_list is not None: + img_name2 = test_name_list[int(target_dict['index'][ind_img2].cpu().detach().numpy())].replace('/', '_') + img_name2 = img_name2.split('.')[0] + else: + img_name2 = str(index) + '_' + str(ind_img2) + img_names.append(img_name2) + preds['image_names'] = img_names + # prepare keypoints for PCK calculation - predicted as well as ground truth + pred_keypoints_norm = output['keypoints_norm'] # -1 to 1 + pred_keypoints_256 = output_reproj['keyp_2d'] + pred_keypoints = pred_keypoints_256 + gt_keypoints_256 = target_dict['tpts'][:, :, :2] / 64. * (256. - 1) + gt_keypoints_norm = gt_keypoints_256 / 256 / 0.5 - 1 + gt_keypoints = torch.cat((gt_keypoints_256, target_dict['tpts'][:, :, 2:3]), dim=2) # gt_keypoints_norm + # prepare silhouette for IoU calculation - predicted as well as ground truth + has_seg = target_dict['has_seg'] + img_border_mask = target_dict['img_border_mask'][:, 0, :, :] + gtseg = target_dict['silh'] + synth_silhouettes = output_reproj['silh'][:, 0, :, :] # output_reproj['silh'] + synth_silhouettes[synth_silhouettes>0.5] = 1 + synth_silhouettes[synth_silhouettes<0.5] = 0 + # calculate PCK as well as IoU (similar to WLDO) + preds['acc_PCK'] = Metrics.PCK( + pred_keypoints, gt_keypoints, + gtseg, has_seg, idxs=EVAL_KEYPOINTS, + thresh_range=[pck_thresh], # [0.15], + ) + preds['acc_IOU'] = Metrics.IOU( + synth_silhouettes, gtseg, + img_border_mask, mask=has_seg + ) + for group, group_kps in KEYPOINT_GROUPS.items(): + preds[f'{group}_PCK'] = Metrics.PCK( + pred_keypoints, gt_keypoints, gtseg, has_seg, + thresh_range=[pck_thresh], # [0.15], + idxs=group_kps + ) + # add results for all images in this batch to lists + curr_batch_size = pred_keypoints_256.shape[0] + if not (preds['acc_PCK'].data.cpu().numpy().shape == (pck[my_step * batch_size:my_step * batch_size + curr_batch_size]).shape): + import pdb; pdb.set_trace() + pck[my_step * batch_size:my_step * batch_size + curr_batch_size] = preds['acc_PCK'].data.cpu().numpy() + acc_sil_2d[my_step * batch_size:my_step * batch_size + curr_batch_size] = preds['acc_IOU'].data.cpu().numpy() + for part in pck_by_part: + pck_by_part[part][my_step * batch_size:my_step * batch_size + curr_batch_size] = preds[f'{part}_PCK'].data.cpu().numpy() + all_betas[my_step * batch_size:my_step * batch_size + curr_batch_size, ...] = preds['betas'] + all_betas_limbs[my_step * batch_size:my_step * batch_size + curr_batch_size, ...] = preds['betas_limbs'] + all_z[my_step * batch_size:my_step * batch_size + curr_batch_size, ...] = preds['z'] + all_pose_rotmat[my_step * batch_size:my_step * batch_size + curr_batch_size, ...] = preds['pose_rotmat'] + all_flength[my_step * batch_size:my_step * batch_size + curr_batch_size, ...] = preds['flength'] + all_trans[my_step * batch_size:my_step * batch_size + curr_batch_size, ...] = preds['trans'] + all_breed_indices[my_step * batch_size:my_step * batch_size + curr_batch_size] = preds['breed_index'] + all_image_names.extend(preds['image_names']) + # update progress bar + if progress is not None: + my_string = "PCK: {0:.2f}, IOU: {1:.2f}".format( + pck[:(my_step * batch_size + curr_batch_size)].mean(), + acc_sil_2d[:(my_step * batch_size + curr_batch_size)].mean()) + progress.set_postfix_str(my_string) + else: + # measure accuracy and record loss + bs_fake = 1 # batch_size + losses.update(loss_dict['loss'], bs_fake) + losses_keyp.update(loss_dict['loss_keyp_weighted'], bs_fake) + losses_silh.update(loss_dict['loss_silh_weighted'], bs_fake) + losses_shape.update(loss_dict['loss_shape_weighted'], bs_fake) + losses_pose.update(loss_dict['loss_poseprior_weighted'], bs_fake) + losses_class.update(loss_dict['loss_class_weighted'], bs_fake) + losses_breed.update(loss_dict['loss_breed_weighted'], bs_fake) + losses_partseg.update(loss_dict['loss_partseg_weighted'], bs_fake) + acc = - loss_dict['loss_keyp_weighted'] # this will be used to keep track of the 'best model' + accuracies.update(acc, bs_fake) + # Show losses as part of the progress bar. + if progress is not None: + my_string = 'Loss: {loss:0.4f}, loss_keyp: {loss_keyp:0.4f}, loss_silh: {loss_silh:0.4f}, loss_partseg: {loss_partseg:0.4f}, loss_shape: {loss_shape:0.4f}, loss_pose: {loss_pose:0.4f}, loss_class: {loss_class:0.4f}, loss_breed: {loss_breed:0.4f}'.format( + loss=losses.avg, + loss_keyp=losses_keyp.avg, + loss_silh=losses_silh.avg, + loss_shape=losses_shape.avg, + loss_pose=losses_pose.avg, + loss_class=losses_class.avg, + loss_breed=losses_breed.avg, + loss_partseg=losses_partseg.avg + ) + progress.set_postfix_str(my_string) + my_step += 1 + if metrics == 'all': + summary = {'pck': pck, 'acc_sil_2d': acc_sil_2d, 'pck_by_part':pck_by_part, + 'betas': all_betas, 'betas_limbs': all_betas_limbs, 'z': all_z, 'pose_rotmat': all_pose_rotmat, + 'flenght': all_flength, 'trans': all_trans, 'image_names': all_image_names, 'breed_indices': all_breed_indices} + return my_string, summary + elif metrics == 'no_loss': + return my_string, np.average(np.asarray(acc_sil_2d)) + else: + return my_string, accuracies.avg + + +# --------------------------------------------------------------------------------------------------------------------------- +def do_visual_epoch(val_loader, model, device, data_info, flip=False, quiet=False, acc_joints=None, save_imgs_path=None, weight_dict=None, metrics=None, val_opt='default', test_name_list=None, render_all=False, pck_thresh=0.15, return_results=False): + if save_imgs_path is not None: + pathlib.Path(save_imgs_path).mkdir(parents=True, exist_ok=True) + all_results = [] + + # Put the model in evaluation mode. + model.eval() + + iterable = enumerate(val_loader) + + # information for normalization + norm_dict = { + 'pose_rot6d_mean': torch.from_numpy(data_info.pose_rot6d_mean).float().to(device), + 'trans_mean': torch.from_numpy(data_info.trans_mean).float().to(device), + 'trans_std': torch.from_numpy(data_info.trans_std).float().to(device), + 'flength_mean': torch.from_numpy(data_info.flength_mean).float().to(device), + 'flength_std': torch.from_numpy(data_info.flength_std).float().to(device)} + + ''' + return_mesh_with_gt_groundplane = True + if return_mesh_with_gt_groundplane: + remeshing_path = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/data/smal_data_remeshed/uniform_surface_sampling/my_smpl_39dogsnorm_Jr_4_dog_remesh4000_info.pkl' + with open(remeshing_path, 'rb') as fp: + remeshing_dict = pkl.load(fp) + remeshing_relevant_faces = torch.tensor(remeshing_dict['smal_faces'][remeshing_dict['faceid_closest']], dtype=torch.long, device=device) + remeshing_relevant_barys = torch.tensor(remeshing_dict['barys_closest'], dtype=torch.float32, device=device) + + # from smal_pytorch.smal_model.smal_torch_new import SMAL + print('start: load smal default model (barc), but only for vertices') + smal = SMAL() + print('end: load smal default model (barc), but only for vertices') + smal_template_verts = smal.v_template.detach().cpu().numpy() + smal_faces = smal.faces.detach().cpu().numpy() + + file_alignment_errors = open(save_imgs_path + '/a_ref_procrustes_alignmnet_errors.txt', 'a') # append mode + file_alignment_errors.write(" ----------- start evaluation ------------- \n ") + + csv_file_alignment_errors = open(save_imgs_path + '/a_ref_procrustes_alignmnet_errors.csv', 'w') # write mode + fieldnames = ['name', 'error'] + writer = csv.DictWriter(csv_file_alignment_errors, fieldnames=fieldnames) + writer.writeheader() + ''' + + my_step = 0 + for i, (input, target_dict) in iterable: + batch_size = input.shape[0] + input = input.float().to(device) + partial_results = {} + + # ----------------------- do visualization step ----------------------- + with torch.no_grad(): + output, output_unnorm, output_reproj = model(input, norm_dict=norm_dict) + + index = i + ind_img = 0 + for ind_img in range(batch_size): # range(min(12, batch_size)): # range(12): # [0]: #range(0, batch_size): + + try: + if test_name_list is not None: + img_name = test_name_list[int(target_dict['index'][ind_img].cpu().detach().numpy())].replace('/', '_') + img_name = img_name.split('.')[0] + else: + img_name = str(index) + '_' + str(ind_img) + partial_results['img_name'] = img_name + visualizations = model.render_vis_nograd(vertices=output_reproj['vertices_smal'], + focal_lengths=output_unnorm['flength'], + color=0) # 2) + # save image with predicted keypoints + pred_unp = (output['keypoints_norm'][ind_img, :, :] + 1.) / 2 * (data_info.image_size - 1) + pred_unp_maxval = output['keypoints_scores'][ind_img, :, :] + pred_unp_prep = torch.cat((pred_unp, pred_unp_maxval), 1) + inp_img = input[ind_img, :, :, :].detach().clone() + if save_imgs_path is not None: + out_path = save_imgs_path + '/keypoints_pred_' + img_name + '.png' + save_input_image_with_keypoints(inp_img, pred_unp_prep, out_path=out_path, threshold=0.1, print_scores=True, ratio_in_out=1.0) # threshold=0.3 + # save predicted 3d model + # (1) front view + pred_tex = visualizations[ind_img, :, :, :].permute((1, 2, 0)).cpu().detach().numpy() / 256 + pred_tex_max = np.max(pred_tex, axis=2) + partial_results['tex_pred'] = pred_tex + if save_imgs_path is not None: + out_path = save_imgs_path + '/tex_pred_' + img_name + '.png' + plt.imsave(out_path, pred_tex) + input_image = input[ind_img, :, :, :].detach().clone() + for t, m, s in zip(input_image, data_info.rgb_mean, data_info.rgb_stddev): t.add_(m) + input_image_np = input_image.detach().cpu().numpy().transpose(1, 2, 0) + im_masked = cv2.addWeighted(input_image_np,0.2,pred_tex,0.8,0) + im_masked[pred_tex_max<0.01, :] = input_image_np[pred_tex_max<0.01, :] + partial_results['comp_pred'] = im_masked + if save_imgs_path is not None: + out_path = save_imgs_path + '/comp_pred_' + img_name + '.png' + plt.imsave(out_path, im_masked) + # (2) side view + vertices_cent = output_reproj['vertices_smal'] - output_reproj['vertices_smal'].mean(dim=1)[:, None, :] + roll = np.pi / 2 * torch.ones(1).float().to(device) + pitch = np.pi / 2 * torch.ones(1).float().to(device) + tensor_0 = torch.zeros(1).float().to(device) + tensor_1 = torch.ones(1).float().to(device) + RX = torch.stack([torch.stack([tensor_1, tensor_0, tensor_0]), torch.stack([tensor_0, torch.cos(roll), -torch.sin(roll)]),torch.stack([tensor_0, torch.sin(roll), torch.cos(roll)])]).reshape(3,3) + RY = torch.stack([ + torch.stack([torch.cos(pitch), tensor_0, torch.sin(pitch)]), + torch.stack([tensor_0, tensor_1, tensor_0]), + torch.stack([-torch.sin(pitch), tensor_0, torch.cos(pitch)])]).reshape(3,3) + vertices_rot = (torch.matmul(RY, vertices_cent.reshape((-1, 3))[:, :, None])).reshape((batch_size, -1, 3)) + vertices_rot[:, :, 2] = vertices_rot[:, :, 2] + torch.ones_like(vertices_rot[:, :, 2]) * 20 # 18 # *16 + visualizations_rot = model.render_vis_nograd(vertices=vertices_rot, + focal_lengths=output_unnorm['flength'], + color=0) # 2) + pred_tex = visualizations_rot[ind_img, :, :, :].permute((1, 2, 0)).cpu().detach().numpy() / 256 + pred_tex_max = np.max(pred_tex, axis=2) + partial_results['rot_tex_pred'] = pred_tex + if save_imgs_path is not None: + out_path = save_imgs_path + '/rot_tex_pred_' + img_name + '.png' + plt.imsave(out_path, pred_tex) + render_all = True + if render_all: + # save input image + inp_img = input[ind_img, :, :, :].detach().clone() + if save_imgs_path is not None: + out_path = save_imgs_path + '/image_' + img_name + '.png' + save_input_image(inp_img, out_path) + # save posed mesh + V_posed = output_reproj['vertices_smal'][ind_img, :, :].detach().cpu().numpy() + Faces = model.smal.f + mesh_posed = trimesh.Trimesh(vertices=V_posed, faces=Faces, process=False, maintain_order=True) + partial_results['mesh_posed'] = mesh_posed + if save_imgs_path is not None: + mesh_posed.export(save_imgs_path + '/mesh_posed_' + img_name + '.obj') + except: + print('pass...') + all_results.append(partial_results) + if return_results: + return all_results + else: + return \ No newline at end of file diff --git a/src/configs/SMAL_configs.py b/src/configs/SMAL_configs.py new file mode 100644 index 0000000000000000000000000000000000000000..3800dd10bcb5065701b6523a5647a0b72a18a2dd --- /dev/null +++ b/src/configs/SMAL_configs.py @@ -0,0 +1,230 @@ + + +import numpy as np +import os +import sys + + +# SMAL_DATA_DIR = '/is/cluster/work/nrueegg/dog_project/pytorch-dogs-inference/src/smal_pytorch/smpl_models/' +# SMAL_DATA_DIR = os.path.join(os.path.dirname(__file__), '..', 'smal_pytorch', 'smal_data') +SMAL_DATA_DIR = os.path.join(os.path.dirname(__file__), '..', '..', 'data', 'smal_data') + +# we replace the old SMAL model by a more dog specific model (see BARC cvpr 2022 paper) +# our model has several differences compared to the original SMAL model, some of them are: +# - the PCA shape space is recalculated (from partially new data and weighted) +# - coefficients for limb length changes are allowed (similar to WLDO, we did borrow some of their code) +# - all dogs have a core of approximately the same length +# - dogs are centered in their root joint (which is close to the tail base) +# -> like this the root rotations is always around this joint AND (0, 0, 0) +# -> before this it would happen that the animal 'slips' from the image middle to the side when rotating it. Now +# 'trans' also defines the center of the rotation +# - we correct the back joint locations such that all those joints are more aligned + +# logscale_part_list = ['legs_l', 'legs_f', 'tail_l', 'tail_f', 'ears_y', 'ears_l', 'head_l'] +# logscale_part_list = ['front_legs_l', 'front_legs_f', 'tail_l', 'tail_f', 'ears_y', 'ears_l', 'head_l', 'back_legs_l', 'back_legs_f'] + +SMAL_MODEL_CONFIG = { + 'barc': { + 'smal_model_type': 'barc', + 'smal_model_path': os.path.join(SMAL_DATA_DIR, 'my_smpl_SMBLD_nbj_v3.pkl'), + 'smal_model_data_path': os.path.join(SMAL_DATA_DIR, 'my_smpl_data_SMBLD_v3.pkl'), + 'unity_smal_shape_prior_dogs': os.path.join(SMAL_DATA_DIR, 'my_smpl_data_SMBLD_v3.pkl'), + 'logscale_part_list': ['legs_l', 'legs_f', 'tail_l', 'tail_f', 'ears_y', 'ears_l', 'head_l'], + }, + '39dogs_diffsize': { + 'smal_model_type': '39dogs_diffsize', + 'smal_model_path': os.path.join(SMAL_DATA_DIR, 'new_dog_models', 'my_smpl_00791_nadine_Jr_4_dog.pkl'), + 'smal_model_data_path': os.path.join(SMAL_DATA_DIR, 'new_dog_models', 'my_smpl_data_00791_nadine_Jr_4_dog.pkl'), + 'unity_smal_shape_prior_dogs': os.path.join(SMAL_DATA_DIR, 'new_dog_models', 'my_smpl_data_00791_nadine_Jr_4_dog.pkl'), + 'logscale_part_list': ['legs_l', 'legs_f', 'tail_l', 'tail_f', 'ears_y', 'ears_l', 'head_l'], + }, + '39dogs_norm': { + 'smal_model_type': '39dogs_norm', + 'smal_model_path': os.path.join(SMAL_DATA_DIR, 'new_dog_models', 'my_smpl_39dogsnorm_Jr_4_dog.pkl'), + 'smal_model_data_path': os.path.join(SMAL_DATA_DIR, 'new_dog_models', 'my_smpl_data_39dogsnorm_Jr_4_dog.pkl'), + 'unity_smal_shape_prior_dogs': os.path.join(SMAL_DATA_DIR, 'new_dog_models', 'my_smpl_data_39dogsnorm_Jr_4_dog.pkl'), + 'logscale_part_list': ['legs_l', 'legs_f', 'tail_l', 'tail_f', 'ears_y', 'ears_l', 'head_l'], + }, + '39dogs_norm_9ll': { # 9 limb length parameters + 'smal_model_type': '39dogs_norm_9ll', + 'smal_model_path': os.path.join(SMAL_DATA_DIR, 'new_dog_models', 'my_smpl_39dogsnorm_Jr_4_dog.pkl'), + 'smal_model_data_path': os.path.join(SMAL_DATA_DIR, 'new_dog_models', 'my_smpl_data_39dogsnorm_Jr_4_dog.pkl'), + 'unity_smal_shape_prior_dogs': os.path.join(SMAL_DATA_DIR, 'new_dog_models', 'my_smpl_data_39dogsnorm_Jr_4_dog.pkl'), + 'logscale_part_list': ['front_legs_l', 'front_legs_f', 'tail_l', 'tail_f', 'ears_y', 'ears_l', 'head_l', 'back_legs_l', 'back_legs_f'], + }, + '39dogs_norm_newv2': { # front and back legs of equal lengths + 'smal_model_type': '39dogs_norm_newv2', + 'smal_model_path': os.path.join(SMAL_DATA_DIR, 'new_dog_models', 'my_smpl_39dogsnorm_newv2_dog.pkl'), + 'smal_model_data_path': os.path.join(SMAL_DATA_DIR, 'new_dog_models', 'my_smpl_data_39dogsnorm_newv2_dog.pkl'), + 'unity_smal_shape_prior_dogs': os.path.join(SMAL_DATA_DIR, 'new_dog_models', 'my_smpl_data_39dogsnorm_newv2_dog.pkl'), + 'logscale_part_list': ['legs_l', 'legs_f', 'tail_l', 'tail_f', 'ears_y', 'ears_l', 'head_l'], + }, + '39dogs_norm_newv3': { # pca on dame AND different front and back legs lengths + 'smal_model_type': '39dogs_norm_newv3', + 'smal_model_path': os.path.join(SMAL_DATA_DIR, 'new_dog_models', 'my_smpl_39dogsnorm_newv3_dog.pkl'), + 'smal_model_data_path': os.path.join(SMAL_DATA_DIR, 'new_dog_models', 'my_smpl_data_39dogsnorm_newv3_dog.pkl'), + 'unity_smal_shape_prior_dogs': os.path.join(SMAL_DATA_DIR, 'new_dog_models', 'my_smpl_data_39dogsnorm_newv3_dog.pkl'), + 'logscale_part_list': ['legs_l', 'legs_f', 'tail_l', 'tail_f', 'ears_y', 'ears_l', 'head_l'], + }, +} + + +SYMMETRY_INDS_FILE = os.path.join(SMAL_DATA_DIR, 'symmetry_inds.json') + +mean_dog_bone_lengths_txt = os.path.join(SMAL_DATA_DIR, 'mean_dog_bone_lengths.txt') + +# some vertex indices, (from silvia zuffi´s code, create_projected_images_cats.py) +KEY_VIDS = np.array(([1068, 1080, 1029, 1226], # left eye + [2660, 3030, 2675, 3038], # right eye + [910], # mouth low + [360, 1203, 1235, 1230], # front left leg, low + [3188, 3156, 2327, 3183], # front right leg, low + [1976, 1974, 1980, 856], # back left leg, low + [3854, 2820, 3852, 3858], # back right leg, low + [452, 1811], # tail start + [416, 235, 182], # front left leg, top + [2156, 2382, 2203], # front right leg, top + [829], # back left leg, top + [2793], # back right leg, top + [60, 114, 186, 59], # throat, close to base of neck + [2091, 2037, 2036, 2160], # withers (a bit lower than in reality) + [384, 799, 1169, 431], # front left leg, middle + [2351, 2763, 2397, 3127], # front right leg, middle + [221, 104], # back left leg, middle + [2754, 2192], # back right leg, middle + [191, 1158, 3116, 2165], # neck + [28], # Tail tip + [542], # Left Ear + [2507], # Right Ear + [1039, 1845, 1846, 1870, 1879, 1919, 2997, 3761, 3762], # nose tip + [0, 464, 465, 726, 1824, 2429, 2430, 2690]), dtype=object) # half tail + +# the following vertices are used for visibility only: if one of the vertices is visible, +# then we assume that the joint is visible! There is some noise, but we don't care, as this is +# for generation of the synthetic dataset only +KEY_VIDS_VISIBILITY_ONLY = np.array(([1068, 1080, 1029, 1226, 645], # left eye + [2660, 3030, 2675, 3038, 2567], # right eye + [910, 11, 5], # mouth low + [360, 1203, 1235, 1230, 298, 408, 303, 293, 384], # front left leg, low + [3188, 3156, 2327, 3183, 2261, 2271, 2573, 2265], # front right leg, low + [1976, 1974, 1980, 856, 559, 851, 556], # back left leg, low + [3854, 2820, 3852, 3858, 2524, 2522, 2815, 2072], # back right leg, low + [452, 1811, 63, 194, 52, 370, 64], # tail start + [416, 235, 182, 440, 8, 80, 73, 112], # front left leg, top + [2156, 2382, 2203, 2050, 2052, 2406, 3], # front right leg, top + [829, 219, 218, 173, 17, 7, 279], # back left leg, top + [2793, 582, 140, 87, 2188, 2147, 2063], # back right leg, top + [60, 114, 186, 59, 878, 130, 189, 45], # throat, close to base of neck + [2091, 2037, 2036, 2160, 190, 2164], # withers (a bit lower than in reality) + [384, 799, 1169, 431, 321, 314, 437, 310, 323], # front left leg, middle + [2351, 2763, 2397, 3127, 2278, 2285, 2282, 2275, 2359], # front right leg, middle + [221, 104, 105, 97, 103], # back left leg, middle + [2754, 2192, 2080, 2251, 2075, 2074], # back right leg, middle + [191, 1158, 3116, 2165, 154, 653, 133, 339], # neck + [28, 474, 475, 731, 24], # Tail tip + [542, 147, 509, 200, 522], # Left Ear + [2507,2174, 2122, 2126, 2474], # Right Ear + [1039, 1845, 1846, 1870, 1879, 1919, 2997, 3761, 3762], # nose tip + [0, 464, 465, 726, 1824, 2429, 2430, 2690]), dtype=object) # half tail + +# Keypoint indices for 3d sketchfab evaluation +SMAL_KEYPOINT_NAMES_FOR_3D_EVAL = ['right_front_paw','right_front_elbow','right_back_paw','right_back_hock','right_ear_top','right_ear_bottom','right_eye', \ + 'left_front_paw','left_front_elbow','left_back_paw','left_back_hock','left_ear_top','left_ear_bottom','left_eye', \ + 'nose','tail_start','tail_end'] +SMAL_KEYPOINT_INDICES_FOR_3D_EVAL = [2577, 2361, 2820, 2085, 2125, 2453, 2668, 613, 394, 855, 786, 149, 486, 1079, 1845, 1820, 28] +SMAL_KEYPOINT_WHICHTOUSE_FOR_3D_EVAL = [1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 0, 1, 1, 0] # [1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0] + + + + +# see: https://github.com/benjiebob/SMALify/blob/master/config.py +# JOINT DEFINITIONS - based on SMAL joints and additional {eyes, ear tips, chin and nose} +TORSO_JOINTS = [2, 5, 8, 11, 12, 23] +CANONICAL_MODEL_JOINTS = [ + 10, 9, 8, # upper_left [paw, middle, top] + 20, 19, 18, # lower_left [paw, middle, top] + 14, 13, 12, # upper_right [paw, middle, top] + 24, 23, 22, # lower_right [paw, middle, top] + 25, 31, # tail [start, end] + 33, 34, # ear base [left, right] + 35, 36, # nose, chin + 38, 37, # ear tip [left, right] + 39, 40, # eyes [left, right] + 6, 11, # withers, throat (throat is inaccurate and withers also) + 28] # tail middle + # old: 15, 15, # withers, throat (TODO: Labelled same as throat for now), throat + +CANONICAL_MODEL_JOINTS_REFINED = [ + 41, 9, 8, # upper_left [paw, middle, top] + 43, 19, 18, # lower_left [paw, middle, top] + 42, 13, 12, # upper_right [paw, middle, top] + 44, 23, 22, # lower_right [paw, middle, top] + 25, 31, # tail [start, end] + 33, 34, # ear base [left, right] + 35, 36, # nose, chin + 38, 37, # ear tip [left, right] + 39, 40, # eyes [left, right] + 46, 45, # withers, throat + 28] # tail middle + +# the following list gives the indices of the KEY_VIDS_JOINTS that must be taken in order +# to judge if the CANONICAL_MODEL_JOINTS are visible - those are all approximations! +CMJ_VISIBILITY_IN_KEY_VIDS = [ + 3, 14, 8, # left front leg + 5, 16, 10, # left rear leg + 4, 15, 9, # right front leg + 6, 17, 11, # right rear leg + 7, 19, # tail front, tail back + 20, 21, # ear base (but can not be found in blue, se we take the tip) + 2, 2, # mouth (was: 22, 2) + 20, 21, # ear tips + 1, 0, # eyes + 18, # withers, not sure where this point is + 12, # throat + 23, # mid tail + ] + +# define which bone lengths are used as input to the 2d-to-3d network +IDXS_BONES_NO_REDUNDANCY = [6,7,8,9,16,17,18,19,32,1,2,3,4,5,14,15,24,25,26,27,28,29,30,31] +# load bone lengths of the mean dog (already filtered) +mean_dog_bone_lengths = [] +with open(mean_dog_bone_lengths_txt, 'r') as f: + for line in f: + mean_dog_bone_lengths.append(float(line.split('\n')[0])) +MEAN_DOG_BONE_LENGTHS_NO_RED = np.asarray(mean_dog_bone_lengths)[IDXS_BONES_NO_REDUNDANCY] # (24, ) + +# Body part segmentation: +# the body can be segmented based on the bones and for the new dog model also based on the new shapedirs +# axis_horizontal = self.shapedirs[2, :].reshape((-1, 3))[:, 0] +# all_indices = np.arange(3889) +# tail_indices = all_indices[axis_horizontal.detach().cpu().numpy() < 0.0] +VERTEX_IDS_TAIL = [ 0, 4, 9, 10, 24, 25, 28, 453, 454, 456, 457, + 458, 459, 460, 461, 462, 463, 464, 465, 466, 467, 468, + 469, 470, 471, 472, 473, 474, 475, 724, 725, 726, 727, + 728, 729, 730, 731, 813, 975, 976, 977, 1109, 1110, 1111, + 1811, 1813, 1819, 1820, 1821, 1822, 1823, 1824, 1825, 1826, 1827, + 1828, 1835, 1836, 1960, 1961, 1962, 1963, 1964, 1965, 1966, 1967, + 1968, 1969, 2418, 2419, 2421, 2422, 2423, 2424, 2425, 2426, 2427, + 2428, 2429, 2430, 2431, 2432, 2433, 2434, 2435, 2436, 2437, 2438, + 2439, 2440, 2688, 2689, 2690, 2691, 2692, 2693, 2694, 2695, 2777, + 3067, 3068, 3069, 3842, 3843, 3844, 3845, 3846, 3847] + +# same as in https://github.com/benjiebob/WLDO/blob/master/global_utils/config.py +EVAL_KEYPOINTS = [ + 0, 1, 2, # left front + 3, 4, 5, # left rear + 6, 7, 8, # right front + 9, 10, 11, # right rear + 12, 13, # tail start -> end + 14, 15, # left ear, right ear + 16, 17, # nose, chin + 18, 19] # left tip, right tip + +KEYPOINT_GROUPS = { + 'legs': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], # legs + 'tail': [12, 13], # tail + 'ears': [14, 15, 18, 19], # ears + 'face': [16, 17] # face +} + + diff --git a/src/configs/anipose_data_info.py b/src/configs/anipose_data_info.py new file mode 100644 index 0000000000000000000000000000000000000000..8e7bad68b45cf9926fdfd3ca1b7e1f147e909cfd --- /dev/null +++ b/src/configs/anipose_data_info.py @@ -0,0 +1,74 @@ +from dataclasses import dataclass +from typing import List +import json +import numpy as np +import os + +STATISTICS_DATA_DIR = os.path.join(os.path.dirname(__file__), '..', '..', 'data', 'statistics') +STATISTICS_PATH = os.path.join(STATISTICS_DATA_DIR, 'statistics_modified_v1.json') + +@dataclass +class DataInfo: + rgb_mean: List[float] + rgb_stddev: List[float] + joint_names: List[str] + hflip_indices: List[int] + n_joints: int + n_keyp: int + n_bones: int + n_betas: int + image_size: int + trans_mean: np.ndarray + trans_std: np.ndarray + flength_mean: np.ndarray + flength_std: np.ndarray + pose_rot6d_mean: np.ndarray + keypoint_weights: List[float] + +# SMAL samples 3d statistics +# statistics like mean values were calculated once when the project was started and they were not changed afterwards anymore +def load_statistics(statistics_path): + with open(statistics_path) as f: + statistics = json.load(f) + '''new_pose_mean = [[[np.round(val, 2) for val in sublst] for sublst in sublst_big] for sublst_big in statistics['pose_mean']] + statistics['pose_mean'] = new_pose_mean + j_out = json.dumps(statistics, indent=4) #, sort_keys=True) + with open(self.statistics_path, 'w') as file: file.write(j_out)''' + new_statistics = {'trans_mean': np.asarray(statistics['trans_mean']), + 'trans_std': np.asarray(statistics['trans_std']), + 'flength_mean': np.asarray(statistics['flength_mean']), + 'flength_std': np.asarray(statistics['flength_std']), + 'pose_mean': np.asarray(statistics['pose_mean']), + } + new_statistics['pose_rot6d_mean'] = new_statistics['pose_mean'][:, :, :2].reshape((-1, 6)) + return new_statistics +STATISTICS = load_statistics(STATISTICS_PATH) + +AniPose_JOINT_NAMES_swapped = [ + 'L_F_Paw', 'L_F_Knee', 'L_F_Elbow', + 'L_B_Paw', 'L_B_Knee', 'L_B_Elbow', + 'R_F_Paw', 'R_F_Knee', 'R_F_Elbow', + 'R_B_Paw', 'R_B_Knee', 'R_B_Elbow', + 'TailBase', '_Tail_end_', 'L_EarBase', 'R_EarBase', + 'Nose', '_Chin_', '_Left_ear_tip_', '_Right_ear_tip_', + 'L_Eye', 'R_Eye', 'Withers', 'Throat'] + +KEYPOINT_WEIGHTS = [3, 2, 2, 3, 2, 2, 3, 2, 2, 3, 2, 2, 3, 3, 2, 2, 3, 1, 2, 2] + +COMPLETE_DATA_INFO = DataInfo( + rgb_mean=[0.4404, 0.4440, 0.4327], # not sure + rgb_stddev=[0.2458, 0.2410, 0.2468], # not sure + joint_names=AniPose_JOINT_NAMES_swapped, # AniPose_JOINT_NAMES, + hflip_indices=[6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 4, 5, 12, 13, 15, 14, 16, 17, 19, 18, 21, 20, 22, 23], + n_joints = 35, + n_keyp = 24, # 20, # 25, + n_bones = 24, + n_betas = 30, # 10, + image_size = 256, + trans_mean = STATISTICS['trans_mean'], + trans_std = STATISTICS['trans_std'], + flength_mean = STATISTICS['flength_mean'], + flength_std = STATISTICS['flength_std'], + pose_rot6d_mean = STATISTICS['pose_rot6d_mean'], + keypoint_weights = KEYPOINT_WEIGHTS + ) diff --git a/src/configs/barc_cfg_defaults.py b/src/configs/barc_cfg_defaults.py new file mode 100644 index 0000000000000000000000000000000000000000..10ecdbbe063544b34933ecb3c1f131ad19e7a3cc --- /dev/null +++ b/src/configs/barc_cfg_defaults.py @@ -0,0 +1,121 @@ + +from yacs.config import CfgNode as CN +import argparse +import yaml +import os + +abs_barc_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..',)) + +_C = CN() +_C.barc_dir = abs_barc_dir +_C.device = 'cuda' + +## path settings +_C.paths = CN() +_C.paths.ROOT_OUT_PATH = abs_barc_dir + '/results/' +_C.paths.ROOT_CHECKPOINT_PATH = abs_barc_dir + '/checkpoint/' +_C.paths.MODELPATH_NORMFLOW = abs_barc_dir + '/checkpoint/barc_normflow_pret/rgbddog_v3_model.pt' + +## parameter settings +_C.params = CN() +_C.params.ARCH = 'hg8' +_C.params.STRUCTURE_POSE_NET = 'normflow' # 'default' # 'vae' +_C.params.NF_VERSION = 3 +_C.params.N_JOINTS = 35 +_C.params.N_KEYP = 24 #20 +_C.params.N_SEG = 2 +_C.params.N_PARTSEG = 15 +_C.params.UPSAMPLE_SEG = True +_C.params.ADD_PARTSEG = True # partseg: for the CVPR paper this part of the network exists, but is not trained (no part labels in StanExt) +_C.params.N_BETAS = 30 # 10 +_C.params.N_BETAS_LIMBS = 7 +_C.params.N_BONES = 24 +_C.params.N_BREEDS = 121 # 120 breeds plus background +_C.params.IMG_SIZE = 256 +_C.params.SILH_NO_TAIL = False +_C.params.KP_THRESHOLD = None +_C.params.ADD_Z_TO_3D_INPUT = False +_C.params.N_SEGBPS = 64*2 +_C.params.ADD_SEGBPS_TO_3D_INPUT = True +_C.params.FIX_FLENGTH = False +_C.params.RENDER_ALL = True +_C.params.VLIN = 2 +_C.params.STRUCTURE_Z_TO_B = 'lin' +_C.params.N_Z_FREE = 64 +_C.params.PCK_THRESH = 0.15 +_C.params.REF_NET_TYPE = 'add' # refinement network type +_C.params.REF_DETACH_SHAPE = True +_C.params.GRAPHCNN_TYPE = 'inexistent' +_C.params.ISFLAT_TYPE = 'inexistent' +_C.params.SHAPEREF_TYPE = 'inexistent' + +## SMAL settings +_C.smal = CN() +_C.smal.SMAL_MODEL_TYPE = 'barc' +_C.smal.SMAL_KEYP_CONF = 'green' + +## optimization settings +_C.optim = CN() +_C.optim.LR = 5e-4 +_C.optim.SCHEDULE = [150, 175, 200] +_C.optim.GAMMA = 0.1 +_C.optim.MOMENTUM = 0 +_C.optim.WEIGHT_DECAY = 0 +_C.optim.EPOCHS = 220 +_C.optim.BATCH_SIZE = 12 # keep 12 (needs to be an even number, as we have a custom data sampler) +_C.optim.TRAIN_PARTS = 'all_without_shapedirs' + +## dataset settings +_C.data = CN() +_C.data.DATASET = 'stanext24' +_C.data.V12 = True +_C.data.SHORTEN_VAL_DATASET_TO = None +_C.data.VAL_OPT = 'val' +_C.data.VAL_METRICS = 'no_loss' + +# --------------------------------------- +def update_dependent_vars(cfg): + cfg.params.N_CLASSES = cfg.params.N_KEYP + cfg.params.N_SEG + if cfg.params.VLIN == 0: + cfg.params.NUM_STAGE_COMB = 2 + cfg.params.NUM_STAGE_HEADS = 1 + cfg.params.NUM_STAGE_HEADS_POSE = 1 + cfg.params.TRANS_SEP = False + elif cfg.params.VLIN == 1: + cfg.params.NUM_STAGE_COMB = 3 + cfg.params.NUM_STAGE_HEADS = 1 + cfg.params.NUM_STAGE_HEADS_POSE = 2 + cfg.params.TRANS_SEP = False + elif cfg.params.VLIN == 2: + cfg.params.NUM_STAGE_COMB = 3 + cfg.params.NUM_STAGE_HEADS = 1 + cfg.params.NUM_STAGE_HEADS_POSE = 2 + cfg.params.TRANS_SEP = True + else: + raise NotImplementedError + if cfg.params.STRUCTURE_Z_TO_B == '1dconv': + cfg.params.N_Z = cfg.params.N_BETAS + cfg.params.N_BETAS_LIMBS + else: + cfg.params.N_Z = cfg.params.N_Z_FREE + return + + +update_dependent_vars(_C) +global _cfg_global +_cfg_global = _C.clone() + + +def get_cfg_defaults(): + # Get a yacs CfgNode object with default values as defined within this file. + # Return a clone so that the defaults will not be altered. + return _C.clone() + +def update_cfg_global_with_yaml(cfg_yaml_file): + _cfg_global.merge_from_file(cfg_yaml_file) + update_dependent_vars(_cfg_global) + return + +def get_cfg_global_updated(): + # return _cfg_global.clone() + return _cfg_global + diff --git a/src/configs/barc_cfg_train.yaml b/src/configs/barc_cfg_train.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a8bcdeb411acf8e5e07b35fd234fb5b03125c270 --- /dev/null +++ b/src/configs/barc_cfg_train.yaml @@ -0,0 +1,24 @@ + +paths: + ROOT_OUT_PATH: './results/' + ROOT_CHECKPOINT_PATH: './checkpoint/' + MODELPATH_NORMFLOW: './checkpoint/barc_normflow_pret/rgbddog_v3_model.pt' + +smal: + SMAL_MODEL_TYPE: '39dogs_norm_newv3' # '39dogs_diffsize' # 'barc' + SMAL_KEYP_CONF: 'olive' # 'green' + +optim: + LR: 5e-4 + SCHEDULE: [150, 175, 200] + GAMMA: 0.1 + MOMENTUM: 0 + WEIGHT_DECAY: 0 + EPOCHS: 220 + BATCH_SIZE: 12 # keep 12 (needs to be an even number, as we have a custom data sampler) + TRAIN_PARTS: 'all_without_shapedirs' + +data: + DATASET: 'stanext24' + SHORTEN_VAL_DATASET_TO: 600 # this is faster as we do not evaluate on the whole validation set + VAL_OPT: 'val' \ No newline at end of file diff --git a/src/configs/barc_loss_weights_allzeros.json b/src/configs/barc_loss_weights_allzeros.json new file mode 100644 index 0000000000000000000000000000000000000000..a96a73d4a3de738e487ee67e9fd7d552ba9bd5f1 --- /dev/null +++ b/src/configs/barc_loss_weights_allzeros.json @@ -0,0 +1,30 @@ + + + +{ + "breed_options": [ + "4" + ], + "breed": 0.0, + "class": 0.0, + "models3d": 0.0, + "keyp": 0.0, + "silh": 0.0, + "shape_options": [ + "smal", + "limbs7" + ], + "shape": [ + 0, + 0 + ], + "poseprior_options": [ + "normalizing_flow_tiger_logprob" + ], + "poseprior": 0.0, + "poselegssidemovement": 0.0, + "flength": 0.0, + "partseg": 0, + "shapedirs": 0, + "pose_0": 0.0 +} \ No newline at end of file diff --git a/src/configs/barc_loss_weights_with3dcgloss_higherbetaloss_v2_dm39dnnv3v2.json b/src/configs/barc_loss_weights_with3dcgloss_higherbetaloss_v2_dm39dnnv3v2.json new file mode 100644 index 0000000000000000000000000000000000000000..48e0f6b921c424294fe549bc1adec4ac61864821 --- /dev/null +++ b/src/configs/barc_loss_weights_with3dcgloss_higherbetaloss_v2_dm39dnnv3v2.json @@ -0,0 +1,30 @@ + + + +{ + "breed_options": [ + "4" + ], + "breed": 5.0, + "class": 5.0, + "models3d": 0.1, + "keyp": 0.2, + "silh": 50.0, + "shape_options": [ + "smal", + "limbs7" + ], + "shape": [ + 0.1, + 1.0 + ], + "poseprior_options": [ + "normalizing_flow_tiger_logprob" + ], + "poseprior": 0.1, + "poselegssidemovement": 10.0, + "flength": 1.0, + "partseg": 0, + "shapedirs": 0, + "pose_0": 0.0 +} diff --git a/src/configs/data_info.py b/src/configs/data_info.py new file mode 100644 index 0000000000000000000000000000000000000000..cf28608e6361b089d49520e6bf03d142e1aab799 --- /dev/null +++ b/src/configs/data_info.py @@ -0,0 +1,115 @@ +from dataclasses import dataclass +from typing import List +import json +import numpy as np +import os +import sys + +STATISTICS_DATA_DIR = os.path.join(os.path.dirname(__file__), '..', '..', 'data', 'statistics') +STATISTICS_PATH = os.path.join(STATISTICS_DATA_DIR, 'statistics_modified_v1.json') + +@dataclass +class DataInfo: + rgb_mean: List[float] + rgb_stddev: List[float] + joint_names: List[str] + hflip_indices: List[int] + n_joints: int + n_keyp: int + n_bones: int + n_betas: int + image_size: int + trans_mean: np.ndarray + trans_std: np.ndarray + flength_mean: np.ndarray + flength_std: np.ndarray + pose_rot6d_mean: np.ndarray + keypoint_weights: List[float] + +# SMAL samples 3d statistics +# statistics like mean values were calculated once when the project was started and they were not changed afterwards anymore +def load_statistics(statistics_path): + with open(statistics_path) as f: + statistics = json.load(f) + '''new_pose_mean = [[[np.round(val, 2) for val in sublst] for sublst in sublst_big] for sublst_big in statistics['pose_mean']] + statistics['pose_mean'] = new_pose_mean + j_out = json.dumps(statistics, indent=4) #, sort_keys=True) + with open(self.statistics_path, 'w') as file: file.write(j_out)''' + new_statistics = {'trans_mean': np.asarray(statistics['trans_mean']), + 'trans_std': np.asarray(statistics['trans_std']), + 'flength_mean': np.asarray(statistics['flength_mean']), + 'flength_std': np.asarray(statistics['flength_std']), + 'pose_mean': np.asarray(statistics['pose_mean']), + } + new_statistics['pose_rot6d_mean'] = new_statistics['pose_mean'][:, :, :2].reshape((-1, 6)) + return new_statistics +STATISTICS = load_statistics(STATISTICS_PATH) + + +############################################################################ +# for StanExt (original number of keypoints, 20 not 24) + +# for keypoint names see: https://github.com/benjiebob/StanfordExtra/blob/master/keypoint_definitions.csv +StanExt_JOINT_NAMES = [ + 'Left_front_leg_paw', 'Left_front_leg_middle_joint', 'Left_front_leg_top', + 'Left_rear_leg_paw', 'Left_rear_leg_middle_joint', 'Left_rear_leg_top', + 'Right_front_leg_paw', 'Right_front_leg_middle_joint', 'Right_front_leg_top', + 'Right_rear_leg_paw', 'Right_rear_leg_middle_joint', 'Right_rear_leg_top', + 'Tail_start', 'Tail_end', 'Base_of_left_ear', 'Base_of_right_ear', + 'Nose', 'Chin', 'Left_ear_tip', 'Right_ear_tip'] + +KEYPOINT_WEIGHTS = [3, 2, 2, 3, 2, 2, 3, 2, 2, 3, 2, 2, 3, 3, 2, 2, 3, 1, 2, 2] + +COMPLETE_DATA_INFO = DataInfo( + rgb_mean=[0.4404, 0.4440, 0.4327], # not sure + rgb_stddev=[0.2458, 0.2410, 0.2468], # not sure + joint_names=StanExt_JOINT_NAMES, + hflip_indices=[6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 4, 5, 12, 13, 15, 14, 16, 17, 19, 18], + n_joints = 35, + n_keyp = 20, # 25, + n_bones = 24, + n_betas = 30, # 10, + image_size = 256, + trans_mean = STATISTICS['trans_mean'], + trans_std = STATISTICS['trans_std'], + flength_mean = STATISTICS['flength_mean'], + flength_std = STATISTICS['flength_std'], + pose_rot6d_mean = STATISTICS['pose_rot6d_mean'], + keypoint_weights = KEYPOINT_WEIGHTS + ) + + +############################################################################ +# new for StanExt24 + +# ..., 'Left_eye', 'Right_eye', 'Withers', 'Throat'] # the last 4 keypoints are in the animal_pose dataset, but not StanfordExtra +StanExt_JOINT_NAMES_24 = [ + 'Left_front_leg_paw', 'Left_front_leg_middle_joint', 'Left_front_leg_top', + 'Left_rear_leg_paw', 'Left_rear_leg_middle_joint', 'Left_rear_leg_top', + 'Right_front_leg_paw', 'Right_front_leg_middle_joint', 'Right_front_leg_top', + 'Right_rear_leg_paw', 'Right_rear_leg_middle_joint', 'Right_rear_leg_top', + 'Tail_start', 'Tail_end', 'Base_of_left_ear', 'Base_of_right_ear', + 'Nose', 'Chin', 'Left_ear_tip', 'Right_ear_tip', + 'Left_eye', 'Right_eye', 'Withers', 'Throat'] + +KEYPOINT_WEIGHTS_24 = [3, 2, 2, 3, 2, 2, 3, 2, 2, 3, 2, 2, 3, 3, 2, 2, 3, 1, 2, 2, 1, 1, 0, 0] + +COMPLETE_DATA_INFO_24 = DataInfo( + rgb_mean=[0.4404, 0.4440, 0.4327], # not sure + rgb_stddev=[0.2458, 0.2410, 0.2468], # not sure + joint_names=StanExt_JOINT_NAMES_24, + hflip_indices=[6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 4, 5, 12, 13, 15, 14, 16, 17, 19, 18, 21, 20, 22, 23], + n_joints = 35, + n_keyp = 24, # 20, # 25, + n_bones = 24, + n_betas = 30, # 10, + image_size = 256, + trans_mean = STATISTICS['trans_mean'], + trans_std = STATISTICS['trans_std'], + flength_mean = STATISTICS['flength_mean'], + flength_std = STATISTICS['flength_std'], + pose_rot6d_mean = STATISTICS['pose_rot6d_mean'], + keypoint_weights = KEYPOINT_WEIGHTS_24 + ) + + diff --git a/src/configs/dataset_path_configs.py b/src/configs/dataset_path_configs.py new file mode 100644 index 0000000000000000000000000000000000000000..d8d5d63ee0ecd6064157f37983820c0db73d2fe7 --- /dev/null +++ b/src/configs/dataset_path_configs.py @@ -0,0 +1,21 @@ + + +import numpy as np +import os +import sys + +abs_barc_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..',)) + +# stanext dataset +# (1) path to stanext dataset +STAN_V12_ROOT_DIR = '/ps/scratch/nrueegg/new_projects/Animals/data/dog_datasets/Stanford_Dogs_Dataset' + '/StanfordExtra_V12/' +IMG_V12_DIR = os.path.join(STAN_V12_ROOT_DIR, 'StanExtV12_Images') +JSON_V12_DIR = os.path.join(STAN_V12_ROOT_DIR, 'labels', "StanfordExtra_v12.json") +STAN_V12_TRAIN_LIST_DIR = os.path.join(STAN_V12_ROOT_DIR, 'labels', 'train_stanford_StanfordExtra_v12.npy') +STAN_V12_VAL_LIST_DIR = os.path.join(STAN_V12_ROOT_DIR, 'labels', 'val_stanford_StanfordExtra_v12.npy') +STAN_V12_TEST_LIST_DIR = os.path.join(STAN_V12_ROOT_DIR, 'labels', 'test_stanford_StanfordExtra_v12.npy') +# (2) path to related data such as breed indices and prepared predictions for withers, throat and eye keypoints +STANEXT_RELATED_DATA_ROOT_DIR = os.path.join(os.path.dirname(__file__), '..', '..', 'data', 'stanext_related_data') + +# test image crop dataset +TEST_IMAGE_CROP_ROOT_DIR = os.path.join(os.path.dirname(__file__), '..', '..', 'datasets', 'test_image_crops') diff --git a/src/configs/dog_breeds/dog_breed_class.py b/src/configs/dog_breeds/dog_breed_class.py new file mode 100644 index 0000000000000000000000000000000000000000..282052164ec6ecb742d91d07ea564cc82cf70ab8 --- /dev/null +++ b/src/configs/dog_breeds/dog_breed_class.py @@ -0,0 +1,170 @@ + +import os +import warnings +warnings.filterwarnings("ignore", category=DeprecationWarning) +import pandas as pd +import difflib +import json +import pickle as pkl +import csv +import numpy as np + + +# ----------------------------------------------------------------------------------------------------------------- # +class DogBreed(object): + def __init__(self, abbrev, name_akc=None, name_stanext=None, name_xlsx=None, path_akc=None, path_stanext=None, ind_in_xlsx=None, ind_in_xlsx_matrix=None, ind_in_stanext=None, clade=None): + self._abbrev = abbrev + self._name_xlsx = name_xlsx + self._name_akc = name_akc + self._name_stanext = name_stanext + self._path_stanext = path_stanext + self._additional_names = set() + if self._name_akc is not None: + self.add_akc_info(name_akc, path_akc) + if self._name_stanext is not None: + self.add_stanext_info(name_stanext, path_stanext, ind_in_stanext) + if self._name_xlsx is not None: + self.add_xlsx_info(name_xlsx, ind_in_xlsx, ind_in_xlsx_matrix, clade) + def add_xlsx_info(self, name_xlsx, ind_in_xlsx, ind_in_xlsx_matrix, clade): + assert (name_xlsx is not None) and (ind_in_xlsx is not None) and (ind_in_xlsx_matrix is not None) and (clade is not None) + self._name_xlsx = name_xlsx + self._ind_in_xlsx = ind_in_xlsx + self._ind_in_xlsx_matrix = ind_in_xlsx_matrix + self._clade = clade + def add_stanext_info(self, name_stanext, path_stanext, ind_in_stanext): + assert (name_stanext is not None) and (path_stanext is not None) and (ind_in_stanext is not None) + self._name_stanext = name_stanext + self._path_stanext = path_stanext + self._ind_in_stanext = ind_in_stanext + def add_akc_info(self, name_akc, path_akc): + assert (name_akc is not None) and (path_akc is not None) + self._name_akc = name_akc + self._path_akc = path_akc + def add_additional_names(self, name_list): + self._additional_names = self._additional_names.union(set(name_list)) + def add_text_info(self, text_height, text_weight, text_life_exp): + self._text_height = text_height + self._text_weight = text_weight + self._text_life_exp = text_life_exp + def get_datasets(self): + # all datasets in which this breed is found + datasets = set() + if self._name_akc is not None: + datasets.add('akc') + if self._name_stanext is not None: + datasets.add('stanext') + if self._name_xlsx is not None: + datasets.add('xlsx') + return datasets + def get_names(self): + # set of names for this breed + names = {self._abbrev, self._name_akc, self._name_stanext, self._name_xlsx, self._path_stanext}.union(self._additional_names) + names.discard(None) + return names + def get_names_as_pointing_dict(self): + # each name points to the abbreviation + names = self.get_names() + my_dict = {} + for name in names: + my_dict[name] = self._abbrev + return my_dict + def print_overview(self): + # print important information to get an overview of the class instance + if self._name_akc is not None: + name = self._name_akc + elif self._name_xlsx is not None: + name = self._name_xlsx + else: + name = self._name_stanext + print('----------------------------------------------------') + print('----- dog breed: ' + name ) + print('----------------------------------------------------') + print('[names]') + print(self.get_names()) + print('[datasets]') + print(self.get_datasets()) + # see https://stackoverflow.com/questions/9058305/getting-attributes-of-a-class + print('[instance attributes]') + for attribute, value in self.__dict__.items(): + print(attribute, '=', value) + def use_dict_to_save_class_instance(self): + my_dict = {} + for attribute, value in self.__dict__.items(): + my_dict[attribute] = value + return my_dict + def use_dict_to_load_class_instance(self, my_dict): + for attribute, value in my_dict.items(): + setattr(self, attribute, value) + return + +# ----------------------------------------------------------------------------------------------------------------- # +def get_name_list_from_summary(summary): + name_from_abbrev_dict = {} + for breed in summary.values(): + abbrev = breed._abbrev + all_names = breed.get_names() + name_from_abbrev_dict[abbrev] = list(all_names) + return name_from_abbrev_dict +def get_partial_summary(summary, part): + assert part in ['xlsx', 'akc', 'stanext'] + partial_summary = {} + for key, value in summary.items(): + if (part == 'xlsx' and value._name_xlsx is not None) \ + or (part == 'akc' and value._name_akc is not None) \ + or (part == 'stanext' and value._name_stanext is not None): + partial_summary[key] = value + return partial_summary +def get_akc_but_not_stanext_partial_summary(summary): + partial_summary = {} + for key, value in summary.items(): + if value._name_akc is not None: + if value._name_stanext is None: + partial_summary[key] = value + return partial_summary + +# ----------------------------------------------------------------------------------------------------------------- # +def main_load_dog_breed_classes(path_complete_abbrev_dict_v1, path_complete_summary_breeds_v1): + with open(path_complete_abbrev_dict_v1, 'rb') as file: + complete_abbrev_dict = pkl.load(file) + with open(path_complete_summary_breeds_v1, 'rb') as file: + complete_summary_breeds_attributes_only = pkl.load(file) + + complete_summary_breeds = {} + for key, value in complete_summary_breeds_attributes_only.items(): + attributes_only = complete_summary_breeds_attributes_only[key] + complete_summary_breeds[key] = DogBreed(abbrev=attributes_only['_abbrev']) + complete_summary_breeds[key].use_dict_to_load_class_instance(attributes_only) + return complete_abbrev_dict, complete_summary_breeds + + +# ----------------------------------------------------------------------------------------------------------------- # +def load_similarity_matrix_raw(xlsx_path): + # --- LOAD EXCEL FILE FROM DOG BREED PAPER + xlsx = pd.read_excel(xlsx_path) + # create an array + abbrev_indices = {} + matrix_raw = np.zeros((168, 168)) + for ind in range(1, 169): + abbrev = xlsx[xlsx.columns[2]][ind] + abbrev_indices[abbrev] = ind-1 + for ind_col in range(0, 168): + for ind_row in range(0, 168): + matrix_raw[ind_col, ind_row] = float(xlsx[xlsx.columns[3+ind_col]][1+ind_row]) + return matrix_raw, abbrev_indices + + + +# ----------------------------------------------------------------------------------------------------------------- # +# ----------------------------------------------------------------------------------------------------------------- # +# load the (in advance created) final dict of dog breed classes +ROOT_PATH_BREED_DATA = os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', '..', '..', 'data', 'breed_data') +path_complete_abbrev_dict_v1 = os.path.join(ROOT_PATH_BREED_DATA, 'complete_abbrev_dict_v2.pkl') +path_complete_summary_breeds_v1 = os.path.join(ROOT_PATH_BREED_DATA, 'complete_summary_breeds_v2.pkl') +COMPLETE_ABBREV_DICT, COMPLETE_SUMMARY_BREEDS = main_load_dog_breed_classes(path_complete_abbrev_dict_v1, path_complete_summary_breeds_v1) +# load similarity matrix, data from: +# Parker H. G., Dreger D. L., Rimbault M., Davis B. W., Mullen A. B., Carpintero-Ramirez G., and Ostrander E. A. +# Genomic analyses reveal the influence of geographic origin, migration, and hybridization on modern dog breed +# development. Cell Reports, 4(19):697–708, 2017. +xlsx_path = os.path.join(ROOT_PATH_BREED_DATA, 'NIHMS866262-supplement-2.xlsx') +SIM_MATRIX_RAW, SIM_ABBREV_INDICES = load_similarity_matrix_raw(xlsx_path) + diff --git a/src/configs/refinement_cfg_test_withvertexwisegc_csaddnonflat.yaml b/src/configs/refinement_cfg_test_withvertexwisegc_csaddnonflat.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0d6854258d9554c8db7b2a9c56ac145d7619e52a --- /dev/null +++ b/src/configs/refinement_cfg_test_withvertexwisegc_csaddnonflat.yaml @@ -0,0 +1,23 @@ + +paths: + ROOT_OUT_PATH: './results/' + ROOT_CHECKPOINT_PATH: './checkpoint/' + MODELPATH_NORMFLOW: './checkpoint/barc_normflow_pret/rgbddog_v3_model.pt' + +smal: + SMAL_MODEL_TYPE: '39dogs_norm_newv3' # '39dogs_norm' # '39dogs_diffsize' # 'barc' + SMAL_KEYP_CONF: 'olive' # 'green' + +optim: + BATCH_SIZE: 12 + +params: + REF_NET_TYPE: 'multrot01all_res34' # 'multrot01all_res34' # 'multrot01all' # 'multrot01' # 'multrot' # 'add' + REF_DETACH_SHAPE: True + GRAPHCNN_TYPE: 'multistage_simple' # 'inexistent' + SHAPEREF_TYPE: 'inexistent' # 'linear' # 'inexistent' + ISFLAT_TYPE: 'linear' # 'inexistent' # 'inexistent' + +data: + DATASET: 'stanext24' + VAL_OPT: 'test' # 'val' diff --git a/src/configs/refinement_cfg_test_withvertexwisegc_csaddnonflat_crops.yaml b/src/configs/refinement_cfg_test_withvertexwisegc_csaddnonflat_crops.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d06f52baeb3c2f33346d5d3850ab6a2d55ec2b0d --- /dev/null +++ b/src/configs/refinement_cfg_test_withvertexwisegc_csaddnonflat_crops.yaml @@ -0,0 +1,23 @@ + +paths: + ROOT_OUT_PATH: './results/' + ROOT_CHECKPOINT_PATH: './checkpoint/' + MODELPATH_NORMFLOW: './checkpoint/barc_normflow_pret/rgbddog_v3_model.pt' + +smal: + SMAL_MODEL_TYPE: '39dogs_norm_newv3' # '39dogs_norm' # '39dogs_diffsize' # 'barc' + SMAL_KEYP_CONF: 'olive' # 'green' + +optim: + BATCH_SIZE: 12 + +params: + REF_NET_TYPE: 'multrot01all_res34' # 'multrot01all_res34' # 'multrot01all' # 'multrot01' # 'multrot' # 'add' + REF_DETACH_SHAPE: True + GRAPHCNN_TYPE: 'multistage_simple' # 'inexistent' + SHAPEREF_TYPE: 'inexistent' # 'linear' # 'inexistent' + ISFLAT_TYPE: 'linear' # 'inexistent' # 'inexistent' + +data: + DATASET: 'ImgCropList' + VAL_OPT: 'test' # 'val' diff --git a/src/configs/refinement_cfg_train_withvertexwisegc_isflat_csmorestanding.yaml b/src/configs/refinement_cfg_train_withvertexwisegc_isflat_csmorestanding.yaml new file mode 100644 index 0000000000000000000000000000000000000000..45b6c5b5a470fdb67c16845f73391a6aeb5e5700 --- /dev/null +++ b/src/configs/refinement_cfg_train_withvertexwisegc_isflat_csmorestanding.yaml @@ -0,0 +1,31 @@ + +paths: + ROOT_OUT_PATH: './results/' + ROOT_CHECKPOINT_PATH: './checkpoint/' + MODELPATH_NORMFLOW: './checkpoint/barc_normflow_pret/rgbddog_v3_model.pt' + +smal: + SMAL_MODEL_TYPE: '39dogs_norm_newv3' # '39dogs_norm' # '39dogs_diffsize' # 'barc' + SMAL_KEYP_CONF: 'olive' # 'green' + +optim: + LR: 5e-5 # 5e-7 # (new) 5e-6 # 5e-5 # 5e-5 # 5e-4 + SCHEDULE: [150, 175, 200] # [220, 270] # [150, 175, 200] + GAMMA: 0.1 + MOMENTUM: 0 + WEIGHT_DECAY: 0 + EPOCHS: 220 # 300 + BATCH_SIZE: 14 # 12 # keep 12 (needs to be an even number, as we have a custom data sampler) + TRAIN_PARTS: 'refinement_model' # 'refinement_model_and_shape' # 'refinement_model' + +params: + REF_NET_TYPE: 'multrot01all_res34' # 'multrot01all_res34' # 'multrot01all' # 'multrot01' # 'multrot01' # 'multrot01' # 'multrot' # 'multrot_res34' # 'multrot' # 'add' + REF_DETACH_SHAPE: True + GRAPHCNN_TYPE: 'multistage_simple' # 'inexistent' + SHAPEREF_TYPE: 'inexistent' # 'linear' # 'inexistent' + ISFLAT_TYPE: 'linear' # 'inexistent' # 'inexistent' + +data: + DATASET: 'stanext24_withgc_csaddnonflatmorestanding' # 'stanext24_withgc_csaddnonflat' # 'stanext24_withgc_cs0' + SHORTEN_VAL_DATASET_TO: 600 # this is faster as we do not evaluate on the whole validation set + VAL_OPT: 'val' \ No newline at end of file diff --git a/src/configs/refinement_loss_weights_withgc_withvertexwise_addnonflat.json b/src/configs/refinement_loss_weights_withgc_withvertexwise_addnonflat.json new file mode 100644 index 0000000000000000000000000000000000000000..703ca0c6a5c87e1af9ac4868b8192fa7069033d0 --- /dev/null +++ b/src/configs/refinement_loss_weights_withgc_withvertexwise_addnonflat.json @@ -0,0 +1,20 @@ + + + +{ + "keyp_ref": 0.2, + "silh_ref": 50.0, + "pose_legs_side": 1.0, + "pose_legs_tors": 1.0, + "pose_tail_side": 0.0, + "pose_tail_tors": 0.0, + "pose_spine_side": 0.0, + "pose_spine_tors": 0.0, + "reg_trans": 0.0, + "reg_flength": 0.0, + "reg_pose": 0.0, + "gc_plane": 5.0, + "gc_blowplane": 5.0, + "gc_vertexwise": 10.0, + "gc_isflat": 0.5 +} diff --git a/src/configs/ttopt_loss_weights/bite_loss_weights_ttopt.json b/src/configs/ttopt_loss_weights/bite_loss_weights_ttopt.json new file mode 100644 index 0000000000000000000000000000000000000000..2c1fec805ebed99db2c7c5fc2becbaf64b36f1d4 --- /dev/null +++ b/src/configs/ttopt_loss_weights/bite_loss_weights_ttopt.json @@ -0,0 +1,77 @@ +{ + "silhouette": { + "weight": 40.0, + "weight_vshift": 20.0, + "value": 0.0 + }, + "keyp":{ + "weight": 0.2, + "weight_vshift": 0.01, + "value": 0.0 + }, + "pose_legs_side":{ + "weight": 1.0, + "weight_vshift": 1.0, + "value": 0.0 + }, + "pose_legs_tors":{ + "weight": 10.0, + "weight_vshift": 10.0, + "value": 0.0 + }, + "pose_tail_side":{ + "weight": 1, + "weight_vshift": 1, + "value": 0.0 + }, + "pose_tail_tors":{ + "weight": 10.0, + "weight_vshift": 10.0, + "value": 0.0 + }, + "pose_spine_side":{ + "weight": 0.0, + "weight_vshift": 0.0, + "value": 0.0 + }, + "pose_spine_tors":{ + "weight": 0.0, + "weight_vshift": 0.0, + "value": 0.0 + }, + "gc_plane":{ + "weight": 10.0, + "weight_vshift": 20.0, + "value": 0.0 + }, + "gc_belowplane":{ + "weight": 10.0, + "weight_vshift": 20.0, + "value": 0.0 + }, + "lapctf": { + "weight": 0.0, + "weight_vshift": 10.0, + "value": 0.0 + }, + "arap": { + "weight": 0.0, + "weight_vshift": 0.0, + "value": 0.0 + }, + "edge": { + "weight": 0.0, + "weight_vshift": 10.0, + "value": 0.0 + }, + "normal": { + "weight": 0.0, + "weight_vshift": 1.0, + "value": 0.0 + }, + "laplacian": { + "weight": 0.0, + "weight_vshift": 0.0, + "value": 0.0 + } +} \ No newline at end of file diff --git a/src/configs/ttopt_loss_weights/ttopt_loss_weights_v2c_withlapcft_v2.json b/src/configs/ttopt_loss_weights/ttopt_loss_weights_v2c_withlapcft_v2.json new file mode 100644 index 0000000000000000000000000000000000000000..2c1fec805ebed99db2c7c5fc2becbaf64b36f1d4 --- /dev/null +++ b/src/configs/ttopt_loss_weights/ttopt_loss_weights_v2c_withlapcft_v2.json @@ -0,0 +1,77 @@ +{ + "silhouette": { + "weight": 40.0, + "weight_vshift": 20.0, + "value": 0.0 + }, + "keyp":{ + "weight": 0.2, + "weight_vshift": 0.01, + "value": 0.0 + }, + "pose_legs_side":{ + "weight": 1.0, + "weight_vshift": 1.0, + "value": 0.0 + }, + "pose_legs_tors":{ + "weight": 10.0, + "weight_vshift": 10.0, + "value": 0.0 + }, + "pose_tail_side":{ + "weight": 1, + "weight_vshift": 1, + "value": 0.0 + }, + "pose_tail_tors":{ + "weight": 10.0, + "weight_vshift": 10.0, + "value": 0.0 + }, + "pose_spine_side":{ + "weight": 0.0, + "weight_vshift": 0.0, + "value": 0.0 + }, + "pose_spine_tors":{ + "weight": 0.0, + "weight_vshift": 0.0, + "value": 0.0 + }, + "gc_plane":{ + "weight": 10.0, + "weight_vshift": 20.0, + "value": 0.0 + }, + "gc_belowplane":{ + "weight": 10.0, + "weight_vshift": 20.0, + "value": 0.0 + }, + "lapctf": { + "weight": 0.0, + "weight_vshift": 10.0, + "value": 0.0 + }, + "arap": { + "weight": 0.0, + "weight_vshift": 0.0, + "value": 0.0 + }, + "edge": { + "weight": 0.0, + "weight_vshift": 10.0, + "value": 0.0 + }, + "normal": { + "weight": 0.0, + "weight_vshift": 1.0, + "value": 0.0 + }, + "laplacian": { + "weight": 0.0, + "weight_vshift": 0.0, + "value": 0.0 + } +} \ No newline at end of file diff --git a/src/graph_networks/__init__.py b/src/graph_networks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/graph_networks/graphcmr/__init__.py b/src/graph_networks/graphcmr/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/graph_networks/graphcmr/get_downsampled_mesh_npz.py b/src/graph_networks/graphcmr/get_downsampled_mesh_npz.py new file mode 100644 index 0000000000000000000000000000000000000000..039ef10d99a7261587ef4c1c345dc436b186a90f --- /dev/null +++ b/src/graph_networks/graphcmr/get_downsampled_mesh_npz.py @@ -0,0 +1,84 @@ + +# try to use aenv_conda3 (maybe also export PYOPENGL_PLATFORM=osmesa) +# python src/graph_networks/graphcmr/get_downsampled_mesh_npz.py + +# see https://github.com/nkolot/GraphCMR/issues/35 + + +from __future__ import print_function +# import mesh_sampling +from psbody.mesh import Mesh, MeshViewer, MeshViewers +import numpy as np +import json +import os +import copy +import argparse +import pickle +import time +import sys +import trimesh + + + +sys.path.append(os.path.join(os.path.dirname(__file__), "../../../../")) +from barc_for_bite.src.graph_networks.graphcmr.pytorch_coma_mesh_operations import generate_transform_matrices +from barc_for_bite.src.configs.SMAL_configs import SMAL_MODEL_CONFIG +from barc_for_bite.src.smal_pytorch.smal_model.smal_torch_new import SMAL +# smal_model_path = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/data/smal_data/new_dog_models/my_smpl_00791_nadine_Jr_4_dog.pkl' + + +SMAL_MODEL_TYPE = '39dogs_diffsize' # '39dogs_diffsize' # '39dogs_norm' # 'barc' +smal_model_path = SMAL_MODEL_CONFIG[SMAL_MODEL_TYPE]['smal_model_path'] + +# data_path_root = "/is/cluster/work/nrueegg/icon_pifu_related/ICON/lib/graph_networks/graphcmr/data/" +data_path_root = "/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/src/graph_networks/graphcmr/data/" + +smal_dog_model_name = os.path.basename(smal_model_path).split('.pkl')[0] # 'my_smpl_SMBLD_nbj_v3' +suffix = "_template" +template_obj_path = data_path_root + smal_dog_model_name + suffix + ".obj" + +print("Loading smal .. ") +print(SMAL_MODEL_TYPE) +print(smal_model_path) + +smal = SMAL(smal_model_type=SMAL_MODEL_TYPE, template_name='neutral') +smal_verts = smal.v_template.detach().cpu().numpy() # (3889, 3) +smal_faces = smal.f # (7774, 3) +smal_trimesh = trimesh.base.Trimesh(vertices=smal_verts, faces=smal_faces, process=False, maintain_order=True) +smal_trimesh.export(file_obj=template_obj_path) # file_type='obj') + + +print("Loading data .. ") +reference_mesh_file = template_obj_path # 'data/barc_neutral_vertices.obj' # 'data/smpl_neutral_vertices.obj' +reference_mesh = Mesh(filename=reference_mesh_file) + +# ds_factors = [4, 4] # ds_factors = [4,1] # Sampling factor of the mesh at each stage of sampling +ds_factors = [4, 4, 4, 4] +print("Generating Transform Matrices ..") + + +# Generates adjecency matrices A, downsampling matrices D, and upsamling matrices U by sampling +# the mesh 4 times. Each time the mesh is sampled by a factor of 4 + +# M,A,D,U = mesh_sampling.generate_transform_matrices(reference_mesh, ds_factors) +M,A,D,U = generate_transform_matrices(reference_mesh, ds_factors) + +# REMARK: there is a warning: +# lib/graph_networks/graphcmr/../../../lib/graph_networks/graphcmr/pytorch_coma_mesh_operations.py:237: FutureWarning: `rcond` parameter will +# change to the default of machine precision times ``max(M, N)`` where M and N are the input matrix dimensions. +# To use the future default and silence this warning we advise to pass `rcond=None`, to keep using the old, explicitly pass `rcond=-1`. + + +print(type(A)) +np.savez(data_path_root + 'mesh_downsampling_' + smal_dog_model_name + suffix + '.npz', A = A, D = D, U = U) +np.savez(data_path_root + 'meshes/' + 'mesh_downsampling_meshes' + smal_dog_model_name + suffix + '.npz', M = M) + +for ind_m, my_mesh in enumerate(M): + new_suffix = '_template_downsampled' + str(ind_m) + my_mesh_tri = trimesh.Trimesh(vertices=my_mesh.v, faces=my_mesh.f, process=False, maintain_order=True) + my_mesh_tri.export(data_path_root + 'meshes/' + 'mesh_downsampling_meshes' + smal_dog_model_name + new_suffix + '.obj') + + + + + diff --git a/src/graph_networks/graphcmr/graph_cnn.py b/src/graph_networks/graphcmr/graph_cnn.py new file mode 100644 index 0000000000000000000000000000000000000000..750ec2c4937dd5f01d9f643acbb4690e8e706f88 --- /dev/null +++ b/src/graph_networks/graphcmr/graph_cnn.py @@ -0,0 +1,53 @@ +""" +code from https://raw.githubusercontent.com/nkolot/GraphCMR/master/models/graph_cnn.py +This file contains the Definition of GraphCNN +GraphCNN includes ResNet50 as a submodule +""" +from __future__ import division + +import torch +import torch.nn as nn + +from .graph_layers import GraphResBlock, GraphLinear +from .resnet import resnet50 + +class GraphCNN(nn.Module): + + def __init__(self, A, ref_vertices, num_layers=5, num_channels=512): + super(GraphCNN, self).__init__() + self.A = A + self.ref_vertices = ref_vertices + self.resnet = resnet50(pretrained=True) + layers = [GraphLinear(3 + 2048, 2 * num_channels)] + layers.append(GraphResBlock(2 * num_channels, num_channels, A)) + for i in range(num_layers): + layers.append(GraphResBlock(num_channels, num_channels, A)) + self.shape = nn.Sequential(GraphResBlock(num_channels, 64, A), + GraphResBlock(64, 32, A), + nn.GroupNorm(32 // 8, 32), + nn.ReLU(inplace=True), + GraphLinear(32, 3)) + self.gc = nn.Sequential(*layers) + self.camera_fc = nn.Sequential(nn.GroupNorm(num_channels // 8, num_channels), + nn.ReLU(inplace=True), + GraphLinear(num_channels, 1), + nn.ReLU(inplace=True), + nn.Linear(A.shape[0], 3)) + + def forward(self, image): + """Forward pass + Inputs: + image: size = (B, 3, 224, 224) + Returns: + Regressed (subsampled) non-parametric shape: size = (B, 1723, 3) + Weak-perspective camera: size = (B, 3) + """ + batch_size = image.shape[0] + ref_vertices = self.ref_vertices[None, :, :].expand(batch_size, -1, -1) + image_resnet = self.resnet(image) + image_enc = image_resnet.view(batch_size, 2048, 1).expand(-1, -1, ref_vertices.shape[-1]) + x = torch.cat([ref_vertices, image_enc], dim=1) + x = self.gc(x) + shape = self.shape(x) + camera = self.camera_fc(x).view(batch_size, 3) + return shape, camera \ No newline at end of file diff --git a/src/graph_networks/graphcmr/graph_cnn_groundcontact.py b/src/graph_networks/graphcmr/graph_cnn_groundcontact.py new file mode 100644 index 0000000000000000000000000000000000000000..bc358cef22022b0b16089a5b9d8bed49b112c6d8 --- /dev/null +++ b/src/graph_networks/graphcmr/graph_cnn_groundcontact.py @@ -0,0 +1,101 @@ +""" +code from https://raw.githubusercontent.com/nkolot/GraphCMR/master/models/graph_cnn.py +This file contains the Definition of GraphCNN +GraphCNN includes ResNet50 as a submodule +""" +from __future__ import division + +import torch +import torch.nn as nn + +# from .resnet import resnet50 +import torchvision.models as models + + +import os +import sys +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', '..')) +from src.graph_networks.graphcmr.utils_mesh import Mesh +from src.graph_networks.graphcmr.graph_layers import GraphResBlock, GraphLinear + + +class GraphCNN(nn.Module): + + def __init__(self, A, ref_vertices, n_resnet_in, n_resnet_out, num_layers=5, num_channels=512): + super(GraphCNN, self).__init__() + self.A = A + self.ref_vertices = ref_vertices + # self.resnet = resnet50(pretrained=True) + # -> within the GraphCMR network they ignore the last fully connected layer + # replace the first layer + self.resnet = models.resnet34(pretrained=False) + n_in = 3 + 1 + self.resnet.conv1 = nn.Conv2d(n_resnet_in, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False) + # replace the last layer + self.resnet.fc = nn.Linear(512, n_resnet_out) + + + layers = [GraphLinear(3 + n_resnet_out, 2 * num_channels)] # [GraphLinear(3 + 2048, 2 * num_channels)] + layers.append(GraphResBlock(2 * num_channels, num_channels, A)) + for i in range(num_layers): + layers.append(GraphResBlock(num_channels, num_channels, A)) + self.n_out_gc = 2 # two labels per vertex + self.gc = nn.Sequential(GraphResBlock(num_channels, 64, A), + GraphResBlock(64, 32, A), + nn.GroupNorm(32 // 8, 32), + nn.ReLU(inplace=True), + GraphLinear(32, self.n_out_gc)) + self.gcnn = nn.Sequential(*layers) + self.n_out_flatground = 1 + self.flat_ground = nn.Sequential(nn.GroupNorm(num_channels // 8, num_channels), + nn.ReLU(inplace=True), + GraphLinear(num_channels, 1), + nn.ReLU(inplace=True), + nn.Linear(A.shape[0], self.n_out_flatground)) + + def forward(self, image): + """Forward pass + Inputs: + image: size = (B, 3, 256, 256) + Returns: + Regressed (subsampled) non-parametric shape: size = (B, 1723, 3) + Weak-perspective camera: size = (B, 3) + """ + # import pdb; pdb.set_trace() + + batch_size = image.shape[0] + ref_vertices = self.ref_vertices[None, :, :].expand(batch_size, -1, -1) # (bs, 3, 973) + image_resnet = self.resnet(image) # (bs, 512) + image_enc = image_resnet.view(batch_size, -1, 1).expand(-1, -1, ref_vertices.shape[-1]) # (bs, 512, 973) + x = torch.cat([ref_vertices, image_enc], dim=1) + x = self.gcnn(x) # (bs, 512, 973) + ground_contact = self.gc(x) # (bs, 2, 973) + ground_flatness = self.flat_ground(x).view(batch_size, self.n_out_flatground) # (bs, 1) + return ground_contact, ground_flatness + + + + +# how to use it: +# +# from src.graph_networks.graphcmr.utils_mesh import Mesh +# +# create Mesh object +# self.mesh = Mesh() +# self.faces = self.mesh.faces.to(self.device) +# +# create GraphCNN +# self.graph_cnn = GraphCNN(self.mesh.adjmat, +# self.mesh.ref_vertices.t(), +# num_channels=self.options.num_channels, +# num_layers=self.options.num_layers +# ).to(self.device) +# ------------ +# +# Feed image in the GraphCNN +# Returns subsampled mesh and camera parameters +# pred_vertices_sub, pred_camera = self.graph_cnn(images) +# +# Upsample mesh in the original size +# pred_vertices = self.mesh.upsample(pred_vertices_sub.transpose(1,2)) +# \ No newline at end of file diff --git a/src/graph_networks/graphcmr/graph_cnn_groundcontact_multistage.py b/src/graph_networks/graphcmr/graph_cnn_groundcontact_multistage.py new file mode 100644 index 0000000000000000000000000000000000000000..6f6bcc2e099adbd0aad6a3da7bfe1efdcd31e4f0 --- /dev/null +++ b/src/graph_networks/graphcmr/graph_cnn_groundcontact_multistage.py @@ -0,0 +1,174 @@ +""" +code from + https://raw.githubusercontent.com/nkolot/GraphCMR/master/models/graph_cnn.py + https://github.com/chaneyddtt/Coarse-to-fine-3D-Animal/blob/main/model/graph_hg.py +This file contains the Definition of GraphCNN +GraphCNN includes ResNet50 as a submodule +""" +from __future__ import division + +import torch +import torch.nn as nn + +# from .resnet import resnet50 +import torchvision.models as models + + +import os +import sys +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', '..')) +from src.graph_networks.graphcmr.utils_mesh import Mesh +from src.graph_networks.graphcmr.graph_layers import GraphResBlock, GraphLinear + + +class GraphCNNMS(nn.Module): + + def __init__(self, mesh, num_downsample=0, num_layers=5, n_resnet_out=256, num_channels=256): + ''' + Args: + mesh: mesh data that store the adjacency matrix + num_channels: number of channels of GCN + num_downsample: number of downsampling of the input mesh + ''' + + super(GraphCNNMS, self).__init__() + + self.A = mesh._A[num_downsample:] # get the correct adjacency matrix because the input might be downsampled + # self.num_layers = len(self.A) - 1 + self.num_layers = num_layers + assert self.num_layers <= len(self.A) - 1 + print("Number of downsampling layer: {}".format(self.num_layers)) + self.num_downsample = num_downsample + self.n_resnet_out = n_resnet_out + + + ''' + self.use_pret_res = use_pret_res + # self.resnet = resnet50(pretrained=True) + # -> within the GraphCMR network they ignore the last fully connected layer + # replace the first layer + self.resnet = models.resnet34(pretrained=self.use_pret_res) + if (self.use_pret_res) and (n_resnet_in == 3): + print('use full pretrained resnet including first layer!') + else: + self.resnet.conv1 = nn.Conv2d(n_resnet_in, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False) + # replace the last layer + self.resnet.fc = nn.Linear(512, n_resnet_out) + ''' + + self.lin1 = GraphLinear(3 + n_resnet_out, 2 * num_channels) + self.res1 = GraphResBlock(2 * num_channels, num_channels, self.A[0]) + encode_layers = [] + decode_layers = [] + + for i in range(self.num_layers + 1): # range(len(self.A)): + encode_layers.append(GraphResBlock(num_channels, num_channels, self.A[i])) + + decode_layers.append(GraphResBlock((i+1)*num_channels, (i+1)*num_channels, + self.A[self.num_layers - i])) + current_channels = (i+1)*num_channels + # number of channels for the input is different because of the concatenation operation + self.n_out_gc = 2 # two labels per vertex + self.gc = nn.Sequential(GraphResBlock(current_channels, 64, self.A[0]), + GraphResBlock(64, 32, self.A[0]), + nn.GroupNorm(32 // 8, 32), + nn.ReLU(inplace=True), + GraphLinear(32, self.n_out_gc)) + + ''' + self.n_out_flatground = 2 + self.flat_ground = nn.Sequential(nn.GroupNorm(current_channels // 8, current_channels), + nn.ReLU(inplace=True), + GraphLinear(current_channels, 1), + nn.ReLU(inplace=True), + nn.Linear(A.shape[0], self.n_out_flatground)) + ''' + + self.encoder = nn.Sequential(*encode_layers) + self.decoder = nn.Sequential(*decode_layers) + self.mesh = mesh + + + + + def forward(self, image_enc): + """Forward pass + Inputs: + image_enc: size = (B, self.n_resnet_out) + Returns: + Regressed (subsampled) non-parametric shape: size = (B, 1723, 3) + Weak-perspective camera: size = (B, 3) + """ + # import pdb; pdb.set_trace() + + batch_size = image_enc.shape[0] + # ref_vertices = (self.mesh.get_ref_vertices(n=self.num_downsample).t())[None, :, :].expand(batch_size, -1, -1) # (bs, 3, 973) + ref_vertices = (self.mesh.ref_vertices.t())[None, :, :].expand(batch_size, -1, -1) # (bs, 3, 973) + '''image_resnet = self.resnet(image) # (bs, 512)''' + image_enc_prep = image_enc.view(batch_size, -1, 1).expand(-1, -1, ref_vertices.shape[-1]) # (bs, 512, 973) + + # prepare network input + # -> for each node we feed the location of the vertex in the template mesh and an image encoding + x = torch.cat([ref_vertices, image_enc_prep], dim=1) + x = self.lin1(x) + x = self.res1(x) + x_ = [x] + output_list = [] + for i in range(self.num_layers + 1): + if i == self.num_layers: + x = self.encoder[i](x) + else: + x = self.encoder[i](x) + x = self.mesh.downsample(x.transpose(1, 2), n1=self.num_downsample+i, n2=self.num_downsample+i+1) + x = x.transpose(1, 2) + if i < self.num_layers-1: + x_.append(x) + for i in range(self.num_layers + 1): + if i == self.num_layers: + x = self.decoder[i](x) + output_list.append(x) + else: + x = self.decoder[i](x) + output_list.append(x) + x = self.mesh.upsample(x.transpose(1, 2), n1=self.num_layers-i+self.num_downsample, + n2=self.num_layers-i-1+self.num_downsample) + x = x.transpose(1, 2) + x = torch.cat([x, x_[self.num_layers-i-1]], dim=1) # skip connection between encoder and decoder + + ground_contact = self.gc(x) + + ''' + ground_flatness = self.flat_ground(x).view(batch_size, self.n_out_flatground) # (bs, 1) + ''' + + return ground_contact, output_list # , ground_flatness + + + + + + + +# how to use it: +# +# from src.graph_networks.graphcmr.utils_mesh import Mesh +# +# create Mesh object +# self.mesh = Mesh() +# self.faces = self.mesh.faces.to(self.device) +# +# create GraphCNN +# self.graph_cnn = GraphCNN(self.mesh.adjmat, +# self.mesh.ref_vertices.t(), +# num_channels=self.options.num_channels, +# num_layers=self.options.num_layers +# ).to(self.device) +# ------------ +# +# Feed image in the GraphCNN +# Returns subsampled mesh and camera parameters +# pred_vertices_sub, pred_camera = self.graph_cnn(images) +# +# Upsample mesh in the original size +# pred_vertices = self.mesh.upsample(pred_vertices_sub.transpose(1,2)) +# \ No newline at end of file diff --git a/src/graph_networks/graphcmr/graph_cnn_groundcontact_multistage_includingresnet.py b/src/graph_networks/graphcmr/graph_cnn_groundcontact_multistage_includingresnet.py new file mode 100644 index 0000000000000000000000000000000000000000..85bc50cc0eced144ad832ca6e9a97490f0c77800 --- /dev/null +++ b/src/graph_networks/graphcmr/graph_cnn_groundcontact_multistage_includingresnet.py @@ -0,0 +1,170 @@ +""" +code from + https://raw.githubusercontent.com/nkolot/GraphCMR/master/models/graph_cnn.py + https://github.com/chaneyddtt/Coarse-to-fine-3D-Animal/blob/main/model/graph_hg.py +This file contains the Definition of GraphCNN +GraphCNN includes ResNet50 as a submodule +""" +from __future__ import division + +import torch +import torch.nn as nn + +# from .resnet import resnet50 +import torchvision.models as models + + +import os +import sys +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', '..')) +from src.graph_networks.graphcmr.utils_mesh import Mesh +from src.graph_networks.graphcmr.graph_layers import GraphResBlock, GraphLinear + + +class GraphCNNMS(nn.Module): + + def __init__(self, mesh, num_downsample=0, num_layers=5, n_resnet_in=3, n_resnet_out=256, num_channels=256, use_pret_res=False): + ''' + Args: + mesh: mesh data that store the adjacency matrix + num_channels: number of channels of GCN + num_downsample: number of downsampling of the input mesh + ''' + + super(GraphCNNMS, self).__init__() + + self.A = mesh._A[num_downsample:] # get the correct adjacency matrix because the input might be downsampled + # self.num_layers = len(self.A) - 1 + self.num_layers = num_layers + assert self.num_layers <= len(self.A) - 1 + print("Number of downsampling layer: {}".format(self.num_layers)) + self.num_downsample = num_downsample + self.use_pret_res = use_pret_res + + # self.resnet = resnet50(pretrained=True) + # -> within the GraphCMR network they ignore the last fully connected layer + # replace the first layer + self.resnet = models.resnet34(pretrained=self.use_pret_res) + if (self.use_pret_res) and (n_resnet_in == 3): + print('use full pretrained resnet including first layer!') + else: + self.resnet.conv1 = nn.Conv2d(n_resnet_in, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False) + # replace the last layer + self.resnet.fc = nn.Linear(512, n_resnet_out) + + self.lin1 = GraphLinear(3 + n_resnet_out, 2 * num_channels) + self.res1 = GraphResBlock(2 * num_channels, num_channels, self.A[0]) + encode_layers = [] + decode_layers = [] + + for i in range(self.num_layers + 1): # range(len(self.A)): + encode_layers.append(GraphResBlock(num_channels, num_channels, self.A[i])) + + decode_layers.append(GraphResBlock((i+1)*num_channels, (i+1)*num_channels, + self.A[self.num_layers - i])) + current_channels = (i+1)*num_channels + # number of channels for the input is different because of the concatenation operation + self.n_out_gc = 2 # two labels per vertex + self.gc = nn.Sequential(GraphResBlock(current_channels, 64, self.A[0]), + GraphResBlock(64, 32, self.A[0]), + nn.GroupNorm(32 // 8, 32), + nn.ReLU(inplace=True), + GraphLinear(32, self.n_out_gc)) + + ''' + self.n_out_flatground = 2 + self.flat_ground = nn.Sequential(nn.GroupNorm(current_channels // 8, current_channels), + nn.ReLU(inplace=True), + GraphLinear(current_channels, 1), + nn.ReLU(inplace=True), + nn.Linear(A.shape[0], self.n_out_flatground)) + ''' + + self.encoder = nn.Sequential(*encode_layers) + self.decoder = nn.Sequential(*decode_layers) + self.mesh = mesh + + + + + def forward(self, image): + """Forward pass + Inputs: + image: size = (B, 3, 256, 256) + Returns: + Regressed (subsampled) non-parametric shape: size = (B, 1723, 3) + Weak-perspective camera: size = (B, 3) + """ + # import pdb; pdb.set_trace() + + batch_size = image.shape[0] + # ref_vertices = (self.mesh.get_ref_vertices(n=self.num_downsample).t())[None, :, :].expand(batch_size, -1, -1) # (bs, 3, 973) + ref_vertices = (self.mesh.ref_vertices.t())[None, :, :].expand(batch_size, -1, -1) # (bs, 3, 973) + image_resnet = self.resnet(image) # (bs, 512) + image_enc = image_resnet.view(batch_size, -1, 1).expand(-1, -1, ref_vertices.shape[-1]) # (bs, 512, 973) + + # prepare network input + # -> for each node we feed the location of the vertex in the template mesh and an image encoding + x = torch.cat([ref_vertices, image_enc], dim=1) + x = self.lin1(x) + x = self.res1(x) + x_ = [x] + output_list = [] + for i in range(self.num_layers + 1): + if i == self.num_layers: + x = self.encoder[i](x) + else: + x = self.encoder[i](x) + x = self.mesh.downsample(x.transpose(1, 2), n1=self.num_downsample+i, n2=self.num_downsample+i+1) + x = x.transpose(1, 2) + if i < self.num_layers-1: + x_.append(x) + for i in range(self.num_layers + 1): + if i == self.num_layers: + x = self.decoder[i](x) + output_list.append(x) + else: + x = self.decoder[i](x) + output_list.append(x) + x = self.mesh.upsample(x.transpose(1, 2), n1=self.num_layers-i+self.num_downsample, + n2=self.num_layers-i-1+self.num_downsample) + x = x.transpose(1, 2) + x = torch.cat([x, x_[self.num_layers-i-1]], dim=1) # skip connection between encoder and decoder + + ground_contact = self.gc(x) + + ''' + ground_flatness = self.flat_ground(x).view(batch_size, self.n_out_flatground) # (bs, 1) + ''' + + return ground_contact, output_list # , ground_flatness + + + + + + + +# how to use it: +# +# from src.graph_networks.graphcmr.utils_mesh import Mesh +# +# create Mesh object +# self.mesh = Mesh() +# self.faces = self.mesh.faces.to(self.device) +# +# create GraphCNN +# self.graph_cnn = GraphCNN(self.mesh.adjmat, +# self.mesh.ref_vertices.t(), +# num_channels=self.options.num_channels, +# num_layers=self.options.num_layers +# ).to(self.device) +# ------------ +# +# Feed image in the GraphCNN +# Returns subsampled mesh and camera parameters +# pred_vertices_sub, pred_camera = self.graph_cnn(images) +# +# Upsample mesh in the original size +# pred_vertices = self.mesh.upsample(pred_vertices_sub.transpose(1,2)) +# \ No newline at end of file diff --git a/src/graph_networks/graphcmr/graph_layers.py b/src/graph_networks/graphcmr/graph_layers.py new file mode 100644 index 0000000000000000000000000000000000000000..0af7cadd08cd3da5e8d3011791231c53ffbb6d57 --- /dev/null +++ b/src/graph_networks/graphcmr/graph_layers.py @@ -0,0 +1,125 @@ +""" +code from https://github.com/nkolot/GraphCMR/blob/master/models/graph_layers.py +This file contains definitions of layers used to build the GraphCNN +""" +from __future__ import division + +import torch +import torch.nn as nn +import torch.nn.functional as F +import math + +class GraphConvolution(nn.Module): + """Simple GCN layer, similar to https://arxiv.org/abs/1609.02907.""" + def __init__(self, in_features, out_features, adjmat, bias=True): + super(GraphConvolution, self).__init__() + self.in_features = in_features + self.out_features = out_features + self.adjmat = adjmat + self.weight = nn.Parameter(torch.FloatTensor(in_features, out_features)) + if bias: + self.bias = nn.Parameter(torch.FloatTensor(out_features)) + else: + self.register_parameter('bias', None) + self.reset_parameters() + + def reset_parameters(self): + # stdv = 1. / math.sqrt(self.weight.size(1)) + stdv = 6. / math.sqrt(self.weight.size(0) + self.weight.size(1)) + self.weight.data.uniform_(-stdv, stdv) + if self.bias is not None: + self.bias.data.uniform_(-stdv, stdv) + + def forward(self, x): + if x.ndimension() == 2: + support = torch.matmul(x, self.weight) + output = torch.matmul(self.adjmat, support) + if self.bias is not None: + output = output + self.bias + return output + else: + output = [] + for i in range(x.shape[0]): + support = torch.matmul(x[i], self.weight) + # output.append(torch.matmul(self.adjmat, support)) + output.append(spmm(self.adjmat, support)) + output = torch.stack(output, dim=0) + if self.bias is not None: + output = output + self.bias + return output + + def __repr__(self): + return self.__class__.__name__ + ' (' \ + + str(self.in_features) + ' -> ' \ + + str(self.out_features) + ')' + +class GraphLinear(nn.Module): + """ + Generalization of 1x1 convolutions on Graphs + """ + def __init__(self, in_channels, out_channels): + super(GraphLinear, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.W = nn.Parameter(torch.FloatTensor(out_channels, in_channels)) + self.b = nn.Parameter(torch.FloatTensor(out_channels)) + self.reset_parameters() + + def reset_parameters(self): + w_stdv = 1 / (self.in_channels * self.out_channels) + self.W.data.uniform_(-w_stdv, w_stdv) + self.b.data.uniform_(-w_stdv, w_stdv) + + def forward(self, x): + return torch.matmul(self.W[None, :], x) + self.b[None, :, None] + +class GraphResBlock(nn.Module): + """ + Graph Residual Block similar to the Bottleneck Residual Block in ResNet + """ + + def __init__(self, in_channels, out_channels, A): + super(GraphResBlock, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.lin1 = GraphLinear(in_channels, out_channels // 2) + self.conv = GraphConvolution(out_channels // 2, out_channels // 2, A) + self.lin2 = GraphLinear(out_channels // 2, out_channels) + self.skip_conv = GraphLinear(in_channels, out_channels) + self.pre_norm = nn.GroupNorm(in_channels // 8, in_channels) + self.norm1 = nn.GroupNorm((out_channels // 2) // 8, (out_channels // 2)) + self.norm2 = nn.GroupNorm((out_channels // 2) // 8, (out_channels // 2)) + + def forward(self, x): + y = F.relu(self.pre_norm(x)) + y = self.lin1(y) + + y = F.relu(self.norm1(y)) + y = self.conv(y.transpose(1,2)).transpose(1,2) + + y = F.relu(self.norm2(y)) + y = self.lin2(y) + if self.in_channels != self.out_channels: + x = self.skip_conv(x) + return x+y + +class SparseMM(torch.autograd.Function): + """Redefine sparse @ dense matrix multiplication to enable backpropagation. + The builtin matrix multiplication operation does not support backpropagation in some cases. + """ + @staticmethod + def forward(ctx, sparse, dense): + ctx.req_grad = dense.requires_grad + ctx.save_for_backward(sparse) + return torch.matmul(sparse, dense) + + @staticmethod + def backward(ctx, grad_output): + grad_input = None + sparse, = ctx.saved_tensors + if ctx.req_grad: + grad_input = torch.matmul(sparse.t(), grad_output) + return None, grad_input + +def spmm(sparse, dense): + return SparseMM.apply(sparse, dense) \ No newline at end of file diff --git a/src/graph_networks/graphcmr/graphcnn_coarse_to_fine_animal_pose.py b/src/graph_networks/graphcmr/graphcnn_coarse_to_fine_animal_pose.py new file mode 100644 index 0000000000000000000000000000000000000000..62392ac754bbf3a1486af055682c5fbaf31749b1 --- /dev/null +++ b/src/graph_networks/graphcmr/graphcnn_coarse_to_fine_animal_pose.py @@ -0,0 +1,97 @@ + +""" +code from: https://github.com/chaneyddtt/Coarse-to-fine-3D-Animal/blob/main/model/graph_hg.py +This file contains the Definition of GraphCNN +GraphCNN includes ResNet50 as a submodule +""" +from __future__ import division + +import torch +import torch.nn as nn + +from model.networks.graph_layers import GraphResBlock, GraphLinear +from smal.mesh import Mesh +from smal.smal_torch import SMAL + +# encoder-decoder structured GCN with skip connections +class GraphCNN_hg(nn.Module): + + def __init__(self, mesh, num_channels=256, local_feat=False, num_downsample=0): + ''' + Args: + mesh: mesh data that store the adjacency matrix + num_channels: number of channels of GCN + local_feat: whether use local feature for refinement + num_downsample: number of downsampling of the input mesh + ''' + super(GraphCNN_hg, self).__init__() + self.A = mesh._A[num_downsample:] # get the correct adjacency matrix because the input might be downsampled + self.num_layers = len(self.A) - 1 + print("Number of downsampling layer: {}".format(self.num_layers)) + self.num_downsample = num_downsample + if local_feat: + self.lin1 = GraphLinear(3 + 2048 + 3840, 2 * num_channels) + else: + self.lin1 = GraphLinear(3 + 2048, 2 * num_channels) + self.res1 = GraphResBlock(2 * num_channels, num_channels, self.A[0]) + encode_layers = [] + decode_layers = [] + + for i in range(len(self.A)): + encode_layers.append(GraphResBlock(num_channels, num_channels, self.A[i])) + + decode_layers.append(GraphResBlock((i+1)*num_channels, (i+1)*num_channels, + self.A[len(self.A) - i - 1])) + current_channels = (i+1)*num_channels + # number of channels for the input is different because of the concatenation operation + self.shape = nn.Sequential(GraphResBlock(current_channels, 64, self.A[0]), + GraphResBlock(64, 32, self.A[0]), + nn.GroupNorm(32 // 8, 32), + nn.ReLU(inplace=True), + GraphLinear(32, 3)) + + self.encoder = nn.Sequential(*encode_layers) + self.decoder = nn.Sequential(*decode_layers) + self.mesh = mesh + + def forward(self, verts_c, img_fea_global, img_fea_multiscale=None, points_local=None): + ''' + Args: + verts_c: vertices from the coarse estimation + img_fea_global: global feature for mesh refinement + img_fea_multiscale: multi-scale feature from the encoder, used for local feature extraction + points_local: 2D keypoint for local feature extraction + Returns: refined mesh + ''' + batch_size = img_fea_global.shape[0] + ref_vertices = verts_c.transpose(1, 2) + image_enc = img_fea_global.view(batch_size, 2048, 1).expand(-1, -1, ref_vertices.shape[-1]) + if points_local is not None: + feat_local = torch.nn.functional.grid_sample(img_fea_multiscale, points_local) + x = torch.cat([ref_vertices, image_enc, feat_local.squeeze(2)], dim=1) + else: + x = torch.cat([ref_vertices, image_enc], dim=1) + x = self.lin1(x) + x = self.res1(x) + x_ = [x] + for i in range(self.num_layers + 1): + if i == self.num_layers: + x = self.encoder[i](x) + else: + x = self.encoder[i](x) + x = self.mesh.downsample(x.transpose(1, 2), n1=self.num_downsample+i, n2=self.num_downsample+i+1) + x = x.transpose(1, 2) + if i < self.num_layers-1: + x_.append(x) + for i in range(self.num_layers + 1): + if i == self.num_layers: + x = self.decoder[i](x) + else: + x = self.decoder[i](x) + x = self.mesh.upsample(x.transpose(1, 2), n1=self.num_layers-i+self.num_downsample, + n2=self.num_layers-i-1+self.num_downsample) + x = x.transpose(1, 2) + x = torch.cat([x, x_[self.num_layers-i-1]], dim=1) # skip connection between encoder and decoder + + shape = self.shape(x) + return shape \ No newline at end of file diff --git a/src/graph_networks/graphcmr/my_remarks.txt b/src/graph_networks/graphcmr/my_remarks.txt new file mode 100644 index 0000000000000000000000000000000000000000..640085191f8793903a8f9546ac0cdf8d12902aaa --- /dev/null +++ b/src/graph_networks/graphcmr/my_remarks.txt @@ -0,0 +1,11 @@ + +this folder contains code from https://github.com/nkolot/GraphCMR/tree/master/models + + +other (newer) networks operating on meshes such as SMAL would be: + https://github.com/microsoft/MeshTransformer + https://github.com/microsoft/MeshGraphormer + +see also: + https://arxiv.org/pdf/2112.01554.pdf, page 13 + (Neural Head Avatars from Monocular RGB Videos) diff --git a/src/graph_networks/graphcmr/pytorch_coma_mesh_operations.py b/src/graph_networks/graphcmr/pytorch_coma_mesh_operations.py new file mode 100644 index 0000000000000000000000000000000000000000..fa54095a09891aebf0506de46a5d47b589108038 --- /dev/null +++ b/src/graph_networks/graphcmr/pytorch_coma_mesh_operations.py @@ -0,0 +1,282 @@ +# code from https://github.com/pixelite1201/pytorch_coma/blob/master/mesh_operations.py + +import math +import heapq +import numpy as np +import scipy.sparse as sp +from psbody.mesh import Mesh + +def row(A): + return A.reshape((1, -1)) + +def col(A): + return A.reshape((-1, 1)) + +def get_vert_connectivity(mesh_v, mesh_f): + """Returns a sparse matrix (of size #verts x #verts) where each nonzero + element indicates a neighborhood relation. For example, if there is a + nonzero element in position (15,12), that means vertex 15 is connected + by an edge to vertex 12.""" + + vpv = sp.csc_matrix((len(mesh_v),len(mesh_v))) + + # for each column in the faces... + for i in range(3): + IS = mesh_f[:,i] + JS = mesh_f[:,(i+1)%3] + data = np.ones(len(IS)) + ij = np.vstack((row(IS.flatten()), row(JS.flatten()))) + mtx = sp.csc_matrix((data, ij), shape=vpv.shape) + vpv = vpv + mtx + mtx.T + + return vpv + +def get_vertices_per_edge(mesh_v, mesh_f): + """Returns an Ex2 array of adjacencies between vertices, where + each element in the array is a vertex index. Each edge is included + only once. If output of get_faces_per_edge is provided, this is used to + avoid call to get_vert_connectivity()""" + + vc = sp.coo_matrix(get_vert_connectivity(mesh_v, mesh_f)) + result = np.hstack((col(vc.row), col(vc.col))) + result = result[result[:,0] < result[:,1]] # for uniqueness + + return result + + +def vertex_quadrics(mesh): + """Computes a quadric for each vertex in the Mesh. + see also: + https://www.cs.cmu.edu/~./garland/Papers/quadrics.pdf + https://users.csc.calpoly.edu/~zwood/teaching/csc570/final06/jseeba/ + Returns: + v_quadrics: an (N x 4 x 4) array, where N is # vertices. + """ + + # Allocate quadrics + v_quadrics = np.zeros((len(mesh.v), 4, 4,)) + + # For each face... + for f_idx in range(len(mesh.f)): + + # Compute normalized plane equation for that face + vert_idxs = mesh.f[f_idx] + verts = np.hstack((mesh.v[vert_idxs], np.array([1, 1, 1]).reshape(-1, 1))) + u, s, v = np.linalg.svd(verts) + eq = v[-1, :].reshape(-1, 1) + eq = eq / (np.linalg.norm(eq[0:3])) + + # Add the outer product of the plane equation to the + # quadrics of the vertices for this face + for k in range(3): + v_quadrics[mesh.f[f_idx, k], :, :] += np.outer(eq, eq) + + return v_quadrics + +def _get_sparse_transform(faces, num_original_verts): + verts_left = np.unique(faces.flatten()) + IS = np.arange(len(verts_left)) + JS = verts_left + data = np.ones(len(JS)) + + mp = np.arange(0, np.max(faces.flatten()) + 1) + mp[JS] = IS + new_faces = mp[faces.copy().flatten()].reshape((-1, 3)) + + ij = np.vstack((IS.flatten(), JS.flatten())) + mtx = sp.csc_matrix((data, ij), shape=(len(verts_left) , num_original_verts )) + + return (new_faces, mtx) + +def qslim_decimator_transformer(mesh, factor=None, n_verts_desired=None): + """Return a simplified version of this mesh. + + A Qslim-style approach is used here. + + :param factor: fraction of the original vertices to retain + :param n_verts_desired: number of the original vertices to retain + :returns: new_faces: An Fx3 array of faces, mtx: Transformation matrix + """ + + if factor is None and n_verts_desired is None: + raise Exception('Need either factor or n_verts_desired.') + + if n_verts_desired is None: + n_verts_desired = math.ceil(len(mesh.v) * factor) + + Qv = vertex_quadrics(mesh) + + # fill out a sparse matrix indicating vertex-vertex adjacency + # from psbody.mesh.topology.connectivity import get_vertices_per_edge + vert_adj = get_vertices_per_edge(mesh.v, mesh.f) + # vert_adj = sp.lil_matrix((len(mesh.v), len(mesh.v))) + # for f_idx in range(len(mesh.f)): + # vert_adj[mesh.f[f_idx], mesh.f[f_idx]] = 1 + + vert_adj = sp.csc_matrix((vert_adj[:, 0] * 0 + 1, (vert_adj[:, 0], vert_adj[:, 1])), shape=(len(mesh.v), len(mesh.v))) + vert_adj = vert_adj + vert_adj.T + vert_adj = vert_adj.tocoo() + + def collapse_cost(Qv, r, c, v): + Qsum = Qv[r, :, :] + Qv[c, :, :] + p1 = np.vstack((v[r].reshape(-1, 1), np.array([1]).reshape(-1, 1))) + p2 = np.vstack((v[c].reshape(-1, 1), np.array([1]).reshape(-1, 1))) + + destroy_c_cost = p1.T.dot(Qsum).dot(p1) + destroy_r_cost = p2.T.dot(Qsum).dot(p2) + result = { + 'destroy_c_cost': destroy_c_cost, + 'destroy_r_cost': destroy_r_cost, + 'collapse_cost': min([destroy_c_cost, destroy_r_cost]), + 'Qsum': Qsum} + return result + + # construct a queue of edges with costs + queue = [] + for k in range(vert_adj.nnz): + r = vert_adj.row[k] + c = vert_adj.col[k] + + if r > c: + continue + + cost = collapse_cost(Qv, r, c, mesh.v)['collapse_cost'] + heapq.heappush(queue, (cost, (r, c))) + + # decimate + collapse_list = [] + nverts_total = len(mesh.v) + faces = mesh.f.copy() + while nverts_total > n_verts_desired: + e = heapq.heappop(queue) + r = e[1][0] + c = e[1][1] + if r == c: + continue + + cost = collapse_cost(Qv, r, c, mesh.v) + if cost['collapse_cost'] > e[0]: + heapq.heappush(queue, (cost['collapse_cost'], e[1])) + # print 'found outdated cost, %.2f < %.2f' % (e[0], cost['collapse_cost']) + continue + else: + + # update old vert idxs to new one, + # in queue and in face list + if cost['destroy_c_cost'] < cost['destroy_r_cost']: + to_destroy = c + to_keep = r + else: + to_destroy = r + to_keep = c + + collapse_list.append([to_keep, to_destroy]) + + # in our face array, replace "to_destroy" vertidx with "to_keep" vertidx + np.place(faces, faces == to_destroy, to_keep) + + # same for queue + which1 = [idx for idx in range(len(queue)) if queue[idx][1][0] == to_destroy] + which2 = [idx for idx in range(len(queue)) if queue[idx][1][1] == to_destroy] + for k in which1: + queue[k] = (queue[k][0], (to_keep, queue[k][1][1])) + for k in which2: + queue[k] = (queue[k][0], (queue[k][1][0], to_keep)) + + Qv[r, :, :] = cost['Qsum'] + Qv[c, :, :] = cost['Qsum'] + + a = faces[:, 0] == faces[:, 1] + b = faces[:, 1] == faces[:, 2] + c = faces[:, 2] == faces[:, 0] + + # remove degenerate faces + def logical_or3(x, y, z): + return np.logical_or(x, np.logical_or(y, z)) + + faces_to_keep = np.logical_not(logical_or3(a, b, c)) + faces = faces[faces_to_keep, :].copy() + + nverts_total = (len(np.unique(faces.flatten()))) + + new_faces, mtx = _get_sparse_transform(faces, len(mesh.v)) + return new_faces, mtx + + +def setup_deformation_transfer(source, target, use_normals=False): + rows = np.zeros(3 * target.v.shape[0]) + cols = np.zeros(3 * target.v.shape[0]) + coeffs_v = np.zeros(3 * target.v.shape[0]) + coeffs_n = np.zeros(3 * target.v.shape[0]) + + nearest_faces, nearest_parts, nearest_vertices = source.compute_aabb_tree().nearest(target.v, True) + nearest_faces = nearest_faces.ravel().astype(np.int64) + nearest_parts = nearest_parts.ravel().astype(np.int64) + nearest_vertices = nearest_vertices.ravel() + + for i in range(target.v.shape[0]): + # Closest triangle index + f_id = nearest_faces[i] + # Closest triangle vertex ids + nearest_f = source.f[f_id] + + # Closest surface point + nearest_v = nearest_vertices[3 * i:3 * i + 3] + # Distance vector to the closest surface point + dist_vec = target.v[i] - nearest_v + + rows[3 * i:3 * i + 3] = i * np.ones(3) + cols[3 * i:3 * i + 3] = nearest_f + + n_id = nearest_parts[i] + if n_id == 0: + # Closest surface point in triangle + A = np.vstack((source.v[nearest_f])).T + coeffs_v[3 * i:3 * i + 3] = np.linalg.lstsq(A, nearest_v)[0] + elif n_id > 0 and n_id <= 3: + # Closest surface point on edge + A = np.vstack((source.v[nearest_f[n_id - 1]], source.v[nearest_f[n_id % 3]])).T + tmp_coeffs = np.linalg.lstsq(A, target.v[i])[0] + coeffs_v[3 * i + n_id - 1] = tmp_coeffs[0] + coeffs_v[3 * i + n_id % 3] = tmp_coeffs[1] + else: + # Closest surface point a vertex + coeffs_v[3 * i + n_id - 4] = 1.0 + + # if use_normals: + # A = np.vstack((vn[nearest_f])).T + # coeffs_n[3 * i:3 * i + 3] = np.linalg.lstsq(A, dist_vec)[0] + + #coeffs = np.hstack((coeffs_v, coeffs_n)) + #rows = np.hstack((rows, rows)) + #cols = np.hstack((cols, source.v.shape[0] + cols)) + matrix = sp.csc_matrix((coeffs_v, (rows, cols)), shape=(target.v.shape[0], source.v.shape[0])) + return matrix + + +def generate_transform_matrices(mesh, factors): + """Generates len(factors) meshes, each of them is scaled by factors[i] and + computes the transformations between them. + + Returns: + M: a set of meshes downsampled from mesh by a factor specified in factors. + A: Adjacency matrix for each of the meshes + D: Downsampling transforms between each of the meshes + U: Upsampling transforms between each of the meshes + """ + + factors = map(lambda x: 1.0 / x, factors) + M, A, D, U = [], [], [], [] + A.append(get_vert_connectivity(mesh.v, mesh.f).tocoo()) + M.append(mesh) + + for i,factor in enumerate(factors): + ds_f, ds_D = qslim_decimator_transformer(M[-1], factor=factor) + D.append(ds_D.tocoo()) + new_mesh_v = ds_D.dot(M[-1].v) + new_mesh = Mesh(v=new_mesh_v, f=ds_f) + M.append(new_mesh) + A.append(get_vert_connectivity(new_mesh.v, new_mesh.f).tocoo()) + U.append(setup_deformation_transfer(M[-1], M[-2]).tocoo()) + + return M, A, D, U \ No newline at end of file diff --git a/src/graph_networks/graphcmr/utils_mesh.py b/src/graph_networks/graphcmr/utils_mesh.py new file mode 100644 index 0000000000000000000000000000000000000000..6a86a3eb6f88ab0298aa13aeaae4cdccab66f800 --- /dev/null +++ b/src/graph_networks/graphcmr/utils_mesh.py @@ -0,0 +1,138 @@ +# code from https://github.com/nkolot/GraphCMR/blob/master/utils/mesh.py + +from __future__ import division +import torch +import numpy as np +import scipy.sparse + +# from models import SMPL +import os +import sys +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..')) +from graph_networks.graphcmr.graph_layers import spmm + +def scipy_to_pytorch(A, U, D): + """Convert scipy sparse matrices to pytorch sparse matrix.""" + ptU = [] + ptD = [] + + for i in range(len(U)): + u = scipy.sparse.coo_matrix(U[i]) + i = torch.LongTensor(np.array([u.row, u.col])) + v = torch.FloatTensor(u.data) + ptU.append(torch.sparse.FloatTensor(i, v, u.shape)) + + for i in range(len(D)): + d = scipy.sparse.coo_matrix(D[i]) + i = torch.LongTensor(np.array([d.row, d.col])) + v = torch.FloatTensor(d.data) + ptD.append(torch.sparse.FloatTensor(i, v, d.shape)) + + return ptU, ptD + + +def adjmat_sparse(adjmat, nsize=1): + """Create row-normalized sparse graph adjacency matrix.""" + adjmat = scipy.sparse.csr_matrix(adjmat) + if nsize > 1: + orig_adjmat = adjmat.copy() + for _ in range(1, nsize): + adjmat = adjmat * orig_adjmat + adjmat.data = np.ones_like(adjmat.data) + for i in range(adjmat.shape[0]): + adjmat[i,i] = 1 + num_neighbors = np.array(1 / adjmat.sum(axis=-1)) + adjmat = adjmat.multiply(num_neighbors) + adjmat = scipy.sparse.coo_matrix(adjmat) + row = adjmat.row + col = adjmat.col + data = adjmat.data + i = torch.LongTensor(np.array([row, col])) + v = torch.from_numpy(data).float() + adjmat = torch.sparse.FloatTensor(i, v, adjmat.shape) + return adjmat + +def get_graph_params(filename, nsize=1): + """Load and process graph adjacency matrix and upsampling/downsampling matrices.""" + data = np.load(filename, encoding='latin1', allow_pickle=True) # np.load(filename, encoding='latin1') + A = data['A'] + U = data['U'] + D = data['D'] + U, D = scipy_to_pytorch(A, U, D) + A = [adjmat_sparse(a, nsize=nsize) for a in A] + return A, U, D + +class Mesh(object): + """Mesh object that is used for handling certain graph operations.""" + def __init__(self, filename='data/mesh_downsampling.npz', + num_downsampling=1, nsize=1, body_model=None, device=torch.device('cuda')): + self._A, self._U, self._D = get_graph_params(filename=filename, nsize=nsize) + self._A = [a.to(device) for a in self._A] + self._U = [u.to(device) for u in self._U] + self._D = [d.to(device) for d in self._D] + self.num_downsampling = num_downsampling + + # load template vertices from SMPL and normalize them + if body_model is None: + smpl = SMPL() + else: + smpl = body_model + ref_vertices = smpl.v_template + center = 0.5*(ref_vertices.max(dim=0)[0] + ref_vertices.min(dim=0)[0])[None] + ref_vertices -= center + ref_vertices /= ref_vertices.abs().max().item() + + self._ref_vertices = ref_vertices.to(device) + self.faces = smpl.faces.int().to(device) + + @property + def adjmat(self): + """Return the graph adjacency matrix at the specified subsampling level.""" + return self._A[self.num_downsampling].float() + + @property + def ref_vertices(self): + """Return the template vertices at the specified subsampling level.""" + ref_vertices = self._ref_vertices + for i in range(self.num_downsampling): + ref_vertices = torch.spmm(self._D[i], ref_vertices) + return ref_vertices + + def get_ref_vertices(self, n_downsample): + """Return the template vertices at any desired subsampling level.""" + ref_vertices = self._ref_vertices + for i in range(n_downsample): + ref_vertices = torch.spmm(self._D[i], ref_vertices) + return ref_vertices + + def downsample(self, x, n1=0, n2=None): + """Downsample mesh.""" + if n2 is None: + n2 = self.num_downsampling + if x.ndimension() < 3: + for i in range(n1, n2): + x = spmm(self._D[i], x) + elif x.ndimension() == 3: + out = [] + for i in range(x.shape[0]): + y = x[i] + for j in range(n1, n2): + y = spmm(self._D[j], y) + out.append(y) + x = torch.stack(out, dim=0) + return x + + def upsample(self, x, n1=1, n2=0): + """Upsample mesh.""" + if x.ndimension() < 3: + for i in reversed(range(n2, n1)): + x = spmm(self._U[i], x) + elif x.ndimension() == 3: + out = [] + for i in range(x.shape[0]): + y = x[i] + for j in reversed(range(n2, n1)): + y = spmm(self._U[j], y) + out.append(y) + x = torch.stack(out, dim=0) + return x \ No newline at end of file diff --git a/src/graph_networks/losses_for_vertex_wise_predictions/calculate_distance_between_points_on_mesh.py b/src/graph_networks/losses_for_vertex_wise_predictions/calculate_distance_between_points_on_mesh.py new file mode 100644 index 0000000000000000000000000000000000000000..b38ba2bd9a058017246d15ed145cbc0408ba690a --- /dev/null +++ b/src/graph_networks/losses_for_vertex_wise_predictions/calculate_distance_between_points_on_mesh.py @@ -0,0 +1,245 @@ + +""" +code adapted from: https://github.com/mikedh/trimesh/blob/main/examples/shortest.py +shortest.py +---------------- +Given a mesh and two vertex indices find the shortest path +between the two vertices while only traveling along edges +of the mesh. +""" + +# python src/graph_networks/losses_for_vertex_wise_predictions/calculate_distance_between_points_on_mesh.py + + +import os +import sys +import glob +import csv +import json +import shutil +import tqdm +import numpy as np +import pickle as pkl +import trimesh +import networkx as nx + + + + + +def read_csv(csv_file): + with open(csv_file,'r') as f: + reader = csv.reader(f) + headers = next(reader) + row_list = [{h:x for (h,x) in zip(headers,row)} for row in reader] + return row_list + + +def load_all_template_mesh_distances(root_out_path, filename='all_vertex_distances.npy'): + vert_dists = np.load(root_out_path + filename) + return vert_dists + + +def prepare_graph_from_template_mesh_and_calculate_all_distances(path_mesh, root_out_path, calc_dist_mat=False): + # root_out_path = ROOT_OUT_PATH + ''' + from smal_pytorch.smal_model.smal_torch_new import SMAL + smal = SMAL() + verts = smal.v_template.detach().cpu().numpy() + faces = smal.faces.detach().cpu().numpy() + ''' + # path_mesh = ROOT_PATH_MESH + 'mesh_downsampling_meshesmy_smpl_39dogsnorm_Jr_4_dog_template_downsampled0.obj' + my_mesh = trimesh.load_mesh(path_mesh, process=False, maintain_order=True) + verts = my_mesh.vertices + faces = my_mesh.faces + # edges without duplication + edges = my_mesh.edges_unique + # the actual length of each unique edge + length = my_mesh.edges_unique_length + # create the graph with edge attributes for length (option A) + # g = nx.Graph() + # for edge, L in zip(edges, length): g.add_edge(*edge, length=L) + # you can create the graph with from_edgelist and + # a list comprehension (option B) + ga = nx.from_edgelist([(e[0], e[1], {'length': L}) for e, L in zip(edges, length)]) + # calculate the distances between all vertex pairs + if calc_dist_mat: + # calculate distances between all possible vertex pairs + # shortest_path = nx.shortest_path(ga, source=ind_v0, target=ind_v1, weight='length') + # shortest_dist = nx.shortest_path_length(ga, source=ind_v0, target=ind_v1, weight='length') + dis = dict(nx.shortest_path_length(ga, weight='length', method='dijkstra')) + vertex_distances = np.zeros((n_verts_smal, n_verts_smal)) + for ind_v0 in range(n_verts_smal): + print(ind_v0) + for ind_v1 in range(ind_v0, n_verts_smal): + vertex_distances[ind_v0, ind_v1] = dis[ind_v0][ind_v1] + vertex_distances[ind_v1, ind_v0] = dis[ind_v0][ind_v1] + # save those distances + np.save(root_out_path + 'all_vertex_distances.npy', vertex_distances) + vert_dists = vertex_distances + else: + vert_dists = np.load(root_out_path + 'all_vertex_distances.npy') + return ga, vert_dists + + +def calculate_vertex_overview_for_gc_annotation(name, gc_info_raw, vert_dists, root_out_path_vis=None, verts=None, faces=None, img_v12_dir=None): + # input: + # root_out_path_vis = ROOT_OUT_PATH + # img_v12_dir = IMG_V12_DIR + # name = images_with_gc_labelled[ind_img] + # gc_info_raw = gc_dict['bite/' + name] + # output: + # vertex_overview: np array of shape (n_verts_smal, 3) with [first: no-contact=0 contact=1 second: index of vertex third: dist] + n_verts_smal = 3889 + gc_vertices = [] + gc_info_np = np.zeros((n_verts_smal)) + for ind_v in gc_info_raw: + if ind_v < n_verts_smal: + gc_vertices.append(ind_v) + gc_info_np[ind_v] = 1 + # save a visualization of those annotations + if root_out_path_vis is not None: + my_mesh = trimesh.Trimesh(vertices=verts, faces=faces, process=False, maintain_order=True) + if img_v12_dir is not None and root_out_path_vis is not None: + vert_colors = np.repeat(255*gc_info_np[:, None], 3, 1) + my_mesh.visual.vertex_colors = vert_colors + my_mesh.export(root_out_path_vis + (name).replace('.jpg', '_withgc.obj')) + img_path = img_v12_dir + name + shutil.copy(img_path, root_out_path_vis + name) + # calculate for each vertex the distance to the closest element of the other group + non_gc_vertices = list(set(range(n_verts_smal)) - set(gc_vertices)) + print('vertices in contact: ' + str(len(gc_vertices))) + print('vertices without contact: ' + str(len(non_gc_vertices))) + vertex_overview = np.zeros((n_verts_smal, 3)) # first: no-contact=0 contact=1 second: index of vertex third: dist + vertex_overview[:, 0] = gc_info_np + # loop through all contact vertices + for ind_v in gc_vertices: + min_length = 100 + for ind_v_ps in non_gc_vertices: # possible solution + # this_path = nx.shortest_path(ga, source=ind_v, target=ind_v_ps, weight='length') + # this_length = nx.shortest_path_length(ga, source=ind_v, target=ind_v_ps, weight='length') + this_length = vert_dists[ind_v, ind_v_ps] + if this_length < min_length: + min_length = this_length + vertex_overview[ind_v, 1] = ind_v_ps + vertex_overview[ind_v, 2] = this_length + # loop through all non-contact vertices + for ind_v in non_gc_vertices: + min_length = 100 + for ind_v_ps in gc_vertices: # possible solution + # this_path = nx.shortest_path(ga, source=ind_v, target=ind_v_ps, weight='length') + # this_length = nx.shortest_path_length(ga, source=ind_v, target=ind_v_ps, weight='length') + this_length = vert_dists[ind_v, ind_v_ps] + if this_length < min_length: + min_length = this_length + vertex_overview[ind_v, 1] = ind_v_ps + vertex_overview[ind_v, 2] = this_length + if root_out_path_vis is not None: + # save a colored mesh + my_mesh_dists = my_mesh.copy() + scale_0 = (vertex_overview[vertex_overview[:, 0]==0, 2]).max() + scale_1 = (vertex_overview[vertex_overview[:, 0]==1, 2]).max() + vert_col = np.zeros((n_verts_smal, 3)) + vert_col[vertex_overview[:, 0]==0, 1] = vertex_overview[vertex_overview[:, 0]==0, 2] * 255 / scale_0 # green + vert_col[vertex_overview[:, 0]==1, 0] = vertex_overview[vertex_overview[:, 0]==1, 2] * 255 / scale_1 # red + my_mesh_dists.visual.vertex_colors = np.uint8(vert_col) + my_mesh_dists.export(root_out_path_vis + (name).replace('.jpg', '_withgcdists.obj')) + return vertex_overview + + + + + + + + +def main(): + + ROOT_PATH_MESH = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/src/graph_networks/graphcmr/data/meshes/' + ROOT_PATH_ANNOT = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/data/stanext_related_data/ground_contact_annotations/stage3/' + IMG_V12_DIR = '/ps/scratch/nrueegg/new_projects/Animals/data/dog_datasets/Stanford_Dogs_Dataset/StanfordExtra_V12/StanExtV12_Images/' + # ROOT_OUT_PATH = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/src/graph_networks/losses_for_vertex_wise_predictions/debugging_results/' + ROOT_OUT_PATH = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/data/stanext_related_data/ground_contact_annotations/stage3/' + ROOT_OUT_PATH_VIS = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/data/stanext_related_data/ground_contact_annotations/stage3/vis/' + ROOT_OUT_PATH_DISTSGCNONGC = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/data/stanext_related_data/ground_contact_annotations/stage3/vertex_distances_gc_nongc/' + ROOT_PATH_ALL_VERT_DIST_TEMPLATE = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/data/stanext_related_data/ground_contact_annotations/' + + # load all vertex distances + path_mesh = ROOT_PATH_MESH + 'mesh_downsampling_meshesmy_smpl_39dogsnorm_Jr_4_dog_template_downsampled0.obj' + my_mesh = trimesh.load_mesh(path_mesh, process=False, maintain_order=True) + verts = my_mesh.vertices + faces = my_mesh.faces + # vert_dists, ga = prepare_graph_from_template_mesh_and_calculate_all_distances(path_mesh, ROOT_OUT_PATH, calc_dist_mat=False) + vert_dists = load_all_template_mesh_distances(ROOT_PATH_ALL_VERT_DIST_TEMPLATE, filename='all_vertex_distances.npy') + + + + + all_keys = [] + gc_dict = {} + # data/stanext_related_data/ground_contact_annotations/stage3/main_partA1667_20221021_140108.csv + # for csv_file in ['main_partA500_20221018_131139.csv', 'pilot_20221017_104201.csv', 'my_gcannotations_qualification.csv']: + # for csv_file in ['main_partA1667_20221021_140108.csv', 'main_partA500_20221018_131139.csv', 'pilot_20221017_104201.csv', 'my_gcannotations_qualification.csv']: + for csv_file in ['main_partA1667_20221021_140108.csv', 'main_partA500_20221018_131139.csv', 'main_partB20221023_150926.csv', 'pilot_20221017_104201.csv', 'my_gcannotations_qualification.csv']: + # load all ground contact annotations + gc_annot_csv = ROOT_PATH_ANNOT + csv_file # 'my_gcannotations_qualification.csv' + gc_row_list = read_csv(gc_annot_csv) + for ind_row in range(len(gc_row_list)): + json_acceptable_string = (gc_row_list[ind_row]['vertices']).replace("'", "\"") + gc_dict_temp = json.loads(json_acceptable_string) + all_keys.extend(gc_dict_temp.keys()) + gc_dict.update(gc_dict_temp) + print(len(gc_dict.keys())) + + print('number of labeled images: ' + str(len(gc_dict.keys()))) # WHY IS THIS ONLY 699? + + import pdb; pdb.set_trace() + + + # prepare and save contact annotations including distances + vertex_overview_dict = {} + for ind_img, name_ingcdict in enumerate(gc_dict.keys()): # range(len(gc_dict.keys())): + name = name_ingcdict.split('bite/')[1] + # name = images_with_gc_labelled[ind_img] + print('work on image ' + str(ind_img) + ': ' + name) + # gc_info_raw = gc_dict['bite/' + name] # a list with all vertex numbers that are in ground contact + gc_info_raw = gc_dict[name_ingcdict] # a list with all vertex numbers that are in ground contact + + if not os.path.exists(ROOT_OUT_PATH_VIS + name.split('/')[0]): os.makedirs(ROOT_OUT_PATH_VIS + name.split('/')[0]) + if not os.path.exists(ROOT_OUT_PATH_DISTSGCNONGC + name.split('/')[0]): os.makedirs(ROOT_OUT_PATH_DISTSGCNONGC + name.split('/')[0]) + + vertex_overview = calculate_vertex_overview_for_gc_annotation(name, gc_info_raw, vert_dists, root_out_path_vis=ROOT_OUT_PATH_VIS, verts=verts, faces=faces, img_v12_dir=None) + np.save(ROOT_OUT_PATH_DISTSGCNONGC + name.replace('.jpg', '_gc_vertdists_overview.npy'), vertex_overview) + + vertex_overview_dict[name.split('.')[0]] = {'gc_vertdists_overview': vertex_overview, 'gc_index_list': gc_info_raw} + + + + + + # import pdb; pdb.set_trace() + + with open(ROOT_OUT_PATH + 'gc_annots_overview_stage3complete_withtraintestval_xx.pkl', 'wb') as fp: + pkl.dump(vertex_overview_dict, fp) + + + + + + + + + + + + + +if __name__ == "__main__": + main() + + + + + + + diff --git a/src/graph_networks/losses_for_vertex_wise_predictions/calculate_distance_between_points_on_mesh_forfourpaws.py b/src/graph_networks/losses_for_vertex_wise_predictions/calculate_distance_between_points_on_mesh_forfourpaws.py new file mode 100644 index 0000000000000000000000000000000000000000..9518dcec3d5274f69f6b5880584bbdc36e58cd22 --- /dev/null +++ b/src/graph_networks/losses_for_vertex_wise_predictions/calculate_distance_between_points_on_mesh_forfourpaws.py @@ -0,0 +1,213 @@ + +""" +code adapted from: https://github.com/mikedh/trimesh/blob/main/examples/shortest.py +shortest.py +---------------- +Given a mesh and two vertex indices find the shortest path +between the two vertices while only traveling along edges +of the mesh. +""" + +# python src/graph_networks/losses_for_vertex_wise_predictions/calculate_distance_between_points_on_mesh_forfourpaws.py + + +import os +import sys +import glob +import csv +import json +import shutil +import tqdm +import numpy as np +import pickle as pkl +import trimesh +import networkx as nx + + + + + +def read_csv(csv_file): + with open(csv_file,'r') as f: + reader = csv.reader(f) + headers = next(reader) + row_list = [{h:x for (h,x) in zip(headers,row)} for row in reader] + return row_list + + +def load_all_template_mesh_distances(root_out_path, filename='all_vertex_distances.npy'): + vert_dists = np.load(root_out_path + filename) + return vert_dists + + +def prepare_graph_from_template_mesh_and_calculate_all_distances(path_mesh, root_out_path, calc_dist_mat=False): + # root_out_path = ROOT_OUT_PATH + ''' + from smal_pytorch.smal_model.smal_torch_new import SMAL + smal = SMAL() + verts = smal.v_template.detach().cpu().numpy() + faces = smal.faces.detach().cpu().numpy() + ''' + # path_mesh = ROOT_PATH_MESH + 'mesh_downsampling_meshesmy_smpl_39dogsnorm_Jr_4_dog_template_downsampled0.obj' + my_mesh = trimesh.load_mesh(path_mesh, process=False, maintain_order=True) + verts = my_mesh.vertices + faces = my_mesh.faces + # edges without duplication + edges = my_mesh.edges_unique + # the actual length of each unique edge + length = my_mesh.edges_unique_length + # create the graph with edge attributes for length (option A) + # g = nx.Graph() + # for edge, L in zip(edges, length): g.add_edge(*edge, length=L) + # you can create the graph with from_edgelist and + # a list comprehension (option B) + ga = nx.from_edgelist([(e[0], e[1], {'length': L}) for e, L in zip(edges, length)]) + # calculate the distances between all vertex pairs + if calc_dist_mat: + # calculate distances between all possible vertex pairs + # shortest_path = nx.shortest_path(ga, source=ind_v0, target=ind_v1, weight='length') + # shortest_dist = nx.shortest_path_length(ga, source=ind_v0, target=ind_v1, weight='length') + dis = dict(nx.shortest_path_length(ga, weight='length', method='dijkstra')) + vertex_distances = np.zeros((n_verts_smal, n_verts_smal)) + for ind_v0 in range(n_verts_smal): + print(ind_v0) + for ind_v1 in range(ind_v0, n_verts_smal): + vertex_distances[ind_v0, ind_v1] = dis[ind_v0][ind_v1] + vertex_distances[ind_v1, ind_v0] = dis[ind_v0][ind_v1] + # save those distances + np.save(root_out_path + 'all_vertex_distances.npy', vertex_distances) + vert_dists = vertex_distances + else: + vert_dists = np.load(root_out_path + 'all_vertex_distances.npy') + return ga, vert_dists + + +def calculate_vertex_overview_for_gc_annotation(name, gc_info_raw, vert_dists, root_out_path_vis=None, verts=None, faces=None, img_v12_dir=None): + # input: + # root_out_path_vis = ROOT_OUT_PATH + # img_v12_dir = IMG_V12_DIR + # name = images_with_gc_labelled[ind_img] + # gc_info_raw = gc_dict['bite/' + name] + # output: + # vertex_overview: np array of shape (n_verts_smal, 3) with [first: no-contact=0 contact=1 second: index of vertex third: dist] + n_verts_smal = 3889 + gc_vertices = [] + gc_info_np = np.zeros((n_verts_smal)) + for ind_v in gc_info_raw: + if ind_v < n_verts_smal: + gc_vertices.append(ind_v) + gc_info_np[ind_v] = 1 + # save a visualization of those annotations + if root_out_path_vis is not None: + my_mesh = trimesh.Trimesh(vertices=verts, faces=faces, process=False, maintain_order=True) + if img_v12_dir is not None and root_out_path_vis is not None: + vert_colors = np.repeat(255*gc_info_np[:, None], 3, 1) + my_mesh.visual.vertex_colors = vert_colors + my_mesh.export(root_out_path_vis + (name).replace('.jpg', '_withgc.obj')) + img_path = img_v12_dir + name + shutil.copy(img_path, root_out_path_vis + name) + # calculate for each vertex the distance to the closest element of the other group + non_gc_vertices = list(set(range(n_verts_smal)) - set(gc_vertices)) + print('vertices in contact: ' + str(len(gc_vertices))) + print('vertices without contact: ' + str(len(non_gc_vertices))) + vertex_overview = np.zeros((n_verts_smal, 3)) # first: no-contact=0 contact=1 second: index of vertex third: dist + vertex_overview[:, 0] = gc_info_np + # loop through all contact vertices + for ind_v in gc_vertices: + min_length = 100 + for ind_v_ps in non_gc_vertices: # possible solution + # this_path = nx.shortest_path(ga, source=ind_v, target=ind_v_ps, weight='length') + # this_length = nx.shortest_path_length(ga, source=ind_v, target=ind_v_ps, weight='length') + this_length = vert_dists[ind_v, ind_v_ps] + if this_length < min_length: + min_length = this_length + vertex_overview[ind_v, 1] = ind_v_ps + vertex_overview[ind_v, 2] = this_length + # loop through all non-contact vertices + for ind_v in non_gc_vertices: + min_length = 100 + for ind_v_ps in gc_vertices: # possible solution + # this_path = nx.shortest_path(ga, source=ind_v, target=ind_v_ps, weight='length') + # this_length = nx.shortest_path_length(ga, source=ind_v, target=ind_v_ps, weight='length') + this_length = vert_dists[ind_v, ind_v_ps] + if this_length < min_length: + min_length = this_length + vertex_overview[ind_v, 1] = ind_v_ps + vertex_overview[ind_v, 2] = this_length + if root_out_path_vis is not None: + # save a colored mesh + my_mesh_dists = my_mesh.copy() + scale_0 = (vertex_overview[vertex_overview[:, 0]==0, 2]).max() + scale_1 = (vertex_overview[vertex_overview[:, 0]==1, 2]).max() + vert_col = np.zeros((n_verts_smal, 3)) + vert_col[vertex_overview[:, 0]==0, 1] = vertex_overview[vertex_overview[:, 0]==0, 2] * 255 / scale_0 # green + vert_col[vertex_overview[:, 0]==1, 0] = vertex_overview[vertex_overview[:, 0]==1, 2] * 255 / scale_1 # red + my_mesh_dists.visual.vertex_colors = np.uint8(vert_col) + my_mesh_dists.export(root_out_path_vis + (name).replace('.jpg', '_withgcdists.obj')) + return vertex_overview + + + + + + + + + +def main(): + + ROOT_PATH_MESH = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/src/graph_networks/graphcmr/data/meshes/' + IMG_V12_DIR = '/ps/scratch/nrueegg/new_projects/Animals/data/dog_datasets/Stanford_Dogs_Dataset/StanfordExtra_V12/StanExtV12_Images/' + # ROOT_OUT_PATH = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/src/graph_networks/losses_for_vertex_wise_predictions/debugging_results/' + ROOT_OUT_PATH = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/data/stanext_related_data/ground_contact_annotations/stages12together/' + ROOT_PATH_ALL_VERT_DIST_TEMPLATE = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/data/stanext_related_data/ground_contact_annotations/' + + # load all vertex distances + path_mesh = ROOT_PATH_MESH + 'mesh_downsampling_meshesmy_smpl_39dogsnorm_Jr_4_dog_template_downsampled0.obj' + my_mesh = trimesh.load_mesh(path_mesh, process=False, maintain_order=True) + verts = my_mesh.vertices + faces = my_mesh.faces + # vert_dists, ga = prepare_graph_from_template_mesh_and_calculate_all_distances(path_mesh, ROOT_OUT_PATH, calc_dist_mat=False) + vert_dists = load_all_template_mesh_distances(ROOT_PATH_ALL_VERT_DIST_TEMPLATE, filename='all_vertex_distances.npy') + + # paw vertices: + # left and right is a bit different, but that is ok (we will anyways mirror data at training time) + right_front_paw = [3829,+3827,+3825,+3718,+3722,+3723,+3743,+3831,+3719,+3726,+3716,+3724,+3828,+3717,+3721,+3725,+3832,+3830,+3720,+3288,+3740,+3714,+3826,+3715,+3728,+3712,+3287,+3284,+3727,+3285,+3742,+3291,+3710,+3697,+3711,+3289,+3730,+3713,+3739,+3282,+3738,+3708,+3709,+3741,+3698,+3696,+3308,+3695,+3706,+3700,+3707,+3306,+3305,+3737,+3304,+3303,+3307,+3736,+3735,+3250,+3261,+3732,+3734,+3733,+3731,+3729,+3299,+3297,+3298,+3295,+3293,+3296,+3294,+3292,+3312,+3311,+3314,+3309,+3290,+3313,+3410,+3315,+3411,+3412,+3316,+3421,+3317,+3415,+3445,+3327,+3328,+3283,+3343,+3326,+3325,+3330,+3286,+3399,+3398,+3329,+3446,+3400,+3331,+3401,+3281,+3332,+3279,+3402,+3419,+3407,+3356,+3358,+3357,+3280,+3354,+3277,+3278,+3346,+3347,+3377,+3378,+3345,+3386,+3379,+3348,+3384,+3418,+3372,+3276,+3275,+3374,+3274,+3373,+3375,+3369,+3371,+3376,+3273,+3396,+3397,+3395,+3388,+3360,+3370,+3361,+3394,+3387,+3420,+3359,+3389,+3272,+3391,+3393,+3390,+3392,+3363,+3362,+3367,+3365,+3705,+3271,+3704,+3703,+3270,+3269,+3702,+3268,+3224,+3267,+3701,+3225,+3699,+3265,+3264,+3266,+3263,+3262,+3249,+3228,+3230,+3251,+3301,+3300,+3302,+3252] + right_back_paw = [3472,+3627,+3470,+3469,+3471,+3473,+3626,+3625,+3475,+3655,+3519,+3468,+3629,+3466,+3476,+3624,+3521,+3654,+3657,+3838,+3518,+3653,+3839,+3553,+3474,+3516,+3656,+3628,+3834,+3535,+3630,+3658,+3477,+3520,+3517,+3595,+3522,+3597,+3596,+3501,+3534,+3503,+3478,+3500,+3479,+3502,+3607,+3499,+3608,+3496,+3605,+3609,+3504,+3606,+3642,+3614,+3498,+3480,+3631,+3610,+3613,+3506,+3659,+3660,+3632,+3841,+3661,+3836,+3662,+3633,+3663,+3664,+3634,+3635,+3486,+3665,+3636,+3637,+3666,+3490,+3837,+3667,+3493,+3638,+3492,+3495,+3616,+3644,+3494,+3835,+3643,+3833,+3840,+3615,+3650,+3668,+3652,+3651,+3645,+3646,+3647,+3649,+3648,+3622,+3617,+3448,+3621,+3618,+3623,+3462,+3464,+3460,+3620,+3458,+3461,+3463,+3465,+3573,+3571,+3467,+3569,+3557,+3558,+3572,+3570,+3556,+3585,+3593,+3594,+3459,+3566,+3592,+3567,+3568,+3538,+3539,+3555,+3537,+3536,+3554,+3575,+3574,+3583,+3541,+3550,+3576,+3581,+3639,+3577,+3551,+3582,+3580,+3552,+3578,+3542,+3549,+3579,+3523,+3526,+3598,+3525,+3600,+3640,+3599,+3601,+3602,+3603,+3529,+3604,+3530,+3533,+3532,+3611,+3612,+3482,+3481,+3505,+3452,+3455,+3456,+3454,+3457,+3619,+3451,+3450,+3449,+3591,+3589,+3641,+3584,+3561,+3587,+3559,+3488,+3484,+3483] + left_front_paw = [1791,+1950,+1948,+1790,+1789,+1746,+1788,+1747,+1949,+1944,+1792,+1945,+1356,+1775,+1759,+1777,+1787,+1946,+1757,+1761,+1745,+1943,+1947,+1744,+1309,+1786,+1771,+1354,+1774,+1765,+1767,+1768,+1772,+1763,+1770,+1773,+1769,+1764,+1766,+1758,+1760,+1762,+1336,+1333,+1330,+1325,+1756,+1323,+1755,+1753,+1749,+1754,+1751,+1321,+1752,+1748,+1750,+1312,+1319,+1315,+1313,+1317,+1318,+1316,+1314,+1311,+1310,+1299,+1276,+1355,+1297,+1353,+1298,+1300,+1352,+1351,+1785,+1784,+1349,+1783,+1782,+1781,+1780,+1779,+1778,+1776,+1343,+1341,+1344,+1339,+1342,+1340,+1360,+1335,+1338,+1362,+1357,+1361,+1363,+1458,+1337,+1459,+1456,+1460,+1493,+1332,+1375,+1376,+1331,+1374,+1378,+1334,+1373,+1494,+1377,+1446,+1448,+1379,+1449,+1329,+1327,+1404,+1406,+1405,+1402,+1328,+1426,+1432,+1434,+1403,+1394,+1395,+1433,+1425,+1286,+1380,+1466,+1431,+1290,+1401,+1381,+1427,+1450,+1393,+1430,+1326,+1396,+1428,+1397,+1429,+1398,+1420,+1324,+1422,+1417,+1419,+1421,+1443,+1418,+1423,+1444,+1442,+1424,+1445,+1495,+1440,+1441,+1468,+1436,+1408,+1322,+1435,+1415,+1439,+1409,+1283,+1438,+1416,+1407,+1437,+1411,+1413,+1414,+1320,+1273,+1272,+1278,+1469,+1463,+1457,+1358,+1464,+1465,+1359,+1372,+1391,+1390,+1455,+1447,+1454,+1467,+1453,+1452,+1451,+1383,+1345,+1347,+1348,+1350,+1364,+1392,+1410,+1412] + left_back_paw = [1957,+1958,+1701,+1956,+1951,+1703,+1715,+1702,+1700,+1673,+1705,+1952,+1955,+1674,+1699,+1675,+1953,+1704,+1954,+1698,+1677,+1671,+1672,+1714,+1706,+1676,+1519,+1523,+1686,+1713,+1692,+1685,+1543,+1664,+1712,+1691,+1959,+1541,+1684,+1542,+1496,+1663,+1540,+1497,+1499,+1498,+1500,+1693,+1665,+1694,+1716,+1666,+1695,+1501,+1502,+1696,+1667,+1503,+1697,+1504,+1668,+1669,+1506,+1670,+1508,+1510,+1507,+1509,+1511,+1512,+1621,+1606,+1619,+1605,+1513,+1620,+1618,+1604,+1633,+1641,+1642,+1607,+1617,+1514,+1632,+1614,+1689,+1640,+1515,+1586,+1616,+1516,+1517,+1603,+1615,+1639,+1585,+1521,+1602,+1587,+1584,+1601,+1623,+1622,+1631,+1598,+1624,+1629,+1589,+1687,+1625,+1599,+1630,+1569,+1570,+1628,+1626,+1597,+1627,+1590,+1594,+1571,+1568,+1567,+1574,+1646,+1573,+1645,+1648,+1564,+1688,+1647,+1643,+1649,+1650,+1651,+1577,+1644,+1565,+1652,+1566,+1578,+1518,+1524,+1583,+1582,+1520,+1581,+1522,+1525,+1549,+1551,+1580,+1552,+1550,+1656,+1658,+1554,+1657,+1659,+1548,+1655,+1690,+1660,+1556,+1653,+1558,+1661,+1544,+1662,+1654,+1547,+1545,+1527,+1560,+1526,+1678,+1679,+1528,+1708,+1707,+1680,+1529,+1530,+1709,+1546,+1681,+1710,+1711,+1682,+1532,+1531,+1683,+1534,+1533,+1536,+1538,+1600,+1553] + + + all_contact_vertices = right_front_paw + right_back_paw + left_front_paw + left_back_paw + + name = 'all4pawsincontact.jpg' + print('work on 4paw images') + gc_info_raw = all_contact_vertices # a list with all vertex numbers that are in ground contact + + vertex_overview = calculate_vertex_overview_for_gc_annotation(name, gc_info_raw, vert_dists, root_out_path_vis=ROOT_OUT_PATH, verts=verts, faces=faces, img_v12_dir=None) + np.save(ROOT_OUT_PATH + name.replace('.jpg', '_gc_vertdists_overview.npy'), vertex_overview) + + vertex_overview_dict = {} + vertex_overview_dict[name.split('.')[0]] = {'gc_vertdists_overview': vertex_overview, 'gc_index_list': gc_info_raw} + with open(ROOT_OUT_PATH + 'gc_annots_overview_all4pawsincontact_xx.pkl', 'wb') as fp: + pkl.dump(vertex_overview_dict, fp) + + + + + + + + + + + +if __name__ == "__main__": + main() + + + + + + + diff --git a/src/graph_networks/losses_for_vertex_wise_predictions/calculate_distance_between_points_on_mesh_forpaws.py b/src/graph_networks/losses_for_vertex_wise_predictions/calculate_distance_between_points_on_mesh_forpaws.py new file mode 100644 index 0000000000000000000000000000000000000000..dc6f3cecd72806a61052099e50b0300542448bb4 --- /dev/null +++ b/src/graph_networks/losses_for_vertex_wise_predictions/calculate_distance_between_points_on_mesh_forpaws.py @@ -0,0 +1,317 @@ + +""" +code adapted from: https://github.com/mikedh/trimesh/blob/main/examples/shortest.py +shortest.py +---------------- +Given a mesh and two vertex indices find the shortest path +between the two vertices while only traveling along edges +of the mesh. +""" + +# python src/graph_networks/losses_for_vertex_wise_predictions/calculate_distance_between_points_on_mesh_forpaws.py + + +import os +import sys +import glob +import csv +import json +import shutil +import tqdm +import numpy as np +import pickle as pkl +import trimesh +import networkx as nx + + + + + +def read_csv(csv_file): + with open(csv_file,'r') as f: + reader = csv.reader(f) + headers = next(reader) + row_list = [{h:x for (h,x) in zip(headers,row)} for row in reader] + return row_list + + +def load_all_template_mesh_distances(root_out_path, filename='all_vertex_distances.npy'): + vert_dists = np.load(root_out_path + filename) + return vert_dists + + +def prepare_graph_from_template_mesh_and_calculate_all_distances(path_mesh, root_out_path, calc_dist_mat=False): + # root_out_path = ROOT_OUT_PATH + ''' + from smal_pytorch.smal_model.smal_torch_new import SMAL + smal = SMAL() + verts = smal.v_template.detach().cpu().numpy() + faces = smal.faces.detach().cpu().numpy() + ''' + # path_mesh = ROOT_PATH_MESH + 'mesh_downsampling_meshesmy_smpl_39dogsnorm_Jr_4_dog_template_downsampled0.obj' + my_mesh = trimesh.load_mesh(path_mesh, process=False, maintain_order=True) + verts = my_mesh.vertices + faces = my_mesh.faces + # edges without duplication + edges = my_mesh.edges_unique + # the actual length of each unique edge + length = my_mesh.edges_unique_length + # create the graph with edge attributes for length (option A) + # g = nx.Graph() + # for edge, L in zip(edges, length): g.add_edge(*edge, length=L) + # you can create the graph with from_edgelist and + # a list comprehension (option B) + ga = nx.from_edgelist([(e[0], e[1], {'length': L}) for e, L in zip(edges, length)]) + # calculate the distances between all vertex pairs + if calc_dist_mat: + # calculate distances between all possible vertex pairs + # shortest_path = nx.shortest_path(ga, source=ind_v0, target=ind_v1, weight='length') + # shortest_dist = nx.shortest_path_length(ga, source=ind_v0, target=ind_v1, weight='length') + dis = dict(nx.shortest_path_length(ga, weight='length', method='dijkstra')) + vertex_distances = np.zeros((n_verts_smal, n_verts_smal)) + for ind_v0 in range(n_verts_smal): + print(ind_v0) + for ind_v1 in range(ind_v0, n_verts_smal): + vertex_distances[ind_v0, ind_v1] = dis[ind_v0][ind_v1] + vertex_distances[ind_v1, ind_v0] = dis[ind_v0][ind_v1] + # save those distances + np.save(root_out_path + 'all_vertex_distances.npy', vertex_distances) + vert_dists = vertex_distances + else: + vert_dists = np.load(root_out_path + 'all_vertex_distances.npy') + return ga, vert_dists + + +def calculate_vertex_overview_for_gc_annotation(name, gc_info_raw, vert_dists, root_out_path_vis=None, verts=None, faces=None, img_v12_dir=None): + # input: + # root_out_path_vis = ROOT_OUT_PATH + # img_v12_dir = IMG_V12_DIR + # name = images_with_gc_labelled[ind_img] + # gc_info_raw = gc_dict['bite/' + name] + # output: + # vertex_overview: np array of shape (n_verts_smal, 3) with [first: no-contact=0 contact=1 second: index of vertex third: dist] + n_verts_smal = 3889 + gc_vertices = [] + gc_info_np = np.zeros((n_verts_smal)) + for ind_v in gc_info_raw: + if ind_v < n_verts_smal: + gc_vertices.append(ind_v) + gc_info_np[ind_v] = 1 + # save a visualization of those annotations + if root_out_path_vis is not None: + my_mesh = trimesh.Trimesh(vertices=verts, faces=faces, process=False, maintain_order=True) + if img_v12_dir is not None and root_out_path_vis is not None: + vert_colors = np.repeat(255*gc_info_np[:, None], 3, 1) + my_mesh.visual.vertex_colors = vert_colors + my_mesh.export(root_out_path_vis + (name).replace('.jpg', '_withgc.obj')) + img_path = img_v12_dir + name + shutil.copy(img_path, root_out_path_vis + name) + # calculate for each vertex the distance to the closest element of the other group + non_gc_vertices = list(set(range(n_verts_smal)) - set(gc_vertices)) + print('vertices in contact: ' + str(len(gc_vertices))) + print('vertices without contact: ' + str(len(non_gc_vertices))) + vertex_overview = np.zeros((n_verts_smal, 3)) # first: no-contact=0 contact=1 second: index of vertex third: dist + vertex_overview[:, 0] = gc_info_np + # loop through all contact vertices + for ind_v in gc_vertices: + min_length = 100 + for ind_v_ps in non_gc_vertices: # possible solution + # this_path = nx.shortest_path(ga, source=ind_v, target=ind_v_ps, weight='length') + # this_length = nx.shortest_path_length(ga, source=ind_v, target=ind_v_ps, weight='length') + this_length = vert_dists[ind_v, ind_v_ps] + if this_length < min_length: + min_length = this_length + vertex_overview[ind_v, 1] = ind_v_ps + vertex_overview[ind_v, 2] = this_length + # loop through all non-contact vertices + for ind_v in non_gc_vertices: + min_length = 100 + for ind_v_ps in gc_vertices: # possible solution + # this_path = nx.shortest_path(ga, source=ind_v, target=ind_v_ps, weight='length') + # this_length = nx.shortest_path_length(ga, source=ind_v, target=ind_v_ps, weight='length') + this_length = vert_dists[ind_v, ind_v_ps] + if this_length < min_length: + min_length = this_length + vertex_overview[ind_v, 1] = ind_v_ps + vertex_overview[ind_v, 2] = this_length + if root_out_path_vis is not None: + # save a colored mesh + my_mesh_dists = my_mesh.copy() + scale_0 = (vertex_overview[vertex_overview[:, 0]==0, 2]).max() + scale_1 = (vertex_overview[vertex_overview[:, 0]==1, 2]).max() + vert_col = np.zeros((n_verts_smal, 3)) + vert_col[vertex_overview[:, 0]==0, 1] = vertex_overview[vertex_overview[:, 0]==0, 2] * 255 / scale_0 # green + vert_col[vertex_overview[:, 0]==1, 0] = vertex_overview[vertex_overview[:, 0]==1, 2] * 255 / scale_1 # red + my_mesh_dists.visual.vertex_colors = np.uint8(vert_col) + my_mesh_dists.export(root_out_path_vis + (name).replace('.jpg', '_withgcdists.obj')) + return vertex_overview + + +def summarize_results_stage2b(row_list, display_worker_performance=False): + # four catch trials are included in every batch + annot_n02088466_3184 = {'paw_rb': 0, 'paw_rf': 1, 'paw_lb': 1, 'paw_lf': 1, 'additional_part': 0, 'no_contact': 0} + annot_n02100583_9922 = {'paw_rb': 1, 'paw_rf': 0, 'paw_lb': 0, 'paw_lf': 0, 'additional_part': 0, 'no_contact': 0} + annot_n02105056_2798 = {'paw_rb': 1, 'paw_rf': 1, 'paw_lb': 1, 'paw_lf': 1, 'additional_part': 1, 'no_contact': 0} + annot_n02091831_2288 = {'paw_rb': 0, 'paw_rf': 1, 'paw_lb': 1, 'paw_lf': 0, 'additional_part': 0, 'no_contact': 0} + all_comments = [] + all_annotations = {} + for row in row_list: + all_comments.append(row['Answer.submitComments']) + worker_id = row['WorkerId'] + if display_worker_performance: + print('----------------------------------------------------------------------------------------------') + print('Worker ID: ' + worker_id) + n_wrong = 0 + n_correct = 0 + for ind in range(0, len(row['Answer.submitValuesNotSure'].split(';')) - 1): + input_image = (row['Input.images'].split(';')[ind]).split('StanExtV12_Images/')[-1] + paw_rb = row['Answer.submitValuesRightBack'].split(';')[ind] + paw_rf = row['Answer.submitValuesRightFront'].split(';')[ind] + paw_lb = row['Answer.submitValuesLeftBack'].split(';')[ind] + paw_lf = row['Answer.submitValuesLeftFront'].split(';')[ind] + addpart = row['Answer.submitValuesAdditional'].split(';')[ind] + no_contact = row['Answer.submitValuesNoContact'].split(';')[ind] + unsure = row['Answer.submitValuesNotSure'].split(';')[ind] + annot = {'paw_rb': paw_rb, 'paw_rf': paw_rf, 'paw_lb': paw_lb, 'paw_lf': paw_lf, + 'additional_part': addpart, 'no_contact': no_contact, 'not_sure': unsure, + 'worker_id': worker_id} # , 'input_image': input_image} + if ind == 0: + gt = annot_n02088466_3184 + elif ind == 1: + gt = annot_n02105056_2798 + elif ind == 2: + gt = annot_n02091831_2288 + elif ind == 3: + gt = annot_n02100583_9922 + else: + pass + if ind < 4: + for key in gt.keys(): + if str(annot[key]) == str(gt[key]): + n_correct += 1 + else: + if display_worker_performance: + print(input_image) + print(key + ':[ expected: ' + str(gt[key]) + ' predicted: ' + str(annot[key]) + ' ]') + n_wrong += 1 + else: + all_annotations[input_image] = annot + if display_worker_performance: + print('n_correct: ' + str(n_correct)) + print('n_wrong: ' + str(n_wrong)) + return all_annotations, all_comments + + + + + + +def main(): + + ROOT_PATH_MESH = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/src/graph_networks/graphcmr/data/meshes/' + ROOT_PATH_ANNOT = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/data/stanext_related_data/ground_contact_annotations/stage2b/' + IMG_V12_DIR = '/ps/scratch/nrueegg/new_projects/Animals/data/dog_datasets/Stanford_Dogs_Dataset/StanfordExtra_V12/StanExtV12_Images/' + # ROOT_OUT_PATH = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/src/graph_networks/losses_for_vertex_wise_predictions/debugging_results/' + ROOT_OUT_PATH = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/data/stanext_related_data/ground_contact_annotations/stage2b/' + ROOT_OUT_PATH_VIS = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/data/stanext_related_data/ground_contact_annotations/stage2b/vis/' + ROOT_OUT_PATH_DISTSGCNONGC = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/data/stanext_related_data/ground_contact_annotations/stage2b/vertex_distances_gc_nongc/' + ROOT_PATH_ALL_VERT_DIST_TEMPLATE = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/data/stanext_related_data/ground_contact_annotations/' + + # load all vertex distances + path_mesh = ROOT_PATH_MESH + 'mesh_downsampling_meshesmy_smpl_39dogsnorm_Jr_4_dog_template_downsampled0.obj' + my_mesh = trimesh.load_mesh(path_mesh, process=False, maintain_order=True) + verts = my_mesh.vertices + faces = my_mesh.faces + # vert_dists, ga = prepare_graph_from_template_mesh_and_calculate_all_distances(path_mesh, ROOT_OUT_PATH, calc_dist_mat=False) + vert_dists = load_all_template_mesh_distances(ROOT_PATH_ALL_VERT_DIST_TEMPLATE, filename='all_vertex_distances.npy') + + + + + + # paw vertices: + # left and right is a bit different, but that is ok (we will anyways mirror data at training time) + right_front_paw = [3829,+3827,+3825,+3718,+3722,+3723,+3743,+3831,+3719,+3726,+3716,+3724,+3828,+3717,+3721,+3725,+3832,+3830,+3720,+3288,+3740,+3714,+3826,+3715,+3728,+3712,+3287,+3284,+3727,+3285,+3742,+3291,+3710,+3697,+3711,+3289,+3730,+3713,+3739,+3282,+3738,+3708,+3709,+3741,+3698,+3696,+3308,+3695,+3706,+3700,+3707,+3306,+3305,+3737,+3304,+3303,+3307,+3736,+3735,+3250,+3261,+3732,+3734,+3733,+3731,+3729,+3299,+3297,+3298,+3295,+3293,+3296,+3294,+3292,+3312,+3311,+3314,+3309,+3290,+3313,+3410,+3315,+3411,+3412,+3316,+3421,+3317,+3415,+3445,+3327,+3328,+3283,+3343,+3326,+3325,+3330,+3286,+3399,+3398,+3329,+3446,+3400,+3331,+3401,+3281,+3332,+3279,+3402,+3419,+3407,+3356,+3358,+3357,+3280,+3354,+3277,+3278,+3346,+3347,+3377,+3378,+3345,+3386,+3379,+3348,+3384,+3418,+3372,+3276,+3275,+3374,+3274,+3373,+3375,+3369,+3371,+3376,+3273,+3396,+3397,+3395,+3388,+3360,+3370,+3361,+3394,+3387,+3420,+3359,+3389,+3272,+3391,+3393,+3390,+3392,+3363,+3362,+3367,+3365,+3705,+3271,+3704,+3703,+3270,+3269,+3702,+3268,+3224,+3267,+3701,+3225,+3699,+3265,+3264,+3266,+3263,+3262,+3249,+3228,+3230,+3251,+3301,+3300,+3302,+3252] + right_back_paw = [3472,+3627,+3470,+3469,+3471,+3473,+3626,+3625,+3475,+3655,+3519,+3468,+3629,+3466,+3476,+3624,+3521,+3654,+3657,+3838,+3518,+3653,+3839,+3553,+3474,+3516,+3656,+3628,+3834,+3535,+3630,+3658,+3477,+3520,+3517,+3595,+3522,+3597,+3596,+3501,+3534,+3503,+3478,+3500,+3479,+3502,+3607,+3499,+3608,+3496,+3605,+3609,+3504,+3606,+3642,+3614,+3498,+3480,+3631,+3610,+3613,+3506,+3659,+3660,+3632,+3841,+3661,+3836,+3662,+3633,+3663,+3664,+3634,+3635,+3486,+3665,+3636,+3637,+3666,+3490,+3837,+3667,+3493,+3638,+3492,+3495,+3616,+3644,+3494,+3835,+3643,+3833,+3840,+3615,+3650,+3668,+3652,+3651,+3645,+3646,+3647,+3649,+3648,+3622,+3617,+3448,+3621,+3618,+3623,+3462,+3464,+3460,+3620,+3458,+3461,+3463,+3465,+3573,+3571,+3467,+3569,+3557,+3558,+3572,+3570,+3556,+3585,+3593,+3594,+3459,+3566,+3592,+3567,+3568,+3538,+3539,+3555,+3537,+3536,+3554,+3575,+3574,+3583,+3541,+3550,+3576,+3581,+3639,+3577,+3551,+3582,+3580,+3552,+3578,+3542,+3549,+3579,+3523,+3526,+3598,+3525,+3600,+3640,+3599,+3601,+3602,+3603,+3529,+3604,+3530,+3533,+3532,+3611,+3612,+3482,+3481,+3505,+3452,+3455,+3456,+3454,+3457,+3619,+3451,+3450,+3449,+3591,+3589,+3641,+3584,+3561,+3587,+3559,+3488,+3484,+3483] + left_front_paw = [1791,+1950,+1948,+1790,+1789,+1746,+1788,+1747,+1949,+1944,+1792,+1945,+1356,+1775,+1759,+1777,+1787,+1946,+1757,+1761,+1745,+1943,+1947,+1744,+1309,+1786,+1771,+1354,+1774,+1765,+1767,+1768,+1772,+1763,+1770,+1773,+1769,+1764,+1766,+1758,+1760,+1762,+1336,+1333,+1330,+1325,+1756,+1323,+1755,+1753,+1749,+1754,+1751,+1321,+1752,+1748,+1750,+1312,+1319,+1315,+1313,+1317,+1318,+1316,+1314,+1311,+1310,+1299,+1276,+1355,+1297,+1353,+1298,+1300,+1352,+1351,+1785,+1784,+1349,+1783,+1782,+1781,+1780,+1779,+1778,+1776,+1343,+1341,+1344,+1339,+1342,+1340,+1360,+1335,+1338,+1362,+1357,+1361,+1363,+1458,+1337,+1459,+1456,+1460,+1493,+1332,+1375,+1376,+1331,+1374,+1378,+1334,+1373,+1494,+1377,+1446,+1448,+1379,+1449,+1329,+1327,+1404,+1406,+1405,+1402,+1328,+1426,+1432,+1434,+1403,+1394,+1395,+1433,+1425,+1286,+1380,+1466,+1431,+1290,+1401,+1381,+1427,+1450,+1393,+1430,+1326,+1396,+1428,+1397,+1429,+1398,+1420,+1324,+1422,+1417,+1419,+1421,+1443,+1418,+1423,+1444,+1442,+1424,+1445,+1495,+1440,+1441,+1468,+1436,+1408,+1322,+1435,+1415,+1439,+1409,+1283,+1438,+1416,+1407,+1437,+1411,+1413,+1414,+1320,+1273,+1272,+1278,+1469,+1463,+1457,+1358,+1464,+1465,+1359,+1372,+1391,+1390,+1455,+1447,+1454,+1467,+1453,+1452,+1451,+1383,+1345,+1347,+1348,+1350,+1364,+1392,+1410,+1412] + left_back_paw = [1957,+1958,+1701,+1956,+1951,+1703,+1715,+1702,+1700,+1673,+1705,+1952,+1955,+1674,+1699,+1675,+1953,+1704,+1954,+1698,+1677,+1671,+1672,+1714,+1706,+1676,+1519,+1523,+1686,+1713,+1692,+1685,+1543,+1664,+1712,+1691,+1959,+1541,+1684,+1542,+1496,+1663,+1540,+1497,+1499,+1498,+1500,+1693,+1665,+1694,+1716,+1666,+1695,+1501,+1502,+1696,+1667,+1503,+1697,+1504,+1668,+1669,+1506,+1670,+1508,+1510,+1507,+1509,+1511,+1512,+1621,+1606,+1619,+1605,+1513,+1620,+1618,+1604,+1633,+1641,+1642,+1607,+1617,+1514,+1632,+1614,+1689,+1640,+1515,+1586,+1616,+1516,+1517,+1603,+1615,+1639,+1585,+1521,+1602,+1587,+1584,+1601,+1623,+1622,+1631,+1598,+1624,+1629,+1589,+1687,+1625,+1599,+1630,+1569,+1570,+1628,+1626,+1597,+1627,+1590,+1594,+1571,+1568,+1567,+1574,+1646,+1573,+1645,+1648,+1564,+1688,+1647,+1643,+1649,+1650,+1651,+1577,+1644,+1565,+1652,+1566,+1578,+1518,+1524,+1583,+1582,+1520,+1581,+1522,+1525,+1549,+1551,+1580,+1552,+1550,+1656,+1658,+1554,+1657,+1659,+1548,+1655,+1690,+1660,+1556,+1653,+1558,+1661,+1544,+1662,+1654,+1547,+1545,+1527,+1560,+1526,+1678,+1679,+1528,+1708,+1707,+1680,+1529,+1530,+1709,+1546,+1681,+1710,+1711,+1682,+1532,+1531,+1683,+1534,+1533,+1536,+1538,+1600,+1553] + + + + + + all_keys = [] + gc_dict = {} + vertex_overview_nocontact = {} + # data/stanext_related_data/ground_contact_annotations/stage3/main_partA1667_20221021_140108.csv + for csv_file in ['Stage2b_finalResults.csv']: + # load all ground contact annotations + gc_annot_csv = ROOT_PATH_ANNOT + csv_file # 'my_gcannotations_qualification.csv' + gc_row_list = read_csv(gc_annot_csv) + all_annotations, all_comments = summarize_results_stage2b(gc_row_list, display_worker_performance=False) + for key, value in all_annotations.items(): + if value['not_sure'] == '0': + if value['no_contact'] == '1': + vertex_overview_nocontact[key.split('.')[0]] = {'gc_vertdists_overview': 'no contact', 'gc_index_list': None} + else: + all_contact_vertices = [] + if value['paw_rf'] == '1': + all_contact_vertices.extend(right_front_paw) + if value['paw_rb'] == '1': + all_contact_vertices.extend(right_back_paw) + if value['paw_lf'] == '1': + all_contact_vertices.extend(left_front_paw) + if value['paw_lb'] == '1': + all_contact_vertices.extend(left_back_paw) + gc_dict[key] = all_contact_vertices + print('number of labeled images: ' + str(len(gc_dict.keys()))) + print('number of images without contact: ' + str(len(vertex_overview_nocontact.keys()))) + + # prepare and save contact annotations including distances + vertex_overview_dict = {} + for ind_img, name_ingcdict in enumerate(gc_dict.keys()): # range(len(gc_dict.keys())): + name = name_ingcdict # name_ingcdict.split('bite/')[1] + # name = images_with_gc_labelled[ind_img] + print('work on image ' + str(ind_img) + ': ' + name) + # gc_info_raw = gc_dict['bite/' + name] # a list with all vertex numbers that are in ground contact + gc_info_raw = gc_dict[name_ingcdict] # a list with all vertex numbers that are in ground contact + + if not os.path.exists(ROOT_OUT_PATH_VIS + name.split('/')[0]): os.makedirs(ROOT_OUT_PATH_VIS + name.split('/')[0]) + if not os.path.exists(ROOT_OUT_PATH_DISTSGCNONGC + name.split('/')[0]): os.makedirs(ROOT_OUT_PATH_DISTSGCNONGC + name.split('/')[0]) + + vertex_overview = calculate_vertex_overview_for_gc_annotation(name, gc_info_raw, vert_dists, root_out_path_vis=ROOT_OUT_PATH_VIS, verts=verts, faces=faces, img_v12_dir=None) + np.save(ROOT_OUT_PATH_DISTSGCNONGC + name.replace('.jpg', '_gc_vertdists_overview.npy'), vertex_overview) + + vertex_overview_dict[name.split('.')[0]] = {'gc_vertdists_overview': vertex_overview, 'gc_index_list': gc_info_raw} + + + + + + # import pdb; pdb.set_trace() + + with open(ROOT_OUT_PATH + 'gc_annots_overview_stage2b_contact_complete_xx.pkl', 'wb') as fp: + pkl.dump(vertex_overview_dict, fp) + + with open(ROOT_OUT_PATH + 'gc_annots_overview_stage2b_nocontact_complete_xx.pkl', 'wb') as fp: + pkl.dump(vertex_overview_nocontact, fp) + + + + + + + + + + + +if __name__ == "__main__": + main() + + + + + + + diff --git a/src/graph_networks/losses_for_vertex_wise_predictions/example_calculate_distance_between_points_on_mesh.py b/src/graph_networks/losses_for_vertex_wise_predictions/example_calculate_distance_between_points_on_mesh.py new file mode 100644 index 0000000000000000000000000000000000000000..7c15d4bc85bb48aa9fed4613e7f0f7298bec4a21 --- /dev/null +++ b/src/graph_networks/losses_for_vertex_wise_predictions/example_calculate_distance_between_points_on_mesh.py @@ -0,0 +1,173 @@ + +""" +code adapted from: https://github.com/mikedh/trimesh/blob/main/examples/shortest.py +shortest.py +---------------- +Given a mesh and two vertex indices find the shortest path +between the two vertices while only traveling along edges +of the mesh. +""" + +# python src/graph_networks/losses_for_vertex_wise_predictions/calculate_distance_between_points_on_mesh.py + + +import os +import sys +import glob +import csv +import json +import shutil +import numpy as np +import trimesh +import networkx as nx + + + +ROOT_PATH_MESH = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/src/graph_networks/graphcmr/data/meshes/' +ROOT_PATH_ANNOT = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/data/stanext_related_data/ground_contact_annotations/stage3/' +STAN_V12_ROOT_DIR = '/ps/scratch/nrueegg/new_projects/Animals/data/dog_datasets/Stanford_Dogs_Dataset/StanfordExtra_V12/' +IMG_V12_DIR = STAN_V12_ROOT_DIR + 'StanExtV12_Images/' +ROOT_OUT_PATH = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/src/graph_networks/losses_for_vertex_wise_predictions/debugging_results/' + + +def read_csv(csv_file): + with open(csv_file,'r') as f: + reader = csv.reader(f) + headers = next(reader) + row_list = [{h:x for (h,x) in zip(headers,row)} for row in reader] + return row_list + +images_with_gc_labelled = ['n02093991-Irish_terrier/n02093991_2874.jpg', + 'n02093754-Border_terrier/n02093754_1062.jpg', + 'n02092339-Weimaraner/n02092339_1672.jpg', + 'n02096177-cairn/n02096177_4916.jpg', + 'n02110185-Siberian_husky/n02110185_725.jpg', + 'n02110806-basenji/n02110806_761.jpg', + 'n02094433-Yorkshire_terrier/n02094433_2474.jpg', + 'n02097474-Tibetan_terrier/n02097474_8796.jpg', + 'n02099601-golden_retriever/n02099601_2495.jpg'] + + +# ----- PART 1: load all ground contact annotations +gc_annot_csv = ROOT_PATH_ANNOT + 'my_gcannotations_qualification.csv' +gc_row_list = read_csv(gc_annot_csv) +json_acceptable_string = (gc_row_list[0]['vertices']).replace("'", "\"") +gc_dict = json.loads(json_acceptable_string) + + +# ----- PART 2: load and prepare the mesh +''' +from smal_pytorch.smal_model.smal_torch_new import SMAL +smal = SMAL() +verts = smal.v_template.detach().cpu().numpy() +faces = smal.faces.detach().cpu().numpy() +''' +path_mesh = ROOT_PATH_MESH + 'mesh_downsampling_meshesmy_smpl_39dogsnorm_Jr_4_dog_template_downsampled0.obj' +my_mesh = trimesh.load_mesh(path_mesh, process=False, maintain_order=True) +verts = my_mesh.vertices +faces = my_mesh.faces +# edges without duplication +edges = my_mesh.edges_unique +# the actual length of each unique edge +length = my_mesh.edges_unique_length +# create the graph with edge attributes for length (option A) +# g = nx.Graph() +# for edge, L in zip(edges, length): g.add_edge(*edge, length=L) +# you can create the graph with from_edgelist and +# a list comprehension (option B) +ga = nx.from_edgelist([(e[0], e[1], {'length': L}) for e, L in zip(edges, length)]) + + +# ----- PART 3: calculate the distances between all vertex pairs +calc_dist_mat = False +if calc_dist_mat: + # calculate distances between all possible vertex pairs + # shortest_path = nx.shortest_path(ga, source=ind_v0, target=ind_v1, weight='length') + # shortest_dist = nx.shortest_path_length(ga, source=ind_v0, target=ind_v1, weight='length') + dis = dict(nx.shortest_path_length(ga, weight='length', method='dijkstra')) + vertex_distances = np.zeros((n_verts_smal, n_verts_smal)) + for ind_v0 in range(n_verts_smal): + print(ind_v0) + for ind_v1 in range(ind_v0, n_verts_smal): + vertex_distances[ind_v0, ind_v1] = dis[ind_v0][ind_v1] + vertex_distances[ind_v1, ind_v0] = dis[ind_v0][ind_v1] + # save those distances + np.save(ROOT_OUT_PATH + 'all_vertex_distances.npy', vertex_distances) + vert_dists = vertex_distances +else: + vert_dists = np.load(ROOT_OUT_PATH + 'all_vertex_distances.npy') + + +# ----- PART 4: prepare contact annotation +n_verts_smal = 3889 +for ind_img in range(len(images_with_gc_labelled)): # range(len(gc_dict.keys())): + name = images_with_gc_labelled[ind_img] + print('work on image ' + name) + gc_info_raw = gc_dict['bite/' + name] # a list with all vertex numbers that are in ground contact + gc_vertices = [] + gc_info_np = np.zeros((n_verts_smal)) + for ind_v in gc_info_raw: + if ind_v < n_verts_smal: + gc_vertices.append(ind_v) + gc_info_np[ind_v] = 1 + # save a visualization of those annotations + vert_colors = np.repeat(255*gc_info_np[:, None], 3, 1) + my_mesh = trimesh.Trimesh(vertices=verts, faces=faces, process=False, maintain_order=True) + my_mesh.visual.vertex_colors = vert_colors + my_mesh.export(ROOT_OUT_PATH + (name.split('/')[1]).replace('.jpg', '_withgc.obj')) + img_path = IMG_V12_DIR + name + shutil.copy(img_path, ROOT_OUT_PATH + name.split('/')[1]) + + # ----- PART 5: calculate for each vertex the distance to the closest element of the other group + non_gc_vertices = list(set(range(n_verts_smal)) - set(gc_vertices)) + print('vertices in contact: ' + str(len(gc_vertices))) + print('vertices without contact: ' + str(len(non_gc_vertices))) + vertex_overview = np.zeros((n_verts_smal, 3)) # first: no-contact=0 contact=1 second: index of vertex third: dist + vertex_overview[:, 0] = gc_info_np + # loop through all contact vertices + for ind_v in gc_vertices: + min_length = 100 + for ind_v_ps in non_gc_vertices: # possible solution + # this_path = nx.shortest_path(ga, source=ind_v, target=ind_v_ps, weight='length') + # this_length = nx.shortest_path_length(ga, source=ind_v, target=ind_v_ps, weight='length') + this_length = vert_dists[ind_v, ind_v_ps] + if this_length < min_length: + min_length = this_length + vertex_overview[ind_v, 1] = ind_v_ps + vertex_overview[ind_v, 2] = this_length + # loop through all non-contact vertices + for ind_v in non_gc_vertices: + min_length = 100 + for ind_v_ps in gc_vertices: # possible solution + # this_path = nx.shortest_path(ga, source=ind_v, target=ind_v_ps, weight='length') + # this_length = nx.shortest_path_length(ga, source=ind_v, target=ind_v_ps, weight='length') + this_length = vert_dists[ind_v, ind_v_ps] + if this_length < min_length: + min_length = this_length + vertex_overview[ind_v, 1] = ind_v_ps + vertex_overview[ind_v, 2] = this_length + # save a colored mesh + my_mesh_dists = my_mesh.copy() + scale_0 = (vertex_overview[vertex_overview[:, 0]==0, 2]).max() + scale_1 = (vertex_overview[vertex_overview[:, 0]==1, 2]).max() + vert_col = np.zeros((n_verts_smal, 3)) + vert_col[vertex_overview[:, 0]==0, 1] = vertex_overview[vertex_overview[:, 0]==0, 2] * 255 / scale_0 # green + vert_col[vertex_overview[:, 0]==1, 0] = vertex_overview[vertex_overview[:, 0]==1, 2] * 255 / scale_1 # red + my_mesh_dists.visual.vertex_colors = np.uint8(vert_col) + my_mesh_dists.export(ROOT_OUT_PATH + (name.split('/')[1]).replace('.jpg', '_withgcdists.obj')) + + + + +import pdb; pdb.set_trace() + + + + + + + + + + + diff --git a/src/graph_networks/losses_for_vertex_wise_predictions/process_stage12_results.py b/src/graph_networks/losses_for_vertex_wise_predictions/process_stage12_results.py new file mode 100644 index 0000000000000000000000000000000000000000..224fbfa0b955a14c34c06d84a82ce6ca3ac8b348 --- /dev/null +++ b/src/graph_networks/losses_for_vertex_wise_predictions/process_stage12_results.py @@ -0,0 +1,322 @@ + +# see also (laptop): +# /home/nadine/Documents/PhD/icon_barc_project/AMT_ground_contact_studies/stages_1and2_together/evaluate_stages12_main_forstage2b_new.py +# +# python src/graph_networks/losses_for_vertex_wise_predictions/process_stage12_results.py +# + + + +import numpy as np +import os +import sys +import csv +import shutil +import pickle as pkl + +ROOT_path = '/home/nadine/Documents/PhD/icon_barc_project/AMT_ground_contact_studies/' +ROOT_path_images = '/ps/scratch/nrueegg/new_projects/Animals/data/dog_datasets/Stanford_Dogs_Dataset/StanfordExtra_V12/StanExtV12_Images/' +ROOT_amt_image_list = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/data/stanext_related_data/ground_contact_annotations/stages12together/amt_image_lists/' +ROOT_OUT_PATH = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/data/stanext_related_data/ground_contact_annotations/stages12together/' + + +root_path_stage1 = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/data/stanext_related_data/ground_contact_annotations/stage1/' +root_path_stage2 = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/data/stanext_related_data/ground_contact_annotations/stage2/' + +csv_file_stage1_pilot = root_path_stage1 + 'stage1_pilot_Batch_4841525_batch_results.csv' +csv_file_stage1_main = root_path_stage1 + 'stage1_main_stage1_Batch_4890079_batch_results.csv' +csv_file_stage2_pilot = root_path_stage2 + 'stage2_pilot_DogStage2PilotResults.csv' +csv_file_stage2_main = root_path_stage2 + 'stage2_main_Batch_4890110_batch_results.csv' + +full_amt_image_list = ROOT_amt_image_list + 'all_stanext_image_names_amt.txt' +train_amt_image_list = ROOT_amt_image_list + 'all_stanext_image_names_train.txt' +test_amt_image_list = ROOT_amt_image_list + 'all_stanext_image_names_test.txt' +val_amt_image_list = ROOT_amt_image_list + 'all_stanext_image_names_val.txt' + +experiment_name = 'stage_2b_image_paths' +AMT_images_root_path = 'https://dogvisground.s3.eu-central-1.amazonaws.com/StanExtV12_Images/' # n02085620-Chihuahua/n02085620_10074.jpg' +# out_folder = '/home/nadine/Documents/PhD/icon_barc_project/AMT_ground_contact_studies/stage_2b/stage2b_html_and_csv_files/' +# out_folder_imgs = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/data/stanext_related_data/ground_contact_annotations/stages12together/' +# csv_out_path_pilot = out_folder + experiment_name + '_pilot_bs22.csv' +# csv_out_path_main = out_folder + experiment_name + '_main_bs22.csv' + + + + + +pose_dict = {'1':'Standing still, all four paws fully on the ground', + '2':'Standing still, at least one paw lifted (if you are in doubt if the paw is on the ground or not, choose this option)', + '3':'Walking or trotting (walk, amble, pace, trot)', + '4':'Running (only canter, gallup, run)', + '5':'Sitting, symmetrical legs', + '6':'Sitting, complicated pose (every sitting pose with asymmetrical leg position)', + '7':'lying, symmetrical legs (and not lying on the side)', + '8':'lying, complicated pose (every lying pose with asymmetrical leg position)', + '9':'Jumping, not touching the ground', + '10':'Jumping or about to jump, touching the ground', + '11':'On hind legs (standing or walking or sitting)', + '12':'Downward facing dog: standing on back legs/paws and bending over front leg', + '13':'Other poses: being carried by a human, ...', + '14':'I can not see the pose (please comment why: hidden, hairy, legs cut off, ...)'} + +pose_dict_abbrev = {'1':'standing_4paws', + '2':'standing_fewpaws', + '3':'walking', + '4':'running', + '5':'sitting_sym', + '6':'sitting_comp', + '7':'lying_sym', + '8':'lying_comp', + '9':'jumping_nottouching', + '10':'jumping_touching', + '11':'onhindlegs', + '12':'downwardfacingdog', + '13':'otherpose', + '14':'cantsee'} + +def read_csv(csv_file): + with open(csv_file,'r') as f: + reader = csv.reader(f) + headers = next(reader) + row_list = [{h:x for (h,x) in zip(headers,row)} for row in reader] + return row_list + +def add_stage2_to_result_dict(row_list_stage2, result_info_dict): + for ind_worker in range(len(row_list_stage2)): + # print('------------------------------------') + image_names = row_list_stage2[ind_worker]['Input.images'].split(';') + all_answers_comment = row_list_stage2[ind_worker]['Answer.submitComments'].split(';') + all_answers_dogpose = row_list_stage2[ind_worker]['Answer.submitValues'].split(';') + for ind in range(len(image_names)): + if 'Qualification_Tutorial_Images' in image_names[ind]: + # print('skip tutorial images') + pass + else: + img_subf = image_names[ind].split('/')[-2] + img_name = image_names[ind].split('/')[-1] + img_name_key = img_subf + '/' + img_name + img_path = ROOT_path_images + img_subf + '/' + img_name + this_img = {'img_name': img_name, + 'img_subf': img_subf, + # 'img_path': img_path, + 'pose': pose_dict_abbrev[all_answers_dogpose[ind]], + 'comment_pose': all_answers_comment[ind]} + assert not (img_name_key in result_info_dict.keys()) + result_info_dict[img_name_key] = this_img + '''folder_name = pose_dict_abbrev[all_answers_dogpose[ind]] + img_name_out = img_name # 'indw' + str(ind_worker) + '_' + img_name + out_folder_this = out_folder_imgs + folder_name + '/' + img_name + shutil.copyfile(img_path, out_folder_this + img_name_out)''' + +def add_stage1_to_result_dict(row_list_stage1, result_info_dict): + for ind_worker in range(len(row_list_stage1)): + # print('------------------------------------') + image_names = row_list_stage1[ind_worker]['Input.images'].split(';') + all_answers_commentvis = row_list_stage1[ind_worker]['Answer.submitCommentsVisible'].split(';') + all_answers_vis = row_list_stage1[ind_worker]['Answer.submitValuesVisible'].split(';') # 1: visible, 2: not visible + all_answers_commentground = row_list_stage1[ind_worker]['Answer.submitCommentsGround'].split(';') + all_answers_ground = row_list_stage1[ind_worker]['Answer.submitValuesGround'].split(';') # 1: flat, 2: not flat + for ind in range(len(image_names)): + if len(image_names[ind].split('/')) < 2: + print('no more image in ind_worker ' + str(ind_worker)) + elif 'Qualification_Tutorial_Images' in image_names[ind]: + # print('skip tutorial images') + pass + else: + img_subf = image_names[ind].split('/')[-2] + img_name = image_names[ind].split('/')[-1] + img_name_key = img_subf + '/' + img_name + img_path = ROOT_path_images + img_subf + '/' + img_name + if all_answers_vis[ind] == '1': + vis = True + elif all_answers_vis[ind] == '2': + vis = False + else: + vis = None + # raise ValueError + if all_answers_ground[ind] == '1': + flat = True + elif all_answers_ground[ind] == '2': + flat = False + else: + flat = None + # raise ValueError + if img_name_key in result_info_dict.keys(): + result_info_dict[img_name_key]['is_vis'] = vis + result_info_dict[img_name_key]['comment_vis'] = all_answers_commentvis[ind] + result_info_dict[img_name_key]['is_flat'] = flat + result_info_dict[img_name_key]['comment_flat'] = all_answers_commentground[ind] + else: + print(img_path) + this_img = {'img_name': img_name, + 'img_subf': img_subf, + # 'img_path': img_path, + 'is_vis': vis, + 'comment_vis': all_answers_commentvis[ind], + 'is_flat': flat, + 'comment_flat': all_answers_commentground[ind]} + result_info_dict[img_name_key] = this_img + + + +# ------------------------------------------------------------------------------ + +''' +if not os.path.exists(out_folder_imgs): os.makedirs(out_folder_imgs) +for folder_name in pose_dict_abbrev.values(): + out_folder_this = out_folder_imgs + folder_name + if not os.path.exists(out_folder_this): os.makedirs(out_folder_this) +''' + + + +row_list_stage2_pilot = read_csv(csv_file_stage2_pilot) +row_list_stage1_pilot = read_csv(csv_file_stage1_pilot) +row_list_stage2_main = read_csv(csv_file_stage2_main) +row_list_stage1_main = read_csv(csv_file_stage1_main) + +result_info_dict = {} +add_stage2_to_result_dict(row_list_stage2_pilot, result_info_dict) +add_stage2_to_result_dict(row_list_stage2_main, result_info_dict) +add_stage1_to_result_dict(row_list_stage1_pilot, result_info_dict) +add_stage1_to_result_dict(row_list_stage1_main, result_info_dict) + + + +# initial image list: all_stanext_image_names_amt.txt +# (/home/nadine/Documents/PhD/icon_barc_project/AMT_ground_contact_studies/all_stanext_image_names_amt.txt) +# the initial image list did first contain randomly shuffeled {train + test} +# images and after that randomly shuffeled {val} images +# see also /is/cluster/work/nrueegg/icon_pifu_related/ICON/lib/ground_contact/create_gc_dataset/get_stanext_images_for_amt.py +# train and test: 6773 + 1703 = 8476 +# val: 4062 +with open(full_amt_image_list) as f: full_amt_lines = f.readlines() +with open(train_amt_image_list) as f: train_amt_lines = f.readlines() +with open(test_amt_image_list) as f: test_amt_lines = f.readlines() +with open(val_amt_image_list) as f: val_amt_lines = f.readlines() + +for ind_l, line in enumerate(train_amt_lines): + img_name_key = (line.split('/')[-2]) + '/' + (line.split('/')[-1]).split('\n')[0] + result_info_dict[img_name_key]['split'] = 'train' +for ind_l, line in enumerate(test_amt_lines): + img_name_key = (line.split('/')[-2]) + '/' + (line.split('/')[-1]).split('\n')[0] + result_info_dict[img_name_key]['split'] = 'test' +for ind_l, line in enumerate(val_amt_lines): + img_name_key = (line.split('/')[-2]) + '/' + (line.split('/')[-1]).split('\n')[0] + result_info_dict[img_name_key]['split'] = 'val' + + +# we have stage 2b labels for: +# constraint_vis = (res['is_vis'] in {True, None}) +# constraint_flat = (res['is_flat'] in {True, None}) +# constraint_pose = (res['pose'] in {'standing_fewpaws', 'walking', 'running', }) +# we have stage 3 labels for: +# constraint_vis = (res['is_vis'] in {True, None}) +# constraint_flat = (res['is_flat'] in {True, None}) +# constraint_pose = (res['pose'] in {'sitting_sym', 'sitting_comp', 'lying_sym', 'lying_comp', 'downwardfacingdog', 'otherpose', 'jumping_touching', 'onhindlegs'}) +# we have no labels for: +# constraint_pose = (res['pose'] in {'standing_4paws', 'jumping_nottouching', 'cantsee'}) + + +with open(ROOT_OUT_PATH + 'gc_annots_categories_stages12_complete.pkl', 'wb') as fp: + pkl.dump(result_info_dict, fp) + + +import pdb; pdb.set_trace() + + + + + + + + + + + + + + + + + + + + + + + + + + + + +# ------------------------------------------------------------------------------------------------- + +''' +# sort the result images. +all_pose_names = [*pose_dict_abbrev.values()] +split_list = ['train', 'test', 'val', 'traintest'] +split_list_dict = {} +for split in split_list: + nimgs_pose_dict = {} + for pose_name in all_pose_names: + nimgs_pose_dict[pose_name] = 0 + images_to_label = [] + for ind_l, line in enumerate(full_amt_lines): + img_name = (line.split('/')[-1]).split('\n')[0] + res = result_info_dict[img_name] + if split == 'traintest': + constraint_split = (res['split'] == 'train') or (res['split'] == 'test') + else: + constraint_split = (res['split'] == split) # (res['split'] == 'train') + constraint_vis = (res['is_vis'] in {True, None}) + constraint_flat = (res['is_flat'] in {True, None}) + # constraint_pose = (res['pose'] in {'sitting_sym', 'sitting_comp', 'lying_sym', 'lying_comp', 'downwardfacingdog', 'otherpose', 'jumping_touching', 'onhindlegs'}) + constraint_pose = (res['pose'] in {'standing_fewpaws', 'walking', 'running', }) + + if constraint_split * constraint_vis * constraint_flat == True: + nimgs_pose_dict[res['pose']] += 1 + if constraint_pose: + images_to_label.append(line) + folder_name = 'imgsforstage2b_' + split # 'imgsforstage3_train' + out_folder_this = out_folder_imgs + folder_name + '/' + if not os.path.exists(out_folder_this): os.makedirs(out_folder_this) + shutil.copyfile(res['img_path'], out_folder_this + img_name) + print('------------------------------------------------------') + print(split) + print(nimgs_pose_dict) + print(len(images_to_label)) + split_list_dict[split] = {'nimgs_pose_dict': nimgs_pose_dict, + 'len(images_to_label)': len(images_to_label), + 'images_to_label': images_to_label} + + + +# create csv files: +traintest_list = split_list_dict['traintest']['images_to_label'] +val_list = split_list_dict['val']['images_to_label'] +complete_list = traintest_list + val_list + +all_lines_refined = [] +for line in complete_list: + all_lines_refined.append(line.split('\n')[0]) + +import pdb; pdb.set_trace() +''' + + + + + + + + + + + + + + + + diff --git a/src/graph_networks/losses_for_vertex_wise_predictions/visualization_calculate_distance_between_points_on_mesh.py b/src/graph_networks/losses_for_vertex_wise_predictions/visualization_calculate_distance_between_points_on_mesh.py new file mode 100644 index 0000000000000000000000000000000000000000..ff0aaf392f81c60d100a6bbe78b8e3667df5d46b --- /dev/null +++ b/src/graph_networks/losses_for_vertex_wise_predictions/visualization_calculate_distance_between_points_on_mesh.py @@ -0,0 +1,73 @@ + +""" +code adapted from: https://github.com/mikedh/trimesh/blob/main/examples/shortest.py +shortest.py +---------------- +Given a mesh and two vertex indices find the shortest path +between the two vertices while only traveling along edges +of the mesh. +""" + +# python src/graph_networks/losses_for_vertex_wise_predictions/calculate_distance_between_points_on_mesh.py + + + +import trimesh + +import networkx as nx + + + +ROOT_PATH_MESH = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/src/graph_networks/graphcmr/data/meshes/' +ROOT_OUT_PATH = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/src/graph_networks/losses_for_vertex_wise_predictions/debugging_results/' + +path_mesh = ROOT_PATH_MESH + 'mesh_downsampling_meshesmy_smpl_39dogsnorm_Jr_4_dog_template_downsampled0.obj' + + +import pdb; pdb.set_trace() + + + +my_mesh = trimesh.load_mesh(path_mesh, process=False, maintain_order=True) + + +# edges without duplication +edges = my_mesh.edges_unique + +# the actual length of each unique edge +length = my_mesh.edges_unique_length + +# create the graph with edge attributes for length (option A) +# g = nx.Graph() +# for edge, L in zip(edges, length): g.add_edge(*edge, length=L) +# you can create the graph with from_edgelist and +# a list comprehension (option B) +ga = nx.from_edgelist([(e[0], e[1], {'length': L}) for e, L in zip(edges, length)]) + +# arbitrary indices of mesh.vertices to test with +start = 0 +end = int(len(my_mesh.vertices) / 2.0) + +# run the shortest path query using length for edge weight +path = nx.shortest_path(ga, source=start, target=end, weight='length') + +# VISUALIZE RESULT +# make the sphere transparent-ish +my_mesh.visual.face_colors = [100, 100, 100, 100] +# Path3D with the path between the points +path_visual = trimesh.load_path(my_mesh.vertices[path]) +# visualizable two points +points_visual = trimesh.points.PointCloud(my_mesh.vertices[[start, end]]) + +# create a scene with the mesh, path, and points +my_scene = trimesh.Scene([points_visual, path_visual, my_mesh]) + +my_scene.export(ROOT_OUT_PATH + 'shortest_path.stl') + + +scene.show(smooth=False) + + + + + diff --git a/src/graph_networks/uniform_surface_sampling/create_remesh_template_for_uniform_smal_surface_sampling.py b/src/graph_networks/uniform_surface_sampling/create_remesh_template_for_uniform_smal_surface_sampling.py new file mode 100644 index 0000000000000000000000000000000000000000..236b62b2e935e6df4667a7eca02ca741c453f260 --- /dev/null +++ b/src/graph_networks/uniform_surface_sampling/create_remesh_template_for_uniform_smal_surface_sampling.py @@ -0,0 +1,57 @@ + +# python src/graph_networks/uniform_surface_sampling/create_remesh_template_for_uniform_smal_surface_sampling.py + +import numpy as np +import pyacvd +import pyvista as pv +import trimesh +import pickle as pkl + +import os +import sys +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..')) +from combined_model.helper import get_triangle_faces_from_pyvista_poly + + +ROOT_OUT_PATH = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/data/smal_data_remeshed/uniform_surface_sampling/' +ROOT_PATH_MESH = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/src/graph_networks/graphcmr/data/meshes/' + +n_points = 25000 # 6000 # 4000 +name = 'my_smpl_39dogsnorm_Jr_4_dog_remesh25000' # 6000' # 'my_smpl_39dogsnorm_Jr_4_dog_remesh4000' + + +# load smal mesh (could also be loaded using SMAL class, this is just the SMAL dog template) +path_mesh = ROOT_PATH_MESH + 'mesh_downsampling_meshesmy_smpl_39dogsnorm_Jr_4_dog_template_downsampled0.obj' +my_mesh = trimesh.load_mesh(path_mesh, process=False, maintain_order=True) +verts = my_mesh.vertices +faces = my_mesh.faces + +# read smal dog mesh with pyvista +mesh_pv = pv.read(path_mesh) +clus = pyacvd.Clustering(mesh_pv) + +# remesh the surface (see https://github.com/pyvista/pyacvd) +clus.subdivide(3) +clus.cluster(n_points) # clus.cluster(20000) +remesh = clus.create_mesh() +remesh_points_of_interest = np.asarray(remesh.points) + +# save the resulting mesh +# remesh.save(ROOT_OUT_PATH + name + '.ply') +remesh_triangle_faces = get_triangle_faces_from_pyvista_poly(remesh) +remesh_tri = trimesh.Trimesh(vertices=remesh_points_of_interest, faces=remesh_triangle_faces, process=False, maintain_order=True) +remesh_tri.export(ROOT_OUT_PATH + name + '.obj') + +# get barycentric coordinates +points_closest, dists_closest, faceid_closest = trimesh.proximity.closest_point(my_mesh, remesh_points_of_interest) +barys_closest = trimesh.triangles.points_to_barycentric(my_mesh.vertices[my_mesh.faces[faceid_closest]], points_closest) # , method='cramer') + +# test that we can get the vertex location of the remeshes mesh back +# -> similarly we will be able to calculate new vertex locations for a deformed smal mesh +verts_closest = np.einsum('ij,ijk->ik', barys_closest, my_mesh.vertices[my_mesh.faces[faceid_closest]]) + +# save all relevant (and more) information +remeshing_dict = {'remeshed_name': name + '.obj', 'is_symmetric': 'no', 'remeshed_verts': np.asarray(remesh.points), 'smal_mesh': path_mesh, 'points_closest': points_closest, 'dists_closest': dists_closest, 'faceid_closest': faceid_closest, 'barys_closest': barys_closest, 'string_to_get_new_point_locations_from_smal': 'verts_closest = np.einsum(ij,ijk->ik, barys_closest, smal_mesh.vertices[smal_mesh.faces[faceid_closest]])'} +with open(ROOT_OUT_PATH + name + '_info.pkl', 'wb') as f: + pkl.dump(remeshing_dict, f) + diff --git a/src/lifting_to_3d/inn_model_for_shape.py b/src/lifting_to_3d/inn_model_for_shape.py new file mode 100644 index 0000000000000000000000000000000000000000..6ab7c1f18ca603a20406092bdd7163e370d17023 --- /dev/null +++ b/src/lifting_to_3d/inn_model_for_shape.py @@ -0,0 +1,61 @@ + + +from torch import distributions +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.data import DataLoader +from torch.distributions import Normal +import numpy as np +import cv2 +import trimesh +from tqdm import tqdm +import warnings +warnings.filterwarnings("ignore", category=DeprecationWarning) +import FrEIA.framework as Ff +import FrEIA.modules as Fm + + +class INNForShape(nn.Module): + def __init__(self, n_betas, n_betas_limbs, k_tot=2, betas_scale=1.0, betas_limbs_scale=0.1): + super(INNForShape, self).__init__() + self.n_betas = n_betas + self.n_betas_limbs = n_betas_limbs + self.n_dim = n_betas + n_betas_limbs + self.betas_scale = betas_scale + self.betas_limbs_scale = betas_limbs_scale + self.k_tot = 2 + self.model_inn = self.build_inn_network(self.n_dim, k_tot=self.k_tot) + + def subnet_fc(self, c_in, c_out): + subnet = nn.Sequential(nn.Linear(c_in, 64), nn.ReLU(), + nn.Linear(64, 64), nn.ReLU(), + nn.Linear(64, c_out)) + return subnet + + def build_inn_network(self, n_input, k_tot=12, verbose=False): + coupling_block = Fm.RNVPCouplingBlock + nodes = [Ff.InputNode(n_input, name='input')] + for k in range(k_tot): + nodes.append(Ff.Node(nodes[-1], + coupling_block, + {'subnet_constructor':self.subnet_fc, 'clamp':2.0}, + name=F'coupling_{k}')) + nodes.append(Ff.Node(nodes[-1], + Fm.PermuteRandom, + {'seed':k}, + name=F'permute_{k}')) + nodes.append(Ff.OutputNode(nodes[-1], name='output')) + model = Ff.ReversibleGraphNet(nodes, verbose=verbose) + return model + + def forward(self, latent_rep): + shape, _ = self.model_inn(latent_rep, rev=False, jac=False) + betas = shape[:, :self.n_betas]*self.betas_scale + betas_limbs = shape[:, self.n_betas:]*self.betas_limbs_scale + return betas, betas_limbs + + def reverse(self, betas, betas_limbs): + shape = torch.cat((betas/self.betas_scale, betas_limbs/self.betas_limbs_scale), dim=1) + latent_rep, _ = self.model_inn(shape, rev=True, jac=False) + return latent_rep \ No newline at end of file diff --git a/src/lifting_to_3d/linear_model.py b/src/lifting_to_3d/linear_model.py new file mode 100644 index 0000000000000000000000000000000000000000..c11266acefcb6bbecd8a748a44cb4915ef4da4b9 --- /dev/null +++ b/src/lifting_to_3d/linear_model.py @@ -0,0 +1,297 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +# some code from https://raw.githubusercontent.com/weigq/3d_pose_baseline_pytorch/master/src/model.py + + +from __future__ import absolute_import +from __future__ import print_function +import torch +import torch.nn as nn + +import os +import sys +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..')) +# from priors.vae_pose_model.vae_model import VAEmodel +from priors.normalizing_flow_prior.normalizing_flow_prior import NormalizingFlowPrior + + +def weight_init_dangerous(m): + # this is dangerous as it may overwrite the normalizing flow weights + if isinstance(m, nn.Linear): + nn.init.kaiming_normal(m.weight) + + +class Linear(nn.Module): + def __init__(self, linear_size, p_dropout=0.5): + super(Linear, self).__init__() + self.l_size = linear_size + + self.relu = nn.ReLU(inplace=True) + self.dropout = nn.Dropout(p_dropout) + + self.w1 = nn.Linear(self.l_size, self.l_size) + self.batch_norm1 = nn.BatchNorm1d(self.l_size) + + self.w2 = nn.Linear(self.l_size, self.l_size) + self.batch_norm2 = nn.BatchNorm1d(self.l_size) + + def forward(self, x): + y = self.w1(x) + y = self.batch_norm1(y) + y = self.relu(y) + y = self.dropout(y) + y = self.w2(y) + y = self.batch_norm2(y) + y = self.relu(y) + y = self.dropout(y) + out = x + y + return out + + +class LinearModel(nn.Module): + def __init__(self, + linear_size=1024, + num_stage=2, + p_dropout=0.5, + input_size=16*2, + output_size=16*3): + super(LinearModel, self).__init__() + self.linear_size = linear_size + self.p_dropout = p_dropout + self.num_stage = num_stage + # input + self.input_size = input_size # 2d joints: 16 * 2 + # output + self.output_size = output_size # 3d joints: 16 * 3 + # process input to linear size + self.w1 = nn.Linear(self.input_size, self.linear_size) + self.batch_norm1 = nn.BatchNorm1d(self.linear_size) + self.linear_stages = [] + for l in range(num_stage): + self.linear_stages.append(Linear(self.linear_size, self.p_dropout)) + self.linear_stages = nn.ModuleList(self.linear_stages) + # post-processing + self.w2 = nn.Linear(self.linear_size, self.output_size) + # helpers (relu and dropout) + self.relu = nn.ReLU(inplace=True) + self.dropout = nn.Dropout(self.p_dropout) + + def forward(self, x): + # pre-processing + y = self.w1(x) + y = self.batch_norm1(y) + y = self.relu(y) + y = self.dropout(y) + # linear layers + for i in range(self.num_stage): + y = self.linear_stages[i](y) + # post-processing + y = self.w2(y) + return y + + +class LinearModelComplete(nn.Module): + def __init__(self, + linear_size=1024, + num_stage_comb=2, + num_stage_heads=1, + num_stage_heads_pose=1, + trans_sep=False, + p_dropout=0.5, + input_size=16*2, + intermediate_size=1024, + output_info=None, + n_joints=25, + n_z=512, + add_z_to_3d_input=False, + n_segbps=64*2, + add_segbps_to_3d_input=False, + structure_pose_net='default', + fix_vae_weights=True, + nf_version=None): # 0): n_silh_enc + super(LinearModelComplete, self).__init__() + if add_z_to_3d_input: + self.n_z_to_add = n_z # 512 + else: + self.n_z_to_add = 0 + if add_segbps_to_3d_input: + self.n_segbps_to_add = n_segbps # 64 + else: + self.n_segbps_to_add = 0 + self.input_size = input_size + self.linear_size = linear_size + self.p_dropout = p_dropout + self.num_stage_comb = num_stage_comb + self.num_stage_heads = num_stage_heads + self.num_stage_heads_pose = num_stage_heads_pose + self.trans_sep = trans_sep + self.input_size = input_size + self.intermediate_size = intermediate_size + self.structure_pose_net = structure_pose_net + self.fix_vae_weights = fix_vae_weights # only relevant if structure_pose_net='vae' + self.nf_version = nf_version + if output_info is None: + pose = {'name': 'pose', 'n': n_joints*6, 'out_shape':[n_joints, 6]} + cam = {'name': 'flength', 'n': 1} + if self.trans_sep: + translation_xy = {'name': 'trans_xy', 'n': 2} + translation_z = {'name': 'trans_z', 'n': 1} + self.output_info = [pose, translation_xy, translation_z, cam] + else: + translation = {'name': 'trans', 'n': 3} + self.output_info = [pose, translation, cam] + if self.structure_pose_net == 'vae' or self.structure_pose_net == 'normflow': + global_pose = {'name': 'global_pose', 'n': 1*6, 'out_shape':[1, 6]} + self.output_info.append(global_pose) + else: + self.output_info = output_info + self.linear_combined = LinearModel(linear_size=self.linear_size, + num_stage=self.num_stage_comb, + p_dropout=p_dropout, + input_size=self.input_size + self.n_segbps_to_add + self.n_z_to_add, ###### + output_size=self.intermediate_size) + self.output_info_linear_models = [] + for ind_el, element in enumerate(self.output_info): + if element['name'] == 'pose': + num_stage = self.num_stage_heads_pose + if self.structure_pose_net == 'default': + output_size_pose_lin = element['n'] + elif self.structure_pose_net == 'vae': + # load vae decoder + self.pose_vae_model = VAEmodel() + self.pose_vae_model.initialize_with_pretrained_weights() + # define the input size of the vae decoder + output_size_pose_lin = self.pose_vae_model.latent_size + elif self.structure_pose_net == 'normflow': + # the following will automatically be initialized + self.pose_normflow_model = NormalizingFlowPrior(nf_version=self.nf_version) + output_size_pose_lin = element['n'] - 6 # no global rotation + else: + raise NotImplementedError + self.output_info_linear_models.append(LinearModel(linear_size=self.linear_size, + num_stage=num_stage, + p_dropout=p_dropout, + input_size=self.intermediate_size, + output_size=output_size_pose_lin)) + else: + if element['name'] == 'global_pose': + num_stage = self.num_stage_heads_pose + else: + num_stage = self.num_stage_heads + self.output_info_linear_models.append(LinearModel(linear_size=self.linear_size, + num_stage=num_stage, + p_dropout=p_dropout, + input_size=self.intermediate_size, + output_size=element['n'])) + element['linear_model_index'] = ind_el + self.output_info_linear_models = nn.ModuleList(self.output_info_linear_models) + + def forward(self, x): + device = x.device + # combined stage + if x.shape[1] == self.input_size + self.n_segbps_to_add + self.n_z_to_add: + y = self.linear_combined(x) + elif x.shape[1] == self.input_size + self.n_segbps_to_add: + x_mod = torch.cat((x, torch.normal(0, 1, size=(x.shape[0], self.n_z_to_add)).to(device)), dim=1) + y = self.linear_combined(x_mod) + else: + print(x.shape) + print(self.input_size) + print(self.n_segbps_to_add) + print(self.n_z_to_add) + raise ValueError + # heads + results = {} + results_trans = {} + for element in self.output_info: + linear_model = self.output_info_linear_models[element['linear_model_index']] + if element['name'] == 'pose': + if self.structure_pose_net == 'default': + results['pose'] = (linear_model(y)).reshape((-1, element['out_shape'][0], element['out_shape'][1])) + normflow_z = None + elif self.structure_pose_net == 'vae': + res_lin = linear_model(y) + if self.fix_vae_weights: + self.pose_vae_model.requires_grad_(False) # let gradients flow through but don't update the parameters + res_vae = self.pose_vae_model.inference(feat=res_lin) + self.pose_vae_model.requires_grad_(True) + else: + res_vae = self.pose_vae_model.inference(feat=res_lin) + res_pose_not_glob = res_vae.reshape((-1, element['out_shape'][0], element['out_shape'][1])) + normflow_z = None + elif self.structure_pose_net == 'normflow': + normflow_z = linear_model(y)*0.1 + self.pose_normflow_model.requires_grad_(False) # let gradients flow though but don't update the parameters + res_pose_not_glob = self.pose_normflow_model.run_backwards(z=normflow_z).reshape((-1, element['out_shape'][0]-1, element['out_shape'][1])) + else: + raise NotImplementedError + elif element['name'] == 'global_pose': + res_pose_glob = (linear_model(y)).reshape((-1, element['out_shape'][0], element['out_shape'][1])) + elif element['name'] == 'trans_xy' or element['name'] == 'trans_z': + results_trans[element['name']] = linear_model(y) + else: + results[element['name']] = linear_model(y) + if self.trans_sep: + results['trans'] = torch.cat((results_trans['trans_xy'], results_trans['trans_z']), dim=1) + # prepare pose including global rotation + if self.structure_pose_net == 'vae': + # results['pose'] = torch.cat((res_pose_glob, res_pose_not_glob), dim=1) + results['pose'] = torch.cat((res_pose_glob, res_pose_not_glob[:, 1:, :]), dim=1) + elif self.structure_pose_net == 'normflow': + results['pose'] = torch.cat((res_pose_glob, res_pose_not_glob[:, :, :]), dim=1) + # return a dictionary which contains all results + results['normflow_z'] = normflow_z + return results # this is a dictionary + + + + + +# ------------------------------------------ +# for pretraining of the 3d model only: +# (see combined_model/model_shape_v2.py) + +class Wrapper_LinearModelComplete(nn.Module): + def __init__(self, + linear_size=1024, + num_stage_comb=2, + num_stage_heads=1, + num_stage_heads_pose=1, + trans_sep=False, + p_dropout=0.5, + input_size=16*2, + intermediate_size=1024, + output_info=None, + n_joints=25, + n_z=512, + add_z_to_3d_input=False, + n_segbps=64*2, + add_segbps_to_3d_input=False, + structure_pose_net='default', + fix_vae_weights=True, + nf_version=None): + self.add_segbps_to_3d_input = add_segbps_to_3d_input + super(Wrapper_LinearModelComplete, self).__init__() + self.model_3d = LinearModelComplete(linear_size=linear_size, + num_stage_comb=num_stage_comb, + num_stage_heads=num_stage_heads, + num_stage_heads_pose=num_stage_heads_pose, + trans_sep=trans_sep, + p_dropout=p_dropout, # 0.5, + input_size=input_size, + intermediate_size=intermediate_size, + output_info=output_info, + n_joints=n_joints, + n_z=n_z, + add_z_to_3d_input=add_z_to_3d_input, + n_segbps=n_segbps, + add_segbps_to_3d_input=add_segbps_to_3d_input, + structure_pose_net=structure_pose_net, + fix_vae_weights=fix_vae_weights, + nf_version=nf_version) + def forward(self, input_vec): + # input_vec = torch.cat((keypoints_prepared.reshape((batch_size, -1)), bone_lengths_prepared), axis=1) + # predict 3d parameters (those are normalized, we need to correct mean and std in a next step) + output = self.model_3d(input_vec) + return output \ No newline at end of file diff --git a/src/lifting_to_3d/utils/geometry_utils.py b/src/lifting_to_3d/utils/geometry_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f7b6cc93db854a173792e7a545761abf8f1825a2 --- /dev/null +++ b/src/lifting_to_3d/utils/geometry_utils.py @@ -0,0 +1,240 @@ + +import torch +from torch.nn import functional as F +import numpy as np +from torch import nn + + +def geodesic_loss(R, Rgt, eps=1e-7): + # see: Silvia tiger pose model 3d code + # other implementations could be found here: + # 1.) https://github.com/papagina/RotationContinuity/blob/master/sanity_test/code/tools.py + # 2.) https://github.com/airalcorn2/pytorch-geodesic-loss/blob/master/geodesic_loss.py + num_joints = R.shape[1] + RT = R.permute(0,1,3,2) + A = torch.matmul(RT.view(-1,3,3),Rgt.view(-1,3,3)) + # torch.trace works only for 2D tensors + n = A.shape[0] + po_loss = 0 + T = torch.sum(A[:,torch.eye(3).bool()],1) + theta = torch.clamp(0.5*(T-1), -1+eps, 1-eps) + angles = torch.acos(theta) + loss = torch.sum(angles)/(n*num_joints) + return loss + +class geodesic_loss_R(nn.Module): + def __init__(self, reduction='mean'): + super(geodesic_loss_R, self).__init__() + self.reduction = reduction + self.eps = 1e-7 # 1e-6 + + # batch geodesic loss for rotation matrices + def bgdR(self,bRgts,bRps): + #return((bRgts - bRps)**2.).mean() + return geodesic_loss(bRgts, bRps, eps=self.eps) + + def forward(self, ypred, ytrue): + theta = geodesic_loss(ypred,ytrue,eps=self.eps) + if self.reduction == 'mean': + return torch.mean(theta) + else: + return theta + +def batch_rodrigues_numpy(theta): + """ Code adapted from spin + Convert axis-angle representation to rotation matrix. + Remark: + this leads to the same result as kornia.angle_axis_to_rotation_matrix(theta) + Args: + theta: size = [B, 3] + Returns: + Rotation matrix corresponding to the quaternion -- size = [B, 3, 3] + """ + l1norm = np.linalg.norm(theta + 1e-8, ord = 2, axis = 1) + # angle = np.unsqueeze(l1norm, -1) + angle = l1norm.reshape((-1, 1)) + # normalized = np.div(theta, angle) + normalized = theta / angle + angle = angle * 0.5 + v_cos = np.cos(angle) + v_sin = np.sin(angle) + # quat = np.cat([v_cos, v_sin * normalized], dim = 1) + quat = np.concatenate([v_cos, v_sin * normalized], axis = 1) + return quat_to_rotmat_numpy(quat) + +def quat_to_rotmat_numpy(quat): + """Code from: https://github.com/nkolot/SPIN/blob/master/utils/geometry.py + Convert quaternion coefficients to rotation matrix. + Args: + quat: size = [B, 4] 4 <===>(w, x, y, z) + Returns: + Rotation matrix corresponding to the quaternion -- size = [B, 3, 3] + """ + norm_quat = quat + # norm_quat = norm_quat/norm_quat.norm(p=2, dim=1, keepdim=True) + norm_quat = norm_quat/np.linalg.norm(norm_quat, ord=2, axis=1, keepdims=True) + w, x, y, z = norm_quat[:,0], norm_quat[:,1], norm_quat[:,2], norm_quat[:,3] + B = quat.shape[0] + # w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2) + w2, x2, y2, z2 = w**2, x**2, y**2, z**2 + wx, wy, wz = w*x, w*y, w*z + xy, xz, yz = x*y, x*z, y*z + rotMat = np.stack([w2 + x2 - y2 - z2, 2*xy - 2*wz, 2*wy + 2*xz, + 2*wz + 2*xy, w2 - x2 + y2 - z2, 2*yz - 2*wx, + 2*xz - 2*wy, 2*wx + 2*yz, w2 - x2 - y2 + z2], axis=1).reshape(B, 3, 3) + return rotMat + + +def batch_rodrigues(theta): + """Code from: https://github.com/nkolot/SPIN/blob/master/utils/geometry.py + Convert axis-angle representation to rotation matrix. + Remark: + this leads to the same result as kornia.angle_axis_to_rotation_matrix(theta) + Args: + theta: size = [B, 3] + Returns: + Rotation matrix corresponding to the quaternion -- size = [B, 3, 3] + """ + l1norm = torch.norm(theta + 1e-8, p = 2, dim = 1) + angle = torch.unsqueeze(l1norm, -1) + normalized = torch.div(theta, angle) + angle = angle * 0.5 + v_cos = torch.cos(angle) + v_sin = torch.sin(angle) + quat = torch.cat([v_cos, v_sin * normalized], dim = 1) + return quat_to_rotmat(quat) + +def batch_rot2aa(Rs, epsilon=1e-7): + """ Code from: https://github.com/vchoutas/expose/blob/dffc38d62ad3817481d15fe509a93c2bb606cb8b/expose/utils/rotation_utils.py#L55 + Rs is B x 3 x 3 + void cMathUtil::RotMatToAxisAngle(const tMatrix& mat, tVector& out_axis, + double& out_theta) + { + double c = 0.5 * (mat(0, 0) + mat(1, 1) + mat(2, 2) - 1); + c = cMathUtil::Clamp(c, -1.0, 1.0); + out_theta = std::acos(c); + if (std::abs(out_theta) < 0.00001) + { + out_axis = tVector(0, 0, 1, 0); + } + else + { + double m21 = mat(2, 1) - mat(1, 2); + double m02 = mat(0, 2) - mat(2, 0); + double m10 = mat(1, 0) - mat(0, 1); + double denom = std::sqrt(m21 * m21 + m02 * m02 + m10 * m10); + out_axis[0] = m21 / denom; + out_axis[1] = m02 / denom; + out_axis[2] = m10 / denom; + out_axis[3] = 0; + } + } + """ + cos = 0.5 * (torch.einsum('bii->b', [Rs]) - 1) + cos = torch.clamp(cos, -1 + epsilon, 1 - epsilon) + theta = torch.acos(cos) + m21 = Rs[:, 2, 1] - Rs[:, 1, 2] + m02 = Rs[:, 0, 2] - Rs[:, 2, 0] + m10 = Rs[:, 1, 0] - Rs[:, 0, 1] + denom = torch.sqrt(m21 * m21 + m02 * m02 + m10 * m10 + epsilon) + axis0 = torch.where(torch.abs(theta) < 0.00001, m21, m21 / denom) + axis1 = torch.where(torch.abs(theta) < 0.00001, m02, m02 / denom) + axis2 = torch.where(torch.abs(theta) < 0.00001, m10, m10 / denom) + return theta.unsqueeze(1) * torch.stack([axis0, axis1, axis2], 1) + +def quat_to_rotmat(quat): + """Code from: https://github.com/nkolot/SPIN/blob/master/utils/geometry.py + Convert quaternion coefficients to rotation matrix. + Args: + quat: size = [B, 4] 4 <===>(w, x, y, z) + Returns: + Rotation matrix corresponding to the quaternion -- size = [B, 3, 3] + """ + norm_quat = quat + norm_quat = norm_quat/norm_quat.norm(p=2, dim=1, keepdim=True) + w, x, y, z = norm_quat[:,0], norm_quat[:,1], norm_quat[:,2], norm_quat[:,3] + + B = quat.size(0) + + w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2) + wx, wy, wz = w*x, w*y, w*z + xy, xz, yz = x*y, x*z, y*z + + rotMat = torch.stack([w2 + x2 - y2 - z2, 2*xy - 2*wz, 2*wy + 2*xz, + 2*wz + 2*xy, w2 - x2 + y2 - z2, 2*yz - 2*wx, + 2*xz - 2*wy, 2*wx + 2*yz, w2 - x2 - y2 + z2], dim=1).view(B, 3, 3) + return rotMat + +def rot6d_to_rotmat(rot6d): + """ Code from: https://github.com/nkolot/SPIN/blob/master/utils/geometry.py + Convert 6D rotation representation to 3x3 rotation matrix. + Based on Zhou et al., "On the Continuity of Rotation Representations in Neural Networks", CVPR 2019 + Input: + (B,6) Batch of 6-D rotation representations + Output: + (B,3,3) Batch of corresponding rotation matrices + """ + assert rot6d.ndim == 2 + rot6d = rot6d.view(-1,3,2) + a1 = rot6d[:, :, 0] + a2 = rot6d[:, :, 1] + b1 = F.normalize(a1) + b2 = F.normalize(a2 - torch.einsum('bi,bi->b', b1, a2).unsqueeze(-1) * b1) + b3 = torch.cross(b1, b2) + rotmat = torch.stack((b1, b2, b3), dim=-1) + return rotmat + +def rotmat_to_rot6d(rotmat): + """ Convert 3x3 rotation matrix to 6D rotation representation. + Based on Zhou et al., "On the Continuity of Rotation Representations in Neural Networks", CVPR 2019 + Input: + (B,3,3) Batch of corresponding rotation matrices + Output: + (B,6) Batch of 6-D rotation representations + """ + assert rotmat.ndim == 3 + rot6d = rotmat[:, :, :2].reshape((-1, 6)) + return rot6d + + +def main(): + # rotation matrix and 6d representation + # see "On the Continuity of Rotation Representations in Neural Networks" + from pyquaternion import Quaternion + batch_size = 5 + rotmat = np.zeros((batch_size, 3, 3)) + for ind in range(0, batch_size): + rotmat[ind, :, :] = Quaternion.random().rotation_matrix + rotmat_torch = torch.Tensor(rotmat) + rot6d = rotmat_to_rot6d(rotmat_torch) + rotmat_rec = rot6d_to_rotmat(rot6d) + print('..................... 1 ....................') + print(rotmat_torch[0, :, :]) + print(rotmat_rec[0, :, :]) + print('Conversion from rotmat to rot6d and inverse are ok!') + # rotation matrix and axis angle representation + import kornia + input = torch.rand(1, 3) + output = kornia.angle_axis_to_rotation_matrix(input) + input_rec = kornia.rotation_matrix_to_angle_axis(output) + print('..................... 2 ....................') + print(input) + print(input_rec) + print('Kornia implementation for rotation_matrix_to_angle_axis is wrong!!!!') + # For non-differential conversions use scipy: + # https://docs.scipy.org/doc/scipy/reference/generated/scipy.spatial.transform.Rotation.html + from scipy.spatial.transform import Rotation as R + r = R.from_matrix(rotmat[0, :, :]) + print('..................... 3 ....................') + print(r.as_matrix()) + print(r.as_rotvec()) + print(r.as_quaternion) + # one might furthermore have a look at: + # https://github.com/silviazuffi/smalst/blob/master/utils/transformations.py + + + +if __name__ == "__main__": + main() + + diff --git a/src/metrics/metrics.py b/src/metrics/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..ffa1ae1c00bd286f55a4ede8565dc3eb619162a9 --- /dev/null +++ b/src/metrics/metrics.py @@ -0,0 +1,74 @@ +# code from: https://github.com/benjiebob/WLDO/blob/master/wldo_regressor/metrics.py + + +import torch +import torch.nn.functional as F +import numpy as np + +IMG_RES = 256 # in WLDO it is 224 + +class Metrics(): + + @staticmethod + def PCK_thresh( + pred_keypoints, gt_keypoints, + gtseg, has_seg, + thresh, idxs, biggs=False): + + pred_keypoints, gt_keypoints, gtseg = pred_keypoints[has_seg], gt_keypoints[has_seg], gtseg[has_seg] + + if idxs is None: + idxs = list(range(pred_keypoints.shape[1])) + + idxs = np.array(idxs).astype(int) + + pred_keypoints = pred_keypoints[:, idxs] + gt_keypoints = gt_keypoints[:, idxs] + + if biggs: + keypoints_gt = ((gt_keypoints + 1.0) * 0.5) * IMG_RES + dist = torch.norm(pred_keypoints - keypoints_gt[:, :, [1, 0]], dim = -1) + else: + keypoints_gt = gt_keypoints # (0 to IMG_SIZE) + dist = torch.norm(pred_keypoints - keypoints_gt[:, :, :2], dim = -1) + + seg_area = torch.sum(gtseg.reshape(gtseg.shape[0], -1), dim = -1).unsqueeze(-1) + + hits = (dist / torch.sqrt(seg_area)) < thresh + total_visible = torch.sum(gt_keypoints[:, :, -1], dim = -1) + pck = torch.sum(hits.float() * gt_keypoints[:, :, -1], dim = -1) / total_visible + + return pck + + @staticmethod + def PCK( + pred_keypoints, keypoints, + gtseg, has_seg, + thresh_range=[0.15], + idxs:list=None, + biggs=False): + """Calc PCK with same method as in eval. + idxs = optional list of subset of keypoints to index from + """ + cumulative_pck = [] + for thresh in thresh_range: + pck = Metrics.PCK_thresh( + pred_keypoints, keypoints, + gtseg, has_seg, thresh, idxs, + biggs=biggs) + cumulative_pck.append(pck) + pck_mean = torch.stack(cumulative_pck, dim = 0).mean(dim=0) + return pck_mean + + @staticmethod + def IOU(synth_silhouettes, gt_seg, img_border_mask, mask): + for i in range(mask.shape[0]): + synth_silhouettes[i] *= mask[i] + # Do not penalize parts of the segmentation outside the img range + gt_seg = (gt_seg * img_border_mask) + synth_silhouettes * (1.0 - img_border_mask) + intersection = torch.sum((synth_silhouettes * gt_seg).reshape(synth_silhouettes.shape[0], -1), dim = -1) + union = torch.sum(((synth_silhouettes + gt_seg).reshape(synth_silhouettes.shape[0], -1) > 0.0).float(), dim = -1) + acc_IOU_SCORE = intersection / union + if torch.isnan(acc_IOU_SCORE).sum() > 0: + import pdb; pdb.set_trace() + return acc_IOU_SCORE \ No newline at end of file diff --git a/src/priors/helper_3dcgmodel_loss.py b/src/priors/helper_3dcgmodel_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..5b16a6a78650a73ecf638a0242159b4589a38f0b --- /dev/null +++ b/src/priors/helper_3dcgmodel_loss.py @@ -0,0 +1,60 @@ + +import pickle as pkl +import torch + +# see also /is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/data/smal_data/new_dog_models/additional_info/debugging_only_info_scanned_toys_for_dog_model_creation.py + + +def load_dog_betas_for_3dcgmodel_loss(data_path, smal_model_type): + assert smal_model_type in {'barc', '39dogs_diffsize', '39dogs_norm', '39dogs_norm_newv2', '39dogs_norm_newv3'} + # load betas for the figures which were used to create the dog model + if smal_model_type in ['barc', '39dogs_norm', '39dogs_norm_newv2', '39dogs_norm_newv3']: + with open(data_path, 'rb') as f: + data = pkl.load(f) + dog_betas_unity = data['dogs_betas'] + elif smal_model_type == '39dogs_diffsize': + with open(data_path, 'rb') as f: + u = pkl._Unpickler(f) + u.encoding = 'latin1' + data = u.load() + dog_betas_unity = data['toys_betas'] + # load correspondencies between those betas and the breeds + if smal_model_type == 'barc': + dog_betas_for_3dcgloss = {29: torch.tensor(dog_betas_unity[0, :]).float(), + 91: torch.tensor(dog_betas_unity[1, :]).float(), + 84: torch.tensor(0.5*dog_betas_unity[3, :] + 0.5*dog_betas_unity[14, :]).float(), + 85: torch.tensor(dog_betas_unity[5, :]).float(), + 28: torch.tensor(dog_betas_unity[6, :]).float(), + 94: torch.tensor(dog_betas_unity[7, :]).float(), + 92: torch.tensor(dog_betas_unity[8, :]).float(), + 95: torch.tensor(dog_betas_unity[10, :]).float(), + 20: torch.tensor(dog_betas_unity[11, :]).float(), + 83: torch.tensor(dog_betas_unity[12, :]).float(), + 99: torch.tensor(dog_betas_unity[16, :]).float()} + elif smal_model_type in ['39dogs_diffsize', '39dogs_norm', '39dogs_norm_newv2', '39dogs_norm_newv3']: + dog_betas_for_3dcgloss = {84: torch.tensor(dog_betas_unity[0, :]).float(), + 99: torch.tensor(dog_betas_unity[2, :]).float(), + 81: torch.tensor(dog_betas_unity[6, :]).float(), + 9: torch.tensor(dog_betas_unity[9, :]).float(), + 40: torch.tensor(dog_betas_unity[10, :]).float(), + 29: torch.tensor(dog_betas_unity[11, :]).float(), + 10: torch.tensor(dog_betas_unity[13, :]).float(), + 11: torch.tensor(dog_betas_unity[14, :]).float(), + 44: torch.tensor(dog_betas_unity[15, :]).float(), + 91: torch.tensor(dog_betas_unity[16, :]).float(), + 28: torch.tensor(dog_betas_unity[17, :]).float(), + 108: torch.tensor(dog_betas_unity[20, :]).float(), + 80: torch.tensor(dog_betas_unity[21, :]).float(), + 85: torch.tensor(dog_betas_unity[23, :]).float(), + 68: torch.tensor(dog_betas_unity[24, :]).float(), + 94: torch.tensor(dog_betas_unity[25, :]).float(), + 95: torch.tensor(dog_betas_unity[26, :]).float(), + 20: torch.tensor(dog_betas_unity[27, :]).float(), + 62: torch.tensor(dog_betas_unity[28, :]).float(), + 57: torch.tensor(dog_betas_unity[30, :]).float(), + 102: torch.tensor(dog_betas_unity[31, :]).float(), + 8: torch.tensor(dog_betas_unity[35, :]).float(), + 83: torch.tensor(dog_betas_unity[36, :]).float(), + 96: torch.tensor(dog_betas_unity[37, :]).float(), + 46: torch.tensor(dog_betas_unity[38, :]).float()} + return dog_betas_for_3dcgloss \ No newline at end of file diff --git a/src/priors/normalizing_flow_prior/normalizing_flow_prior.py b/src/priors/normalizing_flow_prior/normalizing_flow_prior.py new file mode 100644 index 0000000000000000000000000000000000000000..5cf60fe51d722c31d7b045a637e1b57d4b577091 --- /dev/null +++ b/src/priors/normalizing_flow_prior/normalizing_flow_prior.py @@ -0,0 +1,115 @@ + +from torch import distributions +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.data import DataLoader +from torch.distributions import Normal +import numpy as np +import cv2 +import trimesh +from tqdm import tqdm + +import warnings +warnings.filterwarnings("ignore", category=DeprecationWarning) +import FrEIA.framework as Ff +import FrEIA.modules as Fm +from configs.barc_cfg_defaults import get_cfg_global_updated + + +class NormalizingFlowPrior(nn.Module): + def __init__(self, nf_version=None): + super(NormalizingFlowPrior, self).__init__() + # the normalizing flow network takes as input a vector of size (35-1)*6 which is + # [all joints except root joint]*6. At the moment the rotation is represented as 6D + # representation, which is actually not ideal. Nevertheless, in practice the + # results seem to be ok. + n_dim = (35 - 1) * 6 + self.param_dict = self.get_version_param_dict(nf_version) + self.model_inn = self.build_inn_network(n_dim, k_tot=self.param_dict['k_tot']) + self.initialize_with_pretrained_weights() + + def get_version_param_dict(self, nf_version): + # we had trained several version of the normalizing flow pose prior, here we just provide + # the option that was user for the cvpr 2022 paper (nf_version=3) + if nf_version == 3: + param_dict = { + 'k_tot': 2, + 'path_pretrained': get_cfg_global_updated().paths.MODELPATH_NORMFLOW, + 'subnet_fc_type': '3_64'} + else: + print(nf_version) + raise ValueError + return param_dict + + def initialize_with_pretrained_weights(self, weight_path=None): + # The normalizing flow pose prior is pretrained separately. Afterwards all weights + # are kept fixed. Here we load those pretrained weights. + if weight_path is None: + weight_path = self.param_dict['path_pretrained'] + print(' normalizing flow pose prior: loading {}..'.format(weight_path)) + pretrained_dict = torch.load(weight_path)['model_state_dict'] + self.model_inn.load_state_dict(pretrained_dict, strict=True) + + def subnet_fc(self, c_in, c_out): + if self.param_dict['subnet_fc_type']=='3_512': + subnet = nn.Sequential(nn.Linear(c_in, 512), nn.ReLU(), + nn.Linear(512, 512), nn.ReLU(), + nn.Linear(512, c_out)) + elif self.param_dict['subnet_fc_type']=='3_64': + subnet = nn.Sequential(nn.Linear(c_in, 64), nn.ReLU(), + nn.Linear(64, 64), nn.ReLU(), + nn.Linear(64, c_out)) + return subnet + + def build_inn_network(self, n_input, k_tot=12, verbose=False): + coupling_block = Fm.RNVPCouplingBlock + nodes = [Ff.InputNode(n_input, name='input')] + for k in range(k_tot): + nodes.append(Ff.Node(nodes[-1], + coupling_block, + {'subnet_constructor':self.subnet_fc, 'clamp':2.0}, + name=F'coupling_{k}')) + nodes.append(Ff.Node(nodes[-1], + Fm.PermuteRandom, + {'seed':k}, + name=F'permute_{k}')) + nodes.append(Ff.OutputNode(nodes[-1], name='output')) + model = Ff.ReversibleGraphNet(nodes, verbose=verbose) + return model + + def calculate_loss_from_z(self, z, type='square'): + assert type in ['square', 'neg_log_prob'] + if type == 'square': + loss = (z**2).mean() # * 0.00001 + elif type == 'neg_log_prob': + means = torch.zeros((z.shape[0], z.shape[1]), dtype=z.dtype, device=z.device) + stds = torch.ones((z.shape[0], z.shape[1]), dtype=z.dtype, device=z.device) + normal_distribution = Normal(means, stds) + log_prob = normal_distribution.log_prob(z) + loss = - log_prob.mean() + return loss + + def calculate_loss(self, poses_rot6d, type='square'): + assert type in ['square', 'neg_log_prob'] + poses_rot6d_noglob = poses_rot6d[:, 1:, :].reshape((-1, 34*6)) + z, _ = self.model_inn(poses_rot6d_noglob, rev=False, jac=False) + loss = self.calculate_loss_from_z(z, type=type) + return loss + + def forward(self, poses_rot6d): + # from pose to latent pose representation z + # poses_rot6d has shape (bs, 34, 6) + poses_rot6d_noglob = poses_rot6d[:, 1:, :].reshape((-1, 34*6)) + z, _ = self.model_inn(poses_rot6d_noglob, rev=False, jac=False) + return z + + def run_backwards(self, z): + # from latent pose representation z to pose + poses_rot6d_noglob, _ = self.model_inn(z, rev=True, jac=False) + return poses_rot6d_noglob + + + + + \ No newline at end of file diff --git a/src/priors/shape_prior.py b/src/priors/shape_prior.py new file mode 100644 index 0000000000000000000000000000000000000000..15dd80c36bc02250f1655d6675115198cf065b19 --- /dev/null +++ b/src/priors/shape_prior.py @@ -0,0 +1,46 @@ + +# some parts of the code adapted from https://github.com/benjiebob/WLDO and https://github.com/benjiebob/SMALify + +import numpy as np +import torch +import pickle as pkl + + + +class ShapePrior(torch.nn.Module): + def __init__(self, prior_path): + super(ShapePrior, self).__init__() + try: + with open(prior_path, 'r') as f: + res = pkl.load(f) + except (UnicodeDecodeError, TypeError) as e: + with open(prior_path, 'rb') as file: + u = pkl._Unpickler(file) + u.encoding = 'latin1' + res = u.load() + if 'dog_cluster_mean' in res.keys(): + betas_mean = res['dog_cluster_mean'] # (54,) + betas_cov = res['dog_cluster_cov'] # (54, 54) + else: # for silvia's model + assert res['cluster_means'].shape[0]==1 + betas_mean = res['cluster_means'][0, :] # (39,) + betas_cov = res['cluster_cov'][0] # (39, 39) + + single_gaussian_inv_covs = np.linalg.inv(betas_cov + 1e-5 * np.eye(betas_cov.shape[0])) + single_gaussian_precs = torch.tensor(np.linalg.cholesky(single_gaussian_inv_covs)).float() + single_gaussian_means = torch.tensor(betas_mean).float() + self.register_buffer('single_gaussian_precs', single_gaussian_precs) # (20, 20) + self.register_buffer('single_gaussian_means', single_gaussian_means) # (20) + use_ind_tch = torch.from_numpy(np.ones(single_gaussian_means.shape[0], dtype=bool)).float() # .to(device) + self.register_buffer('use_ind_tch', use_ind_tch) + + def forward(self, betas_smal_orig, use_singe_gaussian=False): + n_betas_smal = betas_smal_orig.shape[1] + device = betas_smal_orig.device + use_ind_tch_corrected = self.use_ind_tch * torch.cat((torch.ones_like(self.use_ind_tch[:n_betas_smal]), torch.zeros_like(self.use_ind_tch[n_betas_smal:]))) + samples = torch.cat((betas_smal_orig, torch.zeros((betas_smal_orig.shape[0], self.single_gaussian_means.shape[0]-n_betas_smal)).float().to(device)), dim=1) + mean_sub = samples - self.single_gaussian_means.unsqueeze(0) + single_gaussian_precs_corr = self.single_gaussian_precs * use_ind_tch_corrected[:, None] * use_ind_tch_corrected[None, :] + res = torch.tensordot(mean_sub, single_gaussian_precs_corr, dims = ([1], [0])) + res_final_mean_2 = torch.mean(res ** 2) + return res_final_mean_2 diff --git a/src/smal_pytorch/renderer/differentiable_renderer.py b/src/smal_pytorch/renderer/differentiable_renderer.py new file mode 100644 index 0000000000000000000000000000000000000000..c3d76f1f34a16ee6b559e18d95ecaca4fa267b31 --- /dev/null +++ b/src/smal_pytorch/renderer/differentiable_renderer.py @@ -0,0 +1,280 @@ + +# part of the code from +# https://github.com/benjiebob/SMALify/blob/master/smal_fitter/p3d_renderer.py + +import torch +import torch.nn.functional as F +from scipy.io import loadmat +import numpy as np +# import config + +import pytorch3d +from pytorch3d.structures import Meshes +from pytorch3d.renderer import ( + PerspectiveCameras, look_at_view_transform, look_at_rotation, + RasterizationSettings, MeshRenderer, MeshRasterizer, BlendParams, + PointLights, HardPhongShader, SoftSilhouetteShader, Materials, Textures, + DirectionalLights +) +from pytorch3d.renderer import TexturesVertex, SoftPhongShader +from pytorch3d.io import load_objs_as_meshes + +MESH_COLOR_0 = [0, 172, 223] +MESH_COLOR_1 = [172, 223, 0] + + +''' +Explanation of the shift between projection results from opendr and pytorch3d: + (0, 0, ?) will be projected to 127.5 (pytorch3d) instead of 128 (opendr) + imagine you have an image of size 4: + middle of the first pixel is 0 + middle of the last pixel is 3 + => middle of the imgae would be 1.5 and not 2! + so in order to go from pytorch3d predictions to opendr we would calculate: p_odr = p_p3d * (128/127.5) +To reproject points (p3d) by hand according to this pytorch3d renderer we would do the following steps: + 1.) build camera matrix + K = np.array([[flength, 0, c_x], + [0, flength, c_y], + [0, 0, 1]], np.float) + 2.) we don't need to add extrinsics, as the mesh comes with translation (which is + added within smal_pytorch). all 3d points are already in the camera coordinate system. + -> projection reduces to p2d_proj = K*p3d + 3.) convert to pytorch3d conventions (0 in the middle of the first pixel) + p2d_proj_pytorch3d = p2d_proj / image_size * (image_size-1.) +renderer.py - project_points_p3d: shows an example of what is described above, but + same focal length for the whole batch + +''' + +class SilhRenderer(torch.nn.Module): + def __init__(self, image_size, adapt_R_wldo=False): + super(SilhRenderer, self).__init__() + # see: https://pytorch3d.org/files/fit_textured_mesh.py, line 315 + # adapt_R=True is True for all my experiments + # image_size: one number, integer + # ----- + # set mesh color + self.register_buffer('mesh_color_0', torch.FloatTensor(MESH_COLOR_0)) + self.register_buffer('mesh_color_1', torch.FloatTensor(MESH_COLOR_1)) + # prepare extrinsics, which in our case don't change + R = torch.Tensor(np.eye(3)).float()[None, :, :] + T = torch.Tensor(np.zeros((1, 3))).float() + if adapt_R_wldo: + R[0, 0, 0] = -1 + else: # used for all my own experiments + R[0, 0, 0] = -1 + R[0, 1, 1] = -1 + self.register_buffer('R', R) + self.register_buffer('T', T) + # prepare that part of the intrinsics which does not change either + # principal_point_prep = torch.Tensor([self.image_size / 2., self.image_size / 2.]).float()[None, :].float().to(device) + # image_size_prep = torch.Tensor([self.image_size, self.image_size]).float()[None, :].float().to(device) + self.img_size_scalar = image_size + self.register_buffer('image_size', torch.Tensor([image_size, image_size]).float()[None, :].float()) + self.register_buffer('principal_point', torch.Tensor([image_size / 2., image_size / 2.]).float()[None, :].float()) + # Rasterization settings for differentiable rendering, where the blur_radius + # initialization is based on Liu et al, 'Soft Rasterizer: A Differentiable + # Renderer for Image-based 3D Reasoning', ICCV 2019 + self.blend_params = BlendParams(sigma=1e-4, gamma=1e-4) + self.raster_settings_soft = RasterizationSettings( + image_size=image_size, # 128 + blur_radius=np.log(1. / 1e-4 - 1.)*self.blend_params.sigma, + faces_per_pixel=100) #50, + # Renderer for Image-based 3D Reasoning', body part segmentation + self.blend_params_parts = BlendParams(sigma=2*1e-4, gamma=1e-4) + self.raster_settings_soft_parts = RasterizationSettings( + image_size=image_size, # 128 + blur_radius=np.log(1. / 1e-4 - 1.)*self.blend_params_parts.sigma, + faces_per_pixel=60) #50, + # settings for visualization renderer + self.raster_settings_vis = RasterizationSettings( + image_size=image_size, + blur_radius=0.0, + faces_per_pixel=1) + + def _get_cam(self, focal_lengths): + device = focal_lengths.device + bs = focal_lengths.shape[0] + if pytorch3d.__version__ == '0.2.5': + cameras = PerspectiveCameras(device=device, + focal_length=focal_lengths.repeat((1, 2)), + principal_point=self.principal_point.repeat((bs, 1)), + R=self.R.repeat((bs, 1, 1)), T=self.T.repeat((bs, 1)), + image_size=self.image_size.repeat((bs, 1))) + elif pytorch3d.__version__ == '0.6.1': + cameras = PerspectiveCameras(device=device, in_ndc=False, + focal_length=focal_lengths.repeat((1, 2)), + principal_point=self.principal_point.repeat((bs, 1)), + R=self.R.repeat((bs, 1, 1)), T=self.T.repeat((bs, 1)), + image_size=self.image_size.repeat((bs, 1))) + else: + print('this part depends on the version of pytorch3d, code was developed with 0.2.5') + raise ValueError + return cameras + + def _get_visualization_from_mesh(self, mesh, cameras, lights=None): + # color renderer for visualization + with torch.no_grad(): + device = mesh.device + # renderer for visualization + if lights is None: + lights = PointLights(device=device, location=[[0.0, 0.0, 3.0]]) + vis_renderer = MeshRenderer( + rasterizer=MeshRasterizer( + cameras=cameras, + raster_settings=self.raster_settings_vis), + shader=HardPhongShader( + device=device, + cameras=cameras, + lights=lights)) + # render image: + visualization = vis_renderer(mesh).permute(0, 3, 1, 2)[:, :3, :, :] + return visualization + + + def calculate_vertex_visibility(self, vertices, faces, focal_lengths, soft=False): + tex = torch.ones_like(vertices) * self.mesh_color_0 # (1, V, 3) + textures = Textures(verts_rgb=tex) + mesh = Meshes(verts=vertices, faces=faces, textures=textures) + cameras = self._get_cam(focal_lengths) + # NEW: use the rasterizer to check vertex visibility + # see: https://github.com/facebookresearch/pytorch3d/issues/126 + # Get a rasterizer + if soft: + rasterizer = MeshRasterizer(cameras=cameras, + raster_settings=self.raster_settings_soft) + else: + rasterizer = MeshRasterizer(cameras=cameras, + raster_settings=self.raster_settings_vis) + # Get the output from rasterization + fragments = rasterizer(mesh) + # pix_to_face is of shape (N, H, W, 1) + pix_to_face = fragments.pix_to_face + # (F, 3) where F is the total number of faces across all the meshes in the batch + packed_faces = mesh.faces_packed() + # (V, 3) where V is the total number of verts across all the meshes in the batch + packed_verts = mesh.verts_packed() + vertex_visibility_map = torch.zeros(packed_verts.shape[0]) # (V,) + # Indices of unique visible faces + visible_faces = pix_to_face.unique() # [0] # (num_visible_faces ) + # Get Indices of unique visible verts using the vertex indices in the faces + visible_verts_idx = packed_faces[visible_faces] # (num_visible_faces, 3) + unique_visible_verts_idx = torch.unique(visible_verts_idx) # (num_visible_verts, ) + # Update visibility indicator to 1 for all visible vertices + vertex_visibility_map[unique_visible_verts_idx] = 1.0 + # since all meshes have the same amount of vertices, we can reshape the result + bs = vertices.shape[0] + vertex_visibility_map_resh = vertex_visibility_map.reshape((bs, -1)) + return pix_to_face, vertex_visibility_map_resh + + + def get_torch_meshes(self, vertices, faces, color=0): + # create pytorch mesh + if color == 0: + mesh_color = self.mesh_color_0 + else: + mesh_color = self.mesh_color_1 + tex = torch.ones_like(vertices) * mesh_color # (1, V, 3) + textures = Textures(verts_rgb=tex) + mesh = Meshes(verts=vertices, faces=faces, textures=textures) + return mesh + + + def get_visualization_nograd(self, vertices, faces, focal_lengths, color=0): + # vertices: torch.Size([bs, 3889, 3]) + # faces: torch.Size([bs, 7774, 3]), int + # focal_lengths: torch.Size([bs, 1]) + device = vertices.device + # create cameras + cameras = self._get_cam(focal_lengths) + # create pytorch mesh + if color == 0: + mesh_color = self.mesh_color_0 # blue + elif color == 1: + mesh_color = self.mesh_color_1 + elif color == 2: + MESH_COLOR_2 = [240, 250, 240] # white + mesh_color = torch.FloatTensor(MESH_COLOR_2).to(device) + elif color == 3: + # MESH_COLOR_3 = [223, 0, 172] # pink + # MESH_COLOR_3 = [245, 245, 220] # beige + MESH_COLOR_3 = [166, 173, 164] + mesh_color = torch.FloatTensor(MESH_COLOR_3).to(device) + else: + MESH_COLOR_2 = [240, 250, 240] + mesh_color = torch.FloatTensor(MESH_COLOR_2).to(device) + tex = torch.ones_like(vertices) * mesh_color # (1, V, 3) + textures = Textures(verts_rgb=tex) + mesh = Meshes(verts=vertices, faces=faces, textures=textures) + # render mesh (no gradients) + # lights = PointLights(device=device, location=[[0.0, 0.0, 3.0]]) + # lights = PointLights(device=device, location=[[2.0, 2.0, -2.0]]) + lights = DirectionalLights(device=device, direction=[[0.0, -5.0, -10.0]]) + visualization = self._get_visualization_from_mesh(mesh, cameras, lights=lights) + return visualization + + def project_points(self, points, focal_lengths=None, cameras=None): + # points: torch.Size([bs, n_points, 3]) + # either focal_lengths or cameras is needed: + # focal_lenghts: torch.Size([bs, 1]) + # cameras: pytorch camera, for example PerspectiveCameras() + bs = points.shape[0] + device = points.device + screen_size = self.image_size.repeat((bs, 1)) + if cameras is None: + cameras = self._get_cam(focal_lengths) + if pytorch3d.__version__ == '0.2.5': + proj_points_orig = cameras.transform_points_screen(points, screen_size)[:, :, [1, 0]] # used in the original virtuel environment (for cvpr BARC submission) + elif pytorch3d.__version__ == '0.6.1': + proj_points_orig = cameras.transform_points_screen(points)[:, :, [1, 0]] + else: + print('this part depends on the version of pytorch3d, code was developed with 0.2.5') + raise ValueError + # flip, otherwise the 1st and 2nd row are exchanged compared to the ground truth + proj_points = torch.flip(proj_points_orig, [2]) + # --- project points 'manually' + # j_proj = project_points_p3d(image_size, focal_length, points, device) + return proj_points + + def forward(self, vertices, points, faces, focal_lengths, color=None): + # vertices: torch.Size([bs, 3889, 3]) + # points: torch.Size([bs, n_points, 3]) (or None) + # faces: torch.Size([bs, 7774, 3]), int + # focal_lengths: torch.Size([bs, 1]) + # color: if None we don't render a visualization, else it should + # either be 0 or 1 + # ---> important: results are around 0.5 pixels off compared to chumpy! + # have a look at renderer.py for an explanation + # create cameras + cameras = self._get_cam(focal_lengths) + # create pytorch mesh + if color is None or color == 0: + mesh_color = self.mesh_color_0 + else: + mesh_color = self.mesh_color_1 + tex = torch.ones_like(vertices) * mesh_color # (1, V, 3) + textures = Textures(verts_rgb=tex) + mesh = Meshes(verts=vertices, faces=faces, textures=textures) + # silhouette renderer + renderer_silh = MeshRenderer( + rasterizer=MeshRasterizer( + cameras=cameras, + raster_settings=self.raster_settings_soft), + shader=SoftSilhouetteShader(blend_params=self.blend_params)) + # project silhouette + silh_images = renderer_silh(mesh)[..., -1].unsqueeze(1) + # project points + if points is None: + proj_points = None + else: + proj_points = self.project_points(points=points, cameras=cameras) + if color is not None: + # color renderer for visualization (no gradients) + visualization = self._get_visualization_from_mesh(mesh, cameras) + return silh_images, proj_points, visualization + else: + return silh_images, proj_points + + + + diff --git a/src/smal_pytorch/smal_model/batch_lbs.py b/src/smal_pytorch/smal_model/batch_lbs.py new file mode 100644 index 0000000000000000000000000000000000000000..56e9c7d4740dc5eb9f95627fed47a28065dac3fd --- /dev/null +++ b/src/smal_pytorch/smal_model/batch_lbs.py @@ -0,0 +1,313 @@ +''' +Adjusted version of other PyTorch implementation of the SMAL/SMPL model +see: + 1.) https://github.com/silviazuffi/smalst/blob/master/smal_model/smal_torch.py + 2.) https://github.com/benjiebob/SMALify/blob/master/smal_model/smal_torch.py +''' + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import torch +import numpy as np + + +def batch_skew(vec, batch_size=None): + """ + vec is N x 3, batch_size is int + + returns N x 3 x 3. Skew_sym version of each matrix. + """ + device = vec.device + if batch_size is None: + batch_size = vec.shape.as_list()[0] + col_inds = torch.LongTensor([1, 2, 3, 5, 6, 7]) + indices = torch.reshape(torch.reshape(torch.arange(0, batch_size) * 9, [-1, 1]) + col_inds, [-1, 1]) + updates = torch.reshape( + torch.stack( + [ + -vec[:, 2], vec[:, 1], vec[:, 2], -vec[:, 0], -vec[:, 1], + vec[:, 0] + ], + dim=1), [-1]) + out_shape = [batch_size * 9] + res = torch.Tensor(np.zeros(out_shape[0])).to(device=device) + res[np.array(indices.flatten())] = updates + res = torch.reshape(res, [batch_size, 3, 3]) + + return res + + + +def batch_rodrigues(theta): + """ + Theta is Nx3 + """ + device = theta.device + batch_size = theta.shape[0] + + angle = (torch.norm(theta + 1e-8, p=2, dim=1)).unsqueeze(-1) + r = (torch.div(theta, angle)).unsqueeze(-1) + + angle = angle.unsqueeze(-1) + cos = torch.cos(angle) + sin = torch.sin(angle) + + outer = torch.matmul(r, r.transpose(1,2)) + + eyes = torch.eye(3).unsqueeze(0).repeat([batch_size, 1, 1]).to(device=device) + H = batch_skew(r, batch_size=batch_size) + R = cos * eyes + (1 - cos) * outer + sin * H + + return R + +def batch_lrotmin(theta): + """ + Output of this is used to compute joint-to-pose blend shape mapping. + Equation 9 in SMPL paper. + + + Args: + pose: `Tensor`, N x 72 vector holding the axis-angle rep of K joints. + This includes the global rotation so K=24 + + Returns + diff_vec : `Tensor`: N x 207 rotation matrix of 23=(K-1) joints with identity subtracted., + """ + # Ignore global rotation + theta = theta[:,3:] + + Rs = batch_rodrigues(torch.reshape(theta, [-1,3])) + lrotmin = torch.reshape(Rs - torch.eye(3), [-1, 207]) + + return lrotmin + +def batch_global_rigid_transformation(Rs, Js, parent, rotate_base=False): + """ + Computes absolute joint locations given pose. + + rotate_base: if True, rotates the global rotation by 90 deg in x axis. + if False, this is the original SMPL coordinate. + + Args: + Rs: N x 24 x 3 x 3 rotation vector of K joints + Js: N x 24 x 3, joint locations before posing + parent: 24 holding the parent id for each index + + Returns + new_J : `Tensor`: N x 24 x 3 location of absolute joints + A : `Tensor`: N x 24 4 x 4 relative joint transformations for LBS. + """ + device = Rs.device + if rotate_base: + print('Flipping the SMPL coordinate frame!!!!') + rot_x = torch.Tensor([[1, 0, 0], [0, -1, 0], [0, 0, -1]]) + rot_x = torch.reshape(torch.repeat(rot_x, [N, 1]), [N, 3, 3]) # In tf it was tile + root_rotation = torch.matmul(Rs[:, 0, :, :], rot_x) + else: + root_rotation = Rs[:, 0, :, :] + + # Now Js is N x 24 x 3 x 1 + Js = Js.unsqueeze(-1) + N = Rs.shape[0] + + def make_A(R, t): + # Rs is N x 3 x 3, ts is N x 3 x 1 + R_homo = torch.nn.functional.pad(R, (0,0,0,1,0,0)) + t_homo = torch.cat([t, torch.ones([N, 1, 1]).to(device=device)], 1) + return torch.cat([R_homo, t_homo], 2) + + A0 = make_A(root_rotation, Js[:, 0]) + results = [A0] + for i in range(1, parent.shape[0]): + j_here = Js[:, i] - Js[:, parent[i]] + A_here = make_A(Rs[:, i], j_here) + res_here = torch.matmul( + results[parent[i]], A_here) + results.append(res_here) + + # 10 x 24 x 4 x 4 + results = torch.stack(results, dim=1) + + new_J = results[:, :, :3, 3] + + # --- Compute relative A: Skinning is based on + # how much the bone moved (not the final location of the bone) + # but (final_bone - init_bone) + # --- + Js_w0 = torch.cat([Js, torch.zeros([N, 35, 1, 1]).to(device=device)], 2) + init_bone = torch.matmul(results, Js_w0) + # Append empty 4 x 3: + init_bone = torch.nn.functional.pad(init_bone, (3,0,0,0,0,0,0,0)) + A = results - init_bone + + return new_J, A + + +######################################################################################### + +def get_bone_length_scales(part_list, betas_logscale): + leg_joints = list(range(7,11)) + list(range(11,15)) + list(range(17,21)) + list(range(21,25)) + front_leg_joints = list(range(7,11)) + list(range(11,15)) + back_leg_joints = list(range(17,21)) + list(range(21,25)) + tail_joints = list(range(25, 32)) + ear_joints = [33, 34] + neck_joints = [15, 6] # ? + core_joints = [4, 5] # ? + mouth_joints = [16, 32] + log_scales = torch.zeros(betas_logscale.shape[0], 35).to(betas_logscale.device) + for ind, part in enumerate(part_list): + if part == 'legs_l': + log_scales[:, leg_joints] = betas_logscale[:, ind][:, None] + elif part == 'front_legs_l': + log_scales[:, front_leg_joints] = betas_logscale[:, ind][:, None] + elif part == 'back_legs_l': + log_scales[:, back_leg_joints] = betas_logscale[:, ind][:, None] + elif part == 'tail_l': + log_scales[:, tail_joints] = betas_logscale[:, ind][:, None] + elif part == 'ears_l': + log_scales[:, ear_joints] = betas_logscale[:, ind][:, None] + elif part == 'neck_l': + log_scales[:, neck_joints] = betas_logscale[:, ind][:, None] + elif part == 'core_l': + log_scales[:, core_joints] = betas_logscale[:, ind][:, None] + elif part == 'head_l': + log_scales[:, mouth_joints] = betas_logscale[:, ind][:, None] + else: + pass + all_scales = torch.exp(log_scales) + return all_scales[:, 1:] # don't count root + +def get_beta_scale_mask(part_list): + # which joints belong to which bodypart + leg_joints = list(range(7,11)) + list(range(11,15)) + list(range(17,21)) + list(range(21,25)) + front_leg_joints = list(range(7,11)) + list(range(11,15)) + back_leg_joints = list(range(17,21)) + list(range(21,25)) + tail_joints = list(range(25, 32)) + ear_joints = [33, 34] + neck_joints = [15, 6] # ? + core_joints = [4, 5] # ? + mouth_joints = [16, 32] + n_b_log = len(part_list) #betas_logscale.shape[1] # 8 # 6 + beta_scale_mask = torch.zeros(35, 3, n_b_log) # .to(betas_logscale.device) + for ind, part in enumerate(part_list): + if part == 'legs_l': + beta_scale_mask[leg_joints, [2], [ind]] = 1.0 # Leg lengthening + elif part == 'legs_f': + beta_scale_mask[leg_joints, [0], [ind]] = 1.0 # Leg fatness + beta_scale_mask[leg_joints, [1], [ind]] = 1.0 # Leg fatness + elif part == 'front_legs_l': + beta_scale_mask[front_leg_joints, [2], [ind]] = 1.0 # front Leg lengthening + elif part == 'front_legs_f': + beta_scale_mask[front_leg_joints, [0], [ind]] = 1.0 # front Leg fatness + beta_scale_mask[front_leg_joints, [1], [ind]] = 1.0 # front Leg fatness + elif part == 'back_legs_l': + beta_scale_mask[back_leg_joints, [2], [ind]] = 1.0 # back Leg lengthening + elif part == 'back_legs_f': + beta_scale_mask[back_leg_joints, [0], [ind]] = 1.0 # back Leg fatness + beta_scale_mask[back_leg_joints, [1], [ind]] = 1.0 # back Leg fatness + elif part == 'tail_l': + beta_scale_mask[tail_joints, [0], [ind]] = 1.0 # Tail lengthening + elif part == 'tail_f': + beta_scale_mask[tail_joints, [1], [ind]] = 1.0 # Tail fatness + beta_scale_mask[tail_joints, [2], [ind]] = 1.0 # Tail fatness + elif part == 'ears_y': + beta_scale_mask[ear_joints, [1], [ind]] = 1.0 # Ear y + elif part == 'ears_l': + beta_scale_mask[ear_joints, [2], [ind]] = 1.0 # Ear z + elif part == 'neck_l': + beta_scale_mask[neck_joints, [0], [ind]] = 1.0 # Neck lengthening + elif part == 'neck_f': + beta_scale_mask[neck_joints, [1], [ind]] = 1.0 # Neck fatness + beta_scale_mask[neck_joints, [2], [ind]] = 1.0 # Neck fatness + elif part == 'core_l': + beta_scale_mask[core_joints, [0], [ind]] = 1.0 # Core lengthening + # beta_scale_mask[core_joints, [1], [ind]] = 1.0 # Core fatness (height) + elif part == 'core_fs': + beta_scale_mask[core_joints, [2], [ind]] = 1.0 # Core fatness (side) + elif part == 'head_l': + beta_scale_mask[mouth_joints, [0], [ind]] = 1.0 # Head lengthening + elif part == 'head_f': + beta_scale_mask[mouth_joints, [1], [ind]] = 1.0 # Head fatness 0 + beta_scale_mask[mouth_joints, [2], [ind]] = 1.0 # Head fatness 1 + else: + print(part + ' not available') + raise ValueError + beta_scale_mask = torch.transpose( + beta_scale_mask.reshape(35*3, n_b_log), 0, 1) + return beta_scale_mask + +def batch_global_rigid_transformation_biggs(Rs, Js, parent, scale_factors_3x3, rotate_base = False, betas_logscale=None, opts=None): + """ + Computes absolute joint locations given pose. + + rotate_base: if True, rotates the global rotation by 90 deg in x axis. + if False, this is the original SMPL coordinate. + + Args: + Rs: N x 24 x 3 x 3 rotation vector of K joints + Js: N x 24 x 3, joint locations before posing + parent: 24 holding the parent id for each index + + Returns + new_J : `Tensor`: N x 24 x 3 location of absolute joints + A : `Tensor`: N x 24 4 x 4 relative joint transformations for LBS. + """ + if rotate_base: + print('Flipping the SMPL coordinate frame!!!!') + rot_x = torch.Tensor([[1, 0, 0], [0, -1, 0], [0, 0, -1]]) + rot_x = torch.reshape(torch.repeat(rot_x, [N, 1]), [N, 3, 3]) # In tf it was tile + root_rotation = torch.matmul(Rs[:, 0, :, :], rot_x) + else: + root_rotation = Rs[:, 0, :, :] + + # Now Js is N x 24 x 3 x 1 + Js = Js.unsqueeze(-1) + N = Rs.shape[0] + + Js_orig = Js.clone() + + def make_A(R, t): + # Rs is N x 3 x 3, ts is N x 3 x 1 + R_homo = torch.nn.functional.pad(R, (0,0,0,1,0,0)) + t_homo = torch.cat([t, torch.ones([N, 1, 1]).to(Rs.device)], 1) + return torch.cat([R_homo, t_homo], 2) + + A0 = make_A(root_rotation, Js[:, 0]) + results = [A0] + for i in range(1, parent.shape[0]): + j_here = Js[:, i] - Js[:, parent[i]] + try: + s_par_inv = torch.inverse(scale_factors_3x3[:, parent[i]]) + except: + # import pdb; pdb.set_trace() + s_par_inv = torch.max(scale_factors_3x3[:, parent[i]], 0.01*torch.eye((3))[None, :, :].to(scale_factors_3x3.device)) + rot = Rs[:, i] + s = scale_factors_3x3[:, i] + + rot_new = s_par_inv @ rot @ s + + A_here = make_A(rot_new, j_here) + res_here = torch.matmul( + results[parent[i]], A_here) + + results.append(res_here) + + # 10 x 24 x 4 x 4 + results = torch.stack(results, dim=1) + + # scale updates + new_J = results[:, :, :3, 3] + + # --- Compute relative A: Skinning is based on + # how much the bone moved (not the final location of the bone) + # but (final_bone - init_bone) + # --- + Js_w0 = torch.cat([Js_orig, torch.zeros([N, 35, 1, 1]).to(Rs.device)], 2) + init_bone = torch.matmul(results, Js_w0) + # Append empty 4 x 3: + init_bone = torch.nn.functional.pad(init_bone, (3,0,0,0,0,0,0,0)) + A = results - init_bone + + return new_J, A \ No newline at end of file diff --git a/src/smal_pytorch/smal_model/smal_basics.py b/src/smal_pytorch/smal_model/smal_basics.py new file mode 100644 index 0000000000000000000000000000000000000000..dd83cbe64731830bcfde22e7252023ca097c5a5b --- /dev/null +++ b/src/smal_pytorch/smal_model/smal_basics.py @@ -0,0 +1,82 @@ +''' +Adjusted version of other PyTorch implementation of the SMAL/SMPL model +see: + 1.) https://github.com/silviazuffi/smalst/blob/master/smal_model/smal_torch.py + 2.) https://github.com/benjiebob/SMALify/blob/master/smal_model/smal_torch.py +''' + +import os +import pickle as pkl +import json +import numpy as np +import pickle as pkl + +import os +import sys +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..')) +from configs.SMAL_configs import SMAL_DATA_DIR, SYMMETRY_INDS_FILE + +# model_dir = 'smalst/smpl_models/' +# FILE_DIR = os.path.dirname(os.path.realpath(__file__)) +model_dir = SMAL_DATA_DIR # os.path.join(FILE_DIR, '..', 'smpl_models/') +symmetry_inds_file = SYMMETRY_INDS_FILE # os.path.join(FILE_DIR, '..', 'smpl_models/symmetry_inds.json') +with open(symmetry_inds_file) as f: + symmetry_inds_dict = json.load(f) +LEFT_INDS = np.asarray(symmetry_inds_dict['left_inds']) +RIGHT_INDS = np.asarray(symmetry_inds_dict['right_inds']) +CENTER_INDS = np.asarray(symmetry_inds_dict['center_inds']) + + +def get_symmetry_indices(): + sym_dict = {'left': LEFT_INDS, + 'right': RIGHT_INDS, + 'center': CENTER_INDS} + return sym_dict + +def verify_symmetry(shapedirs, center_inds=CENTER_INDS, left_inds=LEFT_INDS, right_inds=RIGHT_INDS): + # shapedirs: (3889, 3, n_sh) + assert (shapedirs[center_inds, 1, :] == 0.0).all() + assert (shapedirs[right_inds, 1, :] == -shapedirs[left_inds, 1, :]).all() + return + +def from_shapedirs_to_shapedirs_half(shapedirs, center_inds=CENTER_INDS, left_inds=LEFT_INDS, right_inds=RIGHT_INDS, verify=False): + # shapedirs: (3889, 3, n_sh) + # shapedirs_half: (2012, 3, n_sh) + selected_inds = np.concatenate((center_inds, left_inds), axis=0) + shapedirs_half = shapedirs[selected_inds, :, :] + if verify: + verify_symmetry(shapedirs) + else: + shapedirs_half[:center_inds.shape[0], 1, :] = 0.0 + return shapedirs_half + +def from_shapedirs_half_to_shapedirs(shapedirs_half, center_inds=CENTER_INDS, left_inds=LEFT_INDS, right_inds=RIGHT_INDS): + # shapedirs_half: (2012, 3, n_sh) + # shapedirs: (3889, 3, n_sh) + shapedirs = np.zeros((center_inds.shape[0] + 2*left_inds.shape[0], 3, shapedirs_half.shape[2])) + shapedirs[center_inds, :, :] = shapedirs_half[:center_inds.shape[0], :, :] + shapedirs[left_inds, :, :] = shapedirs_half[center_inds.shape[0]:, :, :] + shapedirs[right_inds, :, :] = shapedirs_half[center_inds.shape[0]:, :, :] + shapedirs[right_inds, 1, :] = - shapedirs_half[center_inds.shape[0]:, 1, :] + return shapedirs + +def align_smal_template_to_symmetry_axis(v, subtract_mean=True): + # These are the indexes of the points that are on the symmetry axis + I = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 37, 55, 119, 120, 163, 209, 210, 211, 213, 216, 227, 326, 395, 452, 578, 910, 959, 964, 975, 976, 977, 1172, 1175, 1176, 1178, 1194, 1243, 1739, 1796, 1797, 1798, 1799, 1800, 1801, 1802, 1803, 1804, 1805, 1806, 1807, 1808, 1809, 1810, 1811, 1812, 1813, 1814, 1815, 1816, 1817, 1818, 1819, 1820, 1821, 1822, 1823, 1824, 1825, 1826, 1827, 1828, 1829, 1830, 1831, 1832, 1833, 1834, 1835, 1836, 1837, 1838, 1839, 1840, 1842, 1843, 1844, 1845, 1846, 1847, 1848, 1849, 1850, 1851, 1852, 1853, 1854, 1855, 1856, 1857, 1858, 1859, 1860, 1861, 1862, 1863, 1870, 1919, 1960, 1961, 1965, 1967, 2003] + if subtract_mean: + v = v - np.mean(v) + y = np.mean(v[I,1]) + v[:,1] = v[:,1] - y + v[I,1] = 0 + left_inds = LEFT_INDS + right_inds = RIGHT_INDS + center_inds = CENTER_INDS + v[right_inds, :] = np.array([1,-1,1])*v[left_inds, :] + try: + assert(len(left_inds) == len(right_inds)) + except: + import pdb; pdb.set_trace() + return v, left_inds, right_inds, center_inds + + + diff --git a/src/smal_pytorch/smal_model/smal_torch_new.py b/src/smal_pytorch/smal_model/smal_torch_new.py new file mode 100644 index 0000000000000000000000000000000000000000..32e1dbb57ad5a79fbeaa5448d392edbf55121a19 --- /dev/null +++ b/src/smal_pytorch/smal_model/smal_torch_new.py @@ -0,0 +1,471 @@ +""" +PyTorch implementation of the SMAL/SMPL model +see: + 1.) https://github.com/silviazuffi/smalst/blob/master/smal_model/smal_torch.py + 2.) https://github.com/benjiebob/SMALify/blob/master/smal_model/smal_torch.py +main changes compared to SMALST and WLDO: + * new model + (/ps/scratch/nrueegg/new_projects/side_packages/SMALify/new_smal_pca/results/my_tposeref_results_3/) + dogs are part of the pca to create the model + al meshes are centered around their root joint + the animals are all scaled such that their body length (butt to breast) is 1 + X_init = np.concatenate((vertices_dogs, vertices_smal), axis=0) # vertices_dogs + X = [] + for ind in range(0, X_init.shape[0]): + X_tmp, _, _, _ = align_smal_template_to_symmetry_axis(X_init[ind, :, :], subtract_mean=True) # not sure if this is necessary + X.append(X_tmp) + X = np.asarray(X) + # define points which will be used for normalization + idxs_front = [6, 16, 8, 964] # [1172, 6, 16, 8, 964] + idxs_back = [174, 2148, 175, 2149] # not in the middle, but pairs + reg_j = np.asarray(dd['J_regressor'].todense()) + # normalize the meshes such that X_frontback_dist is 1 and the root joint is in the center (0, 0, 0) + X_front = X[:, idxs_front, :].mean(axis=1) + X_back = X[:, idxs_back, :].mean(axis=1) + X_frontback_dist = np.sqrt(((X_front - X_back)**2).sum(axis=1)) + X = X / X_frontback_dist[:, None, None] + X_j0 = np.sum(X[:, reg_j[0, :]>0, :] * reg_j[0, (reg_j[0, :]>0)][None, :, None], axis=1) + X = X - X_j0[:, None, :] + * add limb length changes the same way as in WLDO + * overall scale factor is added +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import torch +import chumpy as ch +import os.path +from torch import nn +from torch.autograd import Variable +import pickle as pkl +from .batch_lbs import batch_rodrigues, batch_global_rigid_transformation, batch_global_rigid_transformation_biggs, get_bone_length_scales, get_beta_scale_mask + +from .smal_basics import align_smal_template_to_symmetry_axis, get_symmetry_indices + +import os +import sys +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..')) +from configs.SMAL_configs import KEY_VIDS, CANONICAL_MODEL_JOINTS, CANONICAL_MODEL_JOINTS_REFINED, IDXS_BONES_NO_REDUNDANCY # , SMAL_MODEL_PATH +# from configs.SMAL_configs import SMAL_MODEL_TYPE +from configs.SMAL_configs import SMAL_MODEL_CONFIG + +from smal_pytorch.utils import load_vertex_colors + + +# There are chumpy variables so convert them to numpy. +def undo_chumpy(x): + return x if isinstance(x, np.ndarray) else x.r + +# class SMAL(object): +class SMAL(nn.Module): + def __init__(self, pkl_path=None, smal_model_type=None, n_betas=None, template_name='neutral', use_smal_betas=True, logscale_part_list=None): + super(SMAL, self).__init__() + + # before: pkl_path=SMAL_MODEL_PATH + if smal_model_type is not None: + assert (pkl_path is None) + assert smal_model_type in SMAL_MODEL_CONFIG.keys() + pkl_path = SMAL_MODEL_CONFIG[smal_model_type]['smal_model_path'] + self.smal_model_type = smal_model_type + if logscale_part_list is None: + logscale_part_list = SMAL_MODEL_CONFIG[smal_model_type]['logscale_part_list'] + elif (pkl_path is not None): + self.smal_model_type = None + elif (pkl_path is None): + smal_model_type = 'barc' + print('use default smal_model_type: ' + smal_model_type) + pkl_path = SMAL_MODEL_CONFIG[smal_model_type]['smal_model_path'] + self.smal_model_type = smal_model_type + else: + raise ValueError + + + ''' + # save some information about the model if possible + if pkl_path == SMAL_MODEL_PATH: + self.smal_model_type = SMAL_MODEL_TYPE + ''' + + if logscale_part_list is None: + # logscale_part_list = ['front_legs_l', 'front_legs_f', 'tail_l', 'tail_f', 'ears_y', 'ears_l', 'head_l', 'back_legs_l', 'back_legs_f'] + self.logscale_part_list = ['legs_l', 'legs_f', 'tail_l', 'tail_f', 'ears_y', 'ears_l', 'head_l'] + else: + self.logscale_part_list = logscale_part_list + self.betas_scale_mask = get_beta_scale_mask(part_list=self.logscale_part_list) + self.num_betas_logscale = len(self.logscale_part_list) + + self.use_smal_betas = use_smal_betas + + # -- Load SMPL params -- + try: + with open(pkl_path, 'r') as f: + dd = pkl.load(f) + except (UnicodeDecodeError, TypeError) as e: + with open(pkl_path, 'rb') as file: + u = pkl._Unpickler(file) + u.encoding = 'latin1' + dd = u.load() + + self.f = dd['f'] + self.register_buffer('faces', torch.from_numpy(self.f.astype(int))) + + # get the correct template (mean shape) + if template_name=='neutral': + v_template = dd['v_template'] + v = v_template + else: + raise NotImplementedError + + # Mean template vertices + self.register_buffer('v_template', torch.Tensor(v)) + # Size of mesh [Number of vertices, 3] + self.size = [self.v_template.shape[0], 3] + self.num_betas = dd['shapedirs'].shape[-1] + # symmetry indices + self.sym_ids_dict = get_symmetry_indices() + + # Shape blend shape basis + shapedir = np.reshape(undo_chumpy(dd['shapedirs']), [-1, self.num_betas]).T + shapedir.flags['WRITEABLE'] = True # not sure why this is necessary + self.register_buffer('shapedirs', torch.Tensor(shapedir)) + + # Regressor for joint locations given shape + self.register_buffer('J_regressor', torch.Tensor(dd['J_regressor'].T.todense())) + + # Pose blend shape basis + num_pose_basis = dd['posedirs'].shape[-1] + + posedirs = np.reshape(undo_chumpy(dd['posedirs']), [-1, num_pose_basis]).T + self.register_buffer('posedirs', torch.Tensor(posedirs)) + + # indices of parents for each joints + self.parents = dd['kintree_table'][0].astype(np.int32) + + # LBS weights + self.register_buffer('weights', torch.Tensor(undo_chumpy(dd['weights']))) + + # prepare for vertex offsets + self._prepare_for_vertex_offsets() + + + def _prepare_for_vertex_offsets(self): + sym_left_ids = self.sym_ids_dict['left'] + sym_right_ids = self.sym_ids_dict['right'] + sym_center_ids = self.sym_ids_dict['center'] + self.n_center = sym_center_ids.shape[0] + self.n_left = sym_left_ids.shape[0] + self.sl = 2*self.n_center # sl: start left + # get indices to go from half_shapedirs to shapedirs + inds_back = np.zeros((3889)) + for ind in range(0, sym_center_ids.shape[0]): + ind_in_forward = sym_center_ids[ind] + inds_back[ind_in_forward] = ind + for ind in range(0, sym_left_ids.shape[0]): + ind_in_forward = sym_left_ids[ind] + inds_back[ind_in_forward] = sym_center_ids.shape[0] + ind + for ind in range(0, sym_right_ids.shape[0]): + ind_in_forward = sym_right_ids[ind] + inds_back[ind_in_forward] = sym_center_ids.shape[0] + sym_left_ids.shape[0] + ind + # self.register_buffer('inds_back_torch', torch.Tensor(inds_back).long()) + self.inds_back_torch = torch.Tensor(inds_back).long() + return + + + def _caclulate_bone_lengths_from_J(self, J, betas_logscale): + # NEW: calculate bone lengths: + all_bone_lengths_list = [] + for i in range(1, self.parents.shape[0]): + bone_vec = J[:, i] - J[:, self.parents[i]] + bone_length = torch.sqrt(torch.sum(bone_vec ** 2, axis=1)) + all_bone_lengths_list.append(bone_length) + all_bone_lengths = torch.stack(all_bone_lengths_list) + # some bones are pairs, it is enough to take one of the two bones + all_bone_length_scales = get_bone_length_scales(self.logscale_part_list, betas_logscale) + all_bone_lengths = all_bone_lengths.permute((1,0)) * all_bone_length_scales + + return all_bone_lengths #.permute((1,0)) + + + def caclulate_bone_lengths(self, beta, betas_logscale, shapedirs_sel=None, short=True): + nBetas = beta.shape[1] + + # 1. Add shape blend shapes + # do we use the original shapedirs or a new set of selected shapedirs? + if shapedirs_sel is None: + shapedirs_sel = self.shapedirs[:nBetas,:] + else: + assert shapedirs_sel.shape[0] == nBetas + v_shaped = self.v_template + torch.reshape(torch.matmul(beta, shapedirs_sel), [-1, self.size[0], self.size[1]]) + + # 2. Infer shape-dependent joint locations. + Jx = torch.matmul(v_shaped[:, :, 0], self.J_regressor) + Jy = torch.matmul(v_shaped[:, :, 1], self.J_regressor) + Jz = torch.matmul(v_shaped[:, :, 2], self.J_regressor) + J = torch.stack([Jx, Jy, Jz], dim=2) + + # calculate bone lengths + all_bone_lengths = self._caclulate_bone_lengths_from_J(J, betas_logscale) + selected_bone_lengths = all_bone_lengths[:, IDXS_BONES_NO_REDUNDANCY] + + if short: + return selected_bone_lengths + else: + return all_bone_lengths + + + + def __call__(self, beta, betas_limbs, theta=None, pose=None, trans=None, del_v=None, get_skin=True, keyp_conf='red', get_all_info=False, shapedirs_sel=None, vert_off_compact=None): + device = beta.device + + betas_logscale = betas_limbs + # NEW: allow that rotation is given as rotation matrices instead of axis angle rotation + # theta: BSxNJointsx3 or BSx(NJoints*3) + # pose: NxNJointsx3x3 + if (theta is None) and (pose is None): + raise ValueError("Either pose (rotation matrices NxNJointsx3x3) or theta (axis angle BSxNJointsx3) must be given") + elif (theta is not None) and (pose is not None): + raise ValueError("Not both pose (rotation matrices NxNJointsx3x3) and theta (axis angle BSxNJointsx3) can be given") + + if True: # self.use_smal_betas: + nBetas = beta.shape[1] + else: + nBetas = 0 + + # add possibility to have additional vertex offsets + if vert_off_compact is None: + vertex_offsets = torch.zeros_like(self.v_template) + else: + ########################################################## + # bs = 1 + # vert_off_compact = torch.zeros((bs, 2*self.n_center + 3*self.n_left), device=vert_off_compact.device, dtype=vert_off_compact.dtype) + if type(vert_off_compact) is dict: + zero_vec = torch.zeros((vert_off_compact['c0'].shape[0], self.n_center)).to(device) + half_vertex_offsets_center = torch.stack((vert_off_compact['c0'], \ + zero_vec, \ + vert_off_compact['c2']), axis=1) + half_vertex_offsets_left = torch.stack((vert_off_compact['l0'], \ + vert_off_compact['l1'], \ + vert_off_compact['l2']), axis=1) + half_vertex_offsets_right = torch.stack((vert_off_compact['l0'], \ + - vert_off_compact['l1'], \ + vert_off_compact['l2']), axis=1) + else: + zero_vec = torch.zeros((vert_off_compact.shape[0], self.n_center)).to(device) + half_vertex_offsets_center = torch.stack((vert_off_compact[:, :self.n_center], \ + zero_vec, \ + vert_off_compact[:, self.n_center:2*self.n_center]), axis=1) + half_vertex_offsets_left = torch.stack((vert_off_compact[:, self.sl:self.sl+self.n_left], \ + vert_off_compact[:, self.sl+self.n_left:self.sl+2*self.n_left], \ + vert_off_compact[:, self.sl+2*self.n_left:self.sl+3*self.n_left]), axis=1) + half_vertex_offsets_right = torch.stack((vert_off_compact[:, self.sl:self.sl+self.n_left], \ + - vert_off_compact[:, self.sl+self.n_left:self.sl+2*self.n_left], \ + vert_off_compact[:, self.sl+2*self.n_left:self.sl+3*self.n_left]), axis=1) + + half_vertex_offsets_tot = torch.cat((half_vertex_offsets_center, half_vertex_offsets_left, half_vertex_offsets_right), dim=2) # (bs, 3, 3889) + vertex_offsets = torch.index_select(half_vertex_offsets_tot, dim=2, index=self.inds_back_torch.to(half_vertex_offsets_tot.device)).permute((0, 2, 1)) # (bs, 3889, 3) + + + # 1. Add shape blend shapes + # do we use the original shapedirs or a new set of selected shapedirs? + if shapedirs_sel is None: + shapedirs_sel = self.shapedirs[:nBetas,:] + else: + assert shapedirs_sel.shape[0] == nBetas + + if nBetas > 0: + if del_v is None: + v_shaped = self.v_template + torch.reshape(torch.matmul(beta, shapedirs_sel), [-1, self.size[0], self.size[1]]) + vertex_offsets + else: + v_shaped = self.v_template + del_v + torch.reshape(torch.matmul(beta, shapedirs_sel), [-1, self.size[0], self.size[1]]) + vertex_offsets + else: + if del_v is None: + v_shaped = self.v_template.unsqueeze(0) + vertex_offsets + else: + v_shaped = self.v_template + del_v + vertex_offsets + + # 2. Infer shape-dependent joint locations. + Jx = torch.matmul(v_shaped[:, :, 0], self.J_regressor) + Jy = torch.matmul(v_shaped[:, :, 1], self.J_regressor) + Jz = torch.matmul(v_shaped[:, :, 2], self.J_regressor) + J = torch.stack([Jx, Jy, Jz], dim=2) + + # 3. Add pose blend shapes + # N x 24 x 3 x 3 + if pose is None: + Rs = torch.reshape( batch_rodrigues(torch.reshape(theta, [-1, 3])), [-1, 35, 3, 3]) + else: + Rs = pose + # Ignore global rotation. + pose_feature = torch.reshape(Rs[:, 1:, :, :] - torch.eye(3).to(device=device), [-1, 306]) + + v_posed = torch.reshape( + torch.matmul(pose_feature, self.posedirs), + [-1, self.size[0], self.size[1]]) + v_shaped + + #------------------------- + # new: add corrections of bone lengths to the template (before hypothetical pose blend shapes!) + # see biggs batch_lbs.py + betas_scale = torch.exp(betas_logscale @ self.betas_scale_mask.to(betas_logscale.device)) + scaling_factors = betas_scale.reshape(-1, 35, 3) + scale_factors_3x3 = torch.diag_embed(scaling_factors, dim1=-2, dim2=-1) + + # 4. Get the global joint location + # self.J_transformed, A = batch_global_rigid_transformation(Rs, J, self.parents) + self.J_transformed, A = batch_global_rigid_transformation_biggs(Rs, J, self.parents, scale_factors_3x3, betas_logscale=betas_logscale) + + # 2-BONES. Calculate bone lengths + all_bone_lengths = self._caclulate_bone_lengths_from_J(J, betas_logscale) + # selected_bone_lengths = all_bone_lengths[:, IDXS_BONES_NO_REDUNDANCY] + #------------------------- + + # 5. Do skinning: + num_batch = Rs.shape[0] + + weights_t = self.weights.repeat([num_batch, 1]) + W = torch.reshape(weights_t, [num_batch, -1, 35]) + + + T = torch.reshape( + torch.matmul(W, torch.reshape(A, [num_batch, 35, 16])), + [num_batch, -1, 4, 4]) + v_posed_homo = torch.cat( + [v_posed, torch.ones([num_batch, v_posed.shape[1], 1]).to(device=device)], 2) + v_homo = torch.matmul(T, v_posed_homo.unsqueeze(-1)) + + verts = v_homo[:, :, :3, 0] + + if trans is None: + trans = torch.zeros((num_batch,3)).to(device=device) + + verts = verts + trans[:,None,:] + + # Get joints: + joint_x = torch.matmul(verts[:, :, 0], self.J_regressor) + joint_y = torch.matmul(verts[:, :, 1], self.J_regressor) + joint_z = torch.matmul(verts[:, :, 2], self.J_regressor) + joints = torch.stack([joint_x, joint_y, joint_z], dim=2) + + # New... (see https://github.com/benjiebob/SMALify/blob/master/smal_model/smal_torch.py) + joints = torch.cat([ + joints, + verts[:, None, 1863], # end_of_nose + verts[:, None, 26], # chin + verts[:, None, 2124], # right ear tip + verts[:, None, 150], # left ear tip + verts[:, None, 3055], # left eye + verts[:, None, 1097], # right eye + # new: add paw keypoints, not joint locations -> bottom, rather in front + # remark: when i look in the animals direction, left and right are exchanged + verts[:, None, 1330], # front paw, right + verts[:, None, 3282], # front paw, left + verts[:, None, 1521], # back paw, right + verts[:, None, 3473], # back paw, left + verts[:, None, 6], # throat + verts[:, None, 20], # withers + ], dim = 1) + + + if keyp_conf == 'blue' or keyp_conf == 'dict': + # Generate keypoints + nLandmarks = KEY_VIDS.shape[0] # 24 + j3d = torch.zeros((num_batch, nLandmarks, 3)).to(device=device) + for j in range(nLandmarks): + j3d[:, j,:] = torch.mean(verts[:, KEY_VIDS[j],:], dim=1) # translation is already added to the vertices + joints_blue = j3d + + joints_red = joints[:, :-12, :] # joints[:, :-6, :] + joints_green = joints[:, CANONICAL_MODEL_JOINTS, :] + joints_olive = joints[:, CANONICAL_MODEL_JOINTS_REFINED, :] # same order but better paw, withers and throat keypoints + + if keyp_conf == 'red': + relevant_joints = joints_red + elif keyp_conf == 'green': + relevant_joints = joints_green + elif keyp_conf == 'olive': + relevant_joints = joints_olive + elif keyp_conf == 'blue': + relevant_joints = joints_blue + elif keyp_conf == 'dict': + relevant_joints = {'red': joints_red, + 'green': joints_green, + 'olive': joints_olive, + 'blue': joints_blue} + else: + raise NotImplementedError + + if get_all_info: + return verts, relevant_joints, Rs, all_bone_lengths + else: + if get_skin: + return verts, relevant_joints, Rs # , v_shaped + else: + return relevant_joints + + + + + + def get_joints_from_verts(self, verts, keyp_conf='red'): + + num_batch = verts.shape[0] + + # Get joints: + joint_x = torch.matmul(verts[:, :, 0], self.J_regressor) + joint_y = torch.matmul(verts[:, :, 1], self.J_regressor) + joint_z = torch.matmul(verts[:, :, 2], self.J_regressor) + joints = torch.stack([joint_x, joint_y, joint_z], dim=2) + + # New... (see https://github.com/benjiebob/SMALify/blob/master/smal_model/smal_torch.py) + joints = torch.cat([ + joints, + verts[:, None, 1863], # end_of_nose + verts[:, None, 26], # chin + verts[:, None, 2124], # right ear tip + verts[:, None, 150], # left ear tip + verts[:, None, 3055], # left eye + verts[:, None, 1097], # right eye + # new: add paw keypoints, not joint locations -> bottom, rather in front + # remark: when i look in the animals direction, left and right are exchanged + verts[:, None, 1330], # front paw, right + verts[:, None, 3282], # front paw, left + verts[:, None, 1521], # back paw, right + verts[:, None, 3473], # back paw, left + verts[:, None, 6], # throat + verts[:, None, 20], # withers + ], dim = 1) + + + if keyp_conf == 'blue' or keyp_conf == 'dict': + # Generate keypoints + nLandmarks = KEY_VIDS.shape[0] # 24 + j3d = torch.zeros((num_batch, nLandmarks, 3)).to(device=device) + for j in range(nLandmarks): + j3d[:, j,:] = torch.mean(verts[:, KEY_VIDS[j],:], dim=1) # translation is already added to the vertices + joints_blue = j3d + + joints_red = joints[:, :-12, :] # joints[:, :-6, :] + joints_green = joints[:, CANONICAL_MODEL_JOINTS, :] + joints_olive = joints[:, CANONICAL_MODEL_JOINTS_REFINED, :] # same order but better paw, withers and throat keypoints + + if keyp_conf == 'red': + relevant_joints = joints_red + elif keyp_conf == 'green': + relevant_joints = joints_green + elif keyp_conf == 'olive': + relevant_joints = joints_olive + elif keyp_conf == 'blue': + relevant_joints = joints_blue + elif keyp_conf == 'dict': + relevant_joints = {'red': joints_red, + 'green': joints_green, + 'olive': joints_olive, + 'blue': joints_blue} + + return relevant_joints + + + + diff --git a/src/smal_pytorch/utils.py b/src/smal_pytorch/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..11e48a0fe88cf27472c56cb7c9d3359984fd9b9a --- /dev/null +++ b/src/smal_pytorch/utils.py @@ -0,0 +1,13 @@ +import numpy as np + +def load_vertex_colors(obj_path): + v_colors = [] + for line in open(obj_path, "r"): + if line.startswith('#'): continue + values = line.split() + if not values: continue + if values[0] == 'v': + v_colors.append(values[4:7]) + else: + continue + return np.asarray(v_colors, dtype=np.float32) diff --git a/src/stacked_hourglass/__init__.py b/src/stacked_hourglass/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bdb1a50e871a0da6d20f89aaf5d559d40bf5341c --- /dev/null +++ b/src/stacked_hourglass/__init__.py @@ -0,0 +1,5 @@ +import os +import sys +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../')) +from src.stacked_hourglass.model import hg1, hg2, hg4, hg8 +from src.stacked_hourglass.predictor import HumanPosePredictor diff --git a/src/stacked_hourglass/datasets/__init__.py b/src/stacked_hourglass/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/stacked_hourglass/datasets/anipose.py b/src/stacked_hourglass/datasets/anipose.py new file mode 100644 index 0000000000000000000000000000000000000000..ee687898763f278c46ca876b94199b6a0ae24ecf --- /dev/null +++ b/src/stacked_hourglass/datasets/anipose.py @@ -0,0 +1,421 @@ +import gzip +import json +import os +import glob +import random +import math +import numpy as np +import torch +import torch.utils.data as data +from importlib_resources import open_binary +from scipy.io import loadmat +from tabulate import tabulate +import itertools +import json +from scipy import ndimage +import xml.etree.ElementTree as ET + +from csv import DictReader +from pycocotools.mask import decode as decode_RLE + +import os +import sys +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../../')) +# import stacked_hourglass.res +# from stacked_hourglass.datasets.common import DataInfo +from src.configs.anipose_data_info import COMPLETE_DATA_INFO +from src.stacked_hourglass.utils.imutils import load_image, draw_labelmap, draw_multiple_labelmaps +from src.stacked_hourglass.utils.misc import to_torch +from src.stacked_hourglass.utils.transforms import shufflelr, crop, color_normalize, fliplr, transform +import src.stacked_hourglass.datasets.utils_stanext as utils_stanext +from src.stacked_hourglass.utils.visualization import save_input_image_with_keypoints +# from configs.dog_breeds.dog_breed_class import COMPLETE_ABBREV_DICT, COMPLETE_SUMMARY_BREEDS, SIM_MATRIX_RAW, SIM_ABBREV_INDICES + + + +class AniPose(data.Dataset): + DATA_INFO = COMPLETE_DATA_INFO + + # Suggested joints to use for average PCK calculations. + ACC_JOINTS = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19] # don't know ... + + def __init__(self, image_path=None, is_train=True, inp_res=256, out_res=64, sigma=1, + scale_factor=0.25, rot_factor=30, label_type='Gaussian', + do_augment='default', shorten_dataset_to=None, dataset_mode='keyp_only'): + # self.img_folder_mpii = image_path # root image folders + self.is_train = is_train # training set or test set + if do_augment == 'yes': + self.do_augment = True + elif do_augment == 'no': + self.do_augment = False + elif do_augment=='default': + if self.is_train: + self.do_augment = True + else: + self.do_augment = False + else: + raise ValueError + self.inp_res = inp_res + self.out_res = out_res + self.sigma = sigma + self.scale_factor = scale_factor + self.rot_factor = rot_factor + self.label_type = label_type + self.dataset_mode = dataset_mode + if self.dataset_mode=='complete' or self.dataset_mode=='keyp_and_seg': + self.calc_seg = True + else: + self.calc_seg = False + + self.kp_dict = self.keyp_name_to_ind() + + # import pdb; pdb.set_trace() + + self.top_folder = '/ps/scratch/nrueegg/new_projects/Animals/data/animal_pose_dataset/' + self.folder_imgs_0 = '/ps/project/datasets/VOCdevkit/VOC2012/JPEGImages/' + self.folder_imgs_1 = os.path.join(self.top_folder, 'animalpose_image_part2', 'dog') + self.folder_annot_0 = os.path.join(self.top_folder, 'PASCAL2011_animal_annotation', 'dog') + self.folder_annot_1 = os.path.join(self.top_folder, 'animalpose_anno2', 'dog') + all_annot_files_0 = glob.glob(self.folder_annot_0 + '/*.xml') # 1571 + '''all_annot_files_0_raw.sort() + all_annot_files_0 = [] # 1331 + for ind_f, f in enumerate(all_annot_files_0_raw): + name = (f.split('/')[-1]).split('.xml')[0] + name_main = name[:-2] + if ind_f > 0: + if (not name_main == name_main_last) or (ind_f == len(all_annot_files_0_raw)-1): + all_annot_files_0.append(f_last) + f_last = f + name_main_last = name_main''' + all_annot_files_1 = glob.glob(self.folder_annot_1 + '/*.xml') # 200 + all_annot_files = all_annot_files_0 + all_annot_files_1 + + + # old for hg_anipose_v0 + # self.train_name_list = all_annot_files + # self.test_name_list = all_annot_files[0:50] + all_annot_files[200:250] + # new for hg_anipose_v1 + self.train_name_list = all_annot_files[:-50] + self.test_name_list = all_annot_files[-50:] + + '''all_annot_files.sort() + + self.train_name_list = all_annot_files[:24] + self.test_name_list = all_annot_files[24:36]''' + + print('anipose dataset size: ') + print(len(self.train_name_list)) + print(len(self.test_name_list)) + + + # ----------------------------------------- + def read_content(sewlf, xml_file, annot_type='animal_pose'): + # annot_type is either 'animal_pose' or 'animal_pose_voc' or 'voc' + # examples: + # animal_pose: '/ps/scratch/nrueegg/new_projects/Animals/data/animal_pose_dataset/animalpose_anno2/cat/ca137.xml' + # animal_pose_voc: '/ps/scratch/nrueegg/new_projects/Animals/data/animal_pose_dataset/PASCAL2011_animal_annotation/cat/2008_005380_1.xml' + # voc: '/ps/project/datasets/VOCdevkit/VOC2012/Annotations/2011_000192.xml' + if annot_type == 'animal_pose' or annot_type == 'animal_pose_voc': + my_dict = {} + tree = ET.parse(xml_file) + root = tree.getroot() + for child in root: # list + if child.tag == 'image': + my_dict['image'] = child.text + elif child.tag == 'category': + my_dict['category'] = child.text + elif child.tag == 'visible_bounds': + my_dict['visible_bounds'] = child.attrib + elif child.tag == 'keypoints': + n_kp = len(child) + xyzvis = np.zeros((n_kp, 4)) + kp_names = [] + for ind_kp, kp in enumerate(child): # list + xyzvis[ind_kp, 0] = kp.attrib['x'] + xyzvis[ind_kp, 1] = kp.attrib['y'] + xyzvis[ind_kp, 2] = kp.attrib['z'] + xyzvis[ind_kp, 3] = kp.attrib['visible'] + kp_names.append(kp.attrib['name']) + my_dict['keypoints_xyzvis'] = xyzvis + my_dict['keypoints_names'] = kp_names + elif child.tag == 'voc_id': # animal_pose_voc only + my_dict['voc_id'] = child.text + elif child.tag == 'polylinesegments': # animal_pose_voc only + my_dict['polylinesegments'] = child[0].attrib + else: + print('tag does not exist: ' + child.tag) + # print(my_dict) + elif annot_type == 'voc': + my_dict = {} + print('not yet read') + else: + print('this annot_type does not exist') + import pdb; pdb.set_trace() + return my_dict + + + def keyp_name_to_ind(self): + '''AniPose_JOINT_NAMES = [ + 'L_Eye', 'R_Eye', 'Nose', 'L_EarBase', 'Throat', 'R_F_Elbow', 'R_F_Paw', + 'R_B_Paw', 'R_EarBase', 'L_F_Elbow', 'L_F_Paw', 'Withers', 'TailBase', + 'L_B_Paw', 'L_B_Elbow', 'R_B_Elbow', 'L_F_Knee', 'R_F_Knee', 'L_B_Knee', + 'R_B_Knee']''' + kps = self.DATA_INFO.joint_names + kps_dict = {} + for ind_kp, kp in enumerate(kps): + kps_dict[kp] = ind_kp + kps_dict[kp.lower()] = ind_kp + if kp.lower() == 'l_earbase': + kps_dict['l_ear'] = ind_kp + if kp.lower() == 'r_earbase': + kps_dict['r_ear'] = ind_kp + if kp.lower() == 'tailbase': + kps_dict['tail'] = ind_kp + return kps_dict + + + + def __getitem__(self, index): + + # import pdb; pdb.set_trace() + + if self.is_train: + xml_path = self.train_name_list[index] + else: + xml_path = self.test_name_list[index] + + name = (xml_path.split('/')[-1]).split('.xml')[0] + annot_dict = self.read_content(xml_path, annot_type='animal_pose_voc') + + if xml_path.split('/')[-3] == 'PASCAL2011_animal_annotation': + img_path = os.path.join(self.folder_imgs_0, annot_dict['image'] + '.jpg') + keyword_ymin = 'ymin' + else: + # import pdb; pdb.set_trace() + img_path = os.path.join(self.folder_imgs_1, annot_dict['image']) + keyword_ymin = 'xmax' + + '''print(img_path) + print(annot_dict['keypoints_xyzvis'].shape) + print(annot_dict['keypoints_names'])''' + + + + sf = self.scale_factor + rf = self.rot_factor + + + + vis_np = np.zeros((self.DATA_INFO.n_keyp)) + pts_np = np.ones((self.DATA_INFO.n_keyp, 2)) * (-1000) + for ind_key, key in enumerate(annot_dict['keypoints_names']): + key_lower = key.lower() + ind_new = self.kp_dict[key_lower] + vis_np[ind_new] = annot_dict['keypoints_xyzvis'][ind_key, 3] + # remark: the first training run (animalpose_hg8_v0) was without subtracting 1 which would be important! + # pts_np[ind_new] = annot_dict['keypoints_xyzvis'][ind_key, 0:2] + + # what we were doing until 08.09.2022: + pts_np[ind_new] = annot_dict['keypoints_xyzvis'][ind_key, 0:2] - 1 + + # new 08.09.2022 + # pts_np[ind_new] = annot_dict['keypoints_xyzvis'][ind_key, 0:2] + + # pts_np[ind_new] = annot_dict['keypoints_xyzvis'][ind_key, 0:2] # - 1 + + + + '''vis_np = annot_dict['keypoints_xyzvis'][:20, 3] + pts_np = annot_dict['keypoints_xyzvis'][:20, :2] + pts_np[vis_np==0] = -1000''' + + pts_np = np.concatenate((pts_np, vis_np[:, None]), axis=1) + pts = torch.Tensor(pts_np) + + # what we were doing until 08.09.2022: + # bbox_xywh = [float(annot_dict['visible_bounds']['xmin']), float(annot_dict['visible_bounds'][keyword_ymin]), \ + # float(annot_dict['visible_bounds']['width']), float(annot_dict['visible_bounds']['height'])] + bbox_xywh = [float(annot_dict['visible_bounds']['xmin'])-1, float(annot_dict['visible_bounds'][keyword_ymin])-1, \ + float(annot_dict['visible_bounds']['width']), float(annot_dict['visible_bounds']['height'])] + + + + '''pts = torch.Tensor(np.asarray(data['joints'])[:20, :]) + # pts[:, 0:2] -= 1 # Convert pts to zero based + + # inp = crop(img, c, s, [self.inp_res, self.inp_res], rot=r) + # sf = scale * 200.0 / res[0] # res[0]=256 + # center = center * 1.0 / sf + # scale = scale / sf = 256 / 200 + # h = 200 * scale + bbox_xywh = data['img_bbox']''' + + bbox_c = [bbox_xywh[0]+0.5*bbox_xywh[2], bbox_xywh[1]+0.5*bbox_xywh[3]] + bbox_max = max(bbox_xywh[2], bbox_xywh[3]) + bbox_diag = math.sqrt(bbox_xywh[2]**2 + bbox_xywh[3]**2) + # bbox_s = bbox_max / 200. # the dog will fill the image -> bbox_max = 256 + # bbox_s = bbox_diag / 200. # diagonal of the boundingbox will be 200 + bbox_s = bbox_max / 200. * 256. / 200. # maximum side of the bbox will be 200 + c = torch.Tensor(bbox_c) + s = bbox_s + + + + + + + + + + # For single-person pose estimation with a centered/scaled figure + nparts = pts.size(0) + img = load_image(img_path) # CxHxW + + # segmentation map (we reshape it to 3xHxW, such that we can do the + # same transformations as with the image) + if self.calc_seg: + raise NotImplementedError + seg = torch.Tensor(utils_stanext.get_seg_from_entry(data)[None, :, :]) + seg = torch.cat(3*[seg]) + + r = 0 + # self.is_train = False + do_flip = False + if self.do_augment: + s = s*torch.randn(1).mul_(sf).add_(1).clamp(1-sf, 1+sf)[0] + r = torch.randn(1).mul_(rf).clamp(-2*rf, 2*rf)[0] if random.random() <= 0.6 else 0 + # Flip + if random.random() <= 0.5: + do_flip = True + img = fliplr(img) + if self.calc_seg: + seg = fliplr(seg) + # pts = shufflelr(pts, img.size(2), self.DATA_INFO.hflip_indices) + # remark: for BITE we figure out that a -1 was missing in the point mirroring term + # idea: + # image coordinates are 0, 1, 2, 3 + # image size is 4 + # the new point location for former 0 should be 3 and not 4! + pts = shufflelr(pts, img.size(2)-1, self.DATA_INFO.hflip_indices) + c[0] = img.size(2) - c[0] - 1 + # Color + img[0, :, :].mul_(random.uniform(0.8, 1.2)).clamp_(0, 1) + img[1, :, :].mul_(random.uniform(0.8, 1.2)).clamp_(0, 1) + img[2, :, :].mul_(random.uniform(0.8, 1.2)).clamp_(0, 1) + + # Prepare image and groundtruth map + inp = crop(img, c, s, [self.inp_res, self.inp_res], rot=r) + inp = color_normalize(inp, self.DATA_INFO.rgb_mean, self.DATA_INFO.rgb_stddev) + if self.calc_seg: + seg = crop(seg, c, s, [self.inp_res, self.inp_res], rot=r) + + # Generate ground truth + tpts = pts.clone() + target_weight = tpts[:, 2].clone().view(nparts, 1) + + + # cvpr version: + ''' + target = torch.zeros(nparts, self.out_res, self.out_res) + for i in range(nparts): + # if tpts[i, 2] > 0: # This is evil!! + if tpts[i, 1] > 0: + tpts[i, 0:2] = to_torch(transform(tpts[i, 0:2]+1, c, s, [self.out_res, self.out_res], rot=r, as_int=False)) + target[i], vis = draw_labelmap(target[i], tpts[i]-1, self.sigma, type=self.label_type) + target_weight[i, 0] *= vis + # NEW: + target_new, vis_new = draw_multiple_labelmaps((self.out_res, self.out_res), tpts[:, :2]-1, self.sigma, type=self.label_type) + target_weight_new = tpts[:, 2].clone().view(nparts, 1) * vis_new + target_new[(target_weight_new==0).reshape((-1)), :, :] = 0 + ''' + + target = torch.zeros(nparts, self.out_res, self.out_res) + for i in range(nparts): + # if tpts[i, 2] > 0: # This is evil!! + '''if tpts[i, 1] > 0: + tpts[i, 0:2] = to_torch(transform(tpts[i, 0:2], c, s, [self.out_res, self.out_res], rot=r, as_int=False)) + target[i], vis = draw_labelmap(target[i], tpts[i], self.sigma, type=self.label_type) + target_weight[i, 0] *= vis''' + if tpts[i, 1] > 0: + # this pytorch function (transforms) assumes that coordinates which start at 1 instead of 0! + tpts[i, 0:2] = to_torch(transform(tpts[i, 0:2]+1, c, s, [self.out_res, self.out_res], rot=r, as_int=False)) - 1 + target[i], vis = draw_labelmap(target[i], tpts[i], self.sigma, type=self.label_type) + target_weight[i, 0] *= vis + + + + + + + + + + + # Meta info + '''this_breed = self.breed_dict[name.split('/')[0]]''' + + # add information about location within breed similarity matrix + '''folder_name = name.split('/')[0] + breed_name = folder_name.split(folder_name.split('-')[0] + '-')[1] + abbrev = COMPLETE_ABBREV_DICT[breed_name] + try: + sim_breed_index = COMPLETE_SUMMARY_BREEDS[abbrev]._ind_in_xlsx_matrix + except: # some breeds are not in the xlsx file + sim_breed_index = -1''' + + # meta = {'index' : index, 'center' : c, 'scale' : s, 'do_flip' : do_flip, 'rot' : r, 'resolution' : [self.out_res, self.out_res], 'name' : name, + # 'pts' : pts, 'tpts' : tpts, 'target_weight': target_weight, 'breed_index': this_breed['index']} + # meta = {'index' : index, 'center' : c, 'scale' : s, 'do_flip' : do_flip, 'rot' : r, 'resolution' : self.out_res, + # 'pts' : pts, 'tpts' : tpts, 'target_weight': target_weight, 'breed_index': this_breed['index']} + # meta = {'index' : index, 'center' : c, 'scale' : s, + # 'pts' : pts, 'tpts' : tpts, 'target_weight': target_weight, + # 'breed_index': this_breed['index'], 'sim_breed_index': sim_breed_index} + meta = {'index' : index, 'center' : c, 'scale' : s, + 'pts' : pts, 'tpts' : tpts, 'target_weight': target_weight} + + # import pdb; pdb.set_trace() + + + + + + + + + if self.dataset_mode=='keyp_only': + ''' + debugging_path = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/debugging/anipose/' + if self.is_train: + prefix = 'anipose_train_' + else: + prefix = 'anipose_test_' + save_input_image_with_keypoints(inp, meta['tpts'], out_path=debugging_path + prefix + str(index) + '.png', ratio_in_out=self.inp_res/self.out_res) + ''' + return inp, target, meta + elif self.dataset_mode=='keyp_and_seg': + raise NotImplementedError + meta['silh'] = seg[0, :, :] + meta['name'] = name + return inp, target, meta + elif self.dataset_mode=='complete': + raise NotImplementedError + target_dict = meta + target_dict['silh'] = seg[0, :, :] + # NEW for silhouette loss + distmat_tofg = ndimage.distance_transform_edt(1-target_dict['silh']) # values between 0 and up to 100 or more + target_dict['silh_distmat_tofg'] = distmat_tofg + distmat_tobg = ndimage.distance_transform_edt(target_dict['silh']) + target_dict['silh_distmat_tobg'] = distmat_tobg + return inp, target_dict + else: + raise ValueError + + + + def __len__(self): + if self.is_train: + return len(self.train_name_list) # len(self.train_list) + else: + return len(self.test_name_list) # len(self.valid_list) + + diff --git a/src/stacked_hourglass/datasets/dogsvoc.py b/src/stacked_hourglass/datasets/dogsvoc.py new file mode 100644 index 0000000000000000000000000000000000000000..7ba0ed31e568fe610a4117c01ac2e6e30f5a0960 --- /dev/null +++ b/src/stacked_hourglass/datasets/dogsvoc.py @@ -0,0 +1,376 @@ +# 24 joints instead of 20!! + + +import gzip +import json +import os +import random +import math +import numpy as np +import torch +import torch.utils.data as data +from importlib_resources import open_binary +from scipy.io import loadmat +from tabulate import tabulate +import itertools +import json +from scipy import ndimage + +from csv import DictReader +from pycocotools.mask import decode as decode_RLE + +import sys +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', '..')) +# import stacked_hourglass.res +# from stacked_hourglass.datasets.common import DataInfo +# from configs.data_info import COMPLETE_DATA_INFO +# from configs.anipose_data_info import COMPLETE_DATA_INFO_24 +from src.configs.data_info import COMPLETE_DATA_INFO_24 +from src.stacked_hourglass.utils.imutils import load_image, draw_labelmap, draw_multiple_labelmaps +from src.stacked_hourglass.utils.misc import to_torch +from src.stacked_hourglass.utils.transforms import shufflelr, crop, color_normalize, fliplr, transform +import src.stacked_hourglass.datasets.utils_stanext as utils_stanext +from src.stacked_hourglass.utils.visualization import save_input_image_with_keypoints + + + +class DogsVOC(data.Dataset): + DATA_INFO = COMPLETE_DATA_INFO_24 + + # Suggested joints to use for average PCK calculations. + ACC_JOINTS = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 16] # don't know ... + + def __init__(self, image_path=None, is_train=True, inp_res=256, out_res=64, sigma=1, + scale_factor=0.25, rot_factor=30, label_type='Gaussian', + do_augment='default', shorten_dataset_to=None, dataset_mode='keyp_only', V12=None): + # self.img_folder_mpii = image_path # root image folders + self.V12 = V12 + self.is_train = is_train # training set or test set + if do_augment == 'yes': + self.do_augment = True + elif do_augment == 'no': + self.do_augment = False + elif do_augment=='default': + if self.is_train: + self.do_augment = True + else: + self.do_augment = False + else: + raise ValueError + self.inp_res = inp_res + self.out_res = out_res + self.sigma = sigma + self.scale_factor = scale_factor + self.rot_factor = rot_factor + self.label_type = label_type + self.dataset_mode = dataset_mode + if self.dataset_mode=='complete' or self.dataset_mode=='keyp_and_seg' or self.dataset_mode=='keyp_and_seg_and_partseg': + self.calc_seg = True + else: + self.calc_seg = False + + # create train/val split + # REMARK: I assume we should have a different train / test split here + self.img_folder = utils_stanext.get_img_dir(V12=self.V12) + self.train_dict, self.test_dict, self.val_dict = utils_stanext.load_stanext_json_as_dict(split_train_test=True, V12=self.V12) + self.train_name_list = list(self.train_dict.keys()) # 7004 + self.test_name_list = list(self.test_dict.keys()) # 5031 + + # breed json_path + breed_json_path = '/ps/scratch/nrueegg/new_projects/Animals/data/dog_datasets/Stanford_Dogs_Dataset/StanfordExtra/StanExt_breed_dict_v2.json' + + # only use images that show fully visible dogs in standing or walking poses + '''path_easy_images_list = '/ps/scratch/nrueegg/new_projects/Animals/data/dog_datasets/Stanford_Dogs_Dataset/StanfordExtra/AMT_StanExt_easy_images.txt' + easy_images_list = [line.rstrip('\n') for line in open(path_easy_images_list)] + self.train_name_list = sorted(list(set(easy_images_list) & set(self.train_name_list))) + self.test_name_list = sorted(list(set(easy_images_list) & set(self.test_name_list)))''' + self.train_name_list = sorted(self.train_name_list) + self.test_name_list = sorted(self.test_name_list) + + random.seed(4) + random.shuffle(self.train_name_list) + random.shuffle(self.test_name_list) + + + if shorten_dataset_to is not None: + self.train_name_list = self.train_name_list[0 : min(len(self.train_name_list), shorten_dataset_to)] + self.test_name_list = self.test_name_list[0 : min(len(self.test_name_list), shorten_dataset_to)] + + if shorten_dataset_to == 12: + # my_sample = self.test_name_list[2] # black haired dog + my_sample = self.test_name_list[2] + for ind in range(0, 12): + self.test_name_list[ind] = my_sample + + # add results for eyes, whithers and throat as obtained through anipose + self.path_anipose_out_root = '/ps/scratch/nrueegg/new_projects/Animals/data/dog_datasets/Stanford_Dogs_Dataset/StanfordExtra/animalpose_hg8_v0_results_on_StanExt/' + + + ############################################### + + self.dogvoc_path_root = '/ps/scratch/nrueegg/new_projects/Animals/data/pascal_voc_parts/' + self.dogvoc_path_images = self.dogvoc_path_root + 'dog_images/' + self.dogvoc_path_masks = self.dogvoc_path_root + 'dog_masks/' + + with open(self.dogvoc_path_masks + 'voc_dogs_bodypart_info.json', 'r') as file: + self.body_part_info = json.load(file) + with open(self.dogvoc_path_masks + 'voc_dogs_train.json', 'r') as file: + train_set_init = json.load(file) # 707 + with open(self.dogvoc_path_masks + 'voc_dogs_val.json', 'r') as file: + val_set_init = json.load(file) # 709 + self.train_set = train_set_init + val_set_init[:-36] + self.val_set = val_set_init[-36:] + + print('len(dataset): ' + str(self.__len__())) + # print(self.test_name_list[0:10]) + + def get_body_part_indices(self): + silh = [ + ('background', [0]), + ('foreground', [255, 21, 57, 30, 59, 34, 48, 50, 79, 49, 61, 60, 54, 53, 36, 35, 27, 26, 78])] + full_body = [ + ('other', [255]), + ('head', [21, 57, 30, 59, 34, 48, 50]), + ('torso', [79, 49]), + ('right front leg', [61, 60]), + ('right back leg', [54, 53]), + ('left front leg', [36, 35]), + ('left back leg', [27, 26]), + ('tail', [78])] + head = [ + ('other', [21, 59, 34]), + ('right ear', [57]), + ('left ear', [30]), + ('muzzle', [48]), + ('nose', [50])] + torso = [ + ('other', [79]), # wrong 34 + ('neck', [49])] + all_parts = { + 'silh': silh, + 'full_body': full_body, + 'head': head, + 'torso': torso} + return all_parts + + + + + + def __getitem__(self, index): + + if self.is_train: + name = self.train_name_list[index] + data = self.train_dict[name] + # data = utils_stanext.get_dog(self.train_dict, name) + else: + name = self.test_name_list[index] + data = self.test_dict[name] + # data = utils_stanext.get_dog(self.test_dict, name) + + # self.do_augment = False + + # index = 5 ########################## + if self.is_train: + img_info = self.train_set[index] + else: + img_info = self.val_set[index] + + sf = self.scale_factor + rf = self.rot_factor + + img_path = os.path.join(self.dogvoc_path_images, img_info['img_name']) + + # bbox_yxhw = img_info['bbox'] + # bbox_xywh = [bbox_yxhw[1], bbox_yxhw[0], bbox_yxhw[2], bbox_yxhw[3]] + bbox_xywh = img_info['bbox'] + bbox_c = [bbox_xywh[0]+0.5*bbox_xywh[2], bbox_xywh[1]+0.5*bbox_xywh[3]] + bbox_max = max(bbox_xywh[2], bbox_xywh[3]) + bbox_diag = math.sqrt(bbox_xywh[2]**2 + bbox_xywh[3]**2) + # bbox_s = bbox_max / 200. # the dog will fill the image -> bbox_max = 256 + # bbox_s = bbox_diag / 200. # diagonal of the boundingbox will be 200 + bbox_s = bbox_max / 200. * 256. / 200. # maximum side of the bbox will be 200 + c = torch.Tensor(bbox_c) + s = bbox_s + + # For single-person pose estimation with a centered/scaled figure + img = load_image(img_path) # CxHxW + + # img_test = img[0, img_info['bbox'][1]:img_info['bbox'][1]+img_info['bbox'][3], img_info['bbox'][0]:img_info['bbox'][0]+img_info['bbox'][2]] + # import cv2 + # cv2.imwrite('/ps/scratch/nrueegg/new_projects/Animals/dog_project/pytorch-stacked-hourglass/yy.png', np.asarray(img_test*255, np.uint8)) + + + # segmentation map (we reshape it to 3xHxW, such that we can do the + # same transformations as with the image) + if self.do_augment and (random.random() <= 0.5): + do_flip = True + else: + do_flip = False + + if self.calc_seg: + mask = np.load(os.path.join(self.dogvoc_path_masks, img_info['img_name'].split('.')[0] + '_' + str(img_info['ind_bbox']) + '.npz.npy')) + seg_np = mask.copy() + seg_np[mask==0] = 0 + seg_np[mask>0] = 1 + seg = torch.Tensor(seg_np[None, :, :]) + seg = torch.cat(3*[seg]) + + # NEW: body parts + all_parts = self.get_body_part_indices() + body_part_index_list = [] + body_part_name_list = [] + n_tbp = 3 + n_bp = 15 + # body_part_matrix_multiple_hot = np.zeros((n_bp, mask.shape[0], mask.shape[1])) + body_part_matrix_np = np.ones((n_tbp, mask.shape[0], mask.shape[1])) * (-1) + ind_bp = 0 + for ind_tbp, part in enumerate(['full_body', 'head', 'torso']): + # import pdb; pdb.set_trace() + if part == 'full_body': + inds_mirr = [0, 1, 2, 5, 6, 3, 4, 7] + elif part == 'head': + inds_mirr = [0, 2, 1, 3, 4] + else: + inds_mirr = [0, 1] + for ind_sbp, subpart in enumerate(all_parts[part]): + if do_flip: + ind_sbp_corr = inds_mirr[ind_sbp] # we use this if the image is mirrored later on + else: + ind_sbp_corr = ind_sbp + bp_name = subpart[0] + bp_indices = subpart[1] + body_part_index_list.append(bp_indices) + body_part_name_list.append(bp_name) + # create matrix slice + xx = [mask==ind for ind in bp_indices] + xx_mat = (np.stack(xx).sum(axis=0)) + # body_part_matrix_multiple_hot[ind_bp, :, :] = xx_mat + # add to matrix + body_part_matrix_np[ind_tbp, xx_mat>0] = ind_sbp_corr + ind_bp += 1 + body_part_weight_masks_np = np.zeros((n_tbp, mask.shape[0], mask.shape[1])) + body_part_weight_masks_np[0, mask>0] = 1 # full body + body_part_weight_masks_np[1, body_part_matrix_np[0, :, :]==1] = 1 # head + body_part_weight_masks_np[2, body_part_matrix_np[0, :, :]==2] = 1 # torso + body_part_matrix_np[body_part_weight_masks_np==0] = 16 + body_part_matrix = torch.Tensor(body_part_matrix_np + 2.0) # / 100 + + # import pdb; pdb.set_trace() + + bbox_c_int0 = [int(bbox_c[0]), int(bbox_c[1])] + bbox_c_int1 = [int(bbox_c[0])+10, int(bbox_c[1])+10] + '''bpm_c0 = body_part_matrix[:, bbox_c_int0[1], bbox_c_int0[0]].clone() + bpm_c1 = body_part_matrix[:, bbox_c_int1[1], bbox_c_int1[0]].clone() + zero_replacement = torch.Tensor([0, 0, 0.99]) + body_part_matrix[:, bbox_c_int0[1], bbox_c_int0[0]] = zero_replacement + body_part_matrix[:, bbox_c_int1[1], bbox_c_int1[0]] = 1''' + ii = 3 + bpm_c0 = body_part_matrix[2, bbox_c_int0[1]-ii:bbox_c_int0[1]+ii, bbox_c_int0[0]-ii:bbox_c_int0[0]+ii] + bpm_c1 = body_part_matrix[2, bbox_c_int1[1]-ii:bbox_c_int1[1]+ii, bbox_c_int1[0]-ii:bbox_c_int1[0]+ii] + body_part_matrix[2, bbox_c_int0[1]-ii:bbox_c_int0[1]+ii, bbox_c_int0[0]-ii:bbox_c_int0[0]+ii] = 0 + body_part_matrix[2, bbox_c_int1[1]-ii:bbox_c_int1[1]+ii, bbox_c_int1[0]-ii:bbox_c_int1[0]+ii] = 255 + body_part_matrix = (body_part_matrix).long() + # body_part_name_list + # ['other', 'head', 'torso', 'right front leg', 'right back leg', 'left front leg', 'left back leg', 'tail', 'other', 'right ear', 'left ear', 'muzzle', 'nose', 'other', 'neck'] + # swap indices: + # bp_mirroring_inds = [0, 1, 2, 5, 6, 3, 4, 7, 8, 10, 9, 11, 12, 13, 14] + + + r = 0 + # self.is_train = False + if self.do_augment: + s = s*torch.randn(1).mul_(sf).add_(1).clamp(1-sf, 1+sf)[0] + r = torch.randn(1).mul_(rf).clamp(-2*rf, 2*rf)[0] if random.random() <= 0.6 else 0 + # Flip + if do_flip: + img = fliplr(img) + if self.calc_seg: + seg = fliplr(seg) + body_part_matrix = fliplr(body_part_matrix) + c[0] = img.size(2) - c[0] + # Color + img[0, :, :].mul_(random.uniform(0.8, 1.2)).clamp_(0, 1) + img[1, :, :].mul_(random.uniform(0.8, 1.2)).clamp_(0, 1) + img[2, :, :].mul_(random.uniform(0.8, 1.2)).clamp_(0, 1) + + # Prepare image and groundtruth map + inp = crop(img, c, s, [self.inp_res, self.inp_res], rot=r) + inp = color_normalize(inp, self.DATA_INFO.rgb_mean, self.DATA_INFO.rgb_stddev) + + # import pdb; pdb.set_trace() + + if self.calc_seg: + seg = crop(seg, c, s, [self.inp_res, self.inp_res], rot=r) + + # 'crop' will divide by 255 and perform zero padding ( + # -> weird function that tries to rescale! Because of that I add zeros and ones in the beginning + xx = body_part_matrix.clone() + + # import pdb; pdb.set_trace() + + + body_part_matrix = crop(body_part_matrix, c, s, [self.inp_res, self.inp_res], rot=r, interp='nearest') + + body_part_matrix = body_part_matrix*255 - 2 + + body_part_matrix[body_part_matrix == -2] = -1 + body_part_matrix[body_part_matrix == 16] = -1 + body_part_matrix[body_part_matrix == 253] = -1 + + '''print(np.unique(body_part_matrix.numpy())) + print(np.unique(body_part_matrix[0, :, :].numpy())) + print(np.unique(body_part_matrix[1, :, :].numpy())) + print(np.unique(body_part_matrix[2, :, :].numpy()))''' + + # import cv2 + # cv2.imwrite('/ps/scratch/nrueegg/new_projects/Animals/dog_project/pytorch-stacked-hourglass/yy2.png', np.asarray((inp[0, :, :]+1)*100, np.uint8)) + # cv2.imwrite('/ps/scratch/nrueegg/new_projects/Animals/dog_project/pytorch-stacked-hourglass/yy3.png', (40*(1+body_part_matrix[0, :, :].numpy())).astype(np.uint8)) + + + + # Generate ground truth + nparts = 24 + target_weight = torch.zeros(nparts, 1) + target = torch.zeros(nparts, self.out_res, self.out_res) + pts = torch.zeros((nparts, 3)) + tpts = torch.zeros((nparts, 3)) + + # import pdb; pdb.set_trace() + + + # meta = {'index' : index, 'center' : c, 'scale' : s, 'do_flip' : do_flip, 'rot' : r, 'resolution' : [self.out_res, self.out_res], 'name' : name, + # 'pts' : pts, 'tpts' : tpts, 'target_weight': target_weight, 'breed_index': this_breed['index']} + # meta = {'index' : index, 'center' : c, 'scale' : s, 'do_flip' : do_flip, 'rot' : r, 'resolution' : self.out_res, + # 'pts' : pts, 'tpts' : tpts, 'target_weight': target_weight, 'breed_index': this_breed['index']} + # meta = {'index' : index, 'center' : c, 'scale' : s, + # 'pts' : pts, 'tpts' : tpts, 'target_weight': target_weight, + # 'breed_index': this_breed['index'], 'sim_breed_index': sim_breed_index, + # 'ind_dataset': 0} # ind_dataset: 0 for stanext or stanexteasy or stanext 24 + meta = {'index' : index, 'center' : c, 'scale' : s, + 'pts' : pts, 'tpts' : tpts, 'target_weight': target_weight, + 'ind_dataset': 3} + + #import pdb; pdb.set_trace() + + + if self.dataset_mode=='keyp_and_seg_and_partseg': + # meta = {} + meta['silh'] = seg[0, :, :] + meta['name'] = name + meta['body_part_matrix'] = body_part_matrix.long() + # meta['body_part_weights'] = body_part_weight_masks + # import pdb; pdb.set_trace() + return inp, target, meta + else: + raise ValueError + + + + def __len__(self): + if self.is_train: + return len(self.train_set) # len(self.train_list) + else: + return len(self.val_set) # len(self.valid_list) + + diff --git a/src/stacked_hourglass/datasets/imgcrops.py b/src/stacked_hourglass/datasets/imgcrops.py new file mode 100644 index 0000000000000000000000000000000000000000..89face653c8d6c92fb4bf453a1ae46957ee68dff --- /dev/null +++ b/src/stacked_hourglass/datasets/imgcrops.py @@ -0,0 +1,77 @@ + + +import os +import glob +import numpy as np +import torch +import torch.utils.data as data + +import sys +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..')) +from configs.anipose_data_info import COMPLETE_DATA_INFO +from stacked_hourglass.utils.imutils import load_image +from stacked_hourglass.utils.transforms import crop, color_normalize +from stacked_hourglass.utils.pilutil import imresize +from stacked_hourglass.utils.imutils import im_to_torch +from configs.dataset_path_configs import TEST_IMAGE_CROP_ROOT_DIR +from configs.data_info import COMPLETE_DATA_INFO_24 + + +class ImgCrops(data.Dataset): + DATA_INFO = COMPLETE_DATA_INFO_24 + ACC_JOINTS = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 16] + + def __init__(self, img_crop_folder='default', image_path=None, is_train=False, inp_res=256, out_res=64, sigma=1, + scale_factor=0.25, rot_factor=30, label_type='Gaussian', + do_augment='default', shorten_dataset_to=None, dataset_mode='keyp_only'): + assert is_train == False + assert do_augment == 'default' or do_augment == False + self.inp_res = inp_res + if img_crop_folder == 'default': + self.folder_imgs = os.path.join(os.path.dirname(__file__), '..', '..', '..', 'datasets', 'test_image_crops') + else: + self.folder_imgs = img_crop_folder + name_list = glob.glob(os.path.join(self.folder_imgs, '*.png')) + glob.glob(os.path.join(self.folder_imgs, '*.jpg')) + glob.glob(os.path.join(self.folder_imgs, '*.jpeg')) + name_list = sorted(name_list) + self.test_name_list = [name.split('/')[-1] for name in name_list] + print('len(dataset): ' + str(self.__len__())) + + def __getitem__(self, index): + img_name = self.test_name_list[index] + # load image + img_path = os.path.join(self.folder_imgs, img_name) + img = load_image(img_path) # CxHxW + # prepare image (cropping and color) + img_max = max(img.shape[1], img.shape[2]) + img_padded = torch.zeros((img.shape[0], img_max, img_max)) + if img_max == img.shape[2]: + start = (img_max-img.shape[1])//2 + img_padded[:, start:start+img.shape[1], :] = img + else: + start = (img_max-img.shape[2])//2 + img_padded[:, :, start:start+img.shape[2]] = img + img = img_padded + img_prep = im_to_torch(imresize(img, [self.inp_res, self.inp_res], interp='bilinear')) + inp = color_normalize(img_prep, self.DATA_INFO.rgb_mean, self.DATA_INFO.rgb_stddev) + # add the following fields to make it compatible with stanext, most of them are fake + target_dict = {'index': index, 'center' : -2, 'scale' : -2, + 'breed_index': -2, 'sim_breed_index': -2, + 'ind_dataset': 1} + target_dict['pts'] = np.zeros((self.DATA_INFO.n_keyp, 3)) + target_dict['tpts'] = np.zeros((self.DATA_INFO.n_keyp, 3)) + target_dict['target_weight'] = np.zeros((self.DATA_INFO.n_keyp, 1)) + target_dict['silh'] = np.zeros((self.inp_res, self.inp_res)) + return inp, target_dict + + + def __len__(self): + return len(self.test_name_list) + + + + + + + + + diff --git a/src/stacked_hourglass/datasets/imgcropslist.py b/src/stacked_hourglass/datasets/imgcropslist.py new file mode 100644 index 0000000000000000000000000000000000000000..27b25d7cc8e54e7e0d9492be0c493bbcac20f173 --- /dev/null +++ b/src/stacked_hourglass/datasets/imgcropslist.py @@ -0,0 +1,85 @@ + + +import os +import glob +import numpy as np +import math +import torch +import torch.utils.data as data + +import sys +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..')) +from configs.anipose_data_info import COMPLETE_DATA_INFO +from stacked_hourglass.utils.imutils import load_image, im_to_torch +from stacked_hourglass.utils.transforms import crop, color_normalize +from stacked_hourglass.utils.pilutil import imresize +from stacked_hourglass.utils.imutils import im_to_torch +from configs.data_info import COMPLETE_DATA_INFO_24 + + +class ImgCrops(data.Dataset): + DATA_INFO = COMPLETE_DATA_INFO_24 + ACC_JOINTS = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 16] + + def __init__(self, image_list, bbox_list=None, inp_res=256, dataset_mode='keyp_only'): + # the list contains the images directly, not only their paths + self.image_list = image_list + self.bbox_list = bbox_list + self.inp_res = inp_res + self.test_name_list = [] + for ind in np.arange(0, len(self.image_list)): + self.test_name_list.append(str(ind)) + print('len(dataset): ' + str(self.__len__())) + + def __getitem__(self, index): + + # load image + img = im_to_torch(self.image_list[index]) + + # try loading bounding box + if self.bbox_list is not None: + bbox = self.bbox_list[index] + bbox_xywh = [bbox[0][0], bbox[0][1], bbox[1][0]-bbox[0][0], bbox[1][1]-bbox[0][1]] + bbox_c = [bbox_xywh[0]+0.5*bbox_xywh[2], bbox_xywh[1]+0.5*bbox_xywh[3]] + bbox_max = max(bbox_xywh[2], bbox_xywh[3]) + bbox_diag = math.sqrt(bbox_xywh[2]**2 + bbox_xywh[3]**2) + bbox_s = bbox_max / 200. * 256. / 200. # maximum side of the bbox will be 200 + c = torch.Tensor(bbox_c) + s = bbox_s + img_prep = crop(img, c, s, [self.inp_res, self.inp_res], rot=0) + else: + # prepare image (cropping and color) + img_max = max(img.shape[1], img.shape[2]) + img_padded = torch.zeros((img.shape[0], img_max, img_max)) + if img_max == img.shape[2]: + start = (img_max-img.shape[1])//2 + img_padded[:, start:start+img.shape[1], :] = img + else: + start = (img_max-img.shape[2])//2 + img_padded[:, :, start:start+img.shape[2]] = img + img = img_padded + img_prep = im_to_torch(imresize(img, [self.inp_res, self.inp_res], interp='bilinear')) + + inp = color_normalize(img_prep, self.DATA_INFO.rgb_mean, self.DATA_INFO.rgb_stddev) + # add the following fields to make it compatible with stanext, most of them are fake + target_dict = {'index': index, 'center' : -2, 'scale' : -2, + 'breed_index': -2, 'sim_breed_index': -2, + 'ind_dataset': 1} + target_dict['pts'] = np.zeros((self.DATA_INFO.n_keyp, 3)) + target_dict['tpts'] = np.zeros((self.DATA_INFO.n_keyp, 3)) + target_dict['target_weight'] = np.zeros((self.DATA_INFO.n_keyp, 1)) + target_dict['silh'] = np.zeros((self.inp_res, self.inp_res)) + return inp, target_dict + + + def __len__(self): + return len(self.image_list) + + + + + + + + + diff --git a/src/stacked_hourglass/datasets/samplers/__init__.py b/src/stacked_hourglass/datasets/samplers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/stacked_hourglass/datasets/samplers/custom_gc_sampler.py b/src/stacked_hourglass/datasets/samplers/custom_gc_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..9dccab1384d515ddd0aa45e56213cc2b1d84a670 --- /dev/null +++ b/src/stacked_hourglass/datasets/samplers/custom_gc_sampler.py @@ -0,0 +1,197 @@ + +import numpy as np +import random +import copy +import time +import warnings +import random + +from torch.utils.data import Sampler +from torch._six import int_classes as _int_classes + +class CustomGCSampler(Sampler): + """Wraps another sampler to yield a mini-batch of indices. + The structure of this sampler is way to complicated because it is a shorter/simplified version of + CustomBatchSampler. The relations between breeds are not relevant for the cvpr 2022 paper, but we kept + this structure which we were using for the experiments with clade related losses. ToDo: restructure + this sampler. + Args: + data_sampler_info (dict): a dictionnary, containing information about the dataset and breeds. + batch_size (int): Size of mini-batch. + """ + + def __init__(self, data_sampler_info_gc, batch_size, add_nonflat=False, more_standing=False): + if not isinstance(batch_size, _int_classes) or isinstance(batch_size, bool) or \ + batch_size <= 0: + assert (batch_size == 12 and add_nonflat==False) or (batch_size == 14 and add_nonflat==True) + raise ValueError("batch_size should be a positive integer value, " + "but got batch_size={}".format(batch_size)) + self.data_sampler_info_gc = data_sampler_info_gc + self.batch_size = batch_size + self.add_nonflat = add_nonflat + self.more_standing = more_standing + + self.n_images_tot = len(self.data_sampler_info_gc['name_list']) # 4305 + + # get full sorted image list + self.pose_dict = {} + self.dict_name_to_idx = {} + for ind_img, img in enumerate(self.data_sampler_info_gc['name_list']): + self.dict_name_to_idx[img] = ind_img + pose = self.data_sampler_info_gc['gc_annots_categories'][img]['pose'] + if pose in self.pose_dict.keys(): + self.pose_dict[pose].append(img) + else: + self.pose_dict[pose] = [img] + + # prepare non-flat images + if self.add_nonflat: + self.n_images_nonflat_tot = len(self.data_sampler_info_gc['name_list_nonflat']) + + # self.n_desired_batches = int(np.floor(len(self.data_sampler_info_gc['name_list']) / batch_size)) # 157 + self.n_desired_batches = int(np.ceil(len(self.get_list_for_group_index(ind_g=1, n_groups=5, shuffle=True, more_standing=self.more_standing)) / 3)) + + def get_description(self): + description = "\ + This sampler returns stanext data such that poses are more balanced. \n\ + -> works on top of stanext24_withgc_v2" + return description + + def get_nonflat_idx_list(self, shuffle=True): + all_nonflat_idxs = list(range(self.n_images_tot, self.n_images_tot + self.n_images_nonflat_tot)) + if shuffle: + random.shuffle(all_nonflat_idxs) + return all_nonflat_idxs + + def get_list_for_group_index(self, ind_g, n_groups=5, shuffle=True, return_info=False, more_standing=False): + # availabe poses + # sitting_sym: 561 + # lying_sym: 199 + # jumping_touching: 21 + # standing_4paws: 1999 + # running: 132 + # sitting_comp: 306 + # onhindlegs: 16 + # walking: 325 + # lying_comp: 596 + # standing_fewpaws: 98 + # otherpose: 22 + # downwardfacingdog: 14 + # jumping_nottouching: 16 + # + # available groups (7 groups) + # 89: 'otherpose', 'downwardfacingdog', 'jumping_nottouching', 'onhindlegs', 'jumping_touching' + # 561: 'sitting_sym' + # 306: 'sitting_comp' + # 199: 'lying_sym' + # 596: 'lying_comp' + # 555: 'standing_fewpaws', 'running', 'walking' + # 1999: 'standing_4paws' + # -> sample: 2, 1.5, 1.5, 1.5, 1.5, 2, 2 + # + # available groups (5 groups) + # 89: 'otherpose', 'downwardfacingdog', 'jumping_nottouching', 'onhindlegs', 'jumping_touching' + # 867: 'sitting_sym', 'sitting_comp' + # 795: 'lying_sym', 'lying_comp' + # 555: 'standing_fewpaws', 'running', 'walking' + # 1999: 'standing_4paws' + # -> sample: 2, 3, 3, 2, 2 + assert (n_groups == 5) + if more_standing: + if ind_g == 0: + n_samples_per_batch = 2 + pose_names = ['otherpose', 'downwardfacingdog', 'jumping_nottouching', 'onhindlegs', 'jumping_touching'] + elif ind_g == 1: + n_samples_per_batch = 2 + pose_names = ['sitting_sym', 'sitting_comp'] + elif ind_g == 2: + n_samples_per_batch = 2 + pose_names = ['lying_sym', 'lying_comp'] + elif ind_g == 3: + n_samples_per_batch = 2 + pose_names = ['standing_fewpaws', 'running', 'walking'] + elif ind_g == 4: + n_samples_per_batch = 4 + pose_names = ['standing_4paws'] + else: + raise ValueError + else: + if ind_g == 0: + n_samples_per_batch = 2 + pose_names = ['otherpose', 'downwardfacingdog', 'jumping_nottouching', 'onhindlegs', 'jumping_touching'] + elif ind_g == 1: + n_samples_per_batch = 3 + pose_names = ['sitting_sym', 'sitting_comp'] + elif ind_g == 2: + n_samples_per_batch = 3 + pose_names = ['lying_sym', 'lying_comp'] + elif ind_g == 3: + n_samples_per_batch = 2 + pose_names = ['standing_fewpaws', 'running', 'walking'] + elif ind_g == 4: + n_samples_per_batch = 2 + pose_names = ['standing_4paws'] + else: + raise ValueError + all_imgs_this_group = [] + for pose_name in pose_names: + all_imgs_this_group.extend(self.pose_dict[pose_name]) + if shuffle: + random.shuffle(all_imgs_this_group) + if return_info: + return all_imgs_this_group, pose_names, n_samples_per_batch + else: + return all_imgs_this_group + + + def __iter__(self): + + n_groups = 5 + group_lists = {} + n_samples_per_batch = {} + for ind_g in range(n_groups): + group_lists[ind_g], pose_names, n_samples_per_batch[ind_g] = self.get_list_for_group_index(ind_g, n_groups=5, shuffle=True, return_info=True, more_standing=self.more_standing) + if self.add_nonflat: + nonflat_idx_list = self.get_nonflat_idx_list() + + # we want to sample all sitting poses at least once per batch (and ths all other + # images except standing on 4 paws) + all_batches = [] + for ind in range(self.n_desired_batches): + batch_with_idxs = [] + for ind_g in range(n_groups): + for ind_s in range(n_samples_per_batch[ind_g]): + if len(group_lists[ind_g]) == 0: + group_lists[ind_g] = self.get_list_for_group_index(ind_g, n_groups=5, shuffle=True, more_standing=self.more_standing) + name = group_lists[ind_g].pop(0) + idx = self.dict_name_to_idx[name] + batch_with_idxs.append(idx) + if self.add_nonflat: + for ind_x in range(2): + if len(nonflat_idx_list) == 0: + nonflat_idx_list = self.get_nonflat_idx_list() + idx = nonflat_idx_list.pop(0) + batch_with_idxs.append(idx) + all_batches.append(batch_with_idxs) + + for batch in all_batches: + yield batch + + + def __len__(self): + # Since we are sampling pairs of dogs and not each breed has an even number of dogs, we can not + # guarantee to show each dog exacly once. What we do instead, is returning the same amount of + # batches as we would return with a standard sampler which is not based on dog pairs. + '''if self.drop_last: + return len(self.sampler) // self.batch_size # type: ignore + else: + return (len(self.sampler) + self.batch_size - 1) // self.batch_size # type: ignore''' + return self.n_desired_batches + + + + + + + + diff --git a/src/stacked_hourglass/datasets/samplers/custom_gc_sampler_noclasses.py b/src/stacked_hourglass/datasets/samplers/custom_gc_sampler_noclasses.py new file mode 100644 index 0000000000000000000000000000000000000000..3887502b350730f1e32a7ac049dbc07d383cbcd8 --- /dev/null +++ b/src/stacked_hourglass/datasets/samplers/custom_gc_sampler_noclasses.py @@ -0,0 +1,163 @@ + +import numpy as np +import random +import copy +import time +import warnings +import random + +from torch.utils.data import Sampler +from torch._six import int_classes as _int_classes + +class CustomGCSamplerNoCLass(Sampler): + """Wraps another sampler to yield a mini-batch of indices. + The structure of this sampler is way to complicated because it is a shorter/simplified version of + CustomBatchSampler. The relations between breeds are not relevant for the cvpr 2022 paper, but we kept + this structure which we were using for the experiments with clade related losses. ToDo: restructure + this sampler. + Args: + data_sampler_info (dict): a dictionnary, containing information about the dataset and breeds. + batch_size (int): Size of mini-batch. + """ + + def __init__(self, data_sampler_info_gc, batch_size, add_nonflat=False): + if not isinstance(batch_size, _int_classes) or isinstance(batch_size, bool) or \ + batch_size <= 0: + assert (batch_size == 12 and add_nonflat==False) or (batch_size == 14 and add_nonflat==True) + raise ValueError("batch_size should be a positive integer value, " + "but got batch_size={}".format(batch_size)) + self.data_sampler_info_gc = data_sampler_info_gc + self.batch_size = batch_size + self.add_nonflat = add_nonflat + + self.n_images_tot = len(self.data_sampler_info_gc['name_list']) # 4305 + + # get full sorted image list + self.pose_dict = {} + self.dict_name_to_idx = {} + for ind_img, img in enumerate(self.data_sampler_info_gc['name_list']): + self.dict_name_to_idx[img] = ind_img + pose = self.data_sampler_info_gc['gc_annots_categories'][img]['pose'] + if pose in self.pose_dict.keys(): + self.pose_dict[pose].append(img) + else: + self.pose_dict[pose] = [img] + + # prepare non-flat images + if self.add_nonflat: + self.n_images_nonflat_tot = len(self.data_sampler_info_gc['name_list_nonflat']) + + # self.n_desired_batches = int(np.floor(len(self.data_sampler_info_gc['name_list']) / batch_size)) # 157 + self.n_desired_batches = 160 + + def get_description(self): + description = "\ + This sampler returns stanext data such that poses are more balanced. \n\ + -> works on top of stanext24_withgc_v2" + return description + + def get_nonflat_idx_list(self, shuffle=True): + all_nonflat_idxs = list(range(self.n_images_tot, self.n_images_tot + self.n_images_nonflat_tot)) + if shuffle: + random.shuffle(all_nonflat_idxs) + return all_nonflat_idxs + + def get_list_for_group_index(self, ind_g, n_groups=1, shuffle=True, return_info=False): + # availabe poses + # sitting_sym: 561 + # lying_sym: 199 + # jumping_touching: 21 + # standing_4paws: 1999 + # running: 132 + # sitting_comp: 306 + # onhindlegs: 16 + # walking: 325 + # lying_comp: 596 + # standing_fewpaws: 98 + # otherpose: 22 + # downwardfacingdog: 14 + # jumping_nottouching: 16 + # + # available groups (7 groups) + # 89: 'otherpose', 'downwardfacingdog', 'jumping_nottouching', 'onhindlegs', 'jumping_touching' + # 561: 'sitting_sym' + # 306: 'sitting_comp' + # 199: 'lying_sym' + # 596: 'lying_comp' + # 555: 'standing_fewpaws', 'running', 'walking' + # 1999: 'standing_4paws' + # -> sample: 2, 1.5, 1.5, 1.5, 1.5, 2, 2 + # + # available groups (5 groups) + # 89: 'otherpose', 'downwardfacingdog', 'jumping_nottouching', 'onhindlegs', 'jumping_touching' + # 867: 'sitting_sym', 'sitting_comp' + # 795: 'lying_sym', 'lying_comp' + # 555: 'standing_fewpaws', 'running', 'walking' + # 1999: 'standing_4paws' + # -> sample: 2, 3, 3, 2, 2 + assert (n_groups == 1) + if ind_g == 0: + n_samples_per_batch = 12 + pose_names = ['otherpose', 'downwardfacingdog', 'jumping_nottouching', 'onhindlegs', 'jumping_touching', 'sitting_sym', 'sitting_comp', 'lying_sym', 'lying_comp', 'standing_fewpaws', 'running', 'walking', 'standing_4paws'] + all_imgs_this_group = [] + for pose_name in pose_names: + all_imgs_this_group.extend(self.pose_dict[pose_name]) + if shuffle: + random.shuffle(all_imgs_this_group) + if return_info: + return all_imgs_this_group, pose_names, n_samples_per_batch + else: + return all_imgs_this_group + + + def __iter__(self): + + n_groups = 1 + group_lists = {} + n_samples_per_batch = {} + for ind_g in range(n_groups): + group_lists[ind_g], pose_names, n_samples_per_batch[ind_g] = self.get_list_for_group_index(ind_g, n_groups=1, shuffle=True, return_info=True) + if self.add_nonflat: + nonflat_idx_list = self.get_nonflat_idx_list() + + # we want to sample all sitting poses at least once per batch (and ths all other + # images except standing on 4 paws) + all_batches = [] + for ind in range(self.n_desired_batches): + batch_with_idxs = [] + for ind_g in range(n_groups): + for ind_s in range(n_samples_per_batch[ind_g]): + if len(group_lists[ind_g]) == 0: + group_lists[ind_g] = self.get_list_for_group_index(ind_g, n_groups=1, shuffle=True) + name = group_lists[ind_g].pop(0) + idx = self.dict_name_to_idx[name] + batch_with_idxs.append(idx) + if self.add_nonflat: + for ind_x in range(2): + if len(nonflat_idx_list) == 0: + nonflat_idx_list = self.get_nonflat_idx_list() + idx = nonflat_idx_list.pop(0) + batch_with_idxs.append(idx) + all_batches.append(batch_with_idxs) + + for batch in all_batches: + yield batch + + + def __len__(self): + # Since we are sampling pairs of dogs and not each breed has an even number of dogs, we can not + # guarantee to show each dog exacly once. What we do instead, is returning the same amount of + # batches as we would return with a standard sampler which is not based on dog pairs. + '''if self.drop_last: + return len(self.sampler) // self.batch_size # type: ignore + else: + return (len(self.sampler) + self.batch_size - 1) // self.batch_size # type: ignore''' + return self.n_desired_batches + + + + + + + + diff --git a/src/stacked_hourglass/datasets/samplers/custom_pair_samplers.py b/src/stacked_hourglass/datasets/samplers/custom_pair_samplers.py new file mode 100644 index 0000000000000000000000000000000000000000..6bb8a636d1138a58cd2265f931e2c19ef47a9220 --- /dev/null +++ b/src/stacked_hourglass/datasets/samplers/custom_pair_samplers.py @@ -0,0 +1,171 @@ + +import numpy as np +import random +import copy +import time +import warnings + +from torch.utils.data import Sampler +from torch._six import int_classes as _int_classes + +class CustomPairBatchSampler(Sampler): + """Wraps another sampler to yield a mini-batch of indices. + The structure of this sampler is way to complicated because it is a shorter/simplified version of + CustomBatchSampler. The relations between breeds are not relevant for the cvpr 2022 paper, but we kept + this structure which we were using for the experiments with clade related losses. ToDo: restructure + this sampler. + Args: + data_sampler_info (dict): a dictionnary, containing information about the dataset and breeds. + batch_size (int): Size of mini-batch. + """ + + def __init__(self, data_sampler_info, batch_size): + if not isinstance(batch_size, _int_classes) or isinstance(batch_size, bool) or \ + batch_size <= 0: + raise ValueError("batch_size should be a positive integer value, " + "but got batch_size={}".format(batch_size)) + assert batch_size%2 == 0 + self.data_sampler_info = data_sampler_info + self.batch_size = batch_size + self.n_desired_batches = int(np.floor(len(self.data_sampler_info['name_list']) / batch_size)) # 157 + + def get_description(self): + description = "\ + This sampler works only for even batch sizes. \n\ + It returns pairs of dogs of the same breed" + return description + + + def __iter__(self): + breeds_summary = self.data_sampler_info['breeds_summary'] + + breed_image_dict_orig = {} + for img_name in self.data_sampler_info['name_list']: # ['n02093859-Kerry_blue_terrier/n02093859_913.jpg', ... ] + folder_name = img_name.split('/')[0] + breed_name = folder_name.split(folder_name.split('-')[0] + '-')[1] + if not (breed_name in breed_image_dict_orig): + breed_image_dict_orig[breed_name] = [img_name] + else: + breed_image_dict_orig[breed_name].append(img_name) + + lengths = np.zeros((len(breed_image_dict_orig.values()))) + for ind, value in enumerate(breed_image_dict_orig.values()): + lengths[ind] = len(value) + + sim_matrix_raw = self.data_sampler_info['breeds_sim_martix_raw'] + sim_matrix_raw[sim_matrix_raw>0].shape # we have 1061 connections + + # from ind_in_sim_mat to breed_name + inverse_sim_dict = {} + for abbrev, ind in self.data_sampler_info['breeds_sim_abbrev_inds'].items(): + # breed_name might be None + breed = breeds_summary[abbrev] + breed_name = breed._name_stanext + inverse_sim_dict[ind] = {'abbrev': abbrev, + 'breed_name': breed_name} + + # similarity for relevant breeds only: + related_breeds_top_orig = {} + temp = np.arange(sim_matrix_raw.shape[0]) + for breed_name, breed_images in breed_image_dict_orig.items(): + abbrev = self.data_sampler_info['breeds_abbrev_dict'][breed_name] + related_breeds = {} + if abbrev in self.data_sampler_info['breeds_sim_abbrev_inds'].keys(): + ind_in_sim_mat = self.data_sampler_info['breeds_sim_abbrev_inds'][abbrev] + row = sim_matrix_raw[ind_in_sim_mat, :] + rel_inds = temp[row>0] + for ind in rel_inds: + rel_breed_name = inverse_sim_dict[ind]['breed_name'] + rel_abbrev = inverse_sim_dict[ind]['abbrev'] + # does this breed exist in this dataset? + if (rel_breed_name is not None) and (rel_breed_name in breed_image_dict_orig.keys()) and not (rel_breed_name==breed_name): + related_breeds[rel_breed_name] = row[ind] + related_breeds_top_orig[breed_name] = related_breeds + + breed_image_dict = copy.deepcopy(breed_image_dict_orig) + related_breeds_top = copy.deepcopy(related_breeds_top_orig) + + # clean the related_breeds_top dict such that it only contains breeds which are available + for breed_name, breed_images in breed_image_dict.items(): + if len(breed_image_dict[breed_name]) < 1: + for breed_name_rel in list(related_breeds_top[breed_name].keys()): + related_breeds_top[breed_name_rel].pop(breed_name, None) + related_breeds_top[breed_name].pop(breed_name_rel, None) + + # 1) build pairs of dogs + set_of_breeds_with_at_least_2 = set() + for breed_name, breed_images in breed_image_dict.items(): + if len(breed_images) >= 2: + set_of_breeds_with_at_least_2.add(breed_name) + + n_unused_images = len(self.data_sampler_info['name_list']) + all_dog_duos = [] + n_new_duos = 1 + while n_new_duos > 0: + for breed_name, breed_images in breed_image_dict.items(): + # shuffle image list for this specific breed (this changes the dict) + random.shuffle(breed_images) + breed_list = list(related_breeds_top.keys()) + random.shuffle(breed_list) + n_new_duos = 0 + for breed_name in breed_list: + if len(breed_image_dict[breed_name]) >= 2: + dog_a = breed_image_dict[breed_name].pop() + dog_b = breed_image_dict[breed_name].pop() + dog_duo = [dog_a, dog_b] + all_dog_duos.append({'image_names': dog_duo}) + # clean the related_breeds_top dict such that it only contains breeds which are still available + if len(breed_image_dict[breed_name]) < 1: + for breed_name_rel in list(related_breeds_top[breed_name].keys()): + related_breeds_top[breed_name_rel].pop(breed_name, None) + related_breeds_top[breed_name].pop(breed_name_rel, None) + n_new_duos += 1 + n_unused_images -= 2 + + image_name_to_ind = {} + for ind_img_name, img_name in enumerate(self.data_sampler_info['name_list']): + image_name_to_ind[img_name] = ind_img_name + + # take all images and create the batches + n_avail_2 = len(all_dog_duos) + all_batches = [] + ind_in_duos = 0 + n_imgs_used_twice = 0 + for ind_b in range(0, self.n_desired_batches): + batch_with_image_names = [] + for ind in range(int(np.floor(self.batch_size / 2))): + if ind_in_duos >= n_avail_2: + ind_rand = random.randint(0, n_avail_2-1) + batch_with_image_names.extend(all_dog_duos[ind_rand]['image_names']) + n_imgs_used_twice += 2 + else: + batch_with_image_names.extend(all_dog_duos[ind_in_duos]['image_names']) + ind_in_duos += 1 + + + batch_with_inds = [] + for image_name in batch_with_image_names: # rather a folder than name + batch_with_inds.append(image_name_to_ind[image_name]) + + all_batches.append(batch_with_inds) + + for batch in all_batches: + yield batch + + def __len__(self): + # Since we are sampling pairs of dogs and not each breed has an even number of dogs, we can not + # guarantee to show each dog exacly once. What we do instead, is returning the same amount of + # batches as we would return with a standard sampler which is not based on dog pairs. + '''if self.drop_last: + return len(self.sampler) // self.batch_size # type: ignore + else: + return (len(self.sampler) + self.batch_size - 1) // self.batch_size # type: ignore''' + return self.n_desired_batches + + + + + + + + diff --git a/src/stacked_hourglass/datasets/samplers/two_dataset_sampler.py b/src/stacked_hourglass/datasets/samplers/two_dataset_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..cd6c38b49cc5acdc1f7c21d772866d14c6864bfa --- /dev/null +++ b/src/stacked_hourglass/datasets/samplers/two_dataset_sampler.py @@ -0,0 +1,103 @@ + +import numpy as np +import random +import copy +import time +import warnings + +from torch.utils.data import Sampler +from torch._six import int_classes as _int_classes +# from configs.dog_breeds.dog_breed_class import get_partial_summary + + + +class TwoDatasetSampler(Sampler): + """Wraps another sampler to yield a mini-batch of indices. + Args: + sampler (Sampler or Iterable): Base sampler. Can be any iterable object + batch_size (int): Size of mini-batch. + drop_last (bool): If ``True``, the sampler will drop the last batch if + its size would be less than ``batch_size`` + Example: + >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False)) + [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]] + >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True)) + [[0, 1, 2], [3, 4, 5], [6, 7, 8]] + """ + + def __init__(self, batch_size_half, size0, size1, shuffle=True, drop_last=True): + # Since collections.abc.Iterable does not check for `__getitem__`, which + # is one way for an object to be an iterable, we don't do an `isinstance` + # check here. + if not isinstance(batch_size_half, _int_classes) or isinstance(batch_size_half, bool) or \ + batch_size_half <= 0: + raise ValueError("batch_size should be a positive integer value, " + "but got batch_size={}".format(batch_size_half*2)) + if not isinstance(drop_last, bool): + raise ValueError("drop_last should be a boolean value, but got " + "drop_last={}".format(drop_last)) + assert size0 >= size1 + self.batch_size_half = batch_size_half + self.size0 = size0 + self.size1 = size1 + self.shuffle = shuffle + self.n_batches = self.size1//batch_size_half + self.drop_last = drop_last + + def get_description(self): + description = "\ + This sampler samples equally from two different datasets" + return description + + + def __iter__(self): + + dataset0 = np.arange(self.size0) + dataset1_init = np.arange(self.size1) + self.size0 + if self.shuffle: + np.random.shuffle(dataset0) + + dataset1 = [] + for ind in range(self.size0 // self.size1 + 1): + dataset1_part = dataset1_init.copy() + if self.shuffle: + np.random.shuffle(dataset1_part) + dataset1.extend(dataset1_part) + dataset0 = dataset0[0:self.n_batches*self.batch_size_half] + dataset1 = dataset1[0:self.n_batches*self.batch_size_half] + + # import pdb; pdb.set_trace() + + for ind_batch in range(self.n_batches): + d0 = dataset0[ind_batch*self.batch_size_half:(ind_batch+1)*self.batch_size_half] + d1 = dataset1[ind_batch*self.batch_size_half:(ind_batch+1)*self.batch_size_half] + + batch = list(d0) + list(d1) + # print(len(batch)) + + yield batch + + + + + + + + def __len__(self): + # Can only be called if self.sampler has __len__ implemented + # We cannot enforce this condition, so we turn off typechecking for the + # implementation below. + # Somewhat related: see NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ] + '''if self.drop_last: + return len(self.sampler) // self.batch_size # type: ignore + else: + return (len(self.sampler) + self.batch_size - 1) // self.batch_size # type: ignore''' + return self.n_batches + + + + + + + + diff --git a/src/stacked_hourglass/datasets/sketchfab.py b/src/stacked_hourglass/datasets/sketchfab.py new file mode 100644 index 0000000000000000000000000000000000000000..2f66735003f11f7d27feb3d1374fdc1b3c3072f4 --- /dev/null +++ b/src/stacked_hourglass/datasets/sketchfab.py @@ -0,0 +1,312 @@ + + +import os +import glob +import csv +import numpy as np +import cv2 +import math +import glob +import pickle as pkl +import open3d as o3d +import trimesh +import torch +import torch.utils.data as data + +import sys +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..')) +from configs.anipose_data_info import COMPLETE_DATA_INFO +from stacked_hourglass.utils.imutils import load_image +from stacked_hourglass.utils.transforms import crop, color_normalize +from stacked_hourglass.utils.pilutil import imresize +from stacked_hourglass.utils.imutils import im_to_torch +from configs.dataset_path_configs import TEST_IMAGE_CROP_ROOT_DIR +from configs.data_info import COMPLETE_DATA_INFO_24 + + +class SketchfabScans(data.Dataset): + DATA_INFO = COMPLETE_DATA_INFO_24 + ACC_JOINTS = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 16] + + def __init__(self, img_crop_folder='default', image_path=None, is_train=False, inp_res=256, out_res=64, sigma=1, + scale_factor=0.25, rot_factor=30, label_type='Gaussian', + do_augment='default', shorten_dataset_to=None, dataset_mode='keyp_only'): + assert is_train == False + assert do_augment == 'default' or do_augment == False + self.inp_res = inp_res + + self.n_pcpoints = 3000 + self.folder_imgs = os.path.join(os.path.dirname(__file__), '..', '..', '..', 'datasets', 'sketchfab_test_set', 'images') + self.folder_silh = self.folder_imgs.replace('images', 'silhouettes') + self.folder_point_clouds = self.folder_imgs.replace('images', 'point_clouds_' + str(self.n_pcpoints)) + self.folder_meshes = self.folder_imgs.replace('images', 'meshes') + self.csv_keyp_annots_path = self.folder_imgs.replace('images', 'keypoint_annotations/sketchfab_joint_annotations_complete.csv') + self.pkl_keyp_annots_path = self.folder_imgs.replace('images', 'keypoint_annotations/sketchfab_joint_annotations_complete_but_as_pkl_file.pkl') + self.all_mesh_paths = glob.glob(self.folder_meshes + '/**/*.obj', recursive=True) + name_list = glob.glob(os.path.join(self.folder_imgs, '*.png')) + glob.glob(os.path.join(self.folder_imgs, '*.jpg')) + glob.glob(os.path.join(self.folder_imgs, '*.jpeg')) + name_list = sorted(name_list) + # self.test_name_list = [name.split('/')[-1] for name in name_list] + self.test_name_list = [] + for name in name_list: + # if not (('13' in name) or ('dalmatian' in name and '1281' in name)): + # if not ('13' in name): + self.test_name_list.append(name.split('/')[-1]) + + + print('len(dataset): ' + str(self.__len__())) + + ''' + self.test_mesh_path_list = [] + for img_name in self.test_name_list: + breed = img_name.split('_')[0] # will be french instead of french_bulldog + mask = img_name.split('_')[-2] + this_mp = [] + for mp in self.all_mesh_paths: + if (breed in mp) and (mask in mp): + this_mp.append(mp) + if breed in 'french_bulldog': + this_mp_old = this_mp.copy() + this_mp = [] + for mp in this_mp_old: + if ('_' + mask + '.') in mp: + this_mp.append(mp) + if not len(this_mp) == 1: + print(breed) + print(mask) + this_mp[0].index(mask) + import pdb; pdb.set_trace() + else: + self.test_mesh_path_list.append(this_mp[0]) + + all_pc_paths = [] + for index in range(len(self.test_name_list)): + img_name = self.test_name_list[index] + dog_name = img_name.split('_' + img_name.split('_')[-1])[0] + breed = img_name.split('_')[0] # will be french instead of french_bulldog + mask = img_name.split('_')[-2] + path_pc = self.folder_point_clouds + '/' + dog_name + '.ply' + if not path_pc in all_pc_paths: + try: + print(path_pc) + mesh_path = self.test_mesh_path_list[index] + mesh_gt = o3d.io.read_triangle_mesh(mesh_path) + n_points = 3000 # 20000 + pointcloud = mesh_gt.sample_points_uniformly(number_of_points=n_points) + o3d.io.write_point_cloud(path_pc, pointcloud, write_ascii=False, compressed=False, print_progress=False) + all_pc_paths.append(path_pc) + except: + print(path_pc) + ''' + + # import pdb; pdb.set_trace() + + self.test_mesh_path_list = [] + self.all_pc_paths = [] + for index in range(len(self.test_name_list)): + img_name = self.test_name_list[index] + dog_name = img_name.split('_' + img_name.split('_')[-1])[0] + breed = img_name.split('_')[0] # will be french instead of french_bulldog + mask = img_name.split('_')[-2] + mesh_path = self.folder_meshes + '/' + dog_name + '.obj' + path_pc = self.folder_point_clouds + '/' + dog_name + '.ply' + if dog_name in ['dalmatian_1281', 'french_bulldog_13']: + # mesh_path_for_pc = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/datasets/sketchfab_test_set/meshes_old/dalmatian/1281/Renderbot-animal-obj-1281.obj' + mesh_path_for_pc = self.folder_meshes + '/' + dog_name + '_simple.obj' + else: + mesh_path_for_pc = mesh_path + self.test_mesh_path_list.append(mesh_path) + # if not path_pc in self.all_pc_paths: + if os.path.isfile(path_pc): + self.all_pc_paths.append(path_pc) + else: + try: + mesh_gt = o3d.io.read_triangle_mesh(mesh_path_for_pc) + except: + import pdb; pdb.set_trace() + mesh = trimesh.load(mesh_path_for_pc, process=False, maintain_order=True) + vertices = mesh.vertices + faces = mesh.faces + + print(mesh_path_for_pc) + pointcloud = mesh_gt.sample_points_uniformly(number_of_points=self.n_pcpoints) + o3d.io.write_point_cloud(path_pc, pointcloud, write_ascii=False, compressed=False, print_progress=False) + self.all_pc_paths.append(path_pc) + # except: + # print(path_pc) + + # add keypoint annotations (mesh vertices) + read_annots_from_csv = False # True + if read_annots_from_csv: + self.all_keypoint_annotations, self.keypoint_name_dict = self._read_keypoint_csv(self.csv_keyp_annots_path, folder_meshes=self.folder_meshes, get_keyp_coords=True) + with open(self.pkl_keyp_annots_path, 'wb') as handle: + pkl.dump(self.all_keypoint_annotations, handle, protocol=pkl.HIGHEST_PROTOCOL) + else: + with open(self.pkl_keyp_annots_path, 'rb') as handle: + self.all_keypoint_annotations = pkl.load(handle) + + + + + + def _read_keypoint_csv(self, csv_path, folder_meshes=None, get_keyp_coords=True, visualize=False): + with open(csv_path,'r') as f: + reader = csv.reader(f) + headers = next(reader) + row_list = [{h:x for (h,x) in zip(headers,row)} for row in reader] + assert(headers[2] == 'hiwi') + keypoint_names = headers[3:] + center_keypoint_names = ['nose','tail_start','tail_end'] + right_keypoint_names = ['right_front_paw','right_front_elbow','right_back_paw','right_back_hock','right_ear_top','right_ear_bottom','right_eye'] + left_keypoint_names = ['left_front_paw','left_front_elbow','left_back_paw','left_back_hock','left_ear_top','left_ear_bottom','left_eye'] + keypoint_name_dict = {'all': keypoint_names, 'left': left_keypoint_names, 'right': right_keypoint_names, 'center': center_keypoint_names} + # prepare output dicts + all_keypoint_annotations = {} + for ind in range(len(row_list)): + name = row_list[ind]['mesh_name'] + this_dict = row_list[ind] + del this_dict['hiwi'] + all_keypoint_annotations[name] = this_dict + keypoint_idxs = np.zeros((len(keypoint_names), 2)) + if get_keyp_coords: + mesh_path = folder_meshes + '/' + row_list[ind]['mesh_name'] + mesh = trimesh.load(mesh_path, process=False, maintain_order=True) + vertices = mesh.vertices + keypoint_3d_locations = np.zeros((len(keypoint_names), 4)) # 1, 2, 3: coords, 4: is_valid + for ind_kp, name_kp in enumerate(keypoint_names): + idx = this_dict[name_kp] + if idx in ['', 'n/a']: + keypoint_idxs[ind_kp, 0] = -1 + else: + keypoint_idxs[ind_kp, 0] = this_dict[name_kp] + keypoint_idxs[ind_kp, 1] = 1 # is valid + if get_keyp_coords: + keyp = vertices[int(row_list[ind][name_kp])] + keypoint_3d_locations[ind_kp, :3] = keyp + keypoint_3d_locations[ind_kp, 3] = 1 + all_keypoint_annotations[name]['all_keypoint_vertex_idxs'] = keypoint_idxs + if get_keyp_coords: + all_keypoint_annotations[name]['all_keypoint_coords_and_isvalid'] = keypoint_3d_locations + # create visualizations if desired + if visualize: + raise NotImplementedError # only debug path is missing + out_path = '.... some debug path' + red_color = np.asarray([255, 0, 0], dtype=np.uint8) + green_color = np.asarray([0, 255, 0], dtype=np.uint8) + blue_color = np.asarray([0, 0, 255], dtype=np.uint8) + for ind in range(len(row_list)): + mesh_path = folder_meshes + '/' + row_list[ind]['mesh_name'] + mesh = trimesh.load(mesh_path, process=False, maintain_order=True) # maintain_order is very important!!!!! + vertices = mesh.vertices + faces = mesh.faces + dog_mesh_nocolor = trimesh.Trimesh(vertices=vertices, faces=faces, process=False, maintain_order=True) + dog_mesh_nocolor.visual.vertex_colors = np.ones_like(vertices, dtype=np.uint8) * 255 + sphere_list = [dog_mesh_nocolor] + for keyp_name in keypoint_names: + if not (row_list[ind][keyp_name] == '' or row_list[ind][keyp_name] == 'n/a'): + keyp = vertices[int(row_list[ind][keyp_name])] + sphere = trimesh.primitives.Sphere(radius=0.02, center=keyp) + if keyp_name in right_keypoint_names: + colors = np.ones_like(sphere.vertices) * red_color[None, :] + elif keyp_name in left_keypoint_names: + colors = np.ones_like(sphere.vertices) * blue_color[None, :] + else: + colors = np.ones_like(sphere.vertices) * green_color[None, :] + sphere.visual.vertex_colors = colors # trimesh.visual.random_color() + sphere_list.append(sphere) + scene_keyp = trimesh.Scene(sphere_list) + scene_keyp.export(out_path + os.path.basename(mesh_path).replace('.obj', '_withkeyp.obj')) + return all_keypoint_annotations, keypoint_name_dict + + + + def __getitem__(self, index): + img_name = self.test_name_list[index] + dog_name = img_name.split('_' + img_name.split('_')[-1])[0] + breed = img_name.split('_')[0] # will be french instead of french_bulldog + mask = img_name.split('_')[-2] + mesh_path = self.test_mesh_path_list[index] + # mesh_gt = o3d.io.read_triangle_mesh(mesh_path) + + path_pc = self.folder_point_clouds + '/' + dog_name + '.ply' + assert path_pc in self.all_pc_paths + pc_trimesh = trimesh.load(path_pc, process=False, maintain_order=True) + pc_points = np.asarray(pc_trimesh.vertices) + assert pc_points.shape[0] == self.n_pcpoints + + + # get annotated 3d keypoints + keyp_3d = self.all_keypoint_annotations[mesh_path.split('/')[-1]]['all_keypoint_coords_and_isvalid'] + + + # load image + img_path = os.path.join(self.folder_imgs, img_name) + + img = load_image(img_path) # CxHxW + # try on silhouette images! + # seg_path = os.path.join(self.folder_silh, img_name) + # img = load_image(seg_path) # CxHxW + + img_vis = np.transpose(img, (1, 2, 0)) + seg_path = os.path.join(self.folder_silh, img_name) + seg = cv2.imread(seg_path, cv2.IMREAD_UNCHANGED)[:, :, 3] + seg[seg>0] = 1 + seg_s0 = np.nonzero(seg.sum(axis=1)>0)[0] + seg_s1 = np.nonzero(seg.sum(axis=0)>0)[0] + bbox_xywh = [seg_s1.min(), seg_s0.min(), seg_s1.max() - seg_s1.min(), seg_s0.max() - seg_s0.min()] + bbox_c = [bbox_xywh[0]+0.5*bbox_xywh[2], bbox_xywh[1]+0.5*bbox_xywh[3]] + bbox_max = max(bbox_xywh[2], bbox_xywh[3]) + bbox_diag = math.sqrt(bbox_xywh[2]**2 + bbox_xywh[3]**2) + # bbox_s = bbox_max / 200. # the dog will fill the image -> bbox_max = 256 + # bbox_s = bbox_diag / 200. # diagonal of the boundingbox will be 200 + bbox_s = bbox_max / 200. * 256. / 200. # maximum side of the bbox will be 200 + c = torch.Tensor(bbox_c) + s = bbox_s + r = 0 + + # Prepare image and groundtruth map + inp_col = crop(img, c, s, [self.inp_res, self.inp_res], rot=r) + inp = color_normalize(inp_col, self.DATA_INFO.rgb_mean, self.DATA_INFO.rgb_stddev) + + silh_3channels = np.stack((seg, seg, seg), axis=0) + inp_silh = crop(silh_3channels, c, s, [self.inp_res, self.inp_res], rot=r) + + ''' + # prepare image (cropping and color) + img_max = max(img.shape[1], img.shape[2]) + img_padded = torch.zeros((img.shape[0], img_max, img_max)) + if img_max == img.shape[2]: + start = (img_max-img.shape[1])//2 + img_padded[:, start:start+img.shape[1], :] = img + else: + start = (img_max-img.shape[2])//2 + img_padded[:, :, start:start+img.shape[2]] = img + img = img_padded + img_prep = im_to_torch(imresize(img, [self.inp_res, self.inp_res], interp='bilinear')) + inp = color_normalize(img_prep, self.DATA_INFO.rgb_mean, self.DATA_INFO.rgb_stddev) + ''' + # add the following fields to make it compatible with stanext, most of them are fake + target_dict = {'index': index, 'center' : -2, 'scale' : -2, + 'breed_index': -2, 'sim_breed_index': -2, + 'ind_dataset': 1} + target_dict['pts'] = np.zeros((self.DATA_INFO.n_keyp, 3)) + target_dict['tpts'] = np.zeros((self.DATA_INFO.n_keyp, 3)) + target_dict['target_weight'] = np.zeros((self.DATA_INFO.n_keyp, 1)) + target_dict['silh'] = inp_silh[0, :, :] # np.zeros((self.inp_res, self.inp_res)) + target_dict['mesh_path'] = mesh_path + target_dict['pointcloud_path'] = path_pc + target_dict['pointcloud_points'] = pc_points + target_dict['keypoints_3d'] = keyp_3d + return inp, target_dict + + + def __len__(self): + return len(self.test_name_list) + + + + + + + + + diff --git a/src/stacked_hourglass/datasets/stanext24.py b/src/stacked_hourglass/datasets/stanext24.py new file mode 100644 index 0000000000000000000000000000000000000000..09bf9f13f88891f9bbbabe8779f8e4b4ff60abe2 --- /dev/null +++ b/src/stacked_hourglass/datasets/stanext24.py @@ -0,0 +1,403 @@ +# 24 joints instead of 20!! + + +import gzip +import json +import os +import random +import math +import numpy as np +import torch +import torch.utils.data as data +from importlib_resources import open_binary +from scipy.io import loadmat +from tabulate import tabulate +import itertools +import json +from scipy import ndimage + +from csv import DictReader +from pycocotools.mask import decode as decode_RLE + +import sys +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..')) +from configs.data_info import COMPLETE_DATA_INFO_24 +from stacked_hourglass.utils.imutils import load_image, draw_labelmap, draw_multiple_labelmaps +from stacked_hourglass.utils.misc import to_torch +from stacked_hourglass.utils.transforms import shufflelr, crop, color_normalize, fliplr, transform +import stacked_hourglass.datasets.utils_stanext as utils_stanext +from stacked_hourglass.utils.visualization import save_input_image_with_keypoints +from configs.dog_breeds.dog_breed_class import COMPLETE_ABBREV_DICT, COMPLETE_SUMMARY_BREEDS, SIM_MATRIX_RAW, SIM_ABBREV_INDICES +from configs.dataset_path_configs import STANEXT_RELATED_DATA_ROOT_DIR + + +class StanExt(data.Dataset): + DATA_INFO = COMPLETE_DATA_INFO_24 + + # Suggested joints to use for keypoint reprojection error calculations + ACC_JOINTS = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 16] + + def __init__(self, image_path=None, is_train=True, inp_res=256, out_res=64, sigma=1, + scale_factor=0.25, rot_factor=30, label_type='Gaussian', + do_augment='default', shorten_dataset_to=None, dataset_mode='keyp_only', V12=None, val_opt='test'): + self.V12 = V12 + self.is_train = is_train # training set or test set + if do_augment == 'yes': + self.do_augment = True + elif do_augment == 'no': + self.do_augment = False + elif do_augment=='default': + if self.is_train: + self.do_augment = True + else: + self.do_augment = False + else: + raise ValueError + self.inp_res = inp_res + self.out_res = out_res + self.sigma = sigma + self.scale_factor = scale_factor + self.rot_factor = rot_factor + self.label_type = label_type + self.dataset_mode = dataset_mode + if self.dataset_mode=='complete' or self.dataset_mode=='keyp_and_seg' or self.dataset_mode=='keyp_and_seg_and_partseg': + self.calc_seg = True + else: + self.calc_seg = False + self.val_opt = val_opt + + # create train/val split + self.img_folder = utils_stanext.get_img_dir(V12=self.V12) + self.train_dict, init_test_dict, init_val_dict = utils_stanext.load_stanext_json_as_dict(split_train_test=True, V12=self.V12) + self.train_name_list = list(self.train_dict.keys()) # 7004 + if self.val_opt == 'test': + self.test_dict = init_test_dict + self.test_name_list = list(self.test_dict.keys()) + elif self.val_opt == 'val': + self.test_dict = init_val_dict + self.test_name_list = list(self.test_dict.keys()) + else: + raise NotImplementedError + + # stanext breed dict (contains for each name a stanext specific index) + breed_json_path = os.path.join(STANEXT_RELATED_DATA_ROOT_DIR, 'StanExt_breed_dict_v2.json') + self.breed_dict = self.get_breed_dict(breed_json_path, create_new_breed_json=False) + self.train_name_list = sorted(self.train_name_list) + self.test_name_list = sorted(self.test_name_list) + random.seed(4) + random.shuffle(self.train_name_list) + random.shuffle(self.test_name_list) + if shorten_dataset_to is not None: + # sometimes it is useful to have a smaller set (validation speed, debugging) + self.train_name_list = self.train_name_list[0 : min(len(self.train_name_list), shorten_dataset_to)] + self.test_name_list = self.test_name_list[0 : min(len(self.test_name_list), shorten_dataset_to)] + # special case for debugging: 12 similar images + if shorten_dataset_to == 12: + my_sample = self.test_name_list[2] + for ind in range(0, 12): + self.test_name_list[ind] = my_sample + print('len(dataset): ' + str(self.__len__())) + + # add results for eyes, whithers and throat as obtained through anipose -> they are used + # as pseudo ground truth at training time. + # self.path_anipose_out_root = os.path.join(STANEXT_RELATED_DATA_ROOT_DIR, 'animalpose_hg8_v0_results_on_StanExt') + self.path_anipose_out_root = os.path.join(STANEXT_RELATED_DATA_ROOT_DIR, 'animalpose_hg8_v1_results_on_StanExt') # this is from hg_anipose_after01bugfix_v1 + # self.prepare_anipose_res_and_save() + + + def get_data_sampler_info(self): + # for custom data sampler + if self.is_train: + name_list = self.train_name_list + else: + name_list = self.test_name_list + info_dict = {'name_list': name_list, + 'stanext_breed_dict': self.breed_dict, + 'breeds_abbrev_dict': COMPLETE_ABBREV_DICT, + 'breeds_summary': COMPLETE_SUMMARY_BREEDS, + 'breeds_sim_martix_raw': SIM_MATRIX_RAW, + 'breeds_sim_abbrev_inds': SIM_ABBREV_INDICES + } + return info_dict + + + def get_breed_dict(self, breed_json_path, create_new_breed_json=False): + if create_new_breed_json: + breed_dict = {} + breed_index = 0 + for img_name in self.train_name_list: + folder_name = img_name.split('/')[0] + breed_name = folder_name.split(folder_name.split('-')[0] + '-')[1] + if not (folder_name in breed_dict): + breed_dict[folder_name] = { + 'breed_name': breed_name, + 'index': breed_index} + breed_index += 1 + with open(breed_json_path, 'w', encoding='utf-8') as f: json.dump(breed_dict, f, ensure_ascii=False, indent=4) + else: + with open(breed_json_path) as json_file: breed_dict = json.load(json_file) + return breed_dict + + + + def prepare_anipose_res_and_save(self): + # I only had to run this once ... + # path_animalpose_res_root = '/ps/scratch/nrueegg/new_projects/Animals/dog_project/pytorch-stacked-hourglass/results/animalpose_hg8_v0/' + path_animalpose_res_root = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/results/results/hg_anipose_after01bugfix_v1/stanext24_XXX_e300_json/' + + train_dict, init_test_dict, init_val_dict = utils_stanext.load_stanext_json_as_dict(split_train_test=True, V12=self.V12) + train_name_list = list(train_dict.keys()) + val_name_list = list(init_val_dict.keys()) + test_name_list = list(init_test_dict.keys()) + all_dicts = [train_dict, init_val_dict, init_test_dict] + all_name_lists = [train_name_list, val_name_list, test_name_list] + all_prefixes = ['train', 'val', 'test'] + for ind in range(3): + this_name_list = all_name_lists[ind] + this_dict = all_dicts[ind] + this_prefix = all_prefixes[ind] + + for index in range(0, len(this_name_list)): + print(index) + name = this_name_list[index] + data = this_dict[name] + + img_path = os.path.join(self.img_folder, data['img_path']) + + path_animalpose_res = os.path.join(path_animalpose_res_root.replace('XXX', this_prefix), data['img_path'].replace('.jpg', '.json')) + + + # prepare predicted keypoints + '''if is_train: + path_animalpose_res = os.path.join(path_animalpose_res_root, 'train_stanext', 'res_' + str(index) + '.json') + else: + path_animalpose_res = os.path.join(path_animalpose_res_root, 'test_stanext', 'res_' + str(index) + '.json') + ''' + with open(path_animalpose_res) as f: animalpose_data = json.load(f) + anipose_joints_256 = np.asarray(animalpose_data['pred_joints_256']).reshape((-1, 3)) + anipose_center = animalpose_data['center'] + anipose_scale = animalpose_data['scale'] + anipose_joints_64 = anipose_joints_256 / 4 + '''thrs_21to24 = 0.2 + anipose_joints_21to24 = np.zeros((4, 3))) + for ind_j in range(0:4): + anipose_joints_untrans = transform(anipose_joints_64[20+ind_j, 0:2], anipose_center, anipose_scale, [64, 64], invert=True, rot=0, as_int=False)-1 + anipose_joints_trans_again = transform(anipose_joints_untrans+1, anipose_center, anipose_scale, [64, 64], invert=False, rot=0, as_int=False) + anipose_joints_21to24[ind_j, :2] = anipose_joints_untrans + if anipose_joints_256[20+ind_j, 2] >= thrs_21to24: + anipose_joints_21to24[ind_j, 2] = 1''' + anipose_joints_0to24 = np.zeros((24, 3)) + for ind_j in range(24): + # anipose_joints_untrans = transform(anipose_joints_64[ind_j, 0:2], anipose_center, anipose_scale, [64, 64], invert=True, rot=0, as_int=False)-1 + anipose_joints_untrans = transform(anipose_joints_64[ind_j, 0:2]+1, anipose_center, anipose_scale, [64, 64], invert=True, rot=0, as_int=False)-1 + anipose_joints_0to24[ind_j, :2] = anipose_joints_untrans + anipose_joints_0to24[ind_j, 2] = anipose_joints_256[ind_j, 2] + # save anipose result for usage later on + out_path = os.path.join(self.path_anipose_out_root, data['img_path'].replace('.jpg', '.json')) + if not os.path.exists(os.path.dirname(out_path)): os.makedirs(os.path.dirname(out_path)) + out_dict = {'orig_anipose_joints_256': list(anipose_joints_256.reshape((-1))), + 'anipose_joints_0to24': list(anipose_joints_0to24[:, :3].reshape((-1))), + 'orig_index': index, + 'orig_scale': animalpose_data['scale'], + 'orig_center': animalpose_data['center'], + 'data_split': this_prefix, # 'is_train': is_train, + } + with open(out_path, 'w') as outfile: json.dump(out_dict, outfile) + return + + + + + + + + + + + + + + + + + def __getitem__(self, index): + + if self.is_train: + train_val_test_Prefix = 'train' + name = self.train_name_list[index] + data = self.train_dict[name] + else: + train_val_test_Prefix = self.val_opt # 'val' or 'test' + name = self.test_name_list[index] + data = self.test_dict[name] + + + + sf = self.scale_factor + rf = self.rot_factor + + img_path = os.path.join(self.img_folder, data['img_path']) + try: + # import pdb; pdb.set_trace() + + '''new_anipose_root_path = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/results/results/hg_anipose_after01bugfix_v1/stanext24_XXX_e300_json/' + adjusted_new_anipose_root_path = new_anipose_root_path.replace('XXX', train_val_test_Prefix) + new_anipose_res_path = adjusted_new_anipose_root_path + data['img_path'].replace('.jpg', '.json') + with open(new_anipose_res_path) as f: new_anipose_data = json.load(f) + ''' + + anipose_res_path = os.path.join(self.path_anipose_out_root, data['img_path'].replace('.jpg', '.json')) + with open(anipose_res_path) as f: anipose_data = json.load(f) + anipose_thr = 0.2 + anipose_joints_0to24 = np.asarray(anipose_data['anipose_joints_0to24']).reshape((-1, 3)) + anipose_joints_0to24_scores = anipose_joints_0to24[:, 2] + # anipose_joints_0to24_scores[anipose_joints_0to24_scores>anipose_thr] = 1.0 + anipose_joints_0to24_scores[anipose_joints_0to24_scores bbox_max = 256 + # bbox_s = bbox_diag / 200. # diagonal of the boundingbox will be 200 + bbox_s = bbox_max / 200. * 256. / 200. # maximum side of the bbox will be 200 + c = torch.Tensor(bbox_c) + s = bbox_s + + # For single-person pose estimation with a centered/scaled figure + nparts = pts.size(0) + img = load_image(img_path) # CxHxW + + # segmentation map (we reshape it to 3xHxW, such that we can do the + # same transformations as with the image) + if self.calc_seg: + seg = torch.Tensor(utils_stanext.get_seg_from_entry(data)[None, :, :]) + seg = torch.cat(3*[seg]) + + r = 0 + do_flip = False + if self.do_augment: + s = s*torch.randn(1).mul_(sf).add_(1).clamp(1-sf, 1+sf)[0] + r = torch.randn(1).mul_(rf).clamp(-2*rf, 2*rf)[0] if random.random() <= 0.6 else 0 + # Flip + if random.random() <= 0.5: + do_flip = True + img = fliplr(img) + if self.calc_seg: + seg = fliplr(seg) + pts = shufflelr(pts, img.size(2), self.DATA_INFO.hflip_indices) + c[0] = img.size(2) - c[0] + # Color + img[0, :, :].mul_(random.uniform(0.8, 1.2)).clamp_(0, 1) + img[1, :, :].mul_(random.uniform(0.8, 1.2)).clamp_(0, 1) + img[2, :, :].mul_(random.uniform(0.8, 1.2)).clamp_(0, 1) + + # Prepare image and groundtruth map + inp = crop(img, c, s, [self.inp_res, self.inp_res], rot=r) + img_border_mask = torch.all(inp > 1.0/256, dim = 0).unsqueeze(0).float() # 1 is foreground + inp = color_normalize(inp, self.DATA_INFO.rgb_mean, self.DATA_INFO.rgb_stddev) + if self.calc_seg: + seg = crop(seg, c, s, [self.inp_res, self.inp_res], rot=r) + + # Generate ground truth + tpts = pts.clone() + target_weight = tpts[:, 2].clone().view(nparts, 1) + + target = torch.zeros(nparts, self.out_res, self.out_res) + for i in range(nparts): + # if tpts[i, 2] > 0: # This is evil!! + if tpts[i, 1] > 0: + tpts[i, 0:2] = to_torch(transform(tpts[i, 0:2]+1, c, s, [self.out_res, self.out_res], rot=r, as_int=False)) - 1 + target[i], vis = draw_labelmap(target[i], tpts[i], self.sigma, type=self.label_type) + target_weight[i, 0] *= vis + # NEW: + '''target_new, vis_new = draw_multiple_labelmaps((self.out_res, self.out_res), tpts[:, :2]-1, self.sigma, type=self.label_type) + target_weight_new = tpts[:, 2].clone().view(nparts, 1) * vis_new + target_new[(target_weight_new==0).reshape((-1)), :, :] = 0''' + + # --- Meta info + this_breed = self.breed_dict[name.split('/')[0]] # 120 + # add information about location within breed similarity matrix + folder_name = name.split('/')[0] + breed_name = folder_name.split(folder_name.split('-')[0] + '-')[1] + abbrev = COMPLETE_ABBREV_DICT[breed_name] + try: + sim_breed_index = COMPLETE_SUMMARY_BREEDS[abbrev]._ind_in_xlsx_matrix + except: # some breeds are not in the xlsx file + sim_breed_index = -1 + meta = {'index' : index, 'center' : c, 'scale' : s, + 'pts' : pts, 'tpts' : tpts, 'target_weight': target_weight, + 'breed_index': this_breed['index'], 'sim_breed_index': sim_breed_index, + 'ind_dataset': 0} # ind_dataset=0 for stanext or stanexteasy or stanext 2 + meta2 = {'index' : index, 'center' : c, 'scale' : s, + 'pts' : pts, 'tpts' : tpts, 'target_weight': target_weight, + 'ind_dataset': 3} + + # import pdb; pdb.set_trace() + + # out_path_root = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/debugging/stanext_preprocessing/old_animalpose_version/' + # out_path_root = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/debugging/stanext_preprocessing/v0/' + # save_input_image_with_keypoints(inp, meta['tpts'], out_path = out_path_root + name.replace('/', '_'), ratio_in_out=self.inp_res/self.out_res) + + + # return different things depending on dataset_mode + if self.dataset_mode=='keyp_only': + # save_input_image_with_keypoints(inp, meta['tpts'], out_path='./test_input_stanext.png', ratio_in_out=self.inp_res/self.out_res) + return inp, target, meta + elif self.dataset_mode=='keyp_and_seg': + meta['silh'] = seg[0, :, :] + meta['name'] = name + return inp, target, meta + elif self.dataset_mode=='keyp_and_seg_and_partseg': + # partseg is fake! this does only exist such that this dataset can be combined with an other datset that has part segmentations + meta2['silh'] = seg[0, :, :] + meta2['name'] = name + fake_body_part_matrix = torch.ones((3, 256, 256)).long() * (-1) + meta2['body_part_matrix'] = fake_body_part_matrix + return inp, target, meta2 + elif self.dataset_mode=='complete': + target_dict = meta + target_dict['silh'] = seg[0, :, :] + # NEW for silhouette loss + target_dict['img_border_mask'] = img_border_mask + target_dict['has_seg'] = True + if target_dict['silh'].sum() < 1: + if ((not self.is_train) and self.val_opt == 'test'): + raise ValueError + elif self.is_train: + print('had to replace training image') + replacement_index = max(0, index - 1) + inp, target_dict = self.__getitem__(replacement_index) + else: + # There seem to be a few validation images without segmentation + # which would lead to nan in iou calculation + replacement_index = max(0, index - 1) + inp, target_dict = self.__getitem__(replacement_index) + return inp, target_dict + else: + print('sampling error') + import pdb; pdb.set_trace() + raise ValueError + + + def __len__(self): + if self.is_train: + return len(self.train_name_list) + else: + return len(self.test_name_list) + + diff --git a/src/stacked_hourglass/datasets/stanext24_withgc.py b/src/stacked_hourglass/datasets/stanext24_withgc.py new file mode 100644 index 0000000000000000000000000000000000000000..0a9118c1ff272e0044d8cc7bb94d21affc6ad882 --- /dev/null +++ b/src/stacked_hourglass/datasets/stanext24_withgc.py @@ -0,0 +1,561 @@ +# 24 joints instead of 20!! + + +import gzip +import json +import os +import random +import math +import numpy as np +import torch +import torch.utils.data as data +from importlib_resources import open_binary +from scipy.io import loadmat +from tabulate import tabulate +import itertools +import json +from scipy import ndimage +import csv +import pickle as pkl + +from csv import DictReader +from pycocotools.mask import decode as decode_RLE + +import sys +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..')) +from configs.data_info import COMPLETE_DATA_INFO_24 +from stacked_hourglass.utils.imutils import load_image, draw_labelmap, draw_multiple_labelmaps +from stacked_hourglass.utils.misc import to_torch +from stacked_hourglass.utils.transforms import shufflelr, crop, color_normalize, fliplr, transform +import stacked_hourglass.datasets.utils_stanext as utils_stanext +from stacked_hourglass.utils.visualization import save_input_image_with_keypoints +from configs.dog_breeds.dog_breed_class import COMPLETE_ABBREV_DICT, COMPLETE_SUMMARY_BREEDS, SIM_MATRIX_RAW, SIM_ABBREV_INDICES +from configs.dataset_path_configs import STANEXT_RELATED_DATA_ROOT_DIR +from smal_pytorch.smal_model.smal_basics import get_symmetry_indices + + +def read_csv(csv_file): + with open(csv_file,'r') as f: + reader = csv.reader(f) + headers = next(reader) + row_list = [{h:x for (h,x) in zip(headers,row)} for row in reader] + return row_list + +class StanExtGC(data.Dataset): + DATA_INFO = COMPLETE_DATA_INFO_24 + + # Suggested joints to use for keypoint reprojection error calculations + ACC_JOINTS = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 16] + + def __init__(self, image_path=None, is_train=True, inp_res=256, out_res=64, sigma=1, + scale_factor=0.25, rot_factor=30, label_type='Gaussian', + do_augment='default', shorten_dataset_to=None, dataset_mode='keyp_only', V12=None, val_opt='test'): + self.V12 = V12 + self.is_train = is_train # training set or test set + if do_augment == 'yes': + self.do_augment = True + elif do_augment == 'no': + self.do_augment = False + elif do_augment=='default': + if self.is_train: + self.do_augment = True + else: + self.do_augment = False + else: + raise ValueError + self.inp_res = inp_res + self.out_res = out_res + self.sigma = sigma + self.scale_factor = scale_factor + self.rot_factor = rot_factor + self.label_type = label_type + self.dataset_mode = dataset_mode + if self.dataset_mode=='complete' or self.dataset_mode=='complete_with_gc' or self.dataset_mode=='keyp_and_seg' or self.dataset_mode=='keyp_and_seg_and_partseg': + self.calc_seg = True + else: + self.calc_seg = False + self.val_opt = val_opt + + # create train/val split + self.img_folder = utils_stanext.get_img_dir(V12=self.V12) + self.train_dict, init_test_dict, init_val_dict = utils_stanext.load_stanext_json_as_dict(split_train_test=True, V12=self.V12) + self.train_name_list = list(self.train_dict.keys()) # 7004 + if self.val_opt == 'test': + self.test_dict = init_test_dict + self.test_name_list = list(self.test_dict.keys()) + elif self.val_opt == 'val': + self.test_dict = init_val_dict + self.test_name_list = list(self.test_dict.keys()) + else: + raise NotImplementedError + + + # path_gc_annots_overview = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/data/stanext_related_data/ground_contact_annotations/stage3/gc_annots_overview_first699.pkl' + path_gc_annots_overview = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/data/stanext_related_data/ground_contact_annotations/stage3/gc_annots_overview_stage3complete.pkl' + with open(path_gc_annots_overview, 'rb') as f: + self.gc_annots_overview = pkl.load(f) + list_gc_labelled_images = list(self.gc_annots_overview.keys()) + + test_name_list_gc = [] + for name in self.test_name_list: + if name.split('.')[0] in list_gc_labelled_images: + test_name_list_gc.append(name) + + train_name_list_gc = [] + for name in self.train_name_list: + if name.split('.')[0] in list_gc_labelled_images: + train_name_list_gc.append(name) + + self.test_name_list = test_name_list_gc + self.train_name_list = train_name_list_gc + + random.seed(4) + random.shuffle(self.test_name_list) + + ''' + already_labelled = ['n02093991-Irish_terrier/n02093991_2874.jpg', + 'n02093754-Border_terrier/n02093754_1062.jpg', + 'n02092339-Weimaraner/n02092339_1672.jpg', + 'n02096177-cairn/n02096177_4916.jpg', + 'n02110185-Siberian_husky/n02110185_725.jpg', + 'n02110806-basenji/n02110806_761.jpg', + 'n02094433-Yorkshire_terrier/n02094433_2474.jpg', + 'n02097474-Tibetan_terrier/n02097474_8796.jpg', + 'n02099601-golden_retriever/n02099601_2495.jpg'] + self.trainvaltest_dict = dict(self.train_dict) + for d in (init_test_dict, init_val_dict): self.trainvaltest_dict.update(d) + + gc_annot_csv = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/data/stanext_related_data/ground_contact_annotations/my_gcannotations_qualification.csv' + gc_row_list = read_csv(gc_annot_csv) + + json_acceptable_string = (gc_row_list[0]['vertices']).replace("'", "\"") + self.gc_dict = json.loads(json_acceptable_string) + + self.train_name_list = already_labelled + self.test_name_list = already_labelled + ''' + + + # stanext breed dict (contains for each name a stanext specific index) + breed_json_path = os.path.join(STANEXT_RELATED_DATA_ROOT_DIR, 'StanExt_breed_dict_v2.json') + self.breed_dict = self.get_breed_dict(breed_json_path, create_new_breed_json=False) + + # load smal symmetry info + self.sym_ids_dict = get_symmetry_indices() + + ''' + self.train_name_list = sorted(self.train_name_list) + self.test_name_list = sorted(self.test_name_list) + random.seed(4) + random.shuffle(self.train_name_list) + random.shuffle(self.test_name_list) + if shorten_dataset_to is not None: + # sometimes it is useful to have a smaller set (validation speed, debugging) + self.train_name_list = self.train_name_list[0 : min(len(self.train_name_list), shorten_dataset_to)] + self.test_name_list = self.test_name_list[0 : min(len(self.test_name_list), shorten_dataset_to)] + # special case for debugging: 12 similar images + if shorten_dataset_to == 12: + my_sample = self.test_name_list[2] + for ind in range(0, 12): + self.test_name_list[ind] = my_sample + ''' + print('len(dataset): ' + str(self.__len__())) + + # add results for eyes, whithers and throat as obtained through anipose -> they are used + # as pseudo ground truth at training time. + # self.path_anipose_out_root = os.path.join(STANEXT_RELATED_DATA_ROOT_DIR, 'animalpose_hg8_v0_results_on_StanExt') + self.path_anipose_out_root = os.path.join(STANEXT_RELATED_DATA_ROOT_DIR, 'animalpose_hg8_v1_results_on_StanExt') # this is from hg_anipose_after01bugfix_v1 + # self.prepare_anipose_res_and_save() + + + def get_data_sampler_info(self): + # for custom data sampler + if self.is_train: + name_list = self.train_name_list + else: + name_list = self.test_name_list + info_dict = {'name_list': name_list, + 'stanext_breed_dict': self.breed_dict, + 'breeds_abbrev_dict': COMPLETE_ABBREV_DICT, + 'breeds_summary': COMPLETE_SUMMARY_BREEDS, + 'breeds_sim_martix_raw': SIM_MATRIX_RAW, + 'breeds_sim_abbrev_inds': SIM_ABBREV_INDICES + } + return info_dict + + + def get_breed_dict(self, breed_json_path, create_new_breed_json=False): + if create_new_breed_json: + breed_dict = {} + breed_index = 0 + for img_name in self.train_name_list: + folder_name = img_name.split('/')[0] + breed_name = folder_name.split(folder_name.split('-')[0] + '-')[1] + if not (folder_name in breed_dict): + breed_dict[folder_name] = { + 'breed_name': breed_name, + 'index': breed_index} + breed_index += 1 + with open(breed_json_path, 'w', encoding='utf-8') as f: json.dump(breed_dict, f, ensure_ascii=False, indent=4) + else: + with open(breed_json_path) as json_file: breed_dict = json.load(json_file) + return breed_dict + + + + def prepare_anipose_res_and_save(self): + # I only had to run this once ... + # path_animalpose_res_root = '/ps/scratch/nrueegg/new_projects/Animals/dog_project/pytorch-stacked-hourglass/results/animalpose_hg8_v0/' + path_animalpose_res_root = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/results/results/hg_anipose_after01bugfix_v1/stanext24_XXX_e300_json/' + + train_dict, init_test_dict, init_val_dict = utils_stanext.load_stanext_json_as_dict(split_train_test=True, V12=self.V12) + train_name_list = list(train_dict.keys()) + val_name_list = list(init_val_dict.keys()) + test_name_list = list(init_test_dict.keys()) + all_dicts = [train_dict, init_val_dict, init_test_dict] + all_name_lists = [train_name_list, val_name_list, test_name_list] + all_prefixes = ['train', 'val', 'test'] + for ind in range(3): + this_name_list = all_name_lists[ind] + this_dict = all_dicts[ind] + this_prefix = all_prefixes[ind] + + for index in range(0, len(this_name_list)): + print(index) + name = this_name_list[index] + data = this_dict[name] + + img_path = os.path.join(self.img_folder, data['img_path']) + + path_animalpose_res = os.path.join(path_animalpose_res_root.replace('XXX', this_prefix), data['img_path'].replace('.jpg', '.json')) + + + # prepare predicted keypoints + '''if is_train: + path_animalpose_res = os.path.join(path_animalpose_res_root, 'train_stanext', 'res_' + str(index) + '.json') + else: + path_animalpose_res = os.path.join(path_animalpose_res_root, 'test_stanext', 'res_' + str(index) + '.json') + ''' + with open(path_animalpose_res) as f: animalpose_data = json.load(f) + anipose_joints_256 = np.asarray(animalpose_data['pred_joints_256']).reshape((-1, 3)) + anipose_center = animalpose_data['center'] + anipose_scale = animalpose_data['scale'] + anipose_joints_64 = anipose_joints_256 / 4 + '''thrs_21to24 = 0.2 + anipose_joints_21to24 = np.zeros((4, 3))) + for ind_j in range(0:4): + anipose_joints_untrans = transform(anipose_joints_64[20+ind_j, 0:2], anipose_center, anipose_scale, [64, 64], invert=True, rot=0, as_int=False)-1 + anipose_joints_trans_again = transform(anipose_joints_untrans+1, anipose_center, anipose_scale, [64, 64], invert=False, rot=0, as_int=False) + anipose_joints_21to24[ind_j, :2] = anipose_joints_untrans + if anipose_joints_256[20+ind_j, 2] >= thrs_21to24: + anipose_joints_21to24[ind_j, 2] = 1''' + anipose_joints_0to24 = np.zeros((24, 3)) + for ind_j in range(24): + # anipose_joints_untrans = transform(anipose_joints_64[ind_j, 0:2], anipose_center, anipose_scale, [64, 64], invert=True, rot=0, as_int=False)-1 + anipose_joints_untrans = transform(anipose_joints_64[ind_j, 0:2]+1, anipose_center, anipose_scale, [64, 64], invert=True, rot=0, as_int=False)-1 + anipose_joints_0to24[ind_j, :2] = anipose_joints_untrans + anipose_joints_0to24[ind_j, 2] = anipose_joints_256[ind_j, 2] + # save anipose result for usage later on + out_path = os.path.join(self.path_anipose_out_root, data['img_path'].replace('.jpg', '.json')) + if not os.path.exists(os.path.dirname(out_path)): os.makedirs(os.path.dirname(out_path)) + out_dict = {'orig_anipose_joints_256': list(anipose_joints_256.reshape((-1))), + 'anipose_joints_0to24': list(anipose_joints_0to24[:, :3].reshape((-1))), + 'orig_index': index, + 'orig_scale': animalpose_data['scale'], + 'orig_center': animalpose_data['center'], + 'data_split': this_prefix, # 'is_train': is_train, + } + with open(out_path, 'w') as outfile: json.dump(out_dict, outfile) + return + + + + + + + + + + + + + + + + + def __getitem__(self, index): + + + if self.is_train: + train_val_test_Prefix = 'train' + name = self.train_name_list[index] + data = self.train_dict[name] + else: + train_val_test_Prefix = self.val_opt # 'val' or 'test' + name = self.test_name_list[index] + data = self.test_dict[name] + img_path = os.path.join(self.img_folder, data['img_path']) + + + ''' + # for debugging only + train_val_test_Prefix = 'train' + name = self.train_name_list[index] + data = self.trainvaltest_dict[name] + img_path = os.path.join(self.img_folder, data['img_path']) + + if self.dataset_mode=='complete_with_gc': + n_verts_smal = 3889 + + gc_info_raw = self.gc_dict['bite/' + name] # a list with all vertex numbers that are in ground contact + gc_info = [] + gc_info_tch = torch.zeros((n_verts_smal)) + for ind_v in gc_info_raw: + if ind_v < n_verts_smal: + gc_info.append(ind_v) + gc_info_tch[ind_v] = 1 + gc_info_available = True + ''' + + # array of shape (n_verts_smal, 3) with [first: no-contact=0 contact=1 second: index of vertex third: dist] + gc_vertdists_overview = self.gc_annots_overview[name.split('.')[0]]['gc_vertdists_overview'] + + gc_info_tch = torch.tensor(gc_vertdists_overview[:, :]) # torch.tensor(gc_vertdists_overview[:, 0]) + gc_info_available = True + + + + + # import pdb; pdb.set_trace() + debugging = False + if debugging: + import shutil + import trimesh + from smal_pytorch.smal_model.smal_torch_new import SMAL + smal = SMAL() + verts = smal.v_template.detach().cpu().numpy() + faces = smal.faces.detach().cpu().numpy() + vert_colors = np.repeat(255*gc_info_tch[:, 0].detach().cpu().numpy()[:, None], 3, 1) + # vert_colors = np.repeat(255*gc_info_np[:, None], 3, 1) + my_mesh = trimesh.Trimesh(vertices=verts, faces=faces, process=False, maintain_order=True) + my_mesh.visual.vertex_colors = vert_colors + debug_folder = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/debugging/gc_debugging/' + my_mesh.export(debug_folder + (name.split('/')[1]).replace('.jpg', '_withgc.obj')) + shutil.copy(img_path, debug_folder + name.split('/')[1]) + + + + + + sf = self.scale_factor + rf = self.rot_factor + try: + # import pdb; pdb.set_trace() + + '''new_anipose_root_path = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/results/results/hg_anipose_after01bugfix_v1/stanext24_XXX_e300_json/' + adjusted_new_anipose_root_path = new_anipose_root_path.replace('XXX', train_val_test_Prefix) + new_anipose_res_path = adjusted_new_anipose_root_path + data['img_path'].replace('.jpg', '.json') + with open(new_anipose_res_path) as f: new_anipose_data = json.load(f) + ''' + + anipose_res_path = os.path.join(self.path_anipose_out_root, data['img_path'].replace('.jpg', '.json')) + with open(anipose_res_path) as f: anipose_data = json.load(f) + anipose_thr = 0.2 + anipose_joints_0to24 = np.asarray(anipose_data['anipose_joints_0to24']).reshape((-1, 3)) + anipose_joints_0to24_scores = anipose_joints_0to24[:, 2] + # anipose_joints_0to24_scores[anipose_joints_0to24_scores>anipose_thr] = 1.0 + anipose_joints_0to24_scores[anipose_joints_0to24_scores bbox_max = 256 + # bbox_s = bbox_diag / 200. # diagonal of the boundingbox will be 200 + bbox_s = bbox_max / 200. * 256. / 200. # maximum side of the bbox will be 200 + c = torch.Tensor(bbox_c) + s = bbox_s + + # For single-person pose estimation with a centered/scaled figure + nparts = pts.size(0) + img = load_image(img_path) # CxHxW + + # segmentation map (we reshape it to 3xHxW, such that we can do the + # same transformations as with the image) + if self.calc_seg: + seg = torch.Tensor(utils_stanext.get_seg_from_entry(data)[None, :, :]) + seg = torch.cat(3*[seg]) + + r = 0 + do_flip = False + if self.do_augment: + s = s*torch.randn(1).mul_(sf).add_(1).clamp(1-sf, 1+sf)[0] + r = torch.randn(1).mul_(rf).clamp(-2*rf, 2*rf)[0] if random.random() <= 0.6 else 0 + # Flip + if random.random() <= 0.5: + do_flip = True + img = fliplr(img) + if self.calc_seg: + seg = fliplr(seg) + pts = shufflelr(pts, img.size(2), self.DATA_INFO.hflip_indices) + c[0] = img.size(2) - c[0] + # flip ground contact annotations + gc_info_tch_swapped = torch.zeros_like(gc_info_tch) + gc_info_tch_swapped[self.sym_ids_dict['center'], :] = gc_info_tch[self.sym_ids_dict['center'], :] + gc_info_tch_swapped[self.sym_ids_dict['right'], :] = gc_info_tch[self.sym_ids_dict['left'], :] + gc_info_tch_swapped[self.sym_ids_dict['left'], :] = gc_info_tch[self.sym_ids_dict['right'], :] + gc_info_tch = gc_info_tch_swapped + + + # Color + img[0, :, :].mul_(random.uniform(0.8, 1.2)).clamp_(0, 1) + img[1, :, :].mul_(random.uniform(0.8, 1.2)).clamp_(0, 1) + img[2, :, :].mul_(random.uniform(0.8, 1.2)).clamp_(0, 1) + + + + + # import pdb; pdb.set_trace() + debugging = False + if debugging and do_flip: + import shutil + import trimesh + from smal_pytorch.smal_model.smal_torch_new import SMAL + smal = SMAL() + verts = smal.v_template.detach().cpu().numpy() + faces = smal.faces.detach().cpu().numpy() + vert_colors = np.repeat(255*gc_info_tch[:, 0].detach().cpu().numpy()[:, None], 3, 1) + # vert_colors = np.repeat(255*gc_info_np[:, None], 3, 1) + my_mesh = trimesh.Trimesh(vertices=verts, faces=faces, process=False, maintain_order=True) + my_mesh.visual.vertex_colors = vert_colors + debug_folder = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/debugging/gc_debugging/' + my_mesh.export(debug_folder + (name.split('/')[1]).replace('.jpg', '_withgc_flip.obj')) + + + + + + + + + + + + + + + # Prepare image and groundtruth map + inp = crop(img, c, s, [self.inp_res, self.inp_res], rot=r) + img_border_mask = torch.all(inp > 1.0/256, dim = 0).unsqueeze(0).float() # 1 is foreground + inp = color_normalize(inp, self.DATA_INFO.rgb_mean, self.DATA_INFO.rgb_stddev) + if self.calc_seg: + seg = crop(seg, c, s, [self.inp_res, self.inp_res], rot=r) + + # Generate ground truth + tpts = pts.clone() + target_weight = tpts[:, 2].clone().view(nparts, 1) + + target = torch.zeros(nparts, self.out_res, self.out_res) + for i in range(nparts): + # if tpts[i, 2] > 0: # This is evil!! + if tpts[i, 1] > 0: + tpts[i, 0:2] = to_torch(transform(tpts[i, 0:2]+1, c, s, [self.out_res, self.out_res], rot=r, as_int=False)) - 1 + target[i], vis = draw_labelmap(target[i], tpts[i], self.sigma, type=self.label_type) + target_weight[i, 0] *= vis + # NEW: + '''target_new, vis_new = draw_multiple_labelmaps((self.out_res, self.out_res), tpts[:, :2]-1, self.sigma, type=self.label_type) + target_weight_new = tpts[:, 2].clone().view(nparts, 1) * vis_new + target_new[(target_weight_new==0).reshape((-1)), :, :] = 0''' + + + # --- Meta info + this_breed = self.breed_dict[name.split('/')[0]] # 120 + # add information about location within breed similarity matrix + folder_name = name.split('/')[0] + breed_name = folder_name.split(folder_name.split('-')[0] + '-')[1] + abbrev = COMPLETE_ABBREV_DICT[breed_name] + try: + sim_breed_index = COMPLETE_SUMMARY_BREEDS[abbrev]._ind_in_xlsx_matrix + except: # some breeds are not in the xlsx file + sim_breed_index = -1 + meta = {'index' : index, 'center' : c, 'scale' : s, + 'pts' : pts, 'tpts' : tpts, 'target_weight': target_weight, + 'breed_index': this_breed['index'], 'sim_breed_index': sim_breed_index, + 'ind_dataset': 0} # ind_dataset=0 for stanext or stanexteasy or stanext 2 + meta2 = {'index' : index, 'center' : c, 'scale' : s, + 'pts' : pts, 'tpts' : tpts, 'target_weight': target_weight, + 'ind_dataset': 3} + + # import pdb; pdb.set_trace() + + # out_path_root = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/debugging/stanext_preprocessing/old_animalpose_version/' + # out_path_root = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/debugging/stanext_preprocessing/v0/' + # save_input_image_with_keypoints(inp, meta['tpts'], out_path = out_path_root + name.replace('/', '_'), ratio_in_out=self.inp_res/self.out_res) + + + # return different things depending on dataset_mode + if self.dataset_mode=='keyp_only': + # save_input_image_with_keypoints(inp, meta['tpts'], out_path='./test_input_stanext.png', ratio_in_out=self.inp_res/self.out_res) + return inp, target, meta + elif self.dataset_mode=='keyp_and_seg': + meta['silh'] = seg[0, :, :] + meta['name'] = name + return inp, target, meta + elif self.dataset_mode=='keyp_and_seg_and_partseg': + # partseg is fake! this does only exist such that this dataset can be combined with an other datset that has part segmentations + meta2['silh'] = seg[0, :, :] + meta2['name'] = name + fake_body_part_matrix = torch.ones((3, 256, 256)).long() * (-1) + meta2['body_part_matrix'] = fake_body_part_matrix + return inp, target, meta2 + elif (self.dataset_mode=='complete') or (self.dataset_mode=='complete_with_gc'): + target_dict = meta + target_dict['silh'] = seg[0, :, :] + # NEW for silhouette loss + target_dict['img_border_mask'] = img_border_mask + target_dict['has_seg'] = True + # ground contact + if self.dataset_mode=='complete_with_gc': + target_dict['has_gc_is_touching'] = True + target_dict['has_gc'] = gc_info_available + target_dict['gc'] = gc_info_tch + if target_dict['silh'].sum() < 1: + if ((not self.is_train) and self.val_opt == 'test'): + raise ValueError + elif self.is_train: + print('had to replace training image') + replacement_index = max(0, index - 1) + inp, target_dict = self.__getitem__(replacement_index) + else: + # There seem to be a few validation images without segmentation + # which would lead to nan in iou calculation + replacement_index = max(0, index - 1) + inp, target_dict = self.__getitem__(replacement_index) + return inp, target_dict + else: + print('sampling error') + import pdb; pdb.set_trace() + raise ValueError + + + def __len__(self): + if self.is_train: + return len(self.train_name_list) + else: + return len(self.test_name_list) + + diff --git a/src/stacked_hourglass/datasets/stanext24_withgc_v2.py b/src/stacked_hourglass/datasets/stanext24_withgc_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..c63a649fad4d0a809ed68d5775b72643266ccbd9 --- /dev/null +++ b/src/stacked_hourglass/datasets/stanext24_withgc_v2.py @@ -0,0 +1,709 @@ +# this version includes all ground contact labeled data, not only the sitting/lying poses + + +import gzip +import json +import os +import random +import math +import numpy as np +import torch +import torch.utils.data as data +from importlib_resources import open_binary +from scipy.io import loadmat +from tabulate import tabulate +import itertools +import json +from scipy import ndimage +import csv +import pickle as pkl + +from csv import DictReader +from pycocotools.mask import decode as decode_RLE + +import sys +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..')) +from configs.data_info import COMPLETE_DATA_INFO_24 +from stacked_hourglass.utils.imutils import load_image, draw_labelmap, draw_multiple_labelmaps +from stacked_hourglass.utils.misc import to_torch +from stacked_hourglass.utils.transforms import shufflelr, crop, color_normalize, fliplr, transform +import stacked_hourglass.datasets.utils_stanext as utils_stanext +from stacked_hourglass.utils.visualization import save_input_image_with_keypoints +from configs.dog_breeds.dog_breed_class import COMPLETE_ABBREV_DICT, COMPLETE_SUMMARY_BREEDS, SIM_MATRIX_RAW, SIM_ABBREV_INDICES +from configs.dataset_path_configs import STANEXT_RELATED_DATA_ROOT_DIR +from smal_pytorch.smal_model.smal_basics import get_symmetry_indices + + +def read_csv(csv_file): + with open(csv_file,'r') as f: + reader = csv.reader(f) + headers = next(reader) + row_list = [{h:x for (h,x) in zip(headers,row)} for row in reader] + return row_list + +class StanExtGC(data.Dataset): + DATA_INFO = COMPLETE_DATA_INFO_24 + + # Suggested joints to use for keypoint reprojection error calculations + ACC_JOINTS = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 16] + + def __init__(self, image_path=None, is_train=True, inp_res=256, out_res=64, sigma=1, + scale_factor=0.25, rot_factor=30, label_type='Gaussian', + do_augment='default', shorten_dataset_to=None, dataset_mode='keyp_only', V12=None, val_opt='test', add_nonflat=False): + self.V12 = V12 + self.is_train = is_train # training set or test set + if do_augment == 'yes': + self.do_augment = True + elif do_augment == 'no': + self.do_augment = False + elif do_augment=='default': + if self.is_train: + self.do_augment = True + else: + self.do_augment = False + else: + raise ValueError + self.inp_res = inp_res + self.out_res = out_res + self.sigma = sigma + self.scale_factor = scale_factor + self.rot_factor = rot_factor + self.label_type = label_type + self.dataset_mode = dataset_mode + self.add_nonflat = add_nonflat + if self.dataset_mode=='complete' or self.dataset_mode=='complete_with_gc' or self.dataset_mode=='keyp_and_seg' or self.dataset_mode=='keyp_and_seg_and_partseg': + self.calc_seg = True + else: + self.calc_seg = False + self.val_opt = val_opt + + # create train/val split + self.img_folder = utils_stanext.get_img_dir(V12=self.V12) + self.train_dict, init_test_dict, init_val_dict = utils_stanext.load_stanext_json_as_dict(split_train_test=True, V12=self.V12) + self.train_name_list = list(self.train_dict.keys()) # 7004 + if self.val_opt == 'test': + self.test_dict = init_test_dict + self.test_name_list = list(self.test_dict.keys()) + elif self.val_opt == 'val': + self.test_dict = init_val_dict + self.test_name_list = list(self.test_dict.keys()) + else: + raise NotImplementedError + + + # import pdb; pdb.set_trace() + + + # path_gc_annots_overview = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/data/stanext_related_data/ground_contact_annotations/stage3/gc_annots_overview_first699.pkl' + path_gc_annots_overview_stage3 = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/data/stanext_related_data/ground_contact_annotations/stage3/gc_annots_overview_stage3complete.pkl' + with open(path_gc_annots_overview_stage3, 'rb') as f: + self.gc_annots_overview_stage3 = pkl.load(f) # 2346 + + path_gc_annots_overview_stage2b_contact = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/data/stanext_related_data/ground_contact_annotations/stage2b/gc_annots_overview_stage2b_contact_complete.pkl' + with open(path_gc_annots_overview_stage2b_contact, 'rb') as f: + self.gc_annots_overview_stage2b_contact = pkl.load(f) # 832 + + path_gc_annots_overview_stage2b_nocontact = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/data/stanext_related_data/ground_contact_annotations/stage2b/gc_annots_overview_stage2b_nocontact_complete.pkl' + with open(path_gc_annots_overview_stage2b_nocontact, 'rb') as f: + self.gc_annots_overview_stage2b_nocontact = pkl.load(f) # 32 + + path_gc_annots_overview_stages12_all4pawsincontact = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/data/stanext_related_data/ground_contact_annotations/stages12together/gc_annots_overview_all4pawsincontact.pkl' + with open(path_gc_annots_overview_stages12_all4pawsincontact, 'rb') as f: + self.gc_annots_overview_stages12_all4pawsincontact = pkl.load(f) # 1, symbolic only + + path_gc_annots_categories_stages12 = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/data/stanext_related_data/ground_contact_annotations/stages12together/gc_annots_categories_stages12_complete.pkl' + with open(path_gc_annots_categories_stages12, 'rb') as f: + self.gc_annots_categories = pkl.load(f) # 12538 + + + test_name_list_gc = [] + for name in self.test_name_list: + if name in self.gc_annots_categories.keys(): + value = self.gc_annots_categories[name] + if (value['is_vis'] in [True, None]) and (value['is_flat'] in [True, None]) and (not value['pose'] == 'cantsee'): + test_name_list_gc.append(name) + + train_name_list_gc = [] + for name in self.train_name_list: + value = self.gc_annots_categories[name] + if (value['is_vis'] in [True, None]) and (value['is_flat'] in [True, None]) and (not value['pose'] == 'cantsee'): + train_name_list_gc.append(name) + + # import pdb; pdb.set_trace() + + + + + + '''self.gc_annots_overview = self.gc_annots_overview_stage3 + list_gc_labelled_images = list(self.gc_annots_overview.keys()) + + test_name_list_gc = [] + for name in self.test_name_list: + if name.split('.')[0] in list_gc_labelled_images: + test_name_list_gc.append(name) + + train_name_list_gc = [] + for name in self.train_name_list: + if name.split('.')[0] in list_gc_labelled_images: + train_name_list_gc.append(name)''' + + random.seed(4) + random.shuffle(test_name_list_gc) + + + + + # new: add images with non-flat ground in the end + # import pdb; pdb.set_trace() + if self.add_nonflat: + self.train_name_list_nonflat = [] + for name in self.train_name_list: + if name in self.gc_annots_categories.keys(): + value = self.gc_annots_categories[name] + if (value['is_vis'] in [True, None]) and (value['is_flat'] in [False]): + self.train_name_list_nonflat.append(name) + self.test_name_list_nonflat = [] + for name in self.test_name_list: + if name in self.gc_annots_categories.keys(): + value = self.gc_annots_categories[name] + if (value['is_vis'] in [True, None]) and (value['is_flat'] in [False]): + self.test_name_list_nonflat.append(name) + + + + + self.test_name_list = test_name_list_gc + self.train_name_list = train_name_list_gc + + + + + + ''' + already_labelled = ['n02093991-Irish_terrier/n02093991_2874.jpg', + 'n02093754-Border_terrier/n02093754_1062.jpg', + 'n02092339-Weimaraner/n02092339_1672.jpg', + 'n02096177-cairn/n02096177_4916.jpg', + 'n02110185-Siberian_husky/n02110185_725.jpg', + 'n02110806-basenji/n02110806_761.jpg', + 'n02094433-Yorkshire_terrier/n02094433_2474.jpg', + 'n02097474-Tibetan_terrier/n02097474_8796.jpg', + 'n02099601-golden_retriever/n02099601_2495.jpg'] + self.trainvaltest_dict = dict(self.train_dict) + for d in (init_test_dict, init_val_dict): self.trainvaltest_dict.update(d) + + gc_annot_csv = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/data/stanext_related_data/ground_contact_annotations/my_gcannotations_qualification.csv' + gc_row_list = read_csv(gc_annot_csv) + + json_acceptable_string = (gc_row_list[0]['vertices']).replace("'", "\"") + self.gc_dict = json.loads(json_acceptable_string) + + self.train_name_list = already_labelled + self.test_name_list = already_labelled + ''' + + + # stanext breed dict (contains for each name a stanext specific index) + breed_json_path = os.path.join(STANEXT_RELATED_DATA_ROOT_DIR, 'StanExt_breed_dict_v2.json') + self.breed_dict = self.get_breed_dict(breed_json_path, create_new_breed_json=False) + + # load smal symmetry info + self.sym_ids_dict = get_symmetry_indices() + + ''' + self.train_name_list = sorted(self.train_name_list) + self.test_name_list = sorted(self.test_name_list) + random.seed(4) + random.shuffle(self.train_name_list) + random.shuffle(self.test_name_list) + if shorten_dataset_to is not None: + # sometimes it is useful to have a smaller set (validation speed, debugging) + self.train_name_list = self.train_name_list[0 : min(len(self.train_name_list), shorten_dataset_to)] + self.test_name_list = self.test_name_list[0 : min(len(self.test_name_list), shorten_dataset_to)] + # special case for debugging: 12 similar images + if shorten_dataset_to == 12: + my_sample = self.test_name_list[2] + for ind in range(0, 12): + self.test_name_list[ind] = my_sample + ''' + print('len(dataset): ' + str(self.__len__())) + + # add results for eyes, whithers and throat as obtained through anipose -> they are used + # as pseudo ground truth at training time. + # self.path_anipose_out_root = os.path.join(STANEXT_RELATED_DATA_ROOT_DIR, 'animalpose_hg8_v0_results_on_StanExt') + self.path_anipose_out_root = os.path.join(STANEXT_RELATED_DATA_ROOT_DIR, 'animalpose_hg8_v1_results_on_StanExt') # this is from hg_anipose_after01bugfix_v1 + # self.prepare_anipose_res_and_save() + + + def get_data_sampler_info(self): + # for custom data sampler + if self.is_train: + name_list = self.train_name_list + else: + name_list = self.test_name_list + info_dict = {'name_list': name_list, + 'stanext_breed_dict': self.breed_dict, + 'breeds_abbrev_dict': COMPLETE_ABBREV_DICT, + 'breeds_summary': COMPLETE_SUMMARY_BREEDS, + 'breeds_sim_martix_raw': SIM_MATRIX_RAW, + 'breeds_sim_abbrev_inds': SIM_ABBREV_INDICES + } + return info_dict + + def get_data_sampler_info_gc(self): + # for custom data sampler + if self.is_train: + name_list = self.train_name_list + else: + name_list = self.test_name_list + info_dict_gc = {'name_list': name_list, + 'gc_annots_categories': self.gc_annots_categories, + } + if self.add_nonflat: + if self.is_train: + name_list_nonflat = self.train_name_list_nonflat + else: + name_list_nonflat = self.test_name_list_nonflat + info_dict_gc['name_list_nonflat'] = name_list_nonflat + return info_dict_gc + + + + def get_breed_dict(self, breed_json_path, create_new_breed_json=False): + if create_new_breed_json: + breed_dict = {} + breed_index = 0 + for img_name in self.train_name_list: + folder_name = img_name.split('/')[0] + breed_name = folder_name.split(folder_name.split('-')[0] + '-')[1] + if not (folder_name in breed_dict): + breed_dict[folder_name] = { + 'breed_name': breed_name, + 'index': breed_index} + breed_index += 1 + with open(breed_json_path, 'w', encoding='utf-8') as f: json.dump(breed_dict, f, ensure_ascii=False, indent=4) + else: + with open(breed_json_path) as json_file: breed_dict = json.load(json_file) + return breed_dict + + + + def prepare_anipose_res_and_save(self): + # I only had to run this once ... + # path_animalpose_res_root = '/ps/scratch/nrueegg/new_projects/Animals/dog_project/pytorch-stacked-hourglass/results/animalpose_hg8_v0/' + path_animalpose_res_root = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/results/results/hg_anipose_after01bugfix_v1/stanext24_XXX_e300_json/' + + train_dict, init_test_dict, init_val_dict = utils_stanext.load_stanext_json_as_dict(split_train_test=True, V12=self.V12) + train_name_list = list(train_dict.keys()) + val_name_list = list(init_val_dict.keys()) + test_name_list = list(init_test_dict.keys()) + all_dicts = [train_dict, init_val_dict, init_test_dict] + all_name_lists = [train_name_list, val_name_list, test_name_list] + all_prefixes = ['train', 'val', 'test'] + for ind in range(3): + this_name_list = all_name_lists[ind] + this_dict = all_dicts[ind] + this_prefix = all_prefixes[ind] + + for index in range(0, len(this_name_list)): + print(index) + name = this_name_list[index] + data = this_dict[name] + + img_path = os.path.join(self.img_folder, data['img_path']) + + path_animalpose_res = os.path.join(path_animalpose_res_root.replace('XXX', this_prefix), data['img_path'].replace('.jpg', '.json')) + + + # prepare predicted keypoints + '''if is_train: + path_animalpose_res = os.path.join(path_animalpose_res_root, 'train_stanext', 'res_' + str(index) + '.json') + else: + path_animalpose_res = os.path.join(path_animalpose_res_root, 'test_stanext', 'res_' + str(index) + '.json') + ''' + with open(path_animalpose_res) as f: animalpose_data = json.load(f) + anipose_joints_256 = np.asarray(animalpose_data['pred_joints_256']).reshape((-1, 3)) + anipose_center = animalpose_data['center'] + anipose_scale = animalpose_data['scale'] + anipose_joints_64 = anipose_joints_256 / 4 + '''thrs_21to24 = 0.2 + anipose_joints_21to24 = np.zeros((4, 3))) + for ind_j in range(0:4): + anipose_joints_untrans = transform(anipose_joints_64[20+ind_j, 0:2], anipose_center, anipose_scale, [64, 64], invert=True, rot=0, as_int=False)-1 + anipose_joints_trans_again = transform(anipose_joints_untrans+1, anipose_center, anipose_scale, [64, 64], invert=False, rot=0, as_int=False) + anipose_joints_21to24[ind_j, :2] = anipose_joints_untrans + if anipose_joints_256[20+ind_j, 2] >= thrs_21to24: + anipose_joints_21to24[ind_j, 2] = 1''' + anipose_joints_0to24 = np.zeros((24, 3)) + for ind_j in range(24): + # anipose_joints_untrans = transform(anipose_joints_64[ind_j, 0:2], anipose_center, anipose_scale, [64, 64], invert=True, rot=0, as_int=False)-1 + anipose_joints_untrans = transform(anipose_joints_64[ind_j, 0:2]+1, anipose_center, anipose_scale, [64, 64], invert=True, rot=0, as_int=False)-1 + anipose_joints_0to24[ind_j, :2] = anipose_joints_untrans + anipose_joints_0to24[ind_j, 2] = anipose_joints_256[ind_j, 2] + # save anipose result for usage later on + out_path = os.path.join(self.path_anipose_out_root, data['img_path'].replace('.jpg', '.json')) + if not os.path.exists(os.path.dirname(out_path)): os.makedirs(os.path.dirname(out_path)) + out_dict = {'orig_anipose_joints_256': list(anipose_joints_256.reshape((-1))), + 'anipose_joints_0to24': list(anipose_joints_0to24[:, :3].reshape((-1))), + 'orig_index': index, + 'orig_scale': animalpose_data['scale'], + 'orig_center': animalpose_data['center'], + 'data_split': this_prefix, # 'is_train': is_train, + } + with open(out_path, 'w') as outfile: json.dump(out_dict, outfile) + return + + + + + + + + + + + + + + + + + def __getitem__(self, index): + + + if self.is_train: + train_val_test_Prefix = 'train' + if self.add_nonflat and index >= len(self.train_name_list): + name = self.train_name_list_nonflat[index - len(self.train_name_list)] + gc_isflat = 0 + else: + name = self.train_name_list[index] + gc_isflat = 1 + data = self.train_dict[name] + else: + train_val_test_Prefix = self.val_opt # 'val' or 'test' + if self.add_nonflat and index >= len(self.test_name_list): + name = self.test_name_list_nonflat[index - len(self.test_name_list)] + gc_isflat = 0 + else: + name = self.test_name_list[index] + gc_isflat = 1 + data = self.test_dict[name] + img_path = os.path.join(self.img_folder, data['img_path']) + + + ''' + # for debugging only + train_val_test_Prefix = 'train' + name = self.train_name_list[index] + data = self.trainvaltest_dict[name] + img_path = os.path.join(self.img_folder, data['img_path']) + + if self.dataset_mode=='complete_with_gc': + n_verts_smal = 3889 + + gc_info_raw = self.gc_dict['bite/' + name] # a list with all vertex numbers that are in ground contact + gc_info = [] + gc_info_tch = torch.zeros((n_verts_smal)) + for ind_v in gc_info_raw: + if ind_v < n_verts_smal: + gc_info.append(ind_v) + gc_info_tch[ind_v] = 1 + gc_info_available = True + ''' + + # array of shape (n_verts_smal, 3) with [first: no-contact=0 contact=1 second: index of vertex third: dist] + n_verts_smal = 3889 + if gc_isflat: + if name.split('.')[0] in self.gc_annots_overview_stage3: + gc_vertdists_overview = self.gc_annots_overview_stage3[name.split('.')[0]]['gc_vertdists_overview'] + gc_info_tch = torch.tensor(gc_vertdists_overview[:, :]) # torch.tensor(gc_vertdists_overview[:, 0]) + gc_info_available = True + gc_touching_ground = True + elif name.split('.')[0] in self.gc_annots_overview_stage2b_contact: + gc_vertdists_overview = self.gc_annots_overview_stage2b_contact[name.split('.')[0]]['gc_vertdists_overview'] + gc_info_tch = torch.tensor(gc_vertdists_overview[:, :]) # torch.tensor(gc_vertdists_overview[:, 0]) + gc_info_available = True + gc_touching_ground = True + elif name.split('.')[0] in self.gc_annots_overview_stage2b_nocontact: + gc_info_tch = torch.zeros((n_verts_smal, 3)) + gc_info_tch[:, 2] = 2.0 # big distance + gc_info_available = True + gc_touching_ground = False + else: + if 'pose' in self.gc_annots_categories[name]: + pose_label = self.gc_annots_categories[name]['pose'] + if pose_label in ['standing_4paws']: + gc_vertdists_overview = self.gc_annots_overview_stages12_all4pawsincontact['all4pawsincontact']['gc_vertdists_overview'] + gc_info_tch = torch.tensor(gc_vertdists_overview[:, :]) # torch.tensor(gc_vertdists_overview[:, 0]) + gc_info_available = True + gc_touching_ground = True + elif pose_label in ['jumping_nottouching']: + gc_info_tch = torch.zeros((n_verts_smal, 3)) + gc_info_tch[:, 2] = 2.0 # big distance + gc_info_available = True + gc_touching_ground = False + else: + gc_info_tch = torch.zeros((n_verts_smal, 3)) + gc_info_tch[:, 2] = 2.0 # big distance + gc_info_available = False + gc_touching_ground = False + else: + gc_info_tch = torch.zeros((n_verts_smal, 3)) + gc_info_tch[:, 2] = 2.0 # big distance + gc_info_available = False + gc_touching_ground = False + + + # is this pose approximatly symmetric? head pose is not considered + approximately_symmetric_pose = False + if 'pose' in self.gc_annots_categories[name]: + pose_label = self.gc_annots_categories[name]['pose'] + if pose_label in ['lying_sym', 'sitting_sym']: + approximately_symmetric_pose = True + + + + + # import pdb; pdb.set_trace() + debugging = False + if debugging: + import shutil + import trimesh + from smal_pytorch.smal_model.smal_torch_new import SMAL + smal = SMAL() + verts = smal.v_template.detach().cpu().numpy() + faces = smal.faces.detach().cpu().numpy() + vert_colors = np.repeat(255*gc_info_tch[:, 0].detach().cpu().numpy()[:, None], 3, 1) + # vert_colors = np.repeat(255*gc_info_np[:, None], 3, 1) + my_mesh = trimesh.Trimesh(vertices=verts, faces=faces, process=False, maintain_order=True) + my_mesh.visual.vertex_colors = vert_colors + debug_folder = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/debugging/gc_debugging/' + my_mesh.export(debug_folder + (name.split('/')[1]).replace('.jpg', '_withgc.obj')) + shutil.copy(img_path, debug_folder + name.split('/')[1]) + + + + + + sf = self.scale_factor + rf = self.rot_factor + try: + # import pdb; pdb.set_trace() + + '''new_anipose_root_path = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/results/results/hg_anipose_after01bugfix_v1/stanext24_XXX_e300_json/' + adjusted_new_anipose_root_path = new_anipose_root_path.replace('XXX', train_val_test_Prefix) + new_anipose_res_path = adjusted_new_anipose_root_path + data['img_path'].replace('.jpg', '.json') + with open(new_anipose_res_path) as f: new_anipose_data = json.load(f) + ''' + + anipose_res_path = os.path.join(self.path_anipose_out_root, data['img_path'].replace('.jpg', '.json')) + with open(anipose_res_path) as f: anipose_data = json.load(f) + anipose_thr = 0.2 + anipose_joints_0to24 = np.asarray(anipose_data['anipose_joints_0to24']).reshape((-1, 3)) + anipose_joints_0to24_scores = anipose_joints_0to24[:, 2] + # anipose_joints_0to24_scores[anipose_joints_0to24_scores>anipose_thr] = 1.0 + anipose_joints_0to24_scores[anipose_joints_0to24_scores bbox_max = 256 + # bbox_s = bbox_diag / 200. # diagonal of the boundingbox will be 200 + bbox_s = bbox_max / 200. * 256. / 200. # maximum side of the bbox will be 200 + c = torch.Tensor(bbox_c) + s = bbox_s + + # For single-person pose estimation with a centered/scaled figure + nparts = pts.size(0) + img = load_image(img_path) # CxHxW + + # segmentation map (we reshape it to 3xHxW, such that we can do the + # same transformations as with the image) + if self.calc_seg: + seg = torch.Tensor(utils_stanext.get_seg_from_entry(data)[None, :, :]) + seg = torch.cat(3*[seg]) + + r = 0 + do_flip = False + if self.do_augment: + s = s*torch.randn(1).mul_(sf).add_(1).clamp(1-sf, 1+sf)[0] + r = torch.randn(1).mul_(rf).clamp(-2*rf, 2*rf)[0] if random.random() <= 0.6 else 0 + # Flip + if random.random() <= 0.5: + do_flip = True + img = fliplr(img) + if self.calc_seg: + seg = fliplr(seg) + pts = shufflelr(pts, img.size(2), self.DATA_INFO.hflip_indices) + c[0] = img.size(2) - c[0] + # flip ground contact annotations + gc_info_tch_swapped = torch.zeros_like(gc_info_tch) + gc_info_tch_swapped[self.sym_ids_dict['center'], :] = gc_info_tch[self.sym_ids_dict['center'], :] + gc_info_tch_swapped[self.sym_ids_dict['right'], :] = gc_info_tch[self.sym_ids_dict['left'], :] + gc_info_tch_swapped[self.sym_ids_dict['left'], :] = gc_info_tch[self.sym_ids_dict['right'], :] + gc_info_tch = gc_info_tch_swapped + + + # Color + img[0, :, :].mul_(random.uniform(0.8, 1.2)).clamp_(0, 1) + img[1, :, :].mul_(random.uniform(0.8, 1.2)).clamp_(0, 1) + img[2, :, :].mul_(random.uniform(0.8, 1.2)).clamp_(0, 1) + + + + + # import pdb; pdb.set_trace() + debugging = False + if debugging and do_flip: + import shutil + import trimesh + from smal_pytorch.smal_model.smal_torch_new import SMAL + smal = SMAL() + verts = smal.v_template.detach().cpu().numpy() + faces = smal.faces.detach().cpu().numpy() + vert_colors = np.repeat(255*gc_info_tch[:, 0].detach().cpu().numpy()[:, None], 3, 1) + # vert_colors = np.repeat(255*gc_info_np[:, None], 3, 1) + my_mesh = trimesh.Trimesh(vertices=verts, faces=faces, process=False, maintain_order=True) + my_mesh.visual.vertex_colors = vert_colors + debug_folder = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/debugging/gc_debugging/' + my_mesh.export(debug_folder + (name.split('/')[1]).replace('.jpg', '_withgc_flip.obj')) + + + + + + + + + + + + + + + # Prepare image and groundtruth map + inp = crop(img, c, s, [self.inp_res, self.inp_res], rot=r) + img_border_mask = torch.all(inp > 1.0/256, dim = 0).unsqueeze(0).float() # 1 is foreground + inp = color_normalize(inp, self.DATA_INFO.rgb_mean, self.DATA_INFO.rgb_stddev) + if self.calc_seg: + seg = crop(seg, c, s, [self.inp_res, self.inp_res], rot=r) + + # Generate ground truth + tpts = pts.clone() + target_weight = tpts[:, 2].clone().view(nparts, 1) + + target = torch.zeros(nparts, self.out_res, self.out_res) + for i in range(nparts): + # if tpts[i, 2] > 0: # This is evil!! + if tpts[i, 1] > 0: + tpts[i, 0:2] = to_torch(transform(tpts[i, 0:2]+1, c, s, [self.out_res, self.out_res], rot=r, as_int=False)) - 1 + target[i], vis = draw_labelmap(target[i], tpts[i], self.sigma, type=self.label_type) + target_weight[i, 0] *= vis + # NEW: + '''target_new, vis_new = draw_multiple_labelmaps((self.out_res, self.out_res), tpts[:, :2]-1, self.sigma, type=self.label_type) + target_weight_new = tpts[:, 2].clone().view(nparts, 1) * vis_new + target_new[(target_weight_new==0).reshape((-1)), :, :] = 0''' + + + # --- Meta info + this_breed = self.breed_dict[name.split('/')[0]] # 120 + # add information about location within breed similarity matrix + folder_name = name.split('/')[0] + breed_name = folder_name.split(folder_name.split('-')[0] + '-')[1] + abbrev = COMPLETE_ABBREV_DICT[breed_name] + try: + sim_breed_index = COMPLETE_SUMMARY_BREEDS[abbrev]._ind_in_xlsx_matrix + except: # some breeds are not in the xlsx file + sim_breed_index = -1 + meta = {'index' : index, 'center' : c, 'scale' : s, + 'pts' : pts, 'tpts' : tpts, 'target_weight': target_weight, + 'breed_index': this_breed['index'], 'sim_breed_index': sim_breed_index, + 'ind_dataset': 0} # ind_dataset=0 for stanext or stanexteasy or stanext 2 + meta2 = {'index' : index, 'center' : c, 'scale' : s, + 'pts' : pts, 'tpts' : tpts, 'target_weight': target_weight, + 'ind_dataset': 3} + + # import pdb; pdb.set_trace() + + # out_path_root = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/debugging/stanext_preprocessing/old_animalpose_version/' + # out_path_root = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/debugging/stanext_preprocessing/v0/' + # save_input_image_with_keypoints(inp, meta['tpts'], out_path = out_path_root + name.replace('/', '_'), ratio_in_out=self.inp_res/self.out_res) + + + # return different things depending on dataset_mode + if self.dataset_mode=='keyp_only': + # save_input_image_with_keypoints(inp, meta['tpts'], out_path='./test_input_stanext.png', ratio_in_out=self.inp_res/self.out_res) + return inp, target, meta + elif self.dataset_mode=='keyp_and_seg': + meta['silh'] = seg[0, :, :] + meta['name'] = name + return inp, target, meta + elif self.dataset_mode=='keyp_and_seg_and_partseg': + # partseg is fake! this does only exist such that this dataset can be combined with an other datset that has part segmentations + meta2['silh'] = seg[0, :, :] + meta2['name'] = name + fake_body_part_matrix = torch.ones((3, 256, 256)).long() * (-1) + meta2['body_part_matrix'] = fake_body_part_matrix + return inp, target, meta2 + elif (self.dataset_mode=='complete') or (self.dataset_mode=='complete_with_gc'): + target_dict = meta + target_dict['silh'] = seg[0, :, :] + # NEW for silhouette loss + target_dict['img_border_mask'] = img_border_mask + target_dict['has_seg'] = True + # ground contact + if self.dataset_mode=='complete_with_gc': + target_dict['has_gc_is_touching'] = gc_touching_ground + target_dict['has_gc'] = gc_info_available + target_dict['gc'] = gc_info_tch + target_dict['approximately_symmetric_pose'] = approximately_symmetric_pose + target_dict['isflat'] = gc_isflat + if target_dict['silh'].sum() < 1: + if ((not self.is_train) and self.val_opt == 'test'): + raise ValueError + elif self.is_train: + print('had to replace training image') + replacement_index = max(0, index - 1) + inp, target_dict = self.__getitem__(replacement_index) + else: + # There seem to be a few validation images without segmentation + # which would lead to nan in iou calculation + replacement_index = max(0, index - 1) + inp, target_dict = self.__getitem__(replacement_index) + return inp, target_dict + else: + print('sampling error') + import pdb; pdb.set_trace() + raise ValueError + + def get_len_nonflat(self): + if self.is_train: + return len(self.train_name_list_nonflat) + else: + return len(self.test_name_list_nonflat) + + + def __len__(self): + if self.is_train: + return len(self.train_name_list) + else: + return len(self.test_name_list) + + diff --git a/src/stacked_hourglass/datasets/utils_dataset_selection.py b/src/stacked_hourglass/datasets/utils_dataset_selection.py new file mode 100644 index 0000000000000000000000000000000000000000..029a4a44d0d15d461cab57dfacbf60824336f86c --- /dev/null +++ b/src/stacked_hourglass/datasets/utils_dataset_selection.py @@ -0,0 +1,116 @@ + +import torch +from torch.utils.data import DataLoader +import cv2 +import glob +import os +import sys +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', '..', 'src')) + + +def get_evaluation_dataset(cfg_data_dataset, cfg_data_val_opt, cfg_data_V12, cfg_optim_batch_size, args_workers, drop_last=False): + # cfg_data_dataset = cfg.data.DATASET + # cfg_data_val_opt = cfg.data.VAL_OPT + # cfg_data_V12 = cfg.data.V12 + # cfg_optim_batch_size = cfg.optim.BATCH_SIZE + # args_workers = args.workers + assert cfg_data_dataset in ['stanext24_easy', 'stanext24', 'stanext24_withgc', 'stanext24_withgc_big'] + assert cfg_data_val_opt in ['train', 'test', 'val'] + + if cfg_data_dataset == 'stanext24_easy': + from stacked_hourglass.datasets.stanext24_easy import StanExtEasy as StanExt + dataset_mode = 'complete' + elif cfg_data_dataset == 'stanext24': + from stacked_hourglass.datasets.stanext24 import StanExt + dataset_mode = 'complete' + elif cfg_data_dataset == 'stanext24_withgc': + from stacked_hourglass.datasets.stanext24_withgc import StanExtGC as StanExt + dataset_mode = 'complete_with_gc' + elif cfg_data_dataset == 'stanext24_withgc_big': + from stacked_hourglass.datasets.stanext24_withgc_v2 import StanExtGC as StanExt + dataset_mode = 'complete_with_gc' + + # Initialise the validation set dataloader + if cfg_data_val_opt == 'test': + val_dataset = StanExt(image_path=None, is_train=False, dataset_mode=dataset_mode, V12=cfg_data_V12, val_opt='test') + test_name_list = val_dataset.test_name_list + elif cfg_data_val_opt == 'val': + val_dataset = StanExt(image_path=None, is_train=False, dataset_mode=dataset_mode, V12=cfg_data_V12, val_opt='val') + test_name_list = val_dataset.test_name_list + elif cfg_data_val_opt == 'train': + val_dataset = StanExt(image_path=None, is_train=True, do_augment='no', dataset_mode=dataset_mode, V12=cfg_data_V12) + test_name_list = val_dataset.train_name_list + else: + raise ValueError + val_loader = DataLoader(val_dataset, batch_size=cfg_optim_batch_size, shuffle=False, + num_workers=args_workers, pin_memory=True, drop_last=drop_last) # False) # , drop_last=True args.batch_size + len_val_dataset = len(val_dataset) + stanext_data_info = StanExt.DATA_INFO + stanext_acc_joints = StanExt.ACC_JOINTS + return val_dataset, val_loader, len_val_dataset, test_name_list, stanext_data_info, stanext_acc_joints + + +def get_sketchfab_evaluation_dataset(cfg_optim_batch_size, args_workers): + # cfg_optim_batch_size = cfg.optim.BATCH_SIZE + # args_workers = args.workers + from stacked_hourglass.datasets.sketchfab import SketchfabScans + val_dataset = SketchfabScans(image_path=None, is_train=False, dataset_mode='complete') + test_name_list = val_dataset.test_name_list + val_loader = DataLoader(val_dataset, batch_size=cfg_optim_batch_size, shuffle=False, + num_workers=args_workers, pin_memory=True, drop_last=False) # drop_last=True) + from stacked_hourglass.datasets.stanext24 import StanExt + len_val_dataset = len(val_dataset) + stanext_data_info = StanExt.DATA_INFO + stanext_acc_joints = StanExt.ACC_JOINTS + return val_dataset, val_loader, len_val_dataset, test_name_list, stanext_data_info, stanext_acc_joints + +def get_crop_evaluation_dataset(cfg_optim_batch_size, args_workers, input_folder): + from stacked_hourglass.datasets.imgcropslist import ImgCrops + image_list_paths = glob.glob(os.path.join(input_folder, '*.jpg')) + glob.glob(os.path.join(input_folder, '*.png')) + image_list = [] + test_name_list = [] + for image_path in image_list_paths: + test_name_list.append(os.path.basename(image_path).split('.')[0]) + img = cv2.imread(image_path) + image_list.append(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)) + val_dataset = ImgCrops(image_list=image_list, bbox_list=None) + val_loader = DataLoader(val_dataset, batch_size=cfg_optim_batch_size, shuffle=False, + num_workers=args_workers, pin_memory=True, drop_last=False) # drop_last=True) + from stacked_hourglass.datasets.stanext24 import StanExt + len_val_dataset = len(val_dataset) + stanext_data_info = StanExt.DATA_INFO + stanext_acc_joints = StanExt.ACC_JOINTS + return val_dataset, val_loader, len_val_dataset, test_name_list, stanext_data_info, stanext_acc_joints + +def get_single_crop_dataset_from_image(input_image, bbox=None): + from stacked_hourglass.datasets.imgcropslist import ImgCrops + input_image_list = [input_image] + if bbox is not None: + input_bbox_list = [bbox] + else: + input_bbox_list = None + # prepare data loader + val_dataset = ImgCrops(image_list=input_image_list, bbox_list=input_bbox_list, dataset_mode='complete') + test_name_list = val_dataset.test_name_list + val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False, + num_workers=0, pin_memory=True, drop_last=False) + from stacked_hourglass.datasets.stanext24 import StanExt + len_val_dataset = len(val_dataset) + stanext_data_info = StanExt.DATA_INFO + stanext_acc_joints = StanExt.ACC_JOINTS + return val_dataset, val_loader, len_val_dataset, test_name_list, stanext_data_info, stanext_acc_joints + + + + +def get_norm_dict(data_info=None, device="cuda"): + if data_info is None: + from stacked_hourglass.datasets.stanext24 import StanExt + data_info = StanExt.DATA_INFO + norm_dict = { + 'pose_rot6d_mean': torch.from_numpy(data_info.pose_rot6d_mean).float().to(device), + 'trans_mean': torch.from_numpy(data_info.trans_mean).float().to(device), + 'trans_std': torch.from_numpy(data_info.trans_std).float().to(device), + 'flength_mean': torch.from_numpy(data_info.flength_mean).float().to(device), + 'flength_std': torch.from_numpy(data_info.flength_std).float().to(device)} + return norm_dict \ No newline at end of file diff --git a/src/stacked_hourglass/datasets/utils_stanext.py b/src/stacked_hourglass/datasets/utils_stanext.py new file mode 100644 index 0000000000000000000000000000000000000000..83da8452f74ff8fb0ca95e2d8a42ba96972f684b --- /dev/null +++ b/src/stacked_hourglass/datasets/utils_stanext.py @@ -0,0 +1,114 @@ + +import os +from matplotlib import pyplot as plt +import glob +import json +import numpy as np +from scipy.io import loadmat +from csv import DictReader +from collections import OrderedDict +from pycocotools.mask import decode as decode_RLE + +import sys +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..')) +from configs.dataset_path_configs import IMG_V12_DIR, JSON_V12_DIR, STAN_V12_TRAIN_LIST_DIR, STAN_V12_VAL_LIST_DIR, STAN_V12_TEST_LIST_DIR + + +def get_img_dir(V12): + if V12: + return IMG_V12_DIR + else: + return IMG_DIR + +def get_seg_from_entry(entry): + """Given a .json entry, returns the binary mask as a numpy array""" + rle = { + "size": [entry['img_height'], entry['img_width']], + "counts": entry['seg']} + decoded = decode_RLE(rle) + return decoded + +def full_animal_visible(seg_data): + if seg_data[0, :].sum() == 0 and seg_data[seg_data.shape[0]-1, :].sum() == 0 and seg_data[:, 0].sum() == 0 and seg_data[:, seg_data.shape[1]-1].sum() == 0: + return True + else: + return False + +def load_train_and_test_lists(train_list_dir=None , test_list_dir=None): + """ returns sets containing names such as 'n02085620-Chihuahua/n02085620_5927.jpg' """ + # train data + train_list_mat = loadmat(train_list_dir) + train_list = [] + for ind in range(0, train_list_mat['file_list'].shape[0]): + name = train_list_mat['file_list'][ind, 0][0] + train_list.append(name) + # test data + test_list_mat = loadmat(test_list_dir) + test_list = [] + for ind in range(0, test_list_mat['file_list'].shape[0]): + name = test_list_mat['file_list'][ind, 0][0] + test_list.append(name) + return train_list, test_list + + + +def _filter_dict(t_list, j_dict, n_kp_min=4): + """ should only be used by load_stanext_json_as_dict() """ + out_dict = {} + for sample in t_list: + if sample in j_dict.keys(): + n_kp = np.asarray(j_dict[sample]['joints'])[:, 2].sum() + if n_kp >= n_kp_min: + out_dict[sample] = j_dict[sample] + return out_dict + +def load_stanext_json_as_dict(split_train_test=True, V12=True): + # load json into memory + if V12: + with open(JSON_V12_DIR) as infile: + json_data = json.load(infile) + # with open(JSON_V12_DIR) as infile: json_data = json.load(infile, object_pairs_hook=OrderedDict) + else: + with open(JSON_DIR) as infile: + json_data = json.load(infile) + # convert json data to a dictionary of img_path : all_data, for easy lookup + json_dict = {i['img_path']: i for i in json_data} + if split_train_test: + if V12: + train_list_numbers = np.load(STAN_V12_TRAIN_LIST_DIR) + val_list_numbers = np.load(STAN_V12_VAL_LIST_DIR) + test_list_numbers = np.load(STAN_V12_TEST_LIST_DIR) + train_list = [json_data[i]['img_path'] for i in train_list_numbers] + val_list = [json_data[i]['img_path'] for i in val_list_numbers] + test_list = [json_data[i]['img_path'] for i in test_list_numbers] + train_dict = _filter_dict(train_list, json_dict, n_kp_min=4) + val_dict = _filter_dict(val_list, json_dict, n_kp_min=4) + test_dict = _filter_dict(test_list, json_dict, n_kp_min=4) + return train_dict, test_dict, val_dict + else: + train_list, test_list = load_train_and_test_lists(train_list_dir=STAN_ORIG_TRAIN_LIST_DIR , test_list_dir=STAN_ORIG_TEST_LIST_DIR) + train_dict = _filter_dict(train_list, json_dict) + test_dict = _filter_dict(test_list, json_dict) + return train_dict, test_dict, None + else: + return json_dict + +def get_dog(json_dict, name, img_dir=None): # (json_dict, name, img_dir=IMG_DIR) + """ takes the name of a dog, and loads in all the relevant information as a dictionary: + dict_keys(['img_path', 'img_width', 'img_height', 'joints', 'img_bbox', + 'is_multiple_dogs', 'seg', 'img_data', 'seg_data']) + img_bbox: [x0, y0, width, height] """ + data = json_dict[name] + # load img + img_data = plt.imread(os.path.join(img_dir, data['img_path'])) + # load seg + seg_data = get_seg_from_entry(data) + # add to output + data['img_data'] = img_data # 0 to 255 + data['seg_data'] = seg_data # 0: bg, 1: fg + return data + + + + + diff --git a/src/stacked_hourglass/loss.py b/src/stacked_hourglass/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..8b452efc84fd9bc58d619e637ad84a9f108e6f81 --- /dev/null +++ b/src/stacked_hourglass/loss.py @@ -0,0 +1,161 @@ +import torch.nn as nn +import torch +from torch.nn.functional import mse_loss +# for NEW: losses when calculated on keypoint locations +# see https://kornia.readthedocs.io/en/latest/_modules/kornia/geometry/subpix/dsnt.html +# from kornia.geometry import dsnt # old kornia version +from kornia.geometry.subpix import dsnt # kornia 0.4.0 + +def joints_mse_loss_orig(output, target, target_weight=None): + batch_size = output.size(0) + num_joints = output.size(1) + heatmaps_pred = output.view((batch_size, num_joints, -1)).split(1, 1) + heatmaps_gt = target.view((batch_size, num_joints, -1)).split(1, 1) + + loss = 0 + for idx in range(num_joints): + heatmap_pred = heatmaps_pred[idx] + heatmap_gt = heatmaps_gt[idx] + if target_weight is None: + loss += 0.5 * mse_loss(heatmap_pred, heatmap_gt, reduction='mean') + else: + loss += 0.5 * mse_loss( + heatmap_pred.mul(target_weight[:, idx]), + heatmap_gt.mul(target_weight[:, idx]), + reduction='mean' + ) + + return loss / num_joints + + +class JointsMSELoss(nn.Module): + def __init__(self, use_target_weight=True): + super().__init__() + self.use_target_weight = use_target_weight + raise NotImplementedError + + def forward(self, output, target, target_weight): + if not self.use_target_weight: + target_weight = None + return joints_mse_loss_orig(output, target, target_weight) + + + + +# ----- NEW: losses when calculated on keypoint locations instead of keypoint heatmaps ----- + + +def joints_mse_loss_onKPloc(output, target, meta, target_weight=None): + # debugging: + # for old kornia version + # output_softmax_2d = dsnt.spatial_softmax_2d(target, temperature=torch.tensor(100)) + # output_kp = dsnt.spatial_softargmax_2d(output_softmax_2d, normalized_coordinates=False) + 1 + # print(output_kp[0]) + # print(meta['tpts'][0]) + # render gaussian + # dsnt.render_gaussian_2d(meta['tpts'][0][0, :2].to('cpu'), torch.tensor(([5., 5.])).to('cpu'), [256, 256], False) + # output_softmax_2d = dsnt.spatial_softmax_2d(output, temperature=torch.tensor(100)) + # target_norm = target / target.sum(axis=3).sum(axis=2)[:, :, None, None] + # output_softmax_2d = dsnt.spatial_softmax_2d(output*10) # (target, temperature=torch.tensor(10)) + # output_kp = dsnt.spatial_softargmax_2d(target_norm, normalized_coordinates=False) + 1 + + # normalize target heatmap + '''target_sum = target.sum(axis=3).sum(axis=2)[:, :, None, None] + target_sum[target_sum==0] = 1e-2 + target_norm = target / target_sum''' + target_norm = target # now we have normalized heatmaps + + # normalize predictions -> from logits to probability distribution + output_norm = dsnt.spatial_softmax2d(output, temperature=torch.tensor(1)) + + # heatmap loss (for normalization) + heatmap_loss = joints_mse_loss_orig(output_norm, target_norm, target_weight) + + # keypoint distance loss (average distance in pixels) + output_kp = dsnt.spatial_expectation2d(output_norm, normalized_coordinates=False) + 1 # (bs, 20, 2) + target_kp = meta['tpts'].to(output_kp.device) # (bs, 20, 3) + output_kp_resh = output_kp.reshape((-1, 2)) + target_kp_resh = target_kp[:, :, :2].reshape((-1, 2)) + weights_resh = target_kp[:, :, 2].reshape((-1)) + # dist_loss = (((output_kp_resh - target_kp_resh)**2).sum(axis=1).sqrt()*weights_resh)[weights_resh>0].sum() / min(weights_resh[weights_resh>0].sum(), 1e-5) + dist_loss = (((output_kp_resh - target_kp_resh)[weights_resh>0]**2).sum(axis=1).sqrt()*weights_resh[weights_resh>0]).sum() / max(weights_resh[weights_resh>0].sum(), 1e-5) + + + # return heatmap_loss*100 # + 0.0001*dist_loss + + # import pdb; pdb.set_trace() + + + '''import matplotlib as mpl + mpl.use('Agg') + import matplotlib.pyplot as plt + + img_np = output_norm[0, :, :, :].detach().cpu().numpy().transpose(1, 2, 0)[:, :, :3] + img_np = img_np * 255./ img_np.max() + # plot image + plt.imshow(img_np) + plt.savefig('./debugging_output/test_output.png') + plt.close() + + img_np = target_norm[0, :, :, :].detach().cpu().numpy().transpose(1, 2, 0)[:, :, :3] + img_np = img_np * 255./ img_np.max() + # plot image + plt.imshow(img_np) + plt.savefig('./debugging_output/test_gt.png') + plt.close()''' + + # print(heatmap_loss*100) + # print(dist_loss * 1e-4) + + # distlossonly: return dist_loss * 1e-4 + # both: return dist_loss * 1e-4 + heatmap_loss*100 + return dist_loss * 1e-4 + heatmap_loss*100 + + + + +class JointsMSELoss_onKPloc(nn.Module): + def __init__(self, use_target_weight=True): + super().__init__() + self.use_target_weight = use_target_weight + + def forward(self, output, target, target_weight): + if not self.use_target_weight: + target_weight = None + return joints_mse_loss_onKPloc(output, target, meta, target_weight) + + + + + +# ----- NEW: lsegmentation loss ----- + +import torch.nn.functional as F + +'''def resize2d(img, size): + return (F.adaptive_avg_pool2d(Variable(img,volatile=True), size)).data + # F.adaptive_avg_pool2d(meta['silh'], (64,64))).data''' + +def segmentation_loss(output, meta): + # output: (6, 2, 64, 64) + # meta.keys(): ['index', 'center', 'scale', 'pts', 'tpts', 'target_weight', 'breed_index', 'silh'] + # prepare target silhouettes + target_silh = meta['silh'] + target_silh_l = target_silh.to(torch.long) + criterion_ce = nn.CrossEntropyLoss() + if output.shape[2] == 64: + target_silh_64 = F.adaptive_avg_pool2d(target_silh, (64,64)) + target_silh_64[target_silh_64>0.5] = 1 + target_silh_64[target_silh_64<=0.5] = 0 + target_silh_64_l = target_silh_64.to(torch.long) + loss_silh_64 = criterion_ce(output, target_silh_64_l) # 0.7 + return loss_silh_64 + else: + loss_silh_l = criterion_ce(output, target_silh_l) # 0.7 + return loss_silh_l + + + + + + diff --git a/src/stacked_hourglass/model.py b/src/stacked_hourglass/model.py new file mode 100644 index 0000000000000000000000000000000000000000..192e0411bd337415d29fd7a7e294a94d50b7fb93 --- /dev/null +++ b/src/stacked_hourglass/model.py @@ -0,0 +1,308 @@ +# Modified from: +# https://github.com/anibali/pytorch-stacked-hourglass +# https://github.com/bearpaw/pytorch-pose +# Hourglass network inserted in the pre-activated Resnet +# Use lr=0.01 for current version +# (c) YANG, Wei + + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.hub import load_state_dict_from_url + + +__all__ = ['HourglassNet', 'hg'] + + +model_urls = { + 'hg1': 'https://github.com/anibali/pytorch-stacked-hourglass/releases/download/v0.0.0/bearpaw_hg1-ce125879.pth', + 'hg2': 'https://github.com/anibali/pytorch-stacked-hourglass/releases/download/v0.0.0/bearpaw_hg2-15e342d9.pth', + 'hg8': 'https://github.com/anibali/pytorch-stacked-hourglass/releases/download/v0.0.0/bearpaw_hg8-90e5d470.pth', +} + + +class Bottleneck(nn.Module): + expansion = 2 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(Bottleneck, self).__init__() + + self.bn1 = nn.BatchNorm2d(inplanes) + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=True) + self.bn2 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, + padding=1, bias=True) + self.bn3 = nn.BatchNorm2d(planes) + self.conv3 = nn.Conv2d(planes, planes * 2, kernel_size=1, bias=True) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.bn1(x) + out = self.relu(out) + out = self.conv1(out) + + out = self.bn2(out) + out = self.relu(out) + out = self.conv2(out) + + out = self.bn3(out) + out = self.relu(out) + out = self.conv3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + + return out + + +class Hourglass(nn.Module): + def __init__(self, block, num_blocks, planes, depth): + super(Hourglass, self).__init__() + self.depth = depth + self.block = block + self.hg = self._make_hour_glass(block, num_blocks, planes, depth) + + def _make_residual(self, block, num_blocks, planes): + layers = [] + for i in range(0, num_blocks): + layers.append(block(planes*block.expansion, planes)) + return nn.Sequential(*layers) + + def _make_hour_glass(self, block, num_blocks, planes, depth): + hg = [] + for i in range(depth): + res = [] + for j in range(3): + res.append(self._make_residual(block, num_blocks, planes)) + if i == 0: + res.append(self._make_residual(block, num_blocks, planes)) + hg.append(nn.ModuleList(res)) + return nn.ModuleList(hg) + + def _hour_glass_forward(self, n, x): + up1 = self.hg[n-1][0](x) + low1 = F.max_pool2d(x, 2, stride=2) + low1 = self.hg[n-1][1](low1) + + if n > 1: + low2 = self._hour_glass_forward(n-1, low1) + else: + low2 = self.hg[n-1][3](low1) + low3 = self.hg[n-1][2](low2) + up2 = F.interpolate(low3, scale_factor=2) + out = up1 + up2 + return out + + def forward(self, x): + return self._hour_glass_forward(self.depth, x) + + +class HourglassNet(nn.Module): + '''Hourglass model from Newell et al ECCV 2016''' + def __init__(self, block, num_stacks=2, num_blocks=4, num_classes=16, upsample_seg=False, add_partseg=False, num_partseg=None): + super(HourglassNet, self).__init__() + + self.inplanes = 64 + self.num_feats = 128 + self.num_stacks = num_stacks + self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, + bias=True) + self.bn1 = nn.BatchNorm2d(self.inplanes) + self.relu = nn.ReLU(inplace=True) + self.layer1 = self._make_residual(block, self.inplanes, 1) + self.layer2 = self._make_residual(block, self.inplanes, 1) + self.layer3 = self._make_residual(block, self.num_feats, 1) + self.maxpool = nn.MaxPool2d(2, stride=2) + self.upsample_seg = upsample_seg + self.add_partseg = add_partseg + + # build hourglass modules + ch = self.num_feats*block.expansion + hg, res, fc, score, fc_, score_ = [], [], [], [], [], [] + for i in range(num_stacks): + hg.append(Hourglass(block, num_blocks, self.num_feats, 4)) + res.append(self._make_residual(block, self.num_feats, num_blocks)) + fc.append(self._make_fc(ch, ch)) + score.append(nn.Conv2d(ch, num_classes, kernel_size=1, bias=True)) + if i < num_stacks-1: + fc_.append(nn.Conv2d(ch, ch, kernel_size=1, bias=True)) + score_.append(nn.Conv2d(num_classes, ch, kernel_size=1, bias=True)) + self.hg = nn.ModuleList(hg) + self.res = nn.ModuleList(res) + self.fc = nn.ModuleList(fc) + self.score = nn.ModuleList(score) + self.fc_ = nn.ModuleList(fc_) + self.score_ = nn.ModuleList(score_) + + if self.add_partseg: + self.hg_ps = (Hourglass(block, num_blocks, self.num_feats, 4)) + self.res_ps = (self._make_residual(block, self.num_feats, num_blocks)) + self.fc_ps = (self._make_fc(ch, ch)) + self.score_ps = (nn.Conv2d(ch, num_partseg, kernel_size=1, bias=True)) + self.ups_upsampling_ps = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True) + + + if self.upsample_seg: + self.ups_upsampling = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True) + self.ups_conv0 = nn.Conv2d(3, 32, kernel_size=7, stride=1, padding=3, + bias=True) + self.ups_bn1 = nn.BatchNorm2d(32) + self.ups_conv1 = nn.Conv2d(32, 16, kernel_size=7, stride=1, padding=3, + bias=True) + self.ups_bn2 = nn.BatchNorm2d(16+2) + self.ups_conv2 = nn.Conv2d(16+2, 16, kernel_size=5, stride=1, padding=2, + bias=True) + self.ups_bn3 = nn.BatchNorm2d(16) + self.ups_conv3 = nn.Conv2d(16, 2, kernel_size=5, stride=1, padding=2, + bias=True) + + + + def _make_residual(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, planes * block.expansion, + kernel_size=1, stride=stride, bias=True), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + def _make_fc(self, inplanes, outplanes): + bn = nn.BatchNorm2d(inplanes) + conv = nn.Conv2d(inplanes, outplanes, kernel_size=1, bias=True) + return nn.Sequential( + conv, + bn, + self.relu, + ) + + def forward(self, x_in): + out = [] + out_seg = [] + out_partseg = [] + x = self.conv1(x_in) + x = self.bn1(x) + x = self.relu(x) + + x = self.layer1(x) + x = self.maxpool(x) + x = self.layer2(x) + x = self.layer3(x) + + for i in range(self.num_stacks): + if i == self.num_stacks - 1: + if self.add_partseg: + y_ps = self.hg_ps(x) + y_ps = self.res_ps(y_ps) + y_ps = self.fc_ps(y_ps) + score_ps = self.score_ps(y_ps) + out_partseg.append(score_ps[:, :, :, :]) + y = self.hg[i](x) + y = self.res[i](y) + y = self.fc[i](y) + score = self.score[i](y) + if self.upsample_seg: + out.append(score[:, :-2, :, :]) + out_seg.append(score[:, -2:, :, :]) + else: + out.append(score) + if i < self.num_stacks-1: + fc_ = self.fc_[i](y) + score_ = self.score_[i](score) + x = x + fc_ + score_ + + if self.upsample_seg: + # PLAN: add a residual to the upsampled version of the segmentation image + # upsample predicted segmentation + seg_score = score[:, -2:, :, :] + seg_score_256 = self.ups_upsampling(seg_score) + # prepare input image + + ups_img = self.ups_conv0(x_in) + + ups_img = self.ups_bn1(ups_img) + ups_img = self.relu(ups_img) + ups_img = self.ups_conv1(ups_img) + + # import pdb; pdb.set_trace() + + ups_conc = torch.cat((seg_score_256, ups_img), 1) + + # ups_conc = self.ups_bn2(ups_conc) + ups_conc = self.relu(ups_conc) + ups_conc = self.ups_conv2(ups_conc) + + ups_conc = self.ups_bn3(ups_conc) + ups_conc = self.relu(ups_conc) + correction = self.ups_conv3(ups_conc) + + seg_final = seg_score_256 + correction + + if self.add_partseg: + partseg_final = self.ups_upsampling_ps(score_ps) + out_dict = {'out_list_kp': out, + 'out_list_seg': out_seg, + 'seg_final': seg_final, + 'out_list_partseg': out_partseg, + 'partseg_final': partseg_final + } + return out_dict + else: + out_dict = {'out_list_kp': out, + 'out_list_seg': out_seg, + 'seg_final': seg_final + } + return out_dict + + return out + + +def hg(**kwargs): + model = HourglassNet(Bottleneck, num_stacks=kwargs['num_stacks'], num_blocks=kwargs['num_blocks'], + num_classes=kwargs['num_classes'], upsample_seg=kwargs['upsample_seg'], + add_partseg=kwargs['add_partseg'], num_partseg=kwargs['num_partseg']) + return model + + +def _hg(arch, pretrained, progress, **kwargs): + model = hg(**kwargs) + if pretrained: + state_dict = load_state_dict_from_url(model_urls[arch], progress=progress) + model.load_state_dict(state_dict) + return model + + +def hg1(pretrained=False, progress=True, num_blocks=1, num_classes=16, upsample_seg=False, add_partseg=False, num_partseg=None): + return _hg('hg1', pretrained, progress, num_stacks=1, num_blocks=num_blocks, + num_classes=num_classes, upsample_seg=upsample_seg, + add_partseg=add_partseg, num_partseg=num_partseg) + + +def hg2(pretrained=False, progress=True, num_blocks=1, num_classes=16, upsample_seg=False, add_partseg=False, num_partseg=None): + return _hg('hg2', pretrained, progress, num_stacks=2, num_blocks=num_blocks, + num_classes=num_classes, upsample_seg=upsample_seg, + add_partseg=add_partseg, num_partseg=num_partseg) + +def hg4(pretrained=False, progress=True, num_blocks=1, num_classes=16, upsample_seg=False, add_partseg=False, num_partseg=None): + return _hg('hg4', pretrained, progress, num_stacks=4, num_blocks=num_blocks, + num_classes=num_classes, upsample_seg=upsample_seg, + add_partseg=add_partseg, num_partseg=num_partseg) + +def hg8(pretrained=False, progress=True, num_blocks=1, num_classes=16, upsample_seg=False, add_partseg=False, num_partseg=None): + return _hg('hg8', pretrained, progress, num_stacks=8, num_blocks=num_blocks, + num_classes=num_classes, upsample_seg=upsample_seg, + add_partseg=add_partseg, num_partseg=num_partseg) diff --git a/src/stacked_hourglass/predictor.py b/src/stacked_hourglass/predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..d511a9668547cad4a21526aa514589705e06e6f8 --- /dev/null +++ b/src/stacked_hourglass/predictor.py @@ -0,0 +1,122 @@ + +# Modified from: +# https://github.com/anibali/pytorch-stacked-hourglass +# https://github.com/bearpaw/pytorch-pose + +import torch +import os +import sys +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../')) +from src.stacked_hourglass.utils.evaluation import final_preds_untransformed +from src.stacked_hourglass.utils.imfit import fit, calculate_fit_contain_output_area +from src.stacked_hourglass.utils.transforms import color_normalize, fliplr, flip_back + + +def _check_batched(images): + if isinstance(images, (tuple, list)): + return True + if images.ndimension() == 4: + return True + return False + + +class HumanPosePredictor: + def __init__(self, model, device=None, data_info=None, input_shape=None): + """Helper class for predicting 2D human pose joint locations. + + Args: + model: The model for generating joint heatmaps. + device: The computational device to use for inference. + data_info: Specifications of the data (defaults to ``Mpii.DATA_INFO``). + input_shape: The input dimensions of the model (height, width). + """ + if device is None: + device = 'cuda' if torch.cuda.is_available() else 'cpu' + device = torch.device(device) + model.to(device) + self.model = model + self.device = device + + if data_info is None: + raise ValueError + # self.data_info = Mpii.DATA_INFO + else: + self.data_info = data_info + + # Input shape ordering: H, W + if input_shape is None: + self.input_shape = (256, 256) + elif isinstance(input_shape, int): + self.input_shape = (input_shape, input_shape) + else: + self.input_shape = input_shape + + def do_forward(self, input_tensor): + self.model.eval() + with torch.no_grad(): + output = self.model(input_tensor) + return output + + def prepare_image(self, image): + was_fixed_point = not image.is_floating_point() + image = torch.empty_like(image, dtype=torch.float32).copy_(image) + if was_fixed_point: + image /= 255.0 + if image.shape[-2:] != self.input_shape: + image = fit(image, self.input_shape, fit_mode='contain') + image = color_normalize(image, self.data_info.rgb_mean, self.data_info.rgb_stddev) + return image + + def estimate_heatmaps(self, images, flip=False): + is_batched = _check_batched(images) + raw_images = images if is_batched else images.unsqueeze(0) + input_tensor = torch.empty((len(raw_images), 3, *self.input_shape), + device=self.device, dtype=torch.float32) + for i, raw_image in enumerate(raw_images): + input_tensor[i] = self.prepare_image(raw_image) + heatmaps = self.do_forward(input_tensor)[-1].cpu() + if flip: + flip_input = fliplr(input_tensor) + flip_heatmaps = self.do_forward(flip_input)[-1].cpu() + heatmaps += flip_back(flip_heatmaps, self.data_info.hflip_indices) + heatmaps /= 2 + if is_batched: + return heatmaps + else: + return heatmaps[0] + + def estimate_joints(self, images, flip=False): + """Estimate human joint locations from input images. + + Images are expected to be centred on a human subject and scaled reasonably. + + Args: + images: The images to estimate joint locations for. Can be a single image or a list + of images. + flip (bool): If set to true, evaluates on flipped versions of the images as well and + averages the results. + + Returns: + The predicted human joint locations in image pixel space. + """ + is_batched = _check_batched(images) + raw_images = images if is_batched else images.unsqueeze(0) + heatmaps = self.estimate_heatmaps(raw_images, flip=flip).cpu() + # final_preds_untransformed compares the first component of shape with x and second with y + # This relates to the image Width, Height (Heatmap has shape Height, Width) + coords = final_preds_untransformed(heatmaps, heatmaps.shape[-2:][::-1]) + # Rescale coords to pixel space of specified images. + for i, image in enumerate(raw_images): + # When returning to original image space we need to compensate for the fact that we are + # used fit_mode='contain' when preparing the images for inference. + y_off, x_off, height, width = calculate_fit_contain_output_area(*image.shape[-2:], *self.input_shape) + coords[i, :, 1] *= self.input_shape[-2] / heatmaps.shape[-2] + coords[i, :, 1] -= y_off + coords[i, :, 1] *= image.shape[-2] / height + coords[i, :, 0] *= self.input_shape[-1] / heatmaps.shape[-1] + coords[i, :, 0] -= x_off + coords[i, :, 0] *= image.shape[-1] / width + if is_batched: + return coords + else: + return coords[0] diff --git a/src/stacked_hourglass/train.py b/src/stacked_hourglass/train.py new file mode 100644 index 0000000000000000000000000000000000000000..ce3f2b31ef8c6f60a2400f1f2ea05cb764b6f94c --- /dev/null +++ b/src/stacked_hourglass/train.py @@ -0,0 +1,210 @@ + +# scripts/train.py --workers 12 --checkpoint project22_no3dcgloss_smaldogsilvia_v0 --loss-weight-path barc_loss_weights_no3dcgloss.json --config barc_cfg_train.yaml start --model-file-hg hg_ksp_fromnewanipose_stanext_v0/checkpoint.pth.tar --model-file-3d barc_normflow_pret/checkpoint.pth.tar + +import torch +import torch.backends.cudnn +import torch.nn.parallel +from tqdm import tqdm +import os +import json +import pathlib + +import sys +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../', 'src')) +# from stacked_hourglass.loss import joints_mse_loss +from stacked_hourglass.loss import joints_mse_loss_onKPloc +from stacked_hourglass.utils.evaluation import accuracy, AverageMeter, final_preds, get_preds, get_preds_soft +from stacked_hourglass.utils.transforms import fliplr, flip_back +from stacked_hourglass.utils.visualization import save_input_image_with_keypoints + + +def do_training_step(model, optimiser, input, target, meta, data_info, target_weight=None): + assert model.training, 'model must be in training mode.' + assert len(input) == len(target), 'input and target must contain the same number of examples.' + + with torch.enable_grad(): + # Forward pass and loss calculation. + output = model(input) + + # original: loss = sum(joints_mse_loss(o, target, target_weight) for o in output) + # NEW: + loss = sum(joints_mse_loss_onKPloc(o, target, meta, target_weight) for o in output) + + # Backward pass and parameter update. + optimiser.zero_grad() + loss.backward() + optimiser.step() + + return output[-1], loss.item() + + +def do_training_epoch(train_loader, model, device, data_info, optimiser, quiet=False, acc_joints=None): + losses = AverageMeter() + accuracies = AverageMeter() + + # Put the model in training mode. + model.train() + + iterable = enumerate(train_loader) + progress = None + if not quiet: + progress = tqdm(iterable, desc='Train', total=len(train_loader), ascii=True, leave=False) + iterable = progress + + for i, (input, target, meta) in iterable: + input, target = input.to(device), target.to(device, non_blocking=True) + target_weight = meta['target_weight'].to(device, non_blocking=True) + + output, loss = do_training_step(model, optimiser, input, target, meta, data_info, target_weight) + + acc = accuracy(output, target, acc_joints) + + # measure accuracy and record loss + losses.update(loss, input.size(0)) + accuracies.update(acc[0], input.size(0)) + + # Show accuracy and loss as part of the progress bar. + if progress is not None: + progress.set_postfix_str('Loss: {loss:0.4f}, Acc: {acc:6.2f}'.format( + loss=losses.avg, + acc=100 * accuracies.avg + )) + + return losses.avg, accuracies.avg + + +def do_validation_step(model, input, target, meta, data_info, target_weight=None, flip=False): + # assert not model.training, 'model must be in evaluation mode.' + assert len(input) == len(target), 'input and target must contain the same number of examples.' + + # Forward pass and loss calculation. + output = model(input) + + # original: loss = sum(joints_mse_loss(o, target, target_weight) for o in output) + # NEW: + loss = sum(joints_mse_loss_onKPloc(o, target, meta, target_weight) for o in output) + + + # Get the heatmaps. + if flip: + # If `flip` is true, perform horizontally flipped inference as well. This should + # result in more robust predictions at the expense of additional compute. + flip_input = fliplr(input) + flip_output = model(flip_input) + flip_output = flip_output[-1].cpu() + flip_output = flip_back(flip_output.detach(), data_info.hflip_indices) + heatmaps = (output[-1].cpu() + flip_output) / 2 + else: + heatmaps = output[-1].cpu() + + + return heatmaps, loss.item() + + +def do_validation_epoch(val_loader, model, device, data_info, flip=False, quiet=False, acc_joints=None, save_imgs_path=None): + losses = AverageMeter() + accuracies = AverageMeter() + predictions = [None] * len(val_loader.dataset) + + if save_imgs_path is not None: + pathlib.Path(save_imgs_path).mkdir(parents=True, exist_ok=True) + + # Put the model in evaluation mode. + model.eval() + + iterable = enumerate(val_loader) + progress = None + if not quiet: + progress = tqdm(iterable, desc='Valid', total=len(val_loader), ascii=True, leave=False) + iterable = progress + + for i, (input, target, meta) in iterable: + # Copy data to the training device (eg GPU). + input = input.to(device, non_blocking=True) + target = target.to(device, non_blocking=True) + target_weight = meta['target_weight'].to(device, non_blocking=True) + + # import pdb; pdb.set_trace() + + heatmaps, loss = do_validation_step(model, input, target, meta, data_info, target_weight, flip) + + # Calculate PCK from the predicted heatmaps. + acc = accuracy(heatmaps, target.cpu(), acc_joints) + + # Calculate locations in original image space from the predicted heatmaps. + preds = final_preds(heatmaps, meta['center'], meta['scale'], [64, 64]) + # NEW for visualization: (and redundant, but for visualization) + preds_unprocessed, preds_unprocessed_maxval = get_preds_soft(heatmaps, return_maxval=True) + # preds_unprocessed, preds_unprocessed_norm, preds_unprocessed_maxval = get_preds_soft(heatmaps, return_maxval=True, norm_and_unnorm_coords=True) + + + # import pdb; pdb.set_trace() + + ind = 0 + for example_index, pose in zip(meta['index'], preds): + predictions[example_index] = pose + # NEW for visualization + if save_imgs_path is not None: + out_name = os.path.join(save_imgs_path, 'res_' + str( example_index.item()) + '.png') + pred_unp = preds_unprocessed[ind, :, :] + + pred_unp_maxval = preds_unprocessed_maxval[ind, :, :] + pred_unp_prep = torch.cat((pred_unp, pred_unp_maxval), 1) + inp_img = input[ind, :, :, :] + # the following line (with -1) should not be needed anymore after cvpr (after bugfix01 in data preparation 08.09.2022) + # pred_unp_prep[:, :2] = pred_unp_prep[:, :2] - 1 + # save_input_image_with_keypoints(inp_img, pred_unp_prep, out_path=out_name, threshold=0.1, print_scores=True) # here we have default ratio_in_out=4. + + # NEW: 08.09.2022 after bugfix01 + + # import pdb; pdb.set_trace() + + pred_unp_prep[:, :2] = pred_unp_prep[:, :2] * 4 + + if 'name' in meta.keys(): # we do this for the stanext set + name = meta['name'][ind] + out_path_keyp_img = os.path.join(os.path.dirname(out_name), name) + out_path_json = os.path.join(os.path.dirname(out_name), name).replace('_vis', '_json').replace('.jpg', '.json') + if not os.path.exists(os.path.dirname(out_path_json)): + os.makedirs(os.path.dirname(out_path_json)) + if not os.path.exists(os.path.dirname(out_path_keyp_img)): + os.makedirs(os.path.dirname(out_path_keyp_img)) + save_input_image_with_keypoints(inp_img, pred_unp_prep, out_path=out_path_keyp_img, ratio_in_out=1.0, threshold=0.1, print_scores=True) # threshold=0.3 + out_name_json = out_path_json # os.path.join(save_imgs_path, 'res_' + str( example_index.item()) + '.json') + res_dict = { + 'pred_joints_256': list(pred_unp_prep.cpu().numpy().astype(float).reshape((-1))), + 'center': list(meta['center'][ind, :].cpu().numpy().astype(float).reshape((-1))), + 'scale': meta['scale'][ind].item()} + with open(out_name_json, 'w') as outfile: json.dump(res_dict, outfile) + else: + save_input_image_with_keypoints(inp_img, pred_unp_prep, out_path=out_name, ratio_in_out=1.0, threshold=0.1, print_scores=True) # threshold=0.3 + + + + '''# animalpose_hg8_v0 (did forget to subtract 1 in dataset) + pred_unp_prep[:, :2] = pred_unp_prep[:, :2] * 4 ############ Why is this necessary??? + pred_unp_prep[:, :2] = pred_unp_prep[:, :2] - 1 + save_input_image_with_keypoints(inp_img, pred_unp_prep, out_path=out_name, ratio_in_out=1.0, threshold=0.1, print_scores=True) # threshold=0.3 + out_name_json = os.path.join(save_imgs_path, 'res_' + str( example_index.item()) + '.json') + res_dict = { + 'pred_joints_256': list(pred_unp_prep.cpu().numpy().astype(float).reshape((-1))), + 'center': list(meta['center'][ind, :].cpu().numpy().astype(float).reshape((-1))), + 'scale': meta['scale'][ind].item()} + with open(out_name_json, 'w') as outfile: json.dump(res_dict, outfile)''' + + ind += 1 + + # Record accuracy and loss for this batch. + losses.update(loss, input.size(0)) + accuracies.update(acc[0].item(), input.size(0)) + + # Show accuracy and loss as part of the progress bar. + if progress is not None: + progress.set_postfix_str('Loss: {loss:0.4f}, Acc: {acc:6.2f}'.format( + loss=losses.avg, + acc=100 * accuracies.avg + )) + + predictions = torch.stack(predictions, dim=0) + + return losses.avg, accuracies.avg, predictions diff --git a/src/stacked_hourglass/train_ksp.py b/src/stacked_hourglass/train_ksp.py new file mode 100644 index 0000000000000000000000000000000000000000..fef8b911358b63517bfb1275f681ea966c4ddd68 --- /dev/null +++ b/src/stacked_hourglass/train_ksp.py @@ -0,0 +1,304 @@ + +''' +in contrast to train.py, here we do not only predict keypoints but instead: + - keypoints + - segmentation +''' + +import torch +import torch.backends.cudnn +import torch.nn.parallel +import torch.nn as nn +from tqdm import tqdm +import os +import pathlib +from matplotlib import pyplot as plt +import numpy as np +import cv2 +import pickle as pkl + +import sys +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..')) +# from stacked_hourglass.loss import joints_mse_loss +from stacked_hourglass.loss import joints_mse_loss_onKPloc, segmentation_loss +from stacked_hourglass.utils.evaluation import accuracy, AverageMeter, final_preds, get_preds, get_preds_soft +from stacked_hourglass.utils.transforms import fliplr, flip_back +from stacked_hourglass.utils.visualization import save_input_image_with_keypoints, save_image_with_part_segmentation, save_image_with_part_segmentation_from_gt_annotation + + + +def do_training_step(model, optimiser, input, target, meta, data_info, target_weight=None): + assert model.training, 'model must be in training mode.' + assert len(input) == len(target), 'input and target must contain the same number of examples.' + + with torch.enable_grad(): + + # import pdb; pdb.set_trace() + + # Forward pass and loss calculation. + # output = model(input) # this is a list + '''output = out_dict['out_list']''' + # dict_keys(['out_list_kp', 'out_list_seg', 'seg_final', 'out_list_partseg', 'partseg_final']) + out_dict = model(input) + + # original: loss = sum(joints_mse_loss(o, target, target_weight) for o in output) + '''loss_kp = sum(joints_mse_loss_onKPloc(o[:, :-2, :, :], target, meta, target_weight) for o in output) + loss_seg = sum(segmentation_loss(o[:, -2:, :, :], meta) for o in output)''' + loss_kp = sum(joints_mse_loss_onKPloc(o, target, meta, target_weight) for o in out_dict['out_list_kp']) + loss_seg = sum(segmentation_loss(o, meta) for o in out_dict['out_list_seg']) + loss_seg_big = segmentation_loss(out_dict['seg_final'], meta) + + # NEW for body part segmentation + + '''import pdb; pdb.set_trace() + for ind_gt in range(6, 12): + out_path_gt_seg = '/ps/scratch/nrueegg/new_projects/Animals/dog_project/pytorch-stacked-hourglass/debugging_output/partseg/gt_' + str(ind_gt) + '.png' + save_image_with_part_segmentation_from_gt_annotation(meta['body_part_matrix'].detach().cpu().numpy(), out_path_gt_seg, ind_gt)''' + + + # for the second stage where we add a dataset with body part segmentations + # and not just fake -1 labels, we calculate body part segmentation loss as well + # if all body part labels are -1, we ignore this loss calculation + if meta['body_part_matrix'].max() > -1: # this will be the case for dogsvoc but not stanext + tbp_dict = {'full_body': [0, 8], + 'head': [8, 13], + 'torso': [13, 15]} + loss_partseg = [] + criterion_ce = nn.CrossEntropyLoss(reduction='mean', ignore_index=-1) + ''''weights = [5.0, 1.0, 1.0, 1.0, 1.0] + class_weights = torch.FloatTensor(weights).to(input.device) + criterion_ce_weighted = nn.CrossEntropyLoss(reduction='mean', ignore_index=-1, weight=class_weights) + for ind_tbp, part in enumerate(['full_body', 'head', 'torso']): + tbp_out = out_dict['partseg_final'][:, tbp_dict[part][0]:tbp_dict[part][1], :, :] + tbp_target = meta['body_part_matrix'][:, ind_tbp, :, :].to(torch.long) + if part == 'head': + loss_partseg.append(criterion_ce_weighted(tbp_out, tbp_target)) + else: + loss_partseg.append(criterion_ce(tbp_out, tbp_target))''' + for ind_tbp, part in enumerate(['full_body', 'head', 'torso']): + tbp_out = out_dict['partseg_final'][:, tbp_dict[part][0]:tbp_dict[part][1], :, :] + tbp_target = meta['body_part_matrix'][:, ind_tbp, :, :].to(torch.long) + if part == 'full_body': + # ignore parts of the silhouette which dont have a specific body part label + tbp_target[tbp_target==0] = -1 + loss_partseg.append(criterion_ce(tbp_out, tbp_target)) + else: + loss_partseg.append(criterion_ce(tbp_out, tbp_target)) + # print(loss_seg_big) + # print(loss_partseg) + + # loss = loss_kp + loss_seg*0.01 + loss_seg_big*0.1 # orig # 0.001 # 0.01 + loss = loss_kp + loss_seg*0.001 + loss_seg_big*0.01 + 0.01*(loss_partseg[0] + loss_partseg[1] + loss_partseg[2]) + + else: + loss = loss_kp + loss_seg*0.01 + loss_seg_big*0.1 + + + + + # Backward pass and parameter update. + optimiser.zero_grad() + loss.backward() + optimiser.step() + + loss_dict = {'loss': loss.item(), + 'keyp': loss_kp.item(), + 'seg': loss_seg.item(), + 'seg_big': loss_seg_big.item() + } + + return out_dict['out_list_kp'][-1], loss_dict + + +def do_training_epoch(train_loader, model, device, data_info, optimiser, quiet=False, acc_joints=None): + losses = AverageMeter() + accuracies = AverageMeter() + + # Put the model in training mode. + model.train() + + iterable = enumerate(train_loader) + progress = None + if not quiet: + progress = tqdm(iterable, desc='Train', total=len(train_loader), ascii=True, leave=False) + iterable = progress + + for i, (input, target, meta) in iterable: + input, target = input.to(device), target.to(device, non_blocking=True) + target_weight = meta['target_weight'].to(device, non_blocking=True) + meta['silh'] = meta['silh'].to(device, non_blocking=True) + meta['body_part_matrix'] = meta['body_part_matrix'].to(device, non_blocking=True) + + output_kp, loss_dict = do_training_step(model, optimiser, input, target, meta, data_info, target_weight) + loss = loss_dict['loss'] + + acc = accuracy(output_kp, target, acc_joints) + + # measure accuracy and record loss + losses.update(loss, input.size(0)) + accuracies.update(acc[0], input.size(0)) + + # Show accuracy and loss as part of the progress bar. + if progress is not None: + progress.set_postfix_str('Loss: {loss:0.4f}, Acc: {acc:6.2f}'.format( + loss=losses.avg, + acc=100 * accuracies.avg + )) + + return losses.avg, accuracies.avg + + +def do_validation_step(model, input, target, meta, data_info, target_weight=None, flip=False): + assert not model.training, 'model must be in evaluation mode.' + assert len(input) == len(target), 'input and target must contain the same number of examples.' + + # Forward pass and loss calculation. + # output = model(input) + out_dict = model(input) # ['out_list', 'seg_final'] + '''output = out_dict['out_list']''' + + # original: loss = sum(joints_mse_loss(o, target, target_weight) for o in output) + '''loss_kp = sum(joints_mse_loss_onKPloc(o[:, :-2, :, :], target, meta, target_weight) for o in output) + loss_seg = sum(segmentation_loss(o[:, -2:, :, :], meta) for o in output)''' + loss_kp = sum(joints_mse_loss_onKPloc(o, target, meta, target_weight) for o in out_dict['out_list_kp']) + loss_seg = sum(segmentation_loss(o, meta) for o in out_dict['out_list_seg']) + loss_seg_big = segmentation_loss(out_dict['seg_final'], meta) + loss = loss_kp + loss_seg*0.01 + loss_seg_big*0.1 # 0.001 # 0.01 + + # Get the heatmaps. + heatmaps = out_dict['out_list_kp'][-1].cpu() + + '''seg = output[-1][:, -2:, :, :].cpu()''' + seg = out_dict['out_list_seg'][-1].cpu() + seg_big = out_dict['seg_final'].cpu() + partseg_big = out_dict['partseg_final'].cpu() + + loss_dict = {'loss': loss.item(), + 'keyp': loss_kp.item(), + 'seg': loss_seg.item(), + 'seg_big': loss_seg_big.item() + } + + return heatmaps, seg, seg_big, partseg_big, loss_dict # loss.item() + + +def do_validation_epoch(val_loader, model, device, data_info, flip=False, quiet=False, acc_joints=None, save_imgs_path=None, save_pkl_path=None): + losses = AverageMeter() + accuracies = AverageMeter() + predictions = [None] * len(val_loader.dataset) + + if save_imgs_path is not None: + pathlib.Path(save_imgs_path).mkdir(parents=True, exist_ok=True) + + # Put the model in evaluation mode. + model.eval() + + iterable = enumerate(val_loader) + progress = None + if not quiet: + progress = tqdm(iterable, desc='Valid', total=len(val_loader), ascii=True, leave=False) + iterable = progress + + for i, (input, target, meta) in iterable: + # Copy data to the training device (eg GPU). + input = input.to(device, non_blocking=True) + target = target.to(device, non_blocking=True) + target_weight = meta['target_weight'].to(device, non_blocking=True) + meta['silh'] = meta['silh'].to(device, non_blocking=True) + if 'body_part_matrix' in meta.keys(): + meta['body_part_matrix'] = meta['body_part_matrix'].to(device, non_blocking=True) + + heatmaps, seg, seg_big, partseg_big, loss_dict = do_validation_step(model, input, target, meta, data_info, target_weight, flip) + loss = loss_dict['loss'] + + # Calculate PCK from the predicted heatmaps. + acc = accuracy(heatmaps, target.cpu(), acc_joints) + + # Calculate locations in original image space from the predicted heatmaps. + preds = final_preds(heatmaps, meta['center'], meta['scale'], [64, 64]) + # NEW for visualization: (and redundant, but for visualization) + if (save_imgs_path is not None) or (save_pkl_path is not None): + preds_unprocessed, preds_unprocessed_norm, preds_unprocessed_maxval = get_preds_soft(heatmaps, return_maxval=True, norm_and_unnorm_coords=True) + + # import pdb; pdb.set_trace() + + + ind = 0 + for example_index, pose in zip(meta['index'], preds): + # prepare save paths + if save_pkl_path is not None: + out_name_seg_overlay = os.path.join(save_imgs_path, meta['name'][ind].replace('.jpg', '__') + 'seg_overlay.png') + out_name_kp = os.path.join(save_imgs_path, meta['name'][ind].replace('.jpg', '__') + 'res.png') + if not os.path.exists(os.path.dirname(out_name_kp)): + os.makedirs(os.path.dirname(out_name_kp)) + out_name_pkl = os.path.join(save_pkl_path, meta['name'][ind].replace('.jpg', '.pkl')) + if not os.path.exists(os.path.dirname(out_name_pkl)): + os.makedirs(os.path.dirname(out_name_pkl)) + else: + if save_imgs_path is not None: + out_name_seg_overlay = os.path.join(save_imgs_path, 'seg_overlay_' + str( example_index.item()) + '.png') + out_name_kp = os.path.join(save_imgs_path, 'res_' + str( example_index.item()) + '.png') + predictions[example_index] = pose + # NEW for visualization + if save_imgs_path is not None: + soft_max = torch.nn.Softmax(dim= 0) + segm_img_pred = soft_max((seg_big[ind, :, :, :]))[1, :, :] + if save_pkl_path is None: + # save segmentation image + out_name_seg = os.path.join(save_imgs_path, 'seg_' + str( example_index.item()) + '.png') + segm_img_pred_small = soft_max((seg[ind, :, :, :]))[1, :, :] + plt.imsave(out_name_seg, segm_img_pred_small) + # save segmentation image + out_name_seg = os.path.join(save_imgs_path, 'seg_big_' + str( example_index.item()) + '.png') + plt.imsave(out_name_seg, segm_img_pred) + # segmentation overlay + input_image = input[ind, :, :, :].detach().clone() + for t, m, s in zip(input_image, data_info.rgb_mean, data_info.rgb_stddev): t.add_(m) + input_image_np = input_image.detach().cpu().numpy().transpose(1, 2, 0) + thr = 0.3 + segm_img_pred[segm_img_pred>thr] = 1 + segm_img_pred_3 = np.stack([segm_img_pred, np.zeros((256, 256), dtype=np.float32), np.zeros((256, 256), dtype=np.float32)], axis=2) + segm_img_pred_3[segm_img_pred 0 + preds *= pred_mask + if return_maxval: + return preds, maxval + else: + return preds + + +def get_preds_soft(scores, return_maxval=False, norm_coords=False, norm_and_unnorm_coords=False): + ''' get predictions from score maps in torch Tensor + predictions are made assuming a logit output map + return type: torch.LongTensor + ''' + + # New: work on logit predictions + scores_norm = dsnt.spatial_softmax2d(scores, temperature=torch.tensor(1)) + # maxval_norm, idx_norm = torch.max(scores_norm.view(scores.size(0), scores.size(1), -1), 2) + # from unnormalized to normalized see: + # from -1to1 to 0to64 + # see https://github.com/kornia/kornia/blob/b9ffe7efcba7399daeeb8028f10c22941b55d32d/kornia/utils/grid.py#L7 (line 40) + # xs = (xs / (width - 1) - 0.5) * 2 + # ys = (ys / (height - 1) - 0.5) * 2 + + device = scores.device + + if return_maxval: + preds_normalized = dsnt.spatial_expectation2d(scores_norm, normalized_coordinates=True) + # grid_sample(input, grid, mode='bilinear', padding_mode='zeros') + gs_input_single = scores_norm.reshape((-1, 1, scores_norm.shape[2], scores_norm.shape[3])) # (120, 1, 64, 64) + gs_input = scores_norm.reshape((-1, 1, scores_norm.shape[2], scores_norm.shape[3])) # (120, 1, 64, 64) + + half_pad = 2 + gs_input_single_padded = F.pad(input=gs_input_single, pad=(half_pad, half_pad, half_pad, half_pad, 0, 0, 0, 0), mode='constant', value=0) + gs_input_all = torch.zeros((gs_input_single.shape[0], 9, gs_input_single.shape[2], gs_input_single.shape[3])).to(device) + ind_tot = 0 + for ind0 in [-1, 0, 1]: + for ind1 in [-1, 0, 1]: + gs_input_all[:, ind_tot, :, :] = gs_input_single_padded[:, 0, half_pad+ind0:-half_pad+ind0, half_pad+ind1:-half_pad+ind1] + ind_tot +=1 + + gs_grid = preds_normalized.reshape((-1, 2))[:, None, None, :] # (120, 1, 1, 2) + gs_output_all = F.grid_sample(gs_input_all, gs_grid, mode='nearest', padding_mode='zeros', align_corners=True).reshape((gs_input_all.shape[0], gs_input_all.shape[1], 1)) + gs_output = gs_output_all.sum(axis=1) + # scores_norm[0, :, :, :].max(axis=2)[0].max(axis=1)[0] + # gs_output[0, :, 0] + gs_output_resh = gs_output.reshape((scores_norm.shape[0], scores_norm.shape[1], 1)) + + if norm_and_unnorm_coords: + preds = dsnt.spatial_expectation2d(scores_norm, normalized_coordinates=False) + 1 + return preds, preds_normalized, gs_output_resh + elif norm_coords: + return preds_normalized, gs_output_resh + else: + preds = dsnt.spatial_expectation2d(scores_norm, normalized_coordinates=False) + 1 + return preds, gs_output_resh + else: + if norm_coords: + preds_normalized = dsnt.spatial_expectation2d(scores_norm, normalized_coordinates=True) + return preds_normalized + else: + preds = dsnt.spatial_expectation2d(scores_norm, normalized_coordinates=False) + 1 + return preds + + +def calc_dists(preds, target, normalize): + preds = preds.float() + target = target.float() + dists = torch.zeros(preds.size(1), preds.size(0)) + for n in range(preds.size(0)): + for c in range(preds.size(1)): + if target[n,c,0] > 1 and target[n, c, 1] > 1: + dists[c, n] = torch.dist(preds[n,c,:], target[n,c,:])/normalize[n] + else: + dists[c, n] = -1 + return dists + +def dist_acc(dist, thr=0.5): + ''' Return percentage below threshold while ignoring values with a -1 ''' + dist = dist[dist != -1] + if len(dist) > 0: + return 1.0 * (dist < thr).sum().item() / len(dist) + else: + return -1 + +def accuracy(output, target, idxs=None, thr=0.5): + ''' Calculate accuracy according to PCK, but uses ground truth heatmap rather than x,y locations + First value to be returned is average accuracy across 'idxs', followed by individual accuracies + ''' + if idxs is None: + idxs = list(range(target.shape[-3])) + preds = get_preds_soft(output) # get_preds(output) + gts = get_preds(target) + norm = torch.ones(preds.size(0))*output.size(3)/10 + dists = calc_dists(preds, gts, norm) + + acc = torch.zeros(len(idxs)+1) + avg_acc = 0 + cnt = 0 + + for i in range(len(idxs)): + acc[i+1] = dist_acc(dists[idxs[i]], thr=thr) + if acc[i+1] >= 0: + avg_acc = avg_acc + acc[i+1] + cnt += 1 + + if cnt != 0: + acc[0] = avg_acc / cnt + return acc + +def final_preds_untransformed(output, res): + coords = get_preds_soft(output) # get_preds(output) # float type + + # pose-processing + for n in range(coords.size(0)): + for p in range(coords.size(1)): + hm = output[n][p] + px = int(math.floor(coords[n][p][0])) + py = int(math.floor(coords[n][p][1])) + if px > 1 and px < res[0] and py > 1 and py < res[1]: + diff = torch.Tensor([hm[py - 1][px] - hm[py - 1][px - 2], hm[py][px - 1]-hm[py - 2][px - 1]]) + coords[n][p] += diff.sign() * .25 + coords += 0.5 + + if coords.dim() < 3: + coords = coords.unsqueeze(0) + + coords -= 1 # Convert from 1-based to 0-based coordinates + + return coords + +def final_preds(output, center, scale, res): + coords = final_preds_untransformed(output, res) + preds = coords.clone() + + # Transform back + for i in range(coords.size(0)): + preds[i] = transform_preds(coords[i], center[i], scale[i], res) + + if preds.dim() < 3: + preds = preds.unsqueeze(0) + + return preds + + +class AverageMeter(object): + """Computes and stores the average and current value""" + def __init__(self): + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count diff --git a/src/stacked_hourglass/utils/finetune.py b/src/stacked_hourglass/utils/finetune.py new file mode 100644 index 0000000000000000000000000000000000000000..e7990b26a90e824f02141d7908907679f544f98c --- /dev/null +++ b/src/stacked_hourglass/utils/finetune.py @@ -0,0 +1,39 @@ +# Modified from: +# https://github.com/anibali/pytorch-stacked-hourglass +# https://github.com/bearpaw/pytorch-pose + +import torch +from torch.nn import Conv2d, ModuleList + + +def change_hg_outputs(model, indices): + """Change the output classes of the model. + + Args: + model: The model to modify. + indices: An array of indices describing the new model outputs. For example, [3, 4, None] + will modify the model to have 3 outputs, the first two of which have parameters + copied from the fourth and fifth outputs of the original model. + """ + with torch.no_grad(): + new_n_outputs = len(indices) + new_score = ModuleList() + for conv in model.score: + new_conv = Conv2d(conv.in_channels, new_n_outputs, conv.kernel_size, conv.stride) + new_conv = new_conv.to(conv.weight.device, conv.weight.dtype) + for i, index in enumerate(indices): + if index is not None: + new_conv.weight[i] = conv.weight[index] + new_conv.bias[i] = conv.bias[index] + new_score.append(new_conv) + model.score = new_score + new_score_ = ModuleList() + for conv in model.score_: + new_conv = Conv2d(new_n_outputs, conv.out_channels, conv.kernel_size, conv.stride) + new_conv = new_conv.to(conv.weight.device, conv.weight.dtype) + for i, index in enumerate(indices): + if index is not None: + new_conv.weight[:, i] = conv.weight[:, index] + new_conv.bias = conv.bias + new_score_.append(new_conv) + model.score_ = new_score_ diff --git a/src/stacked_hourglass/utils/imfit.py b/src/stacked_hourglass/utils/imfit.py new file mode 100644 index 0000000000000000000000000000000000000000..ee0d2e131bf3c1bd2e0c740d9c8cfd9d847f523d --- /dev/null +++ b/src/stacked_hourglass/utils/imfit.py @@ -0,0 +1,144 @@ +# Modified from: +# https://github.com/anibali/pytorch-stacked-hourglass +# https://github.com/bearpaw/pytorch-pose + +import torch +from torch.nn.functional import interpolate + + +def _resize(tensor, size, mode='bilinear'): + """Resize the image. + + Args: + tensor (torch.Tensor): The image tensor to be resized. + size (tuple of int): Size of the resized image (height, width). + mode (str): The pixel sampling interpolation mode to be used. + + Returns: + Tensor: The resized image tensor. + """ + assert len(size) == 2 + + # If the tensor is already the desired size, return it immediately. + if tensor.shape[-2] == size[0] and tensor.shape[-1] == size[1]: + return tensor + + if not tensor.is_floating_point(): + dtype = tensor.dtype + tensor = tensor.to(torch.float32) + tensor = _resize(tensor, size, mode) + return tensor.to(dtype) + + out_shape = (*tensor.shape[:-2], *size) + if tensor.ndimension() < 3: + raise Exception('tensor must be at least 2D') + elif tensor.ndimension() == 3: + tensor = tensor.unsqueeze(0) + elif tensor.ndimension() > 4: + tensor = tensor.view(-1, *tensor.shape[-3:]) + align_corners = None + if mode in {'linear', 'bilinear', 'trilinear'}: + align_corners = False + resized = interpolate(tensor, size=size, mode=mode, align_corners=align_corners) + return resized.view(*out_shape) + + +def _crop(tensor, t, l, h, w, padding_mode='constant', fill=0): + """Crop the image, padding out-of-bounds regions. + + Args: + tensor (torch.Tensor): The image tensor to be cropped. + t (int): Top pixel coordinate. + l (int): Left pixel coordinate. + h (int): Height of the cropped image. + w (int): Width of the cropped image. + padding_mode (str): Padding mode (currently "constant" is the only valid option). + fill (float): Fill value to use with constant padding. + + Returns: + Tensor: The cropped image tensor. + """ + # If the _crop region is wholly within the image, simply narrow the tensor. + if t >= 0 and l >= 0 and t + h <= tensor.size(-2) and l + w <= tensor.size(-1): + return tensor[..., t:t+h, l:l+w] + + if padding_mode == 'constant': + result = torch.full((*tensor.size()[:-2], h, w), fill, + device=tensor.device, dtype=tensor.dtype) + else: + raise Exception('_crop only supports "constant" padding currently.') + + sx1 = l + sy1 = t + sx2 = l + w + sy2 = t + h + dx1 = 0 + dy1 = 0 + + if sx1 < 0: + dx1 = -sx1 + w += sx1 + sx1 = 0 + + if sy1 < 0: + dy1 = -sy1 + h += sy1 + sy1 = 0 + + if sx2 >= tensor.size(-1): + w -= sx2 - tensor.size(-1) + + if sy2 >= tensor.size(-2): + h -= sy2 - tensor.size(-2) + + # Copy the in-bounds sub-area of the _crop region into the result tensor. + if h > 0 and w > 0: + src = tensor.narrow(-2, sy1, h).narrow(-1, sx1, w) + dst = result.narrow(-2, dy1, h).narrow(-1, dx1, w) + dst.copy_(src) + + return result + + +def calculate_fit_contain_output_area(in_height, in_width, out_height, out_width): + ih, iw = in_height, in_width + k = min(out_width / iw, out_height / ih) + oh = round(k * ih) + ow = round(k * iw) + y_off = (out_height - oh) // 2 + x_off = (out_width - ow) // 2 + return y_off, x_off, oh, ow + + +def fit(tensor, size, fit_mode='cover', resize_mode='bilinear', *, fill=0): + """Fit the image within the given spatial dimensions. + + Args: + tensor (torch.Tensor): The image tensor to be fit. + size (tuple of int): Size of the output (height, width). + fit_mode (str): 'fill', 'contain', or 'cover'. These behave in the same way as CSS's + `object-fit` property. + fill (float): padding value (only applicable in 'contain' mode). + + Returns: + Tensor: The resized image tensor. + """ + if fit_mode == 'fill': + return _resize(tensor, size, mode=resize_mode) + elif fit_mode == 'contain': + y_off, x_off, oh, ow = calculate_fit_contain_output_area(*tensor.shape[-2:], *size) + resized = _resize(tensor, (oh, ow), mode=resize_mode) + result = tensor.new_full((*tensor.size()[:-2], *size), fill) + result[..., y_off:y_off + oh, x_off:x_off + ow] = resized + return result + elif fit_mode == 'cover': + ih, iw = tensor.shape[-2:] + k = max(size[-1] / iw, size[-2] / ih) + oh = round(k * ih) + ow = round(k * iw) + resized = _resize(tensor, (oh, ow), mode=resize_mode) + y_trim = (oh - size[-2]) // 2 + x_trim = (ow - size[-1]) // 2 + result = _crop(resized, y_trim, x_trim, size[-2], size[-1]) + return result + raise ValueError('Invalid fit_mode: ' + repr(fit_mode)) diff --git a/src/stacked_hourglass/utils/imutils.py b/src/stacked_hourglass/utils/imutils.py new file mode 100644 index 0000000000000000000000000000000000000000..5540728cc9f85e55b560308417c3b77d9c678a13 --- /dev/null +++ b/src/stacked_hourglass/utils/imutils.py @@ -0,0 +1,125 @@ +# Modified from: +# https://github.com/anibali/pytorch-stacked-hourglass +# https://github.com/bearpaw/pytorch-pose + +import numpy as np + +from .misc import to_numpy, to_torch +from .pilutil import imread, imresize +from kornia.geometry.subpix import dsnt +import torch + +def im_to_numpy(img): + img = to_numpy(img) + img = np.transpose(img, (1, 2, 0)) # H*W*C + return img + +def im_to_torch(img): + img = np.transpose(img, (2, 0, 1)) # C*H*W + img = to_torch(img).float() + if img.max() > 1: + img /= 255 + return img + +def load_image(img_path): + # H x W x C => C x H x W + return im_to_torch(imread(img_path, mode='RGB')) + +# ============================================================================= +# Helpful functions generating groundtruth labelmap +# ============================================================================= + +def gaussian(shape=(7,7),sigma=1): + """ + 2D gaussian mask - should give the same result as MATLAB's + fspecial('gaussian',[shape],[sigma]) + """ + m,n = [(ss-1.)/2. for ss in shape] + y,x = np.ogrid[-m:m+1,-n:n+1] + h = np.exp( -(x*x + y*y) / (2.*sigma*sigma) ) + h[ h < np.finfo(h.dtype).eps*h.max() ] = 0 + return to_torch(h).float() + +def draw_labelmap_orig(img, pt, sigma, type='Gaussian'): + # Draw a 2D gaussian + # Adopted from https://github.com/anewell/pose-hg-train/blob/master/src/pypose/draw.py + # maximum value of the gaussian is 1 + img = to_numpy(img) + + # Check that any part of the gaussian is in-bounds + ul = [int(pt[0] - 3 * sigma), int(pt[1] - 3 * sigma)] + br = [int(pt[0] + 3 * sigma + 1), int(pt[1] + 3 * sigma + 1)] + if (ul[0] >= img.shape[1] or ul[1] >= img.shape[0] or + br[0] < 0 or br[1] < 0): + # If not, just return the image as is + return to_torch(img), 0 + + # Generate gaussian + size = 6 * sigma + 1 + x = np.arange(0, size, 1, float) + y = x[:, np.newaxis] + x0 = y0 = size // 2 + # The gaussian is not normalized, we want the center value to equal 1 + if type == 'Gaussian': + g = np.exp(- ((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2)) + elif type == 'Cauchy': + g = sigma / (((x - x0) ** 2 + (y - y0) ** 2 + sigma ** 2) ** 1.5) + + # Usable gaussian range + g_x = max(0, -ul[0]), min(br[0], img.shape[1]) - ul[0] + g_y = max(0, -ul[1]), min(br[1], img.shape[0]) - ul[1] + # Image range + img_x = max(0, ul[0]), min(br[0], img.shape[1]) + img_y = max(0, ul[1]), min(br[1], img.shape[0]) + + img[img_y[0]:img_y[1], img_x[0]:img_x[1]] = g[g_y[0]:g_y[1], g_x[0]:g_x[1]] + + return to_torch(img), 1 + + + +def draw_labelmap(img, pt, sigma, type='Gaussian'): + # Draw a 2D gaussian + # real probability distribution: the sum of all values is 1 + img = to_numpy(img) + if not type == 'Gaussian': + raise NotImplementedError + + # Check that any part of the gaussian is in-bounds + ul = [int(pt[0] - 3 * sigma), int(pt[1] - 3 * sigma)] + br = [int(pt[0] + 3 * sigma + 1), int(pt[1] + 3 * sigma + 1)] + if (ul[0] >= img.shape[1] or ul[1] >= img.shape[0] or + br[0] < 0 or br[1] < 0): + # If not, just return the image as is + return to_torch(img), 0 + + # Generate gaussian + # img_new = dsnt.render_gaussian2d(mean=torch.tensor([[-1, 0]]).float(), std=torch.tensor([[sigma, sigma]]).float(), size=(img.shape[0], img.shape[1]), normalized_coordinates=False) + img_new = dsnt.render_gaussian2d(mean=torch.tensor([[pt[0], pt[1]]]).float(), \ + std=torch.tensor([[sigma, sigma]]).float(), \ + size=(img.shape[0], img.shape[1]), \ + normalized_coordinates=False) + img_new = img_new[0, :, :] # this is a torch image + return img_new, 1 + + +def draw_multiple_labelmaps(out_res, pts, sigma, type='Gaussian'): + # Draw a 2D gaussian + # real probability distribution: the sum of all values is 1 + if not type == 'Gaussian': + raise NotImplementedError + + # Generate gaussians + n_pts = pts.shape[0] + imgs_new = dsnt.render_gaussian2d(mean=pts[:, :2], \ + std=torch.tensor([[sigma, sigma]]).float().repeat((n_pts, 1)), \ + size=(out_res[0], out_res[1]), \ + normalized_coordinates=False) # shape: (n_pts, out_res[0], out_res[1]) + + visibility_orig = imgs_new.sum(axis=2).sum(axis=1) # shape: (n_pts) + visibility = torch.zeros((n_pts, 1), dtype=torch.float32) + visibility[visibility_orig>=0.99999] = 1.0 + + # import pdb; pdb.set_trace() + + return imgs_new, visibility.int() \ No newline at end of file diff --git a/src/stacked_hourglass/utils/logger.py b/src/stacked_hourglass/utils/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..8e42823a88ae20117fc5aa191f569491c102b1f3 --- /dev/null +++ b/src/stacked_hourglass/utils/logger.py @@ -0,0 +1,73 @@ + +# Modified from: +# https://github.com/anibali/pytorch-stacked-hourglass +# https://github.com/bearpaw/pytorch-pose + +import numpy as np + +__all__ = ['Logger'] + + +class Logger: + """Log training metrics to a file.""" + def __init__(self, fpath, resume=False): + if resume: ############################################################################ + # Read header names and previously logged values. + with open(fpath, 'r') as f: + header_line = f.readline() + self.names = header_line.rstrip().split('\t') + self.numbers = {} + for _, name in enumerate(self.names): + self.numbers[name] = [] + for numbers in f: + numbers = numbers.rstrip().split('\t') + for i in range(0, len(numbers)): + self.numbers[self.names[i]].append(float(numbers[i])) + + self.file = open(fpath, 'a') + self.header_written = True + else: + self.file = open(fpath, 'w') + self.header_written = False + + def _write_line(self, field_values): + self.file.write('\t'.join(field_values) + '\n') + self.file.flush() + + def set_names(self, names): + """Set field names and write log header line.""" + assert not self.header_written, 'Log header has already been written' + self.names = names + self.numbers = {name: [] for name in self.names} + self._write_line(self.names) + self.header_written = True + + def append(self, numbers): + """Append values to the log.""" + assert self.header_written, 'Log header has not been written yet (use `set_names`)' + assert len(self.names) == len(numbers), 'Numbers do not match names' + for index, num in enumerate(numbers): + self.numbers[self.names[index]].append(num) + self._write_line(['{0:.6f}'.format(num) for num in numbers]) + + def plot(self, ax, names=None): + """Plot logged metrics on a set of Matplotlib axes.""" + names = self.names if names == None else names + for name in names: + values = self.numbers[name] + ax.plot(np.arange(len(values)), np.asarray(values)) + ax.grid(True) + ax.legend(names, loc='best') + + def plot_to_file(self, fpath, names=None, dpi=150): + """Plot logged metrics and save the resulting figure to a file.""" + import matplotlib.pyplot as plt + fig = plt.figure(dpi=dpi) + ax = fig.subplots() + self.plot(ax, names) + fig.savefig(fpath) + plt.close(fig) + del ax, fig + + def close(self): + self.file.close() diff --git a/src/stacked_hourglass/utils/misc.py b/src/stacked_hourglass/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..d754c55dc2206bbb2a5cabf18c4017b5c1ee3d04 --- /dev/null +++ b/src/stacked_hourglass/utils/misc.py @@ -0,0 +1,56 @@ +# Modified from: +# https://github.com/anibali/pytorch-stacked-hourglass +# https://github.com/bearpaw/pytorch-pose + +import os +import shutil + +import scipy.io +import torch + + +def to_numpy(tensor): + if torch.is_tensor(tensor): + return tensor.detach().cpu().numpy() + elif type(tensor).__module__ != 'numpy': + raise ValueError("Cannot convert {} to numpy array" + .format(type(tensor))) + return tensor + + +def to_torch(ndarray): + if type(ndarray).__module__ == 'numpy': + return torch.from_numpy(ndarray) + elif not torch.is_tensor(ndarray): + raise ValueError("Cannot convert {} to torch tensor" + .format(type(ndarray))) + return ndarray + + +def save_checkpoint(state, preds, is_best, checkpoint='checkpoint', filename='checkpoint.pth.tar', snapshot=None): + preds = to_numpy(preds) + filepath = os.path.join(checkpoint, filename) + torch.save(state, filepath) + scipy.io.savemat(os.path.join(checkpoint, 'preds.mat'), mdict={'preds' : preds}) + + if snapshot and state['epoch'] % snapshot == 0: + shutil.copyfile(filepath, os.path.join(checkpoint, 'checkpoint_{}.pth.tar'.format(state['epoch']))) + + if is_best: + shutil.copyfile(filepath, os.path.join(checkpoint, 'model_best.pth.tar')) + scipy.io.savemat(os.path.join(checkpoint, 'preds_best.mat'), mdict={'preds' : preds}) + + +def save_pred(preds, checkpoint='checkpoint', filename='preds_valid.mat'): + preds = to_numpy(preds) + filepath = os.path.join(checkpoint, filename) + scipy.io.savemat(filepath, mdict={'preds' : preds}) + + +def adjust_learning_rate(optimizer, epoch, lr, schedule, gamma): + """Sets the learning rate to the initial LR decayed by schedule""" + if epoch in schedule: + lr *= gamma + for param_group in optimizer.param_groups: + param_group['lr'] = lr + return lr diff --git a/src/stacked_hourglass/utils/pilutil.py b/src/stacked_hourglass/utils/pilutil.py new file mode 100644 index 0000000000000000000000000000000000000000..4306a31e76581cf9a7dd9901b88be1a2df2a75f0 --- /dev/null +++ b/src/stacked_hourglass/utils/pilutil.py @@ -0,0 +1,509 @@ +""" +A collection of image utilities using the Python Imaging Library (PIL). +""" + +# Copyright (c) 2001-2002 Enthought, Inc. 2003-2019, SciPy Developers. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# +# 1. Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following +# disclaimer in the documentation and/or other materials provided +# with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from __future__ import division, print_function, absolute_import + +import numpy +from PIL import Image +from numpy import (amin, amax, ravel, asarray, arange, ones, newaxis, + transpose, iscomplexobj, uint8, issubdtype, array) + +if not hasattr(Image, 'frombytes'): + Image.frombytes = Image.fromstring + +__all__ = ['fromimage', 'toimage', 'imsave', 'imread', 'bytescale', + 'imrotate', 'imresize'] + + +def bytescale(data, cmin=None, cmax=None, high=255, low=0): + """ + Byte scales an array (image). + + Byte scaling means converting the input image to uint8 dtype and scaling + the range to ``(low, high)`` (default 0-255). + If the input image already has dtype uint8, no scaling is done. + + This function is only available if Python Imaging Library (PIL) is installed. + + Parameters + ---------- + data : ndarray + PIL image data array. + cmin : scalar, optional + Bias scaling of small values. Default is ``data.min()``. + cmax : scalar, optional + Bias scaling of large values. Default is ``data.max()``. + high : scalar, optional + Scale max value to `high`. Default is 255. + low : scalar, optional + Scale min value to `low`. Default is 0. + + Returns + ------- + img_array : uint8 ndarray + The byte-scaled array. + + Examples + -------- + >>> img = numpy.array([[ 91.06794177, 3.39058326, 84.4221549 ], + ... [ 73.88003259, 80.91433048, 4.88878881], + ... [ 51.53875334, 34.45808177, 27.5873488 ]]) + >>> bytescale(img) + array([[255, 0, 236], + [205, 225, 4], + [140, 90, 70]], dtype=uint8) + >>> bytescale(img, high=200, low=100) + array([[200, 100, 192], + [180, 188, 102], + [155, 135, 128]], dtype=uint8) + >>> bytescale(img, cmin=0, cmax=255) + array([[91, 3, 84], + [74, 81, 5], + [52, 34, 28]], dtype=uint8) + + """ + if data.dtype == uint8: + return data + + if high > 255: + raise ValueError("`high` should be less than or equal to 255.") + if low < 0: + raise ValueError("`low` should be greater than or equal to 0.") + if high < low: + raise ValueError("`high` should be greater than or equal to `low`.") + + if cmin is None: + cmin = data.min() + if cmax is None: + cmax = data.max() + + cscale = cmax - cmin + if cscale < 0: + raise ValueError("`cmax` should be larger than `cmin`.") + elif cscale == 0: + cscale = 1 + + scale = float(high - low) / cscale + bytedata = (data - cmin) * scale + low + return (bytedata.clip(low, high) + 0.5).astype(uint8) + + +def imread(name, flatten=False, mode=None): + """ + Read an image from a file as an array. + + This function is only available if Python Imaging Library (PIL) is installed. + + Parameters + ---------- + name : str or file object + The file name or file object to be read. + flatten : bool, optional + If True, flattens the color layers into a single gray-scale layer. + mode : str, optional + Mode to convert image to, e.g. ``'RGB'``. See the Notes for more + details. + + Returns + ------- + imread : ndarray + The array obtained by reading the image. + + Notes + ----- + `imread` uses the Python Imaging Library (PIL) to read an image. + The following notes are from the PIL documentation. + + `mode` can be one of the following strings: + + * 'L' (8-bit pixels, black and white) + * 'P' (8-bit pixels, mapped to any other mode using a color palette) + * 'RGB' (3x8-bit pixels, true color) + * 'RGBA' (4x8-bit pixels, true color with transparency mask) + * 'CMYK' (4x8-bit pixels, color separation) + * 'YCbCr' (3x8-bit pixels, color video format) + * 'I' (32-bit signed integer pixels) + * 'F' (32-bit floating point pixels) + + PIL also provides limited support for a few special modes, including + 'LA' ('L' with alpha), 'RGBX' (true color with padding) and 'RGBa' + (true color with premultiplied alpha). + + When translating a color image to black and white (mode 'L', 'I' or + 'F'), the library uses the ITU-R 601-2 luma transform:: + + L = R * 299/1000 + G * 587/1000 + B * 114/1000 + + When `flatten` is True, the image is converted using mode 'F'. + When `mode` is not None and `flatten` is True, the image is first + converted according to `mode`, and the result is then flattened using + mode 'F'. + + """ + + im = Image.open(name) + return fromimage(im, flatten=flatten, mode=mode) + + +def imsave(name, arr, format=None): + """ + Save an array as an image. + + This function is only available if Python Imaging Library (PIL) is installed. + + .. warning:: + + This function uses `bytescale` under the hood to rescale images to use + the full (0, 255) range if ``mode`` is one of ``None, 'L', 'P', 'l'``. + It will also cast data for 2-D images to ``uint32`` for ``mode=None`` + (which is the default). + + Parameters + ---------- + name : str or file object + Output file name or file object. + arr : ndarray, MxN or MxNx3 or MxNx4 + Array containing image values. If the shape is ``MxN``, the array + represents a grey-level image. Shape ``MxNx3`` stores the red, green + and blue bands along the last dimension. An alpha layer may be + included, specified as the last colour band of an ``MxNx4`` array. + format : str + Image format. If omitted, the format to use is determined from the + file name extension. If a file object was used instead of a file name, + this parameter should always be used. + + Examples + -------- + Construct an array of gradient intensity values and save to file: + + >>> x = numpy.zeros((255, 255), dtype=numpy.uint8) + >>> x[:] = numpy.arange(255) + >>> imsave('gradient.png', x) + + Construct an array with three colour bands (R, G, B) and store to file: + + >>> rgb = numpy.zeros((255, 255, 3), dtype=numpy.uint8) + >>> rgb[..., 0] = numpy.arange(255) + >>> rgb[..., 1] = 55 + >>> rgb[..., 2] = 1 - numpy.arange(255) + >>> imsave('rgb_gradient.png', rgb) + + """ + im = toimage(arr, channel_axis=2) + if format is None: + im.save(name) + else: + im.save(name, format) + return + + +def fromimage(im, flatten=False, mode=None): + """ + Return a copy of a PIL image as a numpy array. + + This function is only available if Python Imaging Library (PIL) is installed. + + Parameters + ---------- + im : PIL image + Input image. + flatten : bool + If true, convert the output to grey-scale. + mode : str, optional + Mode to convert image to, e.g. ``'RGB'``. See the Notes of the + `imread` docstring for more details. + + Returns + ------- + fromimage : ndarray + The different colour bands/channels are stored in the + third dimension, such that a grey-image is MxN, an + RGB-image MxNx3 and an RGBA-image MxNx4. + + """ + if not Image.isImageType(im): + raise TypeError("Input is not a PIL image.") + + if mode is not None: + if mode != im.mode: + im = im.convert(mode) + elif im.mode == 'P': + # Mode 'P' means there is an indexed "palette". If we leave the mode + # as 'P', then when we do `a = array(im)` below, `a` will be a 2-D + # containing the indices into the palette, and not a 3-D array + # containing the RGB or RGBA values. + if 'transparency' in im.info: + im = im.convert('RGBA') + else: + im = im.convert('RGB') + + if flatten: + im = im.convert('F') + elif im.mode == '1': + # Workaround for crash in PIL. When im is 1-bit, the call array(im) + # can cause a seg. fault, or generate garbage. See + # https://github.com/scipy/scipy/issues/2138 and + # https://github.com/python-pillow/Pillow/issues/350. + # + # This converts im from a 1-bit image to an 8-bit image. + im = im.convert('L') + + a = array(im) + return a + + +_errstr = "Mode is unknown or incompatible with input array shape." + + +def toimage(arr, high=255, low=0, cmin=None, cmax=None, pal=None, + mode=None, channel_axis=None): + """Takes a numpy array and returns a PIL image. + + This function is only available if Python Imaging Library (PIL) is installed. + + The mode of the PIL image depends on the array shape and the `pal` and + `mode` keywords. + + For 2-D arrays, if `pal` is a valid (N,3) byte-array giving the RGB values + (from 0 to 255) then ``mode='P'``, otherwise ``mode='L'``, unless mode + is given as 'F' or 'I' in which case a float and/or integer array is made. + + .. warning:: + + This function uses `bytescale` under the hood to rescale images to use + the full (0, 255) range if ``mode`` is one of ``None, 'L', 'P', 'l'``. + It will also cast data for 2-D images to ``uint32`` for ``mode=None`` + (which is the default). + + Notes + ----- + For 3-D arrays, the `channel_axis` argument tells which dimension of the + array holds the channel data. + + For 3-D arrays if one of the dimensions is 3, the mode is 'RGB' + by default or 'YCbCr' if selected. + + The numpy array must be either 2 dimensional or 3 dimensional. + + """ + data = asarray(arr) + if iscomplexobj(data): + raise ValueError("Cannot convert a complex-valued array.") + shape = list(data.shape) + valid = len(shape) == 2 or ((len(shape) == 3) and + ((3 in shape) or (4 in shape))) + if not valid: + raise ValueError("'arr' does not have a suitable array shape for " + "any mode.") + if len(shape) == 2: + shape = (shape[1], shape[0]) # columns show up first + if mode == 'F': + data32 = data.astype(numpy.float32) + image = Image.frombytes(mode, shape, data32.tostring()) + return image + if mode in [None, 'L', 'P']: + bytedata = bytescale(data, high=high, low=low, + cmin=cmin, cmax=cmax) + image = Image.frombytes('L', shape, bytedata.tostring()) + if pal is not None: + image.putpalette(asarray(pal, dtype=uint8).tostring()) + # Becomes a mode='P' automagically. + elif mode == 'P': # default gray-scale + pal = (arange(0, 256, 1, dtype=uint8)[:, newaxis] * + ones((3,), dtype=uint8)[newaxis, :]) + image.putpalette(asarray(pal, dtype=uint8).tostring()) + return image + if mode == '1': # high input gives threshold for 1 + bytedata = (data > high) + image = Image.frombytes('1', shape, bytedata.tostring()) + return image + if cmin is None: + cmin = amin(ravel(data)) + if cmax is None: + cmax = amax(ravel(data)) + data = (data*1.0 - cmin)*(high - low)/(cmax - cmin) + low + if mode == 'I': + data32 = data.astype(numpy.uint32) + image = Image.frombytes(mode, shape, data32.tostring()) + else: + raise ValueError(_errstr) + return image + + # if here then 3-d array with a 3 or a 4 in the shape length. + # Check for 3 in datacube shape --- 'RGB' or 'YCbCr' + if channel_axis is None: + if (3 in shape): + ca = numpy.flatnonzero(asarray(shape) == 3)[0] + else: + ca = numpy.flatnonzero(asarray(shape) == 4) + if len(ca): + ca = ca[0] + else: + raise ValueError("Could not find channel dimension.") + else: + ca = channel_axis + + numch = shape[ca] + if numch not in [3, 4]: + raise ValueError("Channel axis dimension is not valid.") + + bytedata = bytescale(data, high=high, low=low, cmin=cmin, cmax=cmax) + if ca == 2: + strdata = bytedata.tobytes() # .tostring() + shape = (shape[1], shape[0]) + elif ca == 1: + strdata = transpose(bytedata, (0, 2, 1)).tobytes() #.tostring() + shape = (shape[2], shape[0]) + elif ca == 0: + strdata = transpose(bytedata, (1, 2, 0)).tobytes() #.tostring() + shape = (shape[2], shape[1]) + else: + raise ValueError("Unexpected channel axis.") + if mode is None: + if numch == 3: + mode = 'RGB' + else: + mode = 'RGBA' + + if mode not in ['RGB', 'RGBA', 'YCbCr', 'CMYK']: + raise ValueError(_errstr) + + if mode in ['RGB', 'YCbCr']: + if numch != 3: + raise ValueError("Invalid array shape for mode.") + if mode in ['RGBA', 'CMYK']: + if numch != 4: + raise ValueError("Invalid array shape for mode.") + + # Here we know data and mode is correct + image = Image.frombytes(mode, shape, strdata) + return image + + +def imrotate(arr, angle, interp='bilinear'): + """ + Rotate an image counter-clockwise by angle degrees. + + This function is only available if Python Imaging Library (PIL) is installed. + + .. warning:: + + This function uses `bytescale` under the hood to rescale images to use + the full (0, 255) range if ``mode`` is one of ``None, 'L', 'P', 'l'``. + It will also cast data for 2-D images to ``uint32`` for ``mode=None`` + (which is the default). + + Parameters + ---------- + arr : ndarray + Input array of image to be rotated. + angle : float + The angle of rotation. + interp : str, optional + Interpolation + + - 'nearest' : for nearest neighbor + - 'bilinear' : for bilinear + - 'lanczos' : for lanczos + - 'cubic' : for bicubic + - 'bicubic' : for bicubic + + Returns + ------- + imrotate : ndarray + The rotated array of image. + + """ + arr = asarray(arr) + func = {'nearest': 0, 'lanczos': 1, 'bilinear': 2, 'bicubic': 3, 'cubic': 3} + im = toimage(arr) + im = im.rotate(angle, resample=func[interp]) + return fromimage(im) + + +def imresize(arr, size, interp='bilinear', mode=None): + """ + Resize an image. + + This function is only available if Python Imaging Library (PIL) is installed. + + .. warning:: + + This function uses `bytescale` under the hood to rescale images to use + the full (0, 255) range if ``mode`` is one of ``None, 'L', 'P', 'l'``. + It will also cast data for 2-D images to ``uint32`` for ``mode=None`` + (which is the default). + + Parameters + ---------- + arr : ndarray + The array of image to be resized. + size : int, float or tuple + * int - Percentage of current size. + * float - Fraction of current size. + * tuple - Size of the output image (height, width). + + interp : str, optional + Interpolation to use for re-sizing ('nearest', 'lanczos', 'bilinear', + 'bicubic' or 'cubic'). + mode : str, optional + The PIL image mode ('P', 'L', etc.) to convert `arr` before resizing. + If ``mode=None`` (the default), 2-D images will be treated like + ``mode='L'``, i.e. casting to long integer. For 3-D and 4-D arrays, + `mode` will be set to ``'RGB'`` and ``'RGBA'`` respectively. + + Returns + ------- + imresize : ndarray + The resized array of image. + + See Also + -------- + toimage : Implicitly used to convert `arr` according to `mode`. + scipy.ndimage.zoom : More generic implementation that does not use PIL. + + """ + im = toimage(arr, mode=mode) + ts = type(size) + if issubdtype(ts, numpy.signedinteger): + percent = size / 100.0 + size = tuple((array(im.size)*percent).astype(int)) + elif issubdtype(type(size), numpy.floating): + size = tuple((array(im.size)*size).astype(int)) + else: + size = (size[1], size[0]) + func = {'nearest': 0, 'lanczos': 1, 'bilinear': 2, 'bicubic': 3, 'cubic': 3} + imnew = im.resize(size, resample=func[interp]) + return fromimage(imnew) diff --git a/src/stacked_hourglass/utils/transforms.py b/src/stacked_hourglass/utils/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..7777e02f7a78e282c9032cb76325bafbbb16a5be --- /dev/null +++ b/src/stacked_hourglass/utils/transforms.py @@ -0,0 +1,150 @@ +# Modified from: +# https://github.com/anibali/pytorch-stacked-hourglass +# https://github.com/bearpaw/pytorch-pose + +import numpy as np +import torch + +from .imutils import im_to_numpy, im_to_torch +from .misc import to_torch +from .pilutil import imresize, imrotate + + +def color_normalize(x, mean, std): + if x.size(0) == 1: + x = x.repeat(3, 1, 1) + + for t, m, s in zip(x, mean, std): + t.sub_(m) + return x + + +def flip_back(flip_output, hflip_indices): + """flip and rearrange output maps""" + return fliplr(flip_output)[:, hflip_indices] + + +def shufflelr(x, width, hflip_indices): + """flip and rearrange coords""" + # Flip horizontal + x[:, 0] = width - x[:, 0] + # Change left-right parts + x = x[hflip_indices] + return x + + +def fliplr(x): + """Flip images horizontally.""" + if torch.is_tensor(x): + return torch.flip(x, [-1]) + else: + return np.ascontiguousarray(np.flip(x, -1)) + + +def get_transform(center, scale, res, rot=0): + """ + General image processing functions + """ + # Generate transformation matrix + h = 200 * scale + t = np.zeros((3, 3)) + t[0, 0] = float(res[1]) / h + t[1, 1] = float(res[0]) / h + t[0, 2] = res[1] * (-float(center[0]) / h + .5) + t[1, 2] = res[0] * (-float(center[1]) / h + .5) + t[2, 2] = 1 + if not rot == 0: + rot = -rot # To match direction of rotation from cropping + rot_mat = np.zeros((3,3)) + rot_rad = rot * np.pi / 180 + sn,cs = np.sin(rot_rad), np.cos(rot_rad) + rot_mat[0,:2] = [cs, -sn] + rot_mat[1,:2] = [sn, cs] + rot_mat[2,2] = 1 + # Need to rotate around center + t_mat = np.eye(3) + t_mat[0,2] = -res[1]/2 + t_mat[1,2] = -res[0]/2 + t_inv = t_mat.copy() + t_inv[:2,2] *= -1 + t = np.dot(t_inv,np.dot(rot_mat,np.dot(t_mat,t))) + return t + + +def transform(pt, center, scale, res, invert=0, rot=0, as_int=True): + # Transform pixel location to different reference + t = get_transform(center, scale, res, rot=rot) + if invert: + t = np.linalg.inv(t) + new_pt = np.array([pt[0] - 1, pt[1] - 1, 1.]).T + new_pt = np.dot(t, new_pt) + if as_int: + return new_pt[:2].astype(int) + 1 + else: + return new_pt[:2] + 1 + + + +def transform_preds(coords, center, scale, res): + # size = coords.size() + # coords = coords.view(-1, coords.size(-1)) + # print(coords.size()) + for p in range(coords.size(0)): + coords[p, 0:2] = to_torch(transform(coords[p, 0:2], center, scale, res, 1, 0)) + return coords + + +def crop(img, center, scale, res, rot=0, interp='bilinear'): + # import pdb; pdb.set_trace() + # mode = 'F' + + img = im_to_numpy(img) + + # Preprocessing for efficient cropping + ht, wd = img.shape[0], img.shape[1] + sf = scale * 200.0 / res[0] + if sf < 2: + sf = 1 + else: + new_size = int(np.math.floor(max(ht, wd) / sf)) + new_ht = int(np.math.floor(ht / sf)) + new_wd = int(np.math.floor(wd / sf)) + if new_size < 2: + return torch.zeros(res[0], res[1], img.shape[2]) \ + if len(img.shape) > 2 else torch.zeros(res[0], res[1]) + else: + img = imresize(img, [new_ht, new_wd], interp=interp) # , mode=mode) + center = center * 1.0 / sf + scale = scale / sf + + # Upper left point + ul = np.array(transform([0, 0], center, scale, res, invert=1)) + # Bottom right point + br = np.array(transform(res, center, scale, res, invert=1)) + + # Padding so that when rotated proper amount of context is included + pad = int(np.linalg.norm(br - ul) / 2 - float(br[1] - ul[1]) / 2) + if not rot == 0: + ul -= pad + br += pad + + new_shape = [br[1] - ul[1], br[0] - ul[0]] + if len(img.shape) > 2: + new_shape += [img.shape[2]] + new_img = np.zeros(new_shape) + + # Range to fill new array + new_x = max(0, -ul[0]), min(br[0], img.shape[1]) - ul[0] + new_y = max(0, -ul[1]), min(br[1], img.shape[0]) - ul[1] + # Range to sample from original image + old_x = max(0, ul[0]), min(img.shape[1], br[0]) + old_y = max(0, ul[1]), min(img.shape[0], br[1]) + new_img[new_y[0]:new_y[1], new_x[0]:new_x[1]] = img[old_y[0]:old_y[1], old_x[0]:old_x[1]] + + if not rot == 0: + # Remove padding + new_img = imrotate(new_img, rot, interp=interp) # , mode=mode) + new_img = new_img[pad:-pad, pad:-pad] + + new_img = im_to_torch(imresize(new_img, res, interp=interp)) #, mode=mode)) + return new_img diff --git a/src/stacked_hourglass/utils/visualization.py b/src/stacked_hourglass/utils/visualization.py new file mode 100644 index 0000000000000000000000000000000000000000..4487e7f10c348af91b3958081f6f029308440772 --- /dev/null +++ b/src/stacked_hourglass/utils/visualization.py @@ -0,0 +1,179 @@ + +# Modified from: +# https://github.com/anibali/pytorch-stacked-hourglass +# https://github.com/bearpaw/pytorch-pose + +import matplotlib as mpl +mpl.use('Agg') +import matplotlib.pyplot as plt +import numpy as np +import cv2 +import torch + +# import stacked_hourglass.datasets.utils_stanext as utils_stanext +# COLORS, labels = utils_stanext.load_keypoint_labels_and_colours() +COLORS = ['#d82400', '#d82400', '#d82400', '#fcfc00', '#fcfc00', '#fcfc00', '#48b455', '#48b455', '#48b455', '#0090aa', '#0090aa', '#0090aa', '#d848ff', '#d848ff', '#fc90aa', '#006caa', '#d89000', '#d89000', '#fc90aa', '#006caa', '#ededed', '#ededed', '#a9d08e', '#a9d08e'] +RGB_MEAN = [0.4404, 0.4440, 0.4327] +RGB_STD = [0.2458, 0.2410, 0.2468] + + + +def get_img_from_fig(fig, dpi=180): + buf = io.BytesIO() + fig.savefig(buf, format="png", dpi=dpi) + buf.seek(0) + img_arr = np.frombuffer(buf.getvalue(), dtype=np.uint8) + buf.close() + img = cv2.imdecode(img_arr, 1) + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + return img + +def save_input_image_with_keypoints(img, tpts, out_path='./test_input.png', colors=COLORS, rgb_mean=RGB_MEAN, rgb_std=RGB_STD, ratio_in_out=4., threshold=0.3, print_scores=False): + """ + img has shape (3, 256, 256) and is a torch tensor + pts has shape (20, 3) and is a torch tensor + -> this function is tested with the mpii dataset and the results look ok + """ + # reverse color normalization + for t, m, s in zip(img, rgb_mean, rgb_std): t.add_(m) # inverse to transforms.color_normalize() + img_np = img.detach().cpu().numpy().transpose(1, 2, 0) + # tpts_np = tpts.detach().cpu().numpy() + # plot image + fig, ax = plt.subplots() + plt.imshow(img_np) # plt.imshow(im) + plt.gca().set_axis_off() + plt.subplots_adjust(top = 1, bottom = 0, right = 1, left = 0, hspace = 0, wspace = 0) + plt.margins(0,0) + # plot all visible keypoints + #import pdb; pdb.set_trace() + + for idx, (x, y, v) in enumerate(tpts): + if v > threshold: + x = int(x*ratio_in_out) + y = int(y*ratio_in_out) + plt.scatter([x], [y], c=[colors[idx]], marker="x", s=50) + if print_scores: + txt = '{:2.2f}'.format(v.item()) + plt.annotate(txt, (x, y)) # , c=colors[idx]) + + plt.savefig(out_path, bbox_inches='tight', pad_inches=0) + + plt.close() + return + + + +def save_input_image(img, out_path, colors=COLORS, rgb_mean=RGB_MEAN, rgb_std=RGB_STD): + for t, m, s in zip(img, rgb_mean, rgb_std): t.add_(m) # inverse to transforms.color_normalize() + img_np = img.detach().cpu().numpy().transpose(1, 2, 0) + plt.imsave(out_path, img_np) + return + +###################################################################### +def get_bodypart_colors(): + # body colors + n_body = 8 + c = np.arange(1, n_body + 1) + norm = mpl.colors.Normalize(vmin=c.min(), vmax=c.max()) + cmap = mpl.cm.ScalarMappable(norm=norm, cmap=mpl.cm.gist_rainbow) + cmap.set_array([]) + body_cols = [] + for i in range(0, n_body): + body_cols.append(cmap.to_rgba(i + 1)) + # head colors + n_blue = 5 + c = np.arange(1, n_blue + 1) + norm = mpl.colors.Normalize(vmin=c.min()-1, vmax=c.max()+1) + cmap = mpl.cm.ScalarMappable(norm=norm, cmap=mpl.cm.Blues) + cmap.set_array([]) + head_cols = [] + for i in range(0, n_body): + head_cols.append(cmap.to_rgba(i + 1)) + # torso colors + n_blue = 2 + c = np.arange(1, n_blue + 1) + norm = mpl.colors.Normalize(vmin=c.min()-1, vmax=c.max()+1) + cmap = mpl.cm.ScalarMappable(norm=norm, cmap=mpl.cm.Greens) + cmap.set_array([]) + torso_cols = [] + for i in range(0, n_body): + torso_cols.append(cmap.to_rgba(i + 1)) + return body_cols, head_cols, torso_cols +body_cols, head_cols, torso_cols = get_bodypart_colors() +tbp_dict = {'full_body': [0, 8], + 'head': [8, 13], + 'torso': [13, 15]} + +def save_image_with_part_segmentation(partseg_big, seg_big, input_image_np, ind_img, out_path_seg=None, out_path_seg_overlay=None, thr=0.3): + soft_max = torch.nn.Softmax(dim=0) + # create dit with results + tbp_dict_res = {} + for ind_tbp, part in enumerate(['full_body', 'head', 'torso']): + partseg_tbp = partseg_big[:, tbp_dict[part][0]:tbp_dict[part][1], :, :] + segm_img_pred = soft_max((partseg_tbp[ind_img, :, :, :])) # [1, :, :] + m_v, m_i = segm_img_pred.max(axis=0) + tbp_dict_res[part] = { + 'inds': tbp_dict[part], + 'seg_probs': segm_img_pred, + 'seg_max_inds': m_i, + 'seg_max_values': m_v} + # create output_image + partseg_image = np.zeros((256, 256, 3)) + for ind_sp in range(0, 5): + # partseg_image[tbp_dict_res['head']['seg_max_inds']==ind_sp, :] = head_cols[ind_sp][0:3] + mask_a = tbp_dict_res['full_body']['seg_max_inds']==1 + mask_b = tbp_dict_res['head']['seg_max_inds']==ind_sp + partseg_image[mask_a*mask_b, :] = head_cols[ind_sp][0:3] + for ind_sp in range(0, 2): + # partseg_image[tbp_dict_res['torso']['seg_max_inds']==ind_sp, :] = torso_cols[ind_sp][0:3] + mask_a = tbp_dict_res['full_body']['seg_max_inds']==2 + mask_b = tbp_dict_res['torso']['seg_max_inds']==ind_sp + partseg_image[mask_a*mask_b, :] = torso_cols[ind_sp][0:3] + for ind_sp in range(0, 8): + if (not ind_sp == 1) and (not ind_sp == 2): # head and torso + partseg_image[tbp_dict_res['full_body']['seg_max_inds']==ind_sp, :] = body_cols[ind_sp][0:3] + partseg_image[soft_max((seg_big[ind_img, :, :, :]))[1, :, :]aik', remeshing_relevant_barys, sel_verts) + target_gc_class_remeshed = torch.einsum('ij,aij->ai', remeshing_relevant_barys, target_gc_class[:, remeshing_relevant_faces].to(device=device, dtype=torch.float32)) + target_gc_class_remeshed_prep = torch.round(target_gc_class_remeshed).to(torch.long) + ''' + res['isflat_prep'] = soft_max(output_ref_unnorm['isflat'])[:, 1] + + + else: + # STEP 1: next loop in refinemnet network + assert (output_ref_unnorm_new is not None) + res['vertices_smal'] = output_ref_unnorm_new['vertices_smal'] + res['flength'] = output_ref_unnorm_new['flength'] + res['pose_rotmat'] = output_ref_unnorm_new['pose_rotmat'] + res['trans'] = output_ref_unnorm_new['trans'] + res['pred_keyp'] = output_ref_unnorm_new['keyp_2d'] + res['pred_silh'] = output_ref_unnorm_new['silh'] + res['prefix'] = 'multref_' + if 'vertexwise_ground_contact' in output_ref_unnorm_new.keys(): + res['vertexwise_ground_contact'] = output_ref_unnorm_new['vertexwise_ground_contact'] + all_sum_res[result_network] = res + return all_sum_res + + +class BITEInferenceModel(): #(nn.Module): + def __init__(self, cfg, path_model_file_complete, norm_dict, device='cuda'): + # def __init__(self, bp, model_weight_path=None, model_weight_stackedhg_path=None, device='cuda'): + # self.bp = bp + self.cfg = cfg + self.device = device + self.norm_dict = norm_dict + + # prepare complete model + self.complete_model = ModelImageTo3d_withshape_withproj( + smal_model_type=cfg.smal.SMAL_MODEL_TYPE, smal_keyp_conf=cfg.smal.SMAL_KEYP_CONF, \ + num_stage_comb=cfg.params.NUM_STAGE_COMB, num_stage_heads=cfg.params.NUM_STAGE_HEADS, \ + num_stage_heads_pose=cfg.params.NUM_STAGE_HEADS_POSE, trans_sep=cfg.params.TRANS_SEP, \ + arch=cfg.params.ARCH, n_joints=cfg.params.N_JOINTS, n_classes=cfg.params.N_CLASSES, \ + n_keyp=cfg.params.N_KEYP, n_bones=cfg.params.N_BONES, n_betas=cfg.params.N_BETAS, n_betas_limbs=cfg.params.N_BETAS_LIMBS, \ + n_breeds=cfg.params.N_BREEDS, n_z=cfg.params.N_Z, image_size=cfg.params.IMG_SIZE, \ + silh_no_tail=cfg.params.SILH_NO_TAIL, thr_keyp_sc=cfg.params.KP_THRESHOLD, add_z_to_3d_input=cfg.params.ADD_Z_TO_3D_INPUT, + n_segbps=cfg.params.N_SEGBPS, add_segbps_to_3d_input=cfg.params.ADD_SEGBPS_TO_3D_INPUT, add_partseg=cfg.params.ADD_PARTSEG, n_partseg=cfg.params.N_PARTSEG, \ + fix_flength=cfg.params.FIX_FLENGTH, structure_z_to_betas=cfg.params.STRUCTURE_Z_TO_B, structure_pose_net=cfg.params.STRUCTURE_POSE_NET, + nf_version=cfg.params.NF_VERSION, ref_net_type=cfg.params.REF_NET_TYPE, graphcnn_type=cfg.params.GRAPHCNN_TYPE, isflat_type=cfg.params.ISFLAT_TYPE, shaperef_type=cfg.params.SHAPEREF_TYPE) + + # load trained model + print(path_model_file_complete) + assert os.path.isfile(path_model_file_complete) + print('Loading model weights from file: {}'.format(path_model_file_complete)) + checkpoint_complete = torch.load(path_model_file_complete) + state_dict_complete = checkpoint_complete['state_dict'] + self.complete_model.load_state_dict(state_dict_complete) # , strict=False) + self.complete_model = self.complete_model.to(self.device) + self.complete_model.eval() + + self.smal_model_type = self.complete_model.smal.smal_model_type + + def get_selected_results(self, preds_dict=None, input_img_prep=None, result_networks=['ref']): + assert ((preds_dict is not None) or (input_img_prep is not None)) + if preds_dict is None: + preds_dict = self.get_all_results(input_img_prep) + all_sum_res = get_summarized_bite_result(preds_dict['output'], preds_dict['output_unnorm'], preds_dict['output_reproj'], preds_dict['output_ref_unnorm'], preds_dict['output_orig_ref_comparison'], result_networks=result_networks) + return all_sum_res + + def get_selected_results_multiple_refinements(self, preds_dict=None, input_img_prep=None, result_networks=['multref']): + assert ((preds_dict is not None) or (input_img_prep is not None)) + if preds_dict is None: + preds_dict = self.get_all_results_multiple_refinements(input_img_prep) + all_sum_res = get_summarized_bite_result(preds_dict['output'], preds_dict['output_unnorm'], preds_dict['output_reproj'], preds_dict['output_ref_unnorm'], preds_dict['output_orig_ref_comparison'], preds_dict['output_ref_unnorm_new'], preds_dict['output_orig_ref_comparison_new'], result_networks=result_networks) + return all_sum_res + + + def get_all_results(self, input_img_prep): + output, output_unnorm, output_reproj, output_ref, output_ref_comp = self.complete_model(input_img_prep, norm_dict=self.norm_dict) + preds_dict = {'output': output, + 'output_unnorm': output_unnorm, + 'output_reproj': output_reproj, + 'output_ref_unnorm': output_ref, + 'output_orig_ref_comparison': output_ref_comp + } + return preds_dict + + + def get_all_results_multiple_refinements(self, input_img_prep): + preds_dict = self.complete_model.forward_with_multiple_refinements(input_img_prep, norm_dict=self.norm_dict) + # output, output_unnorm, output_reproj, output_ref, output_ref_comp, output_ref_unnorm_new, output_orig_ref_comparison_new + return preds_dict + + + + + + + + + + + + + + + + + + + + diff --git a/src/test_time_optimization/evaluate_ttopt.py b/src/test_time_optimization/evaluate_ttopt.py new file mode 100644 index 0000000000000000000000000000000000000000..114a3efde5bd59acbdfaaf8a8cc60cd38f8cd6d4 --- /dev/null +++ b/src/test_time_optimization/evaluate_ttopt.py @@ -0,0 +1,368 @@ + +# evaluate test time optimization from refinement +# python src/test_time_optimization/evaluate_ttopt.py --workers 12 --save-images True --config refinement_cfg_test_withvertexwisegc_csaddnonflat.yaml --model-file-complete=cvpr23_dm39dnnv3barcv2b_refwithgcpervertisflat0morestanding0/checkpoint.pth.tar --ttopt-result-name ttoptv6_stanext_v16b + +# python src/test_time_optimization/evaluate_ttopt.py --workers 12 --save-images True --config refinement_cfg_test_withvertexwisegc_csaddnonflat.yaml --model-file-complete=cvpr23_dm39dnnv3barcv2b_refwithgcpervertisflat0morestanding0/checkpoint.pth.tar --ttopt-result-name ttoptv6_stanext_v16 + + + +import argparse +import os.path +import json +import numpy as np +import pickle as pkl +from distutils.util import strtobool +import torch +from torch import nn +import torch.backends.cudnn +from torch.nn import DataParallel +from torch.utils.data import DataLoader +import pytorch3d as p3d +from collections import OrderedDict +import glob +from tqdm import tqdm +from dominate import document +from dominate.tags import * +from PIL import Image +from matplotlib import pyplot as plt +import trimesh +import cv2 +import shutil + +from pytorch3d.structures import Meshes +from pytorch3d.loss import mesh_edge_loss, mesh_laplacian_smoothing, mesh_normal_consistency + +import sys +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', 'src')) + +from combined_model.train_main_image_to_3d_wbr_withref import do_validation_epoch +# from combined_model.model_shape_v7 import ModelImageTo3d_withshape_withproj +# from combined_model.model_shape_v7_withref import ModelImageTo3d_withshape_withproj +from combined_model.model_shape_v7_withref_withgraphcnn import ModelImageTo3d_withshape_withproj + +from combined_model.loss_image_to_3d_withbreedrel import Loss +from combined_model.loss_image_to_3d_refinement import LossRef +from configs.barc_cfg_defaults import get_cfg_defaults, update_cfg_global_with_yaml, get_cfg_global_updated + +from lifting_to_3d.utils.geometry_utils import rot6d_to_rotmat, rotmat_to_rot6d # , batch_rot2aa, geodesic_loss_R + + +# from test_time_optimization.utils_ttopt import get_evaluation_dataset, get_norm_dict +from stacked_hourglass.datasets.utils_dataset_selection import get_evaluation_dataset, get_norm_dict + +from test_time_optimization.bite_inference_model_for_ttopt import BITEInferenceModel +from smal_pytorch.smal_model.smal_torch_new import SMAL +from configs.SMAL_configs import SMAL_MODEL_CONFIG +from smal_pytorch.renderer.differentiable_renderer import SilhRenderer +from test_time_optimization.utils.utils_ttopt import reset_loss_values, get_optimed_pose_with_glob + +from combined_model.loss_utils.loss_utils import leg_sideway_error, leg_torsion_error, tail_sideway_error, tail_torsion_error, spine_torsion_error, spine_sideway_error +from combined_model.loss_utils.loss_utils_gc import LossGConMesh, calculate_plane_errors_batch +from combined_model.loss_utils.loss_arap import Arap_Loss +from combined_model.loss_utils.loss_laplacian_mesh_comparison import LaplacianCTF # (coarse to fine animal) +from graph_networks import graphcmr # .utils_mesh import Mesh +from stacked_hourglass.utils.visualization import save_input_image_with_keypoints, save_input_image + +from metrics.metrics import Metrics +from configs.SMAL_configs import EVAL_KEYPOINTS, KEYPOINT_GROUPS + + +ROOT_LOSS_WEIGH_PATH = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/src/configs/ttopt_loss_weights/' + + + +def main(args): + + # load configs + # step 1: load default configs + # step 2: load updates from .yaml file + path_config = os.path.join(get_cfg_defaults().barc_dir, 'src', 'configs', args.config) + update_cfg_global_with_yaml(path_config) + cfg = get_cfg_global_updated() + + pck_thresh = 0.15 + print('pck_thresh: ' + str(pck_thresh)) + + + + + ROOT_IN_PATH = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/results/results_ttopt/' + args.ttopt_result_name + '/' # ttoptv6_debug_x8/' + ROOT_IN_PATH_DETAIL = ROOT_IN_PATH + 'details/' + + ROOT_OUT_PATH = ROOT_IN_PATH + 'evaluation/' + if not os.path.exists(ROOT_OUT_PATH): os.makedirs(ROOT_OUT_PATH) + + + + + + + + + + + # NEW!!! + logscale_part_list = ['legs_l', 'legs_f', 'tail_l', 'tail_f', 'ears_y', 'ears_l', 'head_l'] + # logscale_part_list = ['front_legs_l', 'front_legs_f', 'tail_l', 'tail_f', 'ears_y', 'ears_l', 'head_l', 'back_legs_l', 'back_legs_f'] + + + # Select the hardware device to use for training. + if torch.cuda.is_available() and cfg.device=='cuda': + device = torch.device('cuda', torch.cuda.current_device()) + torch.backends.cudnn.benchmark = False # True + else: + device = torch.device('cpu') + + print('structure_pose_net: ' + cfg.params.STRUCTURE_POSE_NET) + print('refinement network type: ' + cfg.params.REF_NET_TYPE) + print('smal_model_type: ' + cfg.smal.SMAL_MODEL_TYPE) + + path_model_file_complete = os.path.join(cfg.paths.ROOT_CHECKPOINT_PATH, args.model_file_complete) + + # Disable gradient calculations. + # torch.set_grad_enabled(False) + + + # prepare dataset and dataset loadr + val_dataset, val_loader, len_val_dataset, test_name_list, stanext_data_info, stanext_acc_joints = get_evaluation_dataset(cfg.data.DATASET, cfg.data.VAL_OPT, cfg.data.V12, cfg.optim.BATCH_SIZE, args.workers) + len_data = len_val_dataset + # summarize information for normalization + norm_dict = get_norm_dict(stanext_data_info, device) + + # prepare complete model + bite_model = BITEInferenceModel(cfg, path_model_file_complete, norm_dict) + # smal_model_type = bite_model.complete_model.smal.smal_model_type + smal_model_type = bite_model.smal_model_type + smal = SMAL(smal_model_type=smal_model_type, template_name='neutral', logscale_part_list=logscale_part_list).to(device) + silh_renderer = SilhRenderer(image_size=256).to(device) + + + + # ---------------------------------------------------------------------------------- + + summary = {} + summary['pck'] = np.zeros((len_data)) + summary['pck_by_part'] = {group:np.zeros((len_data)) for group in KEYPOINT_GROUPS} + summary['acc_sil_2d'] = np.zeros(len_data) + + + + + + + + + + + + # Put the model in training mode. + # model.train() + # prepare progress bar + iterable = enumerate(val_loader) + progress = None + if True: # not quiet: + progress = tqdm(iterable, desc='Train', total=len(val_loader), ascii=True, leave=False) + iterable = progress + ind_img_tot = 0 + # prepare variables, put them on the right device + + my_step = 0 + batch_size = cfg.optim.BATCH_SIZE + + for index, (input, target_dict) in iterable: + for key in target_dict.keys(): + if key == 'breed_index': + target_dict[key] = target_dict[key].long().to(device) + elif key in ['index', 'pts', 'tpts', 'target_weight', 'silh', 'silh_distmat_tofg', 'silh_distmat_tobg', 'sim_breed_index', 'img_border_mask']: + target_dict[key] = target_dict[key].float().to(device) + elif key == 'has_seg': + target_dict[key] = target_dict[key].to(device) + else: + pass + input = input.float().to(device) + + + + # get starting values for the optimization + # -> here from barc, but could also be saved and loaded + preds_dict = bite_model.get_all_results(input) + res_normal_and_ref = bite_model.get_selected_results(preds_dict=preds_dict, result_networks=['normal', 'ref']) + res = bite_model.get_selected_results(preds_dict=preds_dict, result_networks=['ref'])['ref'] + + # -------------------------------------------------------------------- + + # ind_img = 0 + + batch_verts_smal = [] + batch_faces_prep = [] + batch_optimed_camera_flength = [] + + + + for ind_img in range(input.shape[0]): + name = (test_name_list[target_dict['index'][ind_img].long()]).replace('/', '__').split('.')[0] + + print('ind_img_tot: ' + str(ind_img_tot) + ' -> ' + name) + ind_img_tot += 1 + + e_name = 'e000' # 'e300' + + npy_file = ROOT_IN_PATH_DETAIL + name + '_flength_' + e_name +'.npy' + flength = np.load(npy_file) + optimed_camera_flength = torch.tensor(flength, device=device) + + obj_file = ROOT_IN_PATH + name + '_res_' + e_name +'.obj' + + verts, faces, aux = p3d.io.load_obj(obj_file) + verts_smal = verts[None, ...].to(device) + faces_prep = faces.verts_idx[None, ...].to(device) + batch_verts_smal.append(verts_smal) + batch_faces_prep.append(faces_prep) + batch_optimed_camera_flength.append(optimed_camera_flength) + + + # import pdb; pdb.set_trace() + + verts_smal = torch.cat(batch_verts_smal, dim=0) + faces_prep = torch.cat(batch_faces_prep, dim=0) + optimed_camera_flength = torch.cat(batch_optimed_camera_flength, dim=0) + + # get keypoint locations from mesh vertices + keyp_3d = smal.get_joints_from_verts(verts_smal, keyp_conf='olive') + + + # render silhouette and keypoints + pred_silh_images, pred_keyp_raw = silh_renderer(vertices=verts_smal, points=keyp_3d, faces=faces_prep, focal_lengths=optimed_camera_flength) + pred_keyp = pred_keyp_raw[:, :24, :] + + + + # --------------- calculate iou and pck values -------------------- + + gt_keypoints_256 = target_dict['tpts'][:, :, :2] / 64. * (256. - 1) + gt_keypoints = torch.cat((gt_keypoints_256, target_dict['tpts'][:, :, 2:3]), dim=2) + # prepare silhouette for IoU calculation - predicted as well as ground truth + has_seg = target_dict['has_seg'] + img_border_mask = target_dict['img_border_mask'][:, 0, :, :] + gtseg = target_dict['silh'] + synth_silhouettes = pred_silh_images[:, 0, :, :] # pred_silh[:, 0, :, :] # output_reproj['silh'] + synth_silhouettes[synth_silhouettes>0.5] = 1 + synth_silhouettes[synth_silhouettes<0.5] = 0 + # calculate PCK as well as IoU (similar to WLDO) + preds = {} + preds['acc_PCK'] = Metrics.PCK( + pred_keyp, gt_keypoints, + gtseg, has_seg, idxs=EVAL_KEYPOINTS, + thresh_range=[pck_thresh], # [0.15], + ) + preds['acc_IOU'] = Metrics.IOU( + synth_silhouettes, gtseg, + img_border_mask, mask=has_seg + ) + for group, group_kps in KEYPOINT_GROUPS.items(): + preds[f'{group}_PCK'] = Metrics.PCK( + pred_keyp, gt_keypoints, gtseg, has_seg, + thresh_range=[pck_thresh], # [0.15], + idxs=group_kps + ) + + curr_batch_size = pred_keyp.shape[0] + if not (preds['acc_PCK'].data.cpu().numpy().shape == (summary['pck'][my_step * batch_size:my_step * batch_size + curr_batch_size]).shape): + import pdb; pdb.set_trace() + summary['pck'][my_step * batch_size:my_step * batch_size + curr_batch_size] = preds['acc_PCK'].data.cpu().numpy() + summary['acc_sil_2d'][my_step * batch_size:my_step * batch_size + curr_batch_size] = preds['acc_IOU'].data.cpu().numpy() + for part in summary['pck_by_part']: + summary['pck_by_part'][part][my_step * batch_size:my_step * batch_size + curr_batch_size] = preds[f'{part}_PCK'].data.cpu().numpy() + + + + + my_step += 1 + + + + + + # import pdb; pdb.set_trace() + + + + + + + iou = np.nanmean(summary['acc_sil_2d']) + pck = np.nanmean(summary['pck']) + pck_legs = np.nanmean(summary['pck_by_part']['legs']) + pck_tail = np.nanmean(summary['pck_by_part']['tail']) + pck_ears = np.nanmean(summary['pck_by_part']['ears']) + pck_face = np.nanmean(summary['pck_by_part']['face']) + print('------------------------------------------------') + print("iou: {:.2f}".format(iou*100)) + print(' ') + print("pck: {:.2f}".format(pck*100)) + print(' ') + print("pck_legs: {:.2f}".format(pck_legs*100)) + print("pck_tail: {:.2f}".format(pck_tail*100)) + print("pck_ears: {:.2f}".format(pck_ears*100)) + print("pck_face: {:.2f}".format(pck_face*100)) + print('------------------------------------------------') + # save results in a .txt file + with open(ROOT_OUT_PATH + "a_evaluation_" + e_name + ".txt", "w") as text_file: + print("iou: {:.2f}".format(iou*100), file=text_file) + print("pck: {:.2f}".format(pck*100), file=text_file) + print("pck_legs: {:.2f}".format(pck_legs*100), file=text_file) + print("pck_tail: {:.2f}".format(pck_tail*100), file=text_file) + print("pck_ears: {:.2f}".format(pck_ears*100), file=text_file) + print("pck_face: {:.2f}".format(pck_face*100), file=text_file) + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Evaluate a stacked hourglass model.') + parser.add_argument('--model-file-complete', default='', type=str, metavar='PATH', + help='path to saved model weights') + parser.add_argument('--ttopt-result-name', default='', type=str, metavar='PATH', + help='path to saved ttopt results') + parser.add_argument('-cg', '--config', default='barc_cfg_test.yaml', type=str, metavar='PATH', + help='name of config file (default: barc_cfg_test.yaml within src/configs folder)') + parser.add_argument('--save-images', default='True', type=lambda x: bool(strtobool(x)), + help='bool indicating if images should be saved') + parser.add_argument('--workers', default=4, type=int, metavar='N', + help='number of data loading workers') + parser.add_argument('--metrics', '-m', metavar='METRICS', default='all', + choices=['all', None], + help='model architecture') + main(parser.parse_args()) diff --git a/src/test_time_optimization/my_remarks.txt b/src/test_time_optimization/my_remarks.txt new file mode 100644 index 0000000000000000000000000000000000000000..bd93a534b6aadf69f3ff2be4060f33775d37c4b9 --- /dev/null +++ b/src/test_time_optimization/my_remarks.txt @@ -0,0 +1,8 @@ + +# this code can be used to further optimize pose and shape at test time + + + + + + diff --git a/src/test_time_optimization/utils/__init__.py b/src/test_time_optimization/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/test_time_optimization/utils/utils_ttopt.py b/src/test_time_optimization/utils/utils_ttopt.py new file mode 100644 index 0000000000000000000000000000000000000000..d08ccdb9c0e4f0802e4a76f5eaaf5d0d355fc9b8 --- /dev/null +++ b/src/test_time_optimization/utils/utils_ttopt.py @@ -0,0 +1,26 @@ + +import os +import torch + +import sys +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', '..', 'src')) +from lifting_to_3d.utils.geometry_utils import rot6d_to_rotmat, rotmat_to_rot6d # , batch_rot2aa, geodesic_loss_R + + +def reset_loss_values(losses): + # losses is a dict + for key, val in losses.items(): + val['value'] = 0.0 + return losses + +def get_optimed_pose_with_glob(optimed_orient_6d, optimed_pose_6d): + # optimed_orient_6d: (1, 1, 6) + # optimed_pose_6d: (1, 34, 6) + bs = optimed_pose_6d.shape[0] + assert bs == 1 + optimed_pose_with_glob_6d = torch.cat((optimed_orient_6d, optimed_pose_6d), dim=1) + optimed_pose_with_glob = rot6d_to_rotmat(optimed_pose_with_glob_6d.reshape((-1, 6))).reshape((bs, -1, 3, 3)) + return optimed_pose_with_glob + + +