import torch
import torch.nn as nn
import torch.backends.cudnn
import torch.nn.parallel
from tqdm import tqdm
import os
import pathlib
from matplotlib import pyplot as plt
import cv2
import numpy as np
import trimesh
import sys
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..'))
from stacked_hourglass.utils.evaluation import accuracy, AverageMeter, final_preds, get_preds, get_preds_soft
from stacked_hourglass.utils.visualization import save_input_image_with_keypoints, save_input_image
from metrics.metrics import Metrics
from configs.SMAL_configs import EVAL_KEYPOINTS, KEYPOINT_GROUPS

# GOAL: collect all the functions from the validation and visualization epochs in one place
'''
Expected inputs (unpacked from the network outputs):
save_imgs_path = ...
prefix = ''
input  # this is the image batch
data_info
target_dict
render_all
model
vertices_smal = output_reproj['vertices_smal']
flength = output_unnorm['flength']
hg_keyp_norm = output['keypoints_norm']
hg_keyp_scores = output['keypoints_scores']
betas = output_reproj['betas']
betas_limbs = output_reproj['betas_limbs']
zz = output_reproj['z']
pose_rotmat = output_unnorm['pose_rotmat']
trans = output_unnorm['trans']
pred_keyp = output_reproj['keyp_2d']
pred_silh = output_reproj['silh']
'''
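# A minimal usage sketch (hypothetical; the real call sites live in the
# validation/visualization loop and are not shown in this file). The variables
# are unpacked from the network outputs exactly as listed in the docstring above:
#   eval_save_visualizations_and_meshes(model, input, data_info, target_dict, test_name_list,
#       vertices_smal, hg_keyp_norm, hg_keyp_scores, zz, betas, betas_limbs, pose_rotmat,
#       trans, flength, pred_keyp, pred_silh, save_imgs_path, prefix, index, render_all=False)
#   preds = eval_prepare_pck_and_iou(model, input, data_info, target_dict, test_name_list,
#       vertices_smal, hg_keyp_norm, hg_keyp_scores, zz, betas, betas_limbs, pose_rotmat,
#       trans, flength, pred_keyp, pred_silh, save_imgs_path, prefix, index, pck_thresh)
#   eval_add_preds_to_summary(summary, preds, my_step, batch_size, curr_batch_size)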
#################################################
def eval_save_visualizations_and_meshes(model, input, data_info, target_dict, test_name_list, vertices_smal, hg_keyp_norm, hg_keyp_scores, zz, betas, betas_limbs, pose_rotmat, trans, flength, pred_keyp, pred_silh, save_imgs_path, prefix, index, render_all=False):
    device = input.device
    curr_batch_size = input.shape[0]
    # render the predicted 3d models (front view)
    visualizations = model.render_vis_nograd(vertices=vertices_smal,
                                             focal_lengths=flength,
                                             color=0)  # color=2)
    for ind_img in range(len(target_dict['index'])):
        try:
            if test_name_list is not None:
                img_name = test_name_list[int(target_dict['index'][ind_img].cpu().detach().numpy())].replace('/', '_')
                img_name = img_name.split('.')[0]
            else:
                img_name = str(index) + '_' + str(ind_img)
            # save the image with the predicted keypoints overlaid
            out_path = save_imgs_path + '/keypoints_pred_' + img_name + '.png'
            pred_unp = (hg_keyp_norm[ind_img, :, :] + 1.) / 2 * (data_info.image_size - 1)
            pred_unp_maxval = hg_keyp_scores[ind_img, :, :]
            pred_unp_prep = torch.cat((pred_unp, pred_unp_maxval), 1)
            inp_img = input[ind_img, :, :, :].detach().clone()
            save_input_image_with_keypoints(inp_img, pred_unp_prep, out_path=out_path, threshold=0.1, print_scores=True, ratio_in_out=1.0)  # threshold=0.3
            # save the predicted 3d model (front view)
            pred_tex = visualizations[ind_img, :, :, :].permute((1, 2, 0)).cpu().detach().numpy() / 256
            pred_tex_max = np.max(pred_tex, axis=2)
            out_path = save_imgs_path + '/' + prefix + 'tex_pred_' + img_name + '.png'
            plt.imsave(out_path, pred_tex)
            # blend the rendering with the input image (the channel means are
            # added back to undo the input normalization)
            input_image = input[ind_img, :, :, :].detach().clone()
            for t, m, s in zip(input_image, data_info.rgb_mean, data_info.rgb_stddev):
                t.add_(m)
            input_image_np = input_image.detach().cpu().numpy().transpose(1, 2, 0)
            im_masked = cv2.addWeighted(input_image_np, 0.2, pred_tex, 0.8, 0)
            im_masked[pred_tex_max < 0.01, :] = input_image_np[pred_tex_max < 0.01, :]
            out_path = save_imgs_path + '/' + prefix + 'comp_pred_' + img_name + '.png'
            plt.imsave(out_path, im_masked)
            # save the predicted 3d model (side view): rotate the centered mesh
            # by 90 degrees around the y-axis and push it away from the camera
            vertices_cent = vertices_smal - vertices_smal.mean(dim=1)[:, None, :]
            roll = np.pi / 2 * torch.ones(1).float().to(device)
            pitch = np.pi / 2 * torch.ones(1).float().to(device)
            tensor_0 = torch.zeros(1).float().to(device)
            tensor_1 = torch.ones(1).float().to(device)
            # rotation around the x-axis (kept for reference, not applied below)
            RX = torch.stack([
                torch.stack([tensor_1, tensor_0, tensor_0]),
                torch.stack([tensor_0, torch.cos(roll), -torch.sin(roll)]),
                torch.stack([tensor_0, torch.sin(roll), torch.cos(roll)])]).reshape(3, 3)
            # rotation around the y-axis (used for the side view)
            RY = torch.stack([
                torch.stack([torch.cos(pitch), tensor_0, torch.sin(pitch)]),
                torch.stack([tensor_0, tensor_1, tensor_0]),
                torch.stack([-torch.sin(pitch), tensor_0, torch.cos(pitch)])]).reshape(3, 3)
            vertices_rot = (torch.matmul(RY, vertices_cent.reshape((-1, 3))[:, :, None])).reshape((curr_batch_size, -1, 3))
            vertices_rot[:, :, 2] = vertices_rot[:, :, 2] + torch.ones_like(vertices_rot[:, :, 2]) * 20  # 18 # *16
            visualizations_rot = model.render_vis_nograd(vertices=vertices_rot,
                                                         focal_lengths=flength,
                                                         color=0)  # 2)
            pred_tex = visualizations_rot[ind_img, :, :, :].permute((1, 2, 0)).cpu().detach().numpy() / 256
            pred_tex_max = np.max(pred_tex, axis=2)
            out_path = save_imgs_path + '/' + prefix + 'rot_tex_pred_' + img_name + '.png'
            plt.imsave(out_path, pred_tex)
            if render_all:
                # save the input image
                inp_img = input[ind_img, :, :, :].detach().clone()
                out_path = save_imgs_path + '/image_' + img_name + '.png'
                save_input_image(inp_img, out_path)
                # save the posed mesh
                V_posed = vertices_smal[ind_img, :, :].detach().cpu().numpy()
                Faces = model.smal.f
                mesh_posed = trimesh.Trimesh(vertices=V_posed, faces=Faces, process=False, maintain_order=True)
                mesh_posed.export(save_imgs_path + '/' + prefix + 'mesh_posed_' + img_name + '.obj')
        except Exception as e:
            print('could not save visualizations for this image: ' + str(e))
############
def eval_prepare_pck_and_iou(model, input, data_info, target_dict, test_name_list, vertices_smal, hg_keyp_norm, hg_keyp_scores, zz, betas, betas_limbs, pose_rotmat, trans, flength, pred_keyp, pred_silh, save_imgs_path, prefix, index, pck_thresh, progress=None, skip_pck_and_iou=False):
    preds = {}
    preds['betas'] = betas.cpu().detach().numpy()
    preds['betas_limbs'] = betas_limbs.cpu().detach().numpy()
    preds['z'] = zz.cpu().detach().numpy()
    preds['pose_rotmat'] = pose_rotmat.cpu().detach().numpy()
    preds['flength'] = flength.cpu().detach().numpy()
    preds['trans'] = trans.cpu().detach().numpy()
    preds['breed_index'] = target_dict['breed_index'].cpu().detach().numpy().reshape((-1))
    img_names = []
    for ind_img2 in range(0, betas.shape[0]):
        if test_name_list is not None:
            img_name2 = test_name_list[int(target_dict['index'][ind_img2].cpu().detach().numpy())].replace('/', '_')
            img_name2 = img_name2.split('.')[0]
        else:
            img_name2 = str(index) + '_' + str(ind_img2)
        img_names.append(img_name2)
    preds['image_names'] = img_names
    if not skip_pck_and_iou:
        # prepare the keypoints for the PCK calculation - predicted as well as ground
        # truth (rescaled from the 64x64 heatmap to 256x256 image coordinates)
        gt_keypoints_256 = target_dict['tpts'][:, :, :2] / 64. * (256. - 1)
        gt_keypoints = torch.cat((gt_keypoints_256, target_dict['tpts'][:, :, 2:3]), dim=2)
        # prepare the silhouettes for the IoU calculation - predicted as well as ground truth
        has_seg = target_dict['has_seg']
        img_border_mask = target_dict['img_border_mask'][:, 0, :, :]
        gtseg = target_dict['silh']
        # binarize the predicted silhouette (without modifying pred_silh in place)
        synth_silhouettes = (pred_silh[:, 0, :, :] > 0.5).float()
        # calculate PCK as well as IoU (similar to WLDO)
        preds['acc_PCK'] = Metrics.PCK(
            pred_keyp, gt_keypoints,
            gtseg, has_seg, idxs=EVAL_KEYPOINTS,
            thresh_range=[pck_thresh],  # [0.15]
        )
        preds['acc_IOU'] = Metrics.IOU(
            synth_silhouettes, gtseg,
            img_border_mask, mask=has_seg
        )
        for group, group_kps in KEYPOINT_GROUPS.items():
            preds[f'{group}_PCK'] = Metrics.PCK(
                pred_keyp, gt_keypoints, gtseg, has_seg,
                thresh_range=[pck_thresh],  # [0.15]
                idxs=group_kps
            )
    return preds
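# Note on the metrics above (hedged, since the exact definitions live in
# metrics.metrics): PCK counts a predicted keypoint as correct when its distance
# to the ground-truth keypoint falls below pck_thresh times a per-image
# normalizer (WLDO-style evaluation normalizes by the square root of the
# ground-truth silhouette area), and IOU is the intersection-over-union of the
# binarized predicted silhouette against gtseg, restricted to the region given
# by img_border_mask and to images with has_seg set.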
#############################
def eval_add_preds_to_summary(summary, preds, my_step, batch_size, curr_batch_size, skip_pck_and_iou=False):
    # write the per-batch predictions into the preallocated summary buffers, at
    # the global sample indices covered by this batch
    batch_slice = slice(my_step * batch_size, my_step * batch_size + curr_batch_size)
    if not skip_pck_and_iou:
        assert preds['acc_PCK'].data.cpu().numpy().shape == summary['pck'][batch_slice].shape, \
            'shape mismatch between the PCK predictions and the summary buffer'
        summary['pck'][batch_slice] = preds['acc_PCK'].data.cpu().numpy()
        summary['acc_sil_2d'][batch_slice] = preds['acc_IOU'].data.cpu().numpy()
        for part in summary['pck_by_part']:
            summary['pck_by_part'][part][batch_slice] = preds[f'{part}_PCK'].data.cpu().numpy()
    summary['betas'][batch_slice, ...] = preds['betas']
    summary['betas_limbs'][batch_slice, ...] = preds['betas_limbs']
    summary['z'][batch_slice, ...] = preds['z']
    summary['pose_rotmat'][batch_slice, ...] = preds['pose_rotmat']
    summary['flength'][batch_slice, ...] = preds['flength']
    summary['trans'][batch_slice, ...] = preds['trans']
    summary['breed_indices'][batch_slice] = preds['breed_index']
    summary['image_names'].extend(preds['image_names'])
    return
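# A minimal sketch (not part of the original pipeline) of how the flat summary
# buffers used above could be preallocated. Every size except len_data is a
# hypothetical default inferred from the slicing pattern in
# eval_add_preds_to_summary, not a value confirmed by this codebase.
def eval_initialize_summary_sketch(len_data, n_betas=30, n_betas_limbs=7, n_z=512, n_joints=35):
    summary = {
        'pck': np.zeros((len_data)),
        'acc_sil_2d': np.zeros((len_data)),
        'pck_by_part': {group: np.zeros((len_data)) for group in KEYPOINT_GROUPS},
        'betas': np.zeros((len_data, n_betas)),
        'betas_limbs': np.zeros((len_data, n_betas_limbs)),
        'z': np.zeros((len_data, n_z)),
        'pose_rotmat': np.zeros((len_data, n_joints, 3, 3)),
        'flength': np.zeros((len_data, 1)),
        'trans': np.zeros((len_data, 3)),
        'breed_indices': np.zeros((len_data), dtype=np.int64),
        'image_names': [],  # grown with extend() in eval_add_preds_to_summary
    }
    return summary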
def get_triangle_faces_from_pyvista_poly(poly):
    """Fetch all triangle faces from a pyvista PolyData object (non-triangle faces are skipped)."""
    stream = poly.faces
    tris = []
    i = 0
    while i < len(stream):
        n = stream[i]  # number of vertices of the current face
        if n != 3:
            i += n + 1  # skip non-triangle faces
            continue
        stop = i + n + 1
        tris.append(stream[i+1:stop])
        i = stop
    return np.array(tris)
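# Example usage of get_triangle_faces_from_pyvista_poly (a sketch; it assumes
# pyvista is installed, which the rest of this file does not require):
#   import pyvista as pv
#   poly = pv.Sphere()                                  # PolyData with triangle faces
#   faces = get_triangle_faces_from_pyvista_poly(poly)  # (n_tris, 3) index array
#   mesh = trimesh.Trimesh(vertices=np.asarray(poly.points), faces=faces, process=False)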