AniMerPlus / amr /utils /mesh_renderer.py
luoxue-star's picture
Init Commit
1966925
import os
if 'PYOPENGL_PLATFORM' not in os.environ:
os.environ['PYOPENGL_PLATFORM'] = 'egl'
import torch
from torchvision.utils import make_grid
import numpy as np
import pyrender
import trimesh
import cv2
import math
import torch.nn.functional as F
from typing import List, Tuple
from pytorch3d.renderer import RasterizationSettings, MeshRasterizer, SoftSilhouetteShader, BlendParams
from pytorch3d.renderer import MeshRenderer as Pytorch3dMeshRender
from pytorch3d.structures import Meshes
from pytorch3d.renderer import PerspectiveCameras, FoVPerspectiveCameras
def create_raymond_lights():
    """Create three directional-light nodes arranged symmetrically around +z.

    Each light direction sits at a polar angle of pi/6 from the +z axis, at
    azimuths 0, 2*pi/3 and 4*pi/3. The node pose is a pure rotation whose
    third column is that unit direction; the first two columns complete an
    orthonormal frame.

    Returns:
        list: Three ``pyrender.Node`` objects, each holding a white
        ``pyrender.DirectionalLight`` of intensity 1.0.
    """
    import pyrender

    polar_angles = np.pi * np.array([1.0 / 6.0, 1.0 / 6.0, 1.0 / 6.0])
    azimuths = np.pi * np.array([0.0, 2.0 / 3.0, 4.0 / 3.0])

    def _light_pose(azimuth, polar):
        # Unit direction from spherical angles; becomes the z column.
        forward = np.array([
            np.sin(polar) * np.cos(azimuth),
            np.sin(polar) * np.sin(azimuth),
            np.cos(polar),
        ])
        forward = forward / np.linalg.norm(forward)

        # A vector perpendicular to `forward` in the xy-plane; fall back to
        # the x axis when `forward` is exactly +/-z and this degenerates.
        right = np.array([-forward[1], forward[0], 0.0])
        if np.linalg.norm(right) == 0:
            right = np.array([1.0, 0.0, 0.0])
        right = right / np.linalg.norm(right)

        up = np.cross(forward, right)

        pose = np.eye(4)
        pose[:3, :3] = np.column_stack((right, up, forward))
        return pose

    return [
        pyrender.Node(
            light=pyrender.DirectionalLight(color=np.ones(3), intensity=1.0),
            matrix=_light_pose(azimuth, polar),
        )
        for azimuth, polar in zip(azimuths, polar_angles)
    ]
def get_keypoints_rectangle(keypoints: np.array, threshold: float) -> Tuple[float, float, float]:
    """
    Measure the axis-aligned box around the confident keypoints.
    Args:
        keypoints (np.array): Array whose rows are (x, y, confidence).
        threshold (float): Only rows with confidence strictly greater than
            this value are taken into account.
    Returns:
        Tuple[float, float, float]: Box width, height and area, or
        ``(0, 0, 0)`` when no row passes the threshold.
    """
    confident = keypoints[:, -1] > threshold
    if not confident.any():
        return 0, 0, 0

    coords = keypoints[confident][:, :-1]
    xs = coords[:, 0]
    ys = coords[:, 1]
    width = xs.max() - xs.min()
    height = ys.max() - ys.min()
    return width, height, width * height
