# -*- coding: utf-8 -*-

# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
#
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# Contact: ps-license@tuebingen.mpg.de

from lib.dataset.mesh_util import projection
from lib.common.render import Render
import numpy as np
import torch
import os.path as osp
from torchvision.utils import make_grid
from pytorch3d.io import IO
from pytorch3d.ops import sample_points_from_meshes
from pytorch3d.loss.point_mesh_distance import _PointFaceDistance
from pytorch3d.structures import Pointclouds
from PIL import Image


def point_mesh_distance(meshes, pcls, min_triangle_area=5e-3):
    """Return the mean point-to-surface distance from ``pcls`` to ``meshes``.

    Each point of cloud ``i`` is matched against the triangles of mesh ``i``.
    Per-point distances are averaged inside each cloud (every point is
    weighted by ``1 / num_points`` of its cloud), and the per-cloud means are
    then averaged over the batch.

    Args:
        meshes: pytorch3d ``Meshes`` batch.
        pcls: pytorch3d ``Pointclouds`` batch with the same batch size.
        min_triangle_area: forwarded as the last argument of
            ``_PointFaceDistance.apply`` (presumably pytorch3d's
            ``min_triangle_area`` -- confirm against the installed version).
            Defaults to the value previously hard-coded here.

    Returns:
        Scalar tensor: batch-averaged point-to-surface distance.

    Raises:
        ValueError: if ``meshes`` and ``pcls`` differ in batch size.
    """
    if len(meshes) != len(pcls):
        raise ValueError("meshes and pointclouds must be equal sized batches")
    N = len(meshes)

    # packed representation for pointclouds
    points = pcls.points_packed()  # (P, 3)
    points_first_idx = pcls.cloud_to_packed_first_idx()
    max_points = pcls.num_points_per_cloud().max().item()

    # packed representation for faces
    verts_packed = meshes.verts_packed()
    faces_packed = meshes.faces_packed()
    tris = verts_packed[faces_packed]  # (T, 3, 3)
    tris_first_idx = meshes.mesh_to_faces_packed_first_idx()

    # point to face distance: shape (P,); presumably squared distances,
    # hence the sqrt below.
    point_to_face = _PointFaceDistance.apply(
        points, points_first_idx, tris, tris_first_idx, max_points,
        min_triangle_area)

    # weight each example by the inverse of number of points in the example
    point_to_cloud_idx = pcls.packed_to_cloud_idx()  # (sum(P_i),)
    num_points_per_cloud = pcls.num_points_per_cloud()  # (N,)
    weights_p = num_points_per_cloud.gather(0, point_to_cloud_idx)
    weights_p = 1.0 / weights_p.float()
    point_to_face = torch.sqrt(point_to_face) * weights_p
    point_dist = point_to_face.sum() / N

    return point_dist