def render_keypoint(img: np.ndarray, keypoint: np.ndarray, threshold=0.1,
                    use_confidence=False, map_fn=lambda x: np.ones_like(x), alpha=1.0) -> np.ndarray:
    """Draw a keypoint skeleton (limb lines and joint circles) onto an image.

    Args:
        img: Image laid out as (height, width, channels); it is drawn on with
            OpenCV. The input array itself is not modified.
        keypoint: Array of shape (N, 3) whose rows are (x, y, confidence) in
            pixel coordinates. N must be 26 or 18; the skeleton connectivity
            is selected from N.
        threshold: A joint is drawn only if its confidence is strictly greater
            than this value; a limb is drawn only if both of its joints are.
        use_confidence: When true (and ``map_fn`` is not None) the per-joint
            thickness ratio is ``map_fn(confidence) / 50`` instead of a
            constant ``1 / 50``.
        map_fn: Maps the confidence column to per-joint thickness multipliers.
        alpha: Unused; kept so existing callers keep working.

    Returns:
        A contiguous copy of ``img`` with the skeleton drawn, or ``img`` itself
        (no copy) when no keypoint has confidence above 0.1 or the confident
        keypoints span zero area.

    Raises:
        ValueError: If ``keypoint`` does not have 26 or 18 rows.
    """
    if use_confidence and map_fn is not None:
        circle_ratio = 1. / 50 * map_fn(keypoint[:, -1])
    else:
        circle_ratio = 1. / 50 * np.ones(keypoint.shape[0])
    line_ratio_wrt_circle = 0.75

    # Flat list of joint-index pairs; each consecutive pair is one limb.
    num_keypoints = keypoint.shape[0]
    if num_keypoints == 26:
        pairs = [0, 24, 1, 24, 2, 24, 3, 14, 4, 15, 5, 16, 6, 17, 7, 18, 8, 12, 9, 13, 10, 7, 11, 7,
                 12, 18, 13, 18, 14, 8, 15, 9, 16, 10, 17, 11, 18, 24, 19, 25, 20, 0, 21, 1, 22, 24,
                 23, 24, 25, 7]
    elif num_keypoints == 18:
        pairs = [9, 8, 8, 2, 2, 3, 3, 4, 2, 0, 2, 1, 4, 5,
                 5, 14, 14, 15, 4, 6, 6, 7, 7, 11, 11, 10,
                 7, 13, 13, 12, 5, 16, 5, 17]
    else:
        raise ValueError("Keypoint shape not supported")
    pairs = np.array(pairs).reshape(-1, 2)

    # One colour triple per joint index (indexed modulo the table length).
    colors = np.array([255., 0., 85.,
                       255., 0., 0.,
                       255., 85., 0.,
                       255., 170., 0.,
                       255., 255., 0.,
                       170., 255., 0.,
                       85., 255., 0.,
                       0., 255., 0.,
                       255., 0., 0.,
                       0., 255., 85.,
                       0., 255., 170.,
                       0., 255., 255.,
                       0., 170., 255.,
                       0., 85., 255.,
                       0., 0., 255.,
                       255., 0., 170.,
                       170., 0., 255.,
                       255., 0., 255.,
                       85., 0., 255.,
                       0., 0., 255.,
                       0., 0., 255.,
                       0., 0., 255.,
                       0., 255., 255.,
                       0., 255., 255.,
                       0., 255., 255.,
                       255., 225., 255.]).reshape(-1, 3)
    num_colors = len(colors)

    pose_scale = 1
    line_type = 8
    shift = 0
    rectangle_threshold = 0.1

    # The image is (H, W, C): height is axis 0 and width is axis 1. Axis 2 is
    # the channel count and must not be used as a spatial size.
    height, width = img.shape[0], img.shape[1]
    area = width * height

    animal_width, animal_height, animal_area = get_keypoints_rectangle(keypoint, rectangle_threshold)
    if not animal_area > 0:
        return img

    # Scale stroke sizes by how much of the frame the subject occupies.
    ratio_areas = min(1, max(animal_width / width, animal_height / height))
    thickness_ratio = np.maximum(np.round(math.sqrt(area) * circle_ratio * ratio_areas), 2)
    # np.maximum(1, -1) is 1, so very small subjects get 1-pixel outlines.
    thickness_circle = np.maximum(
        1, thickness_ratio if ratio_areas > 0.05 else -np.ones_like(thickness_ratio))
    thickness_line = np.maximum(1, np.round(thickness_ratio * line_ratio_wrt_circle))
    radius = thickness_ratio / 2

    # OpenCV drawing needs a contiguous buffer; copy so the input is untouched.
    canvas = np.ascontiguousarray(img.copy())

    # Limbs: a line between each connected joint pair, coloured by the second joint.
    for index1, index2 in pairs:
        if keypoint[index1, -1] > threshold and keypoint[index2, -1] > threshold:
            line_thickness = int(round(min(thickness_line[index1], thickness_line[index2]) * pose_scale))
            color = colors[index2 % num_colors]
            start = keypoint[index1, :-1].astype(np.int32)
            end = keypoint[index2, :-1].astype(np.int32)
            cv2.line(canvas, tuple(start.tolist()), tuple(end.tolist()), tuple(color.tolist()),
                     line_thickness, line_type, shift)

    # Joints: one circle outline per confident keypoint.
    for part in range(len(keypoint)):
        if keypoint[part, -1] > threshold:
            radius_scaled = int(round(radius[part] * pose_scale))
            circle_thickness = int(round(thickness_circle[part] * pose_scale))
            color = colors[part % num_colors]
            center = keypoint[part, :-1].astype(np.int32)
            cv2.circle(canvas, tuple(center.tolist()), radius_scaled, tuple(color.tolist()),
                       circle_thickness, line_type, shift)
    return canvas