class Evaluator:
    """Compare a predicted mesh with a ground-truth mesh.

    Call ``set_mesh`` first; the metric methods then read ``self.src_mesh``
    (prediction) and ``self.tgt_mesh`` (ground truth).
    """

    def __init__(self, device):
        self.render = Render(size=512, device=device)
        self.device = device

    def set_mesh(self, result_dict):
        """Copy ``result_dict`` entries onto ``self`` and build both meshes.

        Every key becomes an attribute. The keys read here are ``verts_pr``,
        ``faces_pr``, ``recon_size``, ``verts_gt``, ``faces_gt`` and ``calib``.
        """
        for k, v in result_dict.items():
            setattr(self, k, v)

        # Affine map sending 0 -> -1 and recon_size -> 1. Done out-of-place
        # so the array stored in the caller's result_dict is not modified
        # (an in-place update would re-normalize it if the dict were reused).
        half_size = self.recon_size / 2.0
        self.verts_pr = (self.verts_pr - half_size) / half_size

        self.verts_gt = projection(self.verts_gt, self.calib)
        # flip the sign of the y coordinate of the projected ground truth
        self.verts_gt[:, 1] *= -1

        self.src_mesh = self.render.VF2Mesh(self.verts_pr, self.faces_pr)
        self.tgt_mesh = self.render.VF2Mesh(self.verts_gt, self.faces_gt)

    @staticmethod
    def _normalize_normals(normal_arr):
        """Rescale vectors to unit length along the channel axis, map to [0, 1].

        ``dim=-3`` is the channel axis both for a (C, H, W) grid and for a
        (B, C, H, W) per-view tensor. Zero-length vectors (black background)
        are left at zero, which maps to 0.5.
        """
        norm = torch.norm(normal_arr, dim=-3, keepdim=True)
        norm = norm.masked_fill(norm == 0.0, 1.0)
        return (normal_arr / norm + 1.0) * 0.5

    @staticmethod
    def _normal_error(src_arr, tgt_arr):
        """Return 4 * mean over pixels of the channel-summed squared difference.

        The [0, 1] mapping halves every difference, so the factor 4 restores
        the squared distance between the original unit vectors.
        """
        return ((src_arr - tgt_arr) ** 2).sum(dim=-3).mean() * 4.0

    def calculate_normal_consist(self, normal_path):
        """Render both meshes from cameras 0-3, save a comparison, return error.

        Views come from ``Render.get_rgb_image(..., bg='black')`` (presumably
        normal-shaded renders -- confirm against ``Render``). The tiled source
        views are stacked above the tiled target views and saved to
        ``normal_path`` as an 8-bit image.

        Returns:
            A scalar tensor computed over the tiled views or, when the renderer
            returns more than four images, a list with one error per image.
        """
        self.render.meshes = self.src_mesh
        src_normal_imgs = self.render.get_rgb_image(cam_ids=[0, 1, 2, 3], bg='black')
        self.render.meshes = self.tgt_mesh
        tgt_normal_imgs = self.render.get_rgb_image(cam_ids=[0, 1, 2, 3], bg='black')

        # tile the views, four per row, without padding -> (C, H, W), [0,1]
        src_normal_arr = self._normalize_normals(
            make_grid(torch.cat(src_normal_imgs, dim=0), nrow=4, padding=0))
        tgt_normal_arr = self._normalize_normals(
            make_grid(torch.cat(tgt_normal_imgs, dim=0), nrow=4, padding=0))

        error = self._normal_error(src_normal_arr, tgt_normal_arr)

        # concatenate along H (source on top), convert CHW -> HWC, save as uint8
        normal_img = Image.fromarray(
            (torch.cat([src_normal_arr, tgt_normal_arr], dim=1).permute(
                1, 2, 0).detach().cpu().numpy() * 255.0).astype(np.uint8))
        normal_img.save(normal_path)

        if len(src_normal_imgs) > 4:
            # One error per rendered image, normalized along the channel axis
            # (previously dim=0, i.e. the batch axis of a per-view tensor).
            return [
                self._normal_error(self._normalize_normals(src_img),
                                   self._normalize_normals(tgt_img))
                for src_img, tgt_img in zip(src_normal_imgs, tgt_normal_imgs)
            ]
        return error

    def export_mesh(self, dir, name):
        """Save source and target meshes as ``{name}_src.obj`` / ``{name}_tgt.obj`` in ``dir``."""
        IO().save_mesh(self.src_mesh, osp.join(dir, f"{name}_src.obj"))
        IO().save_mesh(self.tgt_mesh, osp.join(dir, f"{name}_tgt.obj"))

    def calculate_chamfer_p2s(self, num_samples=1000):
        """Return ``(chamfer_dist, p2s_dist)``, both multiplied by 100.

        ``p2s_dist`` is the mean distance from ``num_samples`` points sampled on
        the target mesh to the source mesh; ``chamfer_dist`` averages it with
        the distance from points sampled on the source mesh to the target mesh.
        """
        tgt_points = Pointclouds(
            sample_points_from_meshes(self.tgt_mesh, num_samples))
        src_points = Pointclouds(
            sample_points_from_meshes(self.src_mesh, num_samples))
        p2s_dist = point_mesh_distance(self.src_mesh, tgt_points) * 100.0
        chamfer_dist = (point_mesh_distance(self.tgt_mesh, src_points) * 100.0 +
                        p2s_dist) * 0.5
        return chamfer_dist, p2s_dist

    @staticmethod
    def _binarize(values, thres):
        """Set entries below ``thres`` to 0 and above it to 1 (equal: unchanged).

        Both masks come from the original values, so the result is correct for
        any threshold (sequential filling broke thresholds below 0).
        """
        below = values < thres
        above = values > thres
        return values.masked_fill(below, 0.0).masked_fill(above, 1.0)

    def calc_acc(self, output, target, thres=0.5, use_sdf=False):
        """Return ``(accuracy, iou, precision, recall)`` for thresholded values.

        Args:
            output: predicted values.
            target: ground-truth values; binarized with ``thres`` too when
                ``use_sdf`` is True, otherwise compared as given for accuracy.
            thres: decision threshold; ``> thres`` counts as positive.
            use_sdf: whether ``target`` must be binarized as well.

        Note: every count is clamped to at least 1, including the true
        positives, so disjoint non-empty masks give ``1 / union`` instead of 0
        and two empty masks give 1. Kept for compatibility with earlier results.
        """
        # # remove the surface points with thres
        # non_surf_ids = (target != thres)
        # output = output[non_surf_ids]
        # target = target[non_surf_ids]
        with torch.no_grad():
            output_bin = self._binarize(output, thres)
            target_bin = self._binarize(target, thres) if use_sdf else target
            acc = output_bin.eq(target_bin).float().mean()

            # iou, precison, recall -- masks taken from the raw values, which
            # matches the binarized ones for thres in [0, 1) and stays correct
            # for other thresholds.
            pred = output > thres
            gt = target > thres

            union = (pred | gt).sum().float().clamp(min=1.0)
            true_pos = (pred & gt).sum().float().clamp(min=1.0)
            vol_pred = pred.sum().float().clamp(min=1.0)
            vol_gt = gt.sum().float().clamp(min=1.0)

        return acc, true_pos / union, true_pos / vol_pred, true_pos / vol_gt