class MeshRenderer:
    """Render a triangle mesh with pyrender and compose it with input images.

    ``__call__`` renders a single mesh over a single image; ``visualize`` and
    ``visualize_tensorboard`` build per-sample image lists/grids from batches.
    """

    def __init__(self, cfg, faces=None):
        """
        Args:
            cfg: Configuration object; ``cfg.MODEL.IMAGE_SIZE`` is read as the
                square image resolution.
            faces: Triangle index array shared by every mesh passed to
                ``__call__`` (required before rendering).
        """
        self.cfg = cfg
        self.img_res = cfg.MODEL.IMAGE_SIZE
        # NOTE: __call__ creates its own renderer sized to each input image and
        # does not use this one; it is kept so the attribute stays available.
        self.renderer = pyrender.OffscreenRenderer(viewport_width=self.img_res,
                                                   viewport_height=self.img_res,
                                                   point_size=1.0)
        self.camera_center = [self.img_res // 2, self.img_res // 2]
        self.faces = faces

    def _render_views(self, vertices, camera_translation, image, focal_length):
        """Render the front and side views of one mesh as float (C, H, W) tensors."""
        views = []
        for side_view in (False, True):
            rendered = self.__call__(vertices, camera_translation, image,
                                     focal_length=focal_length, side_view=side_view)
            views.append(torch.from_numpy(np.transpose(rendered, (2, 0, 1))).float())
        return views

    def visualize(self, vertices, camera_translation, images, focal_length, nrow=3, padding=2):
        """Build a grid of (input image, front render, side render) per sample.

        Args:
            vertices: Batch of mesh vertices, indexed along axis 0.
            camera_translation: Batch of camera translations, one per sample.
            images: Batch of channel-first images (numpy array); transposed to
                channel-last before rendering.
            focal_length: Focal length forwarded to ``__call__``.
            nrow: Images per row, forwarded to ``make_grid``.
            padding: Padding forwarded to ``make_grid``.

        Returns:
            The tensor produced by ``torchvision.utils.make_grid``.
        """
        images_np = np.transpose(images, (0, 2, 3, 1))
        rend_imgs = []
        for i in range(vertices.shape[0]):
            rend_img, rend_img_side = self._render_views(
                vertices[i], camera_translation[i], images_np[i], focal_length)
            rend_imgs.append(torch.from_numpy(images[i]))
            rend_imgs.append(rend_img)
            rend_imgs.append(rend_img_side)
        return make_grid(rend_imgs, nrow=nrow, padding=padding)

    def visualize_tensorboard(self, vertices, camera_translation, images, focal_length, pred_keypoints, gt_keypoints,
                              pred_masks=None, gt_masks=None):
        """Collect per-sample visualisation tensors (renders, masks, keypoints).

        Args:
            vertices: Batch of mesh vertices, indexed along axis 0.
            camera_translation: Batch of camera translations, one per sample.
            images: Batch of channel-first images (numpy array).
            focal_length: Focal length forwarded to ``__call__``.
            pred_keypoints: Predicted 2D keypoints without a confidence column;
                mapped to pixels via ``img_res * (k + 0.5)``.
            gt_keypoints: Ground-truth keypoints whose last column is a
                confidence; the other columns are mapped to pixels the same
                way. The caller's array is left unchanged.
            pred_masks: Optional per-sample predicted masks (numpy).
            gt_masks: Optional per-sample ground-truth masks (numpy).

        Returns:
            list: For each sample, in order: input image, front render, side
            render, optional predicted mask, optional ground-truth mask,
            predicted-keypoint overlay, ground-truth-keypoint overlay.
        """
        images_np = np.transpose(images, (0, 2, 3, 1))
        rend_imgs = []

        # Append a confidence column of ones. It is scaled together with the
        # coordinates below (to 1.5 * img_res), so it stays above the drawing
        # threshold for any positive image size.
        pred_keypoints = np.concatenate((pred_keypoints, np.ones_like(pred_keypoints)[:, :, [0]]), axis=-1)
        pred_keypoints = self.img_res * (pred_keypoints + 0.5)

        # Work on a copy: rescaling in place would corrupt the caller's
        # ground truth and scale it again on any repeated call.
        gt_keypoints = gt_keypoints.copy()
        gt_keypoints[:, :, :-1] = self.img_res * (gt_keypoints[:, :, :-1] + 0.5)

        for i in range(vertices.shape[0]):
            rend_img, rend_img_side = self._render_views(
                vertices[i], camera_translation[i], images_np[i], focal_length)
            pred_keypoints_img = render_keypoint(255 * images_np[i].copy(), pred_keypoints[i]) / 255
            gt_keypoints_img = render_keypoint(255 * images_np[i].copy(), gt_keypoints[i]) / 255
            rend_imgs.append(torch.from_numpy(images[i]))
            rend_imgs.append(rend_img)
            rend_imgs.append(rend_img_side)
            if pred_masks is not None:
                rend_imgs.append(torch.from_numpy(pred_masks[i]))
            if gt_masks is not None:
                rend_imgs.append(torch.from_numpy(gt_masks[i]))
            rend_imgs.append(torch.from_numpy(pred_keypoints_img).permute(2, 0, 1))
            rend_imgs.append(torch.from_numpy(gt_keypoints_img).permute(2, 0, 1))
        return rend_imgs

    def __call__(self, vertices, camera_translation, image, focal_length, text=None, resize=None, side_view=False,
                 baseColorFactor=(1.0, 1.0, 0.9, 1.0), rot_angle=90):
        """Render one mesh and, for the front view, paste it over ``image``.

        Args:
            vertices: Mesh vertices (numpy array); combined with ``self.faces``.
            camera_translation: Length-3 camera translation. It is copied, so
                the caller's array is never modified.
            image: Channel-last background image; its height and width set the
                viewport and the principal point (image centre).
            focal_length: Focal length used for both ``fx`` and ``fy``.
            text: Unused; kept so existing callers keep working.
            resize: Optional size passed to ``cv2.resize`` for the output.
            side_view: If true, the mesh is first rotated by ``rot_angle``
                degrees about the y axis and the result is NOT composited
                over ``image``.
            baseColorFactor: RGBA base colour of the mesh material.
            rot_angle: Side-view rotation in degrees.

        Returns:
            np.ndarray: float32 image with 3 channels.
        """
        # Private copy: negating x in place would flip the caller's array, and
        # a second call on the same array (front view, then side view) would
        # silently flip it back and render with the opposite sign.
        camera_translation = np.array(camera_translation, dtype=np.float64)
        camera_translation[0] *= -1.

        renderer = pyrender.OffscreenRenderer(viewport_width=image.shape[1],
                                              viewport_height=image.shape[0],
                                              point_size=1.0)
        try:
            material = pyrender.MetallicRoughnessMaterial(
                metallicFactor=0.0,
                alphaMode='OPAQUE',
                baseColorFactor=baseColorFactor)

            mesh = trimesh.Trimesh(vertices.copy(), self.faces.copy())
            if side_view:
                rot = trimesh.transformations.rotation_matrix(
                    np.radians(rot_angle), [0, 1, 0])
                mesh.apply_transform(rot)
            # 180-degree rotation about x (flips y and z).
            # NOTE(review): presumably converts to the renderer's camera
            # convention -- confirm.
            rot = trimesh.transformations.rotation_matrix(
                np.radians(180), [1, 0, 0])
            mesh.apply_transform(rot)
            mesh = pyrender.Mesh.from_trimesh(mesh, material=material)

            scene = pyrender.Scene(bg_color=[0.0, 0.0, 0.0, 0.0],
                                   ambient_light=(0.3, 0.3, 0.3))
            scene.add(mesh, 'mesh')

            camera_pose = np.eye(4)
            camera_pose[:3, 3] = camera_translation
            camera_center = [image.shape[1] / 2., image.shape[0] / 2.]
            camera = pyrender.IntrinsicsCamera(fx=focal_length, fy=focal_length,
                                               cx=camera_center[0], cy=camera_center[1],
                                               zfar=1000)
            scene.add(camera, pose=camera_pose)

            for node in create_raymond_lights():
                scene.add_node(node)

            color, _ = renderer.render(scene, flags=pyrender.RenderFlags.RGBA)
        finally:
            # Always release the offscreen context, even if rendering failed.
            renderer.delete()

        color = color.astype(np.float32) / 255.0
        # Pixels with non-zero alpha are covered by the mesh.
        valid_mask = (color[:, :, -1] > 0)[:, :, np.newaxis]
        if not side_view:
            output_img = (color[:, :, :3] * valid_mask +
                          (1 - valid_mask) * image)
        else:
            output_img = color[:, :, :3]
        if resize is not None:
            output_img = cv2.resize(output_img, resize)
        return output_img.astype(np.float32)
class SilhouetteRenderer:
    """Soft-silhouette renderer built on PyTorch3D.

    Example: ``SilhouetteRenderer(size=256, focal=focal, device=device)``.
    """

    def __init__(self, size, focal=None, fov=None, device='cpu'):
        """
        Args:
            size: Output image size passed to the rasterizer.
            focal: Focal length; used to derive the field of view when
                ``fov`` is not given.
            fov: Field of view in degrees; takes precedence over ``focal``.
            device: Device on which the camera and the R/T tensors live.
        """
        self.device = device
        self.size = size
        self.camera = self.init_camera(focal, fov)
        self.silhouette_renderer = self.init_silhouette_renderer()

        # Batch-of-one rotation that negates the x and y axes, and a zero
        # translation; both are passed to the renderer on every call.
        flip_xy = [[-1, 0, 0],
                   [0, -1, 0],
                   [0, 0, 1]]
        self.R = torch.tensor(flip_xy, dtype=torch.float32).unsqueeze(0).to(device)
        self.T = torch.zeros(1, 3, dtype=torch.float32).to(device)

    def init_camera(self, focal, fov):
        """Create the FoV perspective camera (``zfar=350``).

        When ``fov`` is None it is computed in degrees from ``self.size`` and
        ``focal`` as ``2 * atan(size / (2 * focal))``.
        """
        if fov is None:
            fov = 2 * np.arctan(self.size / (focal * 2)) * 180 / np.pi
        return FoVPerspectiveCameras(zfar=350, fov=fov, device=self.device)

    def init_silhouette_renderer(self):
        """Compose a rasterizer and a soft-silhouette shader into a renderer."""
        blend_params = BlendParams(sigma=1e-4, gamma=1e-4)
        raster_settings = RasterizationSettings(
            image_size=self.size,
            blur_radius=np.log(1. / 1e-4 - 1.) * blend_params.sigma,
            faces_per_pixel=100,
            bin_size=0,
        )
        rasterizer = MeshRasterizer(cameras=self.camera,
                                    raster_settings=raster_settings)
        shader = SoftSilhouetteShader(blend_params=blend_params)
        return Pytorch3dMeshRender(rasterizer=rasterizer, shader=shader)

    @torch.autocast(device_type='cuda', dtype=torch.float32)
    def __call__(self, vertices, faces):
        """Render one silhouette per mesh.

        Args:
            vertices: Per-mesh vertex tensors, iterated along the first axis
                (the original note gives the shape as BN * V * 3).
            faces: Per-mesh face tensors, iterated in lockstep with
                ``vertices`` (BN * F * 3 per the original note).

        Returns:
            torch.Tensor: Channel index 3 of each rendered image, concatenated
            along dimension 0.
        """
        # Meshes are rendered one at a time to avoid running out of memory.
        per_mesh = []
        for mesh_verts, mesh_faces in zip(vertices, faces):
            single_mesh = Meshes(verts=mesh_verts.unsqueeze(0),
                                 faces=mesh_faces.unsqueeze(0))
            rendered = self.silhouette_renderer(meshes_world=single_mesh.clone(),
                                                R=self.R, T=self.T)
            per_mesh.append(rendered[..., 3])
        return torch.cat(per_mesh, dim=0)