| |
| |
|
|
| import os |
| import torch |
| import numpy as np |
| import torch.nn as nn |
| from pytorch3d.io import load_obj |
| from pytorch3d.structures import ( |
| Meshes, Pointclouds |
| ) |
| from pytorch3d.renderer import ( |
| PerspectiveCameras, OrthographicCameras, |
| look_at_view_transform, RasterizationSettings, |
| FoVPerspectiveCameras, PointsRasterizationSettings, |
| PointsRenderer, PointsRasterizer, AlphaCompositor, |
| PointLights, AmbientLights, TexturesVertex, TexturesUV, BlendParams, |
| SoftPhongShader, MeshRasterizer, MeshRenderer, SoftSilhouetteShader, HardFlatShader |
| ) |
| from pytorch3d.transforms import matrix_to_rotation_6d, rotation_6d_to_matrix |
| from .util import weak_cam2persp_cam |
|
|
|
|
class BaseMeshRenderer(nn.Module):
    """Render triangle meshes with a uniform skin colour using PyTorch3D.

    Args:
        assets_dir: Stored as ``self.assets_dir``; not read by any method here.
        image_size: Square output resolution (pixels) used for rasterization
            and for the cameras built by this class.
        device: Torch device for cameras, lights and meshes.
        skin_color: RGB vertex colour in the 0-255 range.
        bg_color: RGB background colour in the 0-255 range.
        focal_length: Default focal length used by :meth:`forward` when none
            is passed.

    NOTE(review): ``faces=None`` falls back to ``self.faces``, which this
    class never assigns; presumably a subclass sets it -- confirm.
    """

    def __init__(self, assets_dir, image_size=512, device='cuda',
                 skin_color=(252, 224, 203), bg_color=(0, 0, 0), focal_length=12):
        super(BaseMeshRenderer, self).__init__()
        self.device = device
        self.image_size = image_size
        self.assets_dir = assets_dir
        self.skin_color = np.array(skin_color)
        self.bg_color = bg_color
        self.focal_length = focal_length

        self.raster_settings = RasterizationSettings(image_size=image_size, blur_radius=0.0, faces_per_pixel=1)
        # Light used by forward().
        self.lights = PointLights(device=device, location=[[0.0, 0.0, 3.0]])
        # Not referenced by any method of this class; kept for external callers.
        self.manual_lights = PointLights(
            device=self.device,
            location=((0.0, 0.0, 5.0), ),
            ambient_color=((0.5, 0.5, 0.5), ),
            diffuse_color=((0.5, 0.5, 0.5), ),
            specular_color=((0.01, 0.01, 0.01), ),
        )
        # Colours are given in 0-255 but PyTorch3D blends in 0-1 (the divisor
        # must be 255; 225 pushed a white background above 1.0).
        self.blend = BlendParams(background_color=self._to_unit_color(bg_color))

    @staticmethod
    def _to_unit_color(color):
        """Map an RGB colour from the 0-255 range to a float array in 0-1."""
        return np.array(color) / 255.

    def _build_cameras(self, transform_matrix, focal_length):
        """Build screen-space perspective cameras from a batch of 4x4 extrinsics.

        Rotation is taken from ``transform_matrix[:, :3, :3]`` and translation
        from ``transform_matrix[:, :3, 3]``; the principal point is zero.
        """
        batch_size = transform_matrix.shape[0]
        screen_size = torch.tensor(
            [self.image_size, self.image_size], device=self.device
        ).float()[None].repeat(batch_size, 1)
        cameras_kwargs = {
            'principal_point': torch.zeros(batch_size, 2, device=self.device).float(), 'focal_length': focal_length,
            'image_size': screen_size, 'device': self.device,
        }
        return PerspectiveCameras(**cameras_kwargs, R=transform_matrix[:, :3, :3], T=transform_matrix[:, :3, 3])

    def _build_orth_cameras(self):
        """Build a single orthographic camera looking at the origin from distance 10."""
        R, T = look_at_view_transform(dist=10)
        return OrthographicCameras(device=self.device, focal_length=1, R=R, T=T,
                                   image_size=torch.tensor([self.image_size, self.image_size], device=self.device).float())

    def _default_lights(self):
        """Point light used by render_mesh/render_alpha when no lights are given."""
        return PointLights(device=self.device, location=[[0.0, 1.0, 10.0]])

    def _resolve_faces(self, faces):
        """Return ``faces``, or ``self.faces`` with its leading dim squeezed when None."""
        return self.faces.squeeze(0) if faces is None else faces

    def _render_rgba(self, vertices, faces, cameras, lights):
        """Phong-shade the skin-coloured meshes; return channels-first RGBA.

        ``faces`` is one shared topology, repeated for every batch element.
        """
        batch_size, num_verts = vertices.shape[:2]
        batch_faces = faces[None].repeat(batch_size, 1, 1)
        vert_rgb = torch.from_numpy(self._to_unit_color(self.skin_color)).float().to(self.device)
        textures = TexturesVertex(verts_features=vert_rgb[None, None, :].repeat(batch_size, num_verts, 1))
        mesh = Meshes(
            verts=vertices.to(self.device),
            faces=batch_faces.to(self.device),
            textures=textures,
        )
        renderer = MeshRenderer(
            rasterizer=MeshRasterizer(cameras=cameras, raster_settings=self.raster_settings),
            shader=SoftPhongShader(cameras=cameras, lights=lights, device=self.device, blend_params=self.blend),
        )
        # PyTorch3D returns channels-last; move channels to dim 1.
        return renderer(mesh).permute(0, 3, 1, 2)

    @staticmethod
    def _masked_rgb(render_results):
        """Zero RGB where alpha < 0.5 and rescale colours from 0-1 to 0-255.

        Note: the zeroing is done in place on ``render_results``' RGB channels.
        """
        images = render_results[:, :3]
        alpha = render_results[:, 3:]
        images[alpha.expand(-1, 3, -1, -1) < 0.5] = 0.0
        return images * 255

    def forward(self, vertices, faces=None, landmarks=None, cameras=None, transform_matrix=None,
                focal_length=None, is_weak_cam=False, ret_image=True):
        """Project vertices/landmarks to screen space and optionally render.

        Args:
            vertices: Batched mesh vertices; dims 0 and 1 are batch and vertex.
            faces: Shared face indices; ``None`` falls back to ``self.faces``
                (only needed when ``ret_image`` is True).
            landmarks: Mapping name -> points to project with the same cameras.
            cameras: Camera object to use. NOTE(review): a ``torch.Tensor`` is
                cloned, but a plain tensor has no ``transform_points_screen`` --
                confirm what callers pass here.
            transform_matrix: Extrinsics used to build perspective cameras when
                ``cameras`` is None and ``is_weak_cam`` is False.
            focal_length: Overrides ``self.focal_length`` when given.
            is_weak_cam: Use orthographic look-at cameras; this replaces any
                non-tensor ``cameras`` argument.
            ret_image: Whether to render the shaded image.

        Returns:
            ``(screen_vertices, screen_landmarks, images)`` where ``images`` is
            RGB in 0-255 with background pixels zeroed, or None.
        """
        landmarks = {} if landmarks is None else landmarks
        focal_length = self.focal_length if focal_length is None else focal_length
        if isinstance(cameras, torch.Tensor):
            cameras = cameras.clone()
        elif is_weak_cam:
            cameras = self._build_orth_cameras()
        elif cameras is None:
            cameras = self._build_cameras(transform_matrix, focal_length)

        ret_vertices = cameras.transform_points_screen(vertices)
        ret_landmarks = {k: cameras.transform_points_screen(v) for k, v in landmarks.items()}

        images = None
        if ret_image:
            faces = self._resolve_faces(faces)
            images = self._masked_rgb(self._render_rgba(vertices, faces, cameras, self.lights))
        return ret_vertices, ret_landmarks, images

    def render_mesh(self, vertices, cameras, faces=None, lights=None):
        """Render shaded meshes; return RGB in 0-255 with background zeroed.

        When ``lights`` is None a default point light is used. ``self.lights``
        is left untouched, so this does not change the lighting of forward().
        """
        faces = self._resolve_faces(faces)
        lights = self._default_lights() if lights is None else lights
        return self._masked_rgb(self._render_rgba(vertices, faces, cameras, lights))

    def render_alpha(self, vertices, cameras, faces=None, lights=None):
        """Render meshes and return only the alpha channel (channels-first).

        ``self.lights`` is left untouched.
        """
        faces = self._resolve_faces(faces)
        lights = self._default_lights() if lights is None else lights
        return self._render_rgba(vertices, faces, cameras, lights)[:, 3:]
| |
class PointRenderer(nn.Module):
    """Render point clouds with random per-point colours from a look-at camera.

    Args:
        image_size: Output resolution passed to the points rasterizer.
        device: Torch device for cameras and points.
    """

    # (dist, elev, azim) of the cameras built in __init__.
    _INIT_VIEW = (4, 30, 30)

    def __init__(self, image_size=256, device='cpu'):
        super(PointRenderer, self).__init__()
        self.device = device
        # Track which view self.cameras was built for, so forward() rebuilds
        # exactly when the requested view differs from the current one.
        self._view = self._INIT_VIEW
        self.cameras = self._make_cameras(*self._INIT_VIEW)
        raster_settings = PointsRasterizationSettings(
            image_size=image_size, radius=0.005, points_per_pixel=10
        )
        rasterizer = PointsRasterizer(cameras=self.cameras, raster_settings=raster_settings)
        self.renderer = PointsRenderer(rasterizer=rasterizer, compositor=AlphaCompositor())

    def _make_cameras(self, dist, elev, azim):
        """Build FoV perspective cameras looking at the origin from the given view."""
        R, T = look_at_view_transform(dist, elev, azim)
        return FoVPerspectiveCameras(device=self.device, R=R, T=T, znear=0.01, zfar=1.0)

    @staticmethod
    def _axis_points(num, like):
        """Return ``3 * num`` points sampling [0, 1] along the x, y and z axes.

        Rows ``[0, num)`` lie on x, ``[num, 2num)`` on y, ``[2num, 3num)`` on z.
        The result has ``like``'s dtype and device.
        """
        axes = like.new_zeros(num * 3, 3)
        ramp = torch.linspace(0, 1.0, steps=num, device=axes.device)
        axes[:num, 0] = ramp
        axes[num:num * 2, 1] = ramp
        axes[num * 2:, 2] = ramp
        return axes

    def forward(self, points, D=3, E=15, A=30, coords=True, ex_points=None):
        """Render ``points`` viewed from distance ``D``, elevation ``E``, azimuth ``A``.

        Args:
            points: Batched points (array-like or tensor); dim 1 indexes points.
            D, E, A: Camera distance, elevation and azimuth for
                ``look_at_view_transform``.
            coords: Append points drawn along the unit x/y/z axes.
            ex_points: Optional extra points appended to every batch element.

        Returns:
            Channels-first images scaled by 255.
        """
        view = (D, E, A)
        if view != self._view:
            self.cameras = self._make_cameras(D, E, A)
            self._view = view
        # as_tensor handles numpy arrays, lists and tensors of any dtype/device
        # (the legacy torch.Tensor(...) constructor rejects CUDA/float64 tensors).
        verts = torch.as_tensor(points, dtype=torch.float32, device=self.device)
        # Keep a random subset of at most 10000 points.
        verts = verts[:, torch.randperm(verts.shape[1])[:10000]]
        if ex_points is not None:
            verts = torch.cat([verts, ex_points.expand(verts.shape[0], -1, -1)], dim=1)
        if coords:
            axes = self._axis_points(verts.shape[1] // 10, verts)
            verts = torch.cat([verts, axes.unsqueeze(0).expand(verts.shape[0], -1, -1)], dim=1)
        rgb = torch.rand_like(verts)
        point_cloud = Pointclouds(points=verts, features=rgb)
        images = self.renderer(point_cloud, cameras=self.cameras).permute(0, 3, 1, 2)
        return images * 255
|
|
|
|
class TextureRenderer(nn.Module):
    """Render UV-textured meshes with optional SH shading and a face-region mask.

    Args:
        obj_filename: OBJ file providing faces and UVs (textures not loaded).
        tuv: Alternative dict with 'verts_uvs', 'textures_idx' and 'verts_idx'.
        flame_mask: Optional collection of vertex indices (list, set, array or
            tensor); faces whose three vertices are all in it make up the region
            returned as ``masks_face`` by :meth:`forward`.
        device: Torch device.

    Raises:
        NotImplementedError: If neither ``obj_filename`` nor ``tuv`` is given.
    """

    def __init__(self, obj_filename=None, tuv=None, flame_mask=None, device='cpu'):
        super(TextureRenderer, self).__init__()
        self.device = device

        if obj_filename is not None:
            _, faces, aux = load_obj(obj_filename, load_textures=False)
            self.uvverts = aux.verts_uvs[None, ...].to(self.device)
            self.uvfaces = faces.textures_idx[None, ...].to(self.device)
            self.faces = faces.verts_idx[None, ...].to(self.device)
        elif tuv is not None:
            self.uvverts = tuv['verts_uvs'][None, ...].to(self.device)
            self.uvfaces = tuv['textures_idx'][None, ...].to(self.device)
            self.faces = tuv['verts_idx'][None, ...].to(self.device)
        else:
            raise NotImplementedError('Must have faces and uvs.')

        self.lights = AmbientLights(device=self.device)

        if flame_mask is not None:
            # Bool per face; only defined when a mask is given (forward checks hasattr).
            self.flame_mask = self._faces_inside(self.faces[0], flame_mask).to(self.device)

        pi = np.pi
        sh_const = torch.tensor(
            [
                1 / np.sqrt(4 * pi),
                ((2 * pi) / 3) * (np.sqrt(3 / (4 * pi))),
                ((2 * pi) / 3) * (np.sqrt(3 / (4 * pi))),
                ((2 * pi) / 3) * (np.sqrt(3 / (4 * pi))),
                (pi / 4) * (3) * (np.sqrt(5 / (12 * pi))),
                (pi / 4) * (3) * (np.sqrt(5 / (12 * pi))),
                (pi / 4) * (3) * (np.sqrt(5 / (12 * pi))),
                (pi / 4) * (3 / 2) * (np.sqrt(5 / (12 * pi))),
                (pi / 4) * (1 / 2) * (np.sqrt(5 / (4 * pi))),
            ],
            dtype=torch.float32,
        )
        # One scale per SH basis term, in the order built by add_SHlight.
        self.constant_factor = sh_const.to(self.device)

    @staticmethod
    def _faces_inside(faces, vertex_ids):
        """Return a CPU bool tensor marking faces whose vertices all lie in ``vertex_ids``.

        Vectorised with ``np.isin`` instead of a per-vertex Python membership
        loop (which was O(faces * 3 * len(vertex_ids)) for list inputs).
        """
        if isinstance(vertex_ids, torch.Tensor):
            ids = vertex_ids.detach().cpu().numpy()
        else:
            ids = np.asarray(list(vertex_ids))
        inside = np.isin(faces.detach().cpu().numpy(), ids).all(axis=1)
        return torch.from_numpy(inside)

    def add_SHlight(self, normal_images, sh_coeff):
        """Evaluate 9-term spherical-harmonics shading per pixel.

        Args:
            normal_images: 4-D tensor; channels 0-2 of dim 1 are read as x, y, z.
            sh_coeff: 3-D tensor of per-basis coefficients, shape (B, 9, C).

        Returns:
            Shading of shape (B, C, H, W).
        """
        nx, ny, nz = normal_images[:, 0], normal_images[:, 1], normal_images[:, 2]
        sh = torch.stack([
            torch.ones_like(nx), nx, ny,
            nz, nx * ny, nx * nz,
            ny * nz, nx ** 2 - ny ** 2, 3 * (nz ** 2) - 1,
        ], 1)
        sh = sh * self.constant_factor[None, :, None, None]
        return torch.sum(sh_coeff[:, :, :, None, None] * sh[:, :, None, :, :], 1)

    def _build_cameras(self, transform_matrix, focal_length, principal_point, image_size):
        """Build screen-space perspective cameras from a batch of 4x4 extrinsics.

        ``principal_point`` is repeated over the batch; when None it defaults
        to zeros (previously this crashed on ``None.repeat``).
        """
        batch_size = transform_matrix.shape[0]
        screen_size = torch.tensor(
            [image_size, image_size], device=self.device
        ).float()[None].repeat(batch_size, 1)
        if principal_point is None:
            principal_point = torch.zeros(1, 2, device=self.device)
        cameras_kwargs = {
            'principal_point': principal_point.repeat(batch_size, 1), 'focal_length': focal_length,
            'image_size': screen_size, 'device': self.device,
        }
        return PerspectiveCameras(**cameras_kwargs, R=transform_matrix[:, :3, :3], T=transform_matrix[:, :3, 3])

    def forward(
        self, vertices_world, texture_images, lights=None, image_size=512,
        cameras=None, transform_matrix=None, focal_length=None, principal_point=None
    ):
        """Render textured meshes.

        Args:
            vertices_world: Batched vertices; dim 0 is the batch.
            texture_images: Channels-first texture maps, expanded over the batch.
            lights: Optional SH coefficients passed to :meth:`add_SHlight`.
            image_size: Square output resolution.
            cameras: Camera object; built from ``transform_matrix`` etc. if None.

        Returns:
            ``(images, masks_all, masks_face)``: channels-first images with
            uncovered pixels zeroed, a bool coverage mask, and a bool mask of
            the ``flame_mask`` face region (None when no mask was configured).
        """
        if cameras is None:
            cameras = self._build_cameras(transform_matrix, focal_length, principal_point, image_size)
        batch_size = vertices_world.shape[0]
        faces = self.faces.expand(batch_size, -1, -1)
        textures_uv = TexturesUV(
            maps=texture_images.expand(batch_size, -1, -1, -1).permute(0, 2, 3, 1),
            faces_uvs=self.uvfaces.expand(batch_size, -1, -1),
            verts_uvs=self.uvverts.expand(batch_size, -1, -1),
        )
        meshes_world = Meshes(verts=vertices_world, faces=faces, textures=textures_uv)

        raster_settings = RasterizationSettings(
            image_size=image_size, blur_radius=0.0, faces_per_pixel=1,
            perspective_correct=True, cull_backfaces=True
        )
        phong_renderer = MeshRenderer(
            rasterizer=MeshRasterizer(cameras=cameras, raster_settings=raster_settings),
            shader=SoftPhongShader(device=self.device, cameras=cameras, lights=self.lights)
        )
        image_ref = phong_renderer(meshes_world=meshes_world)
        images = image_ref[..., :3].permute(0, 3, 1, 2)
        masks_all = image_ref[..., 3:].permute(0, 3, 1, 2) > 0.0
        if lights is not None:
            # NOTE(review): the rendered texture colours are passed where
            # add_SHlight expects normals, and the shading replaces (rather
            # than modulates) the texture -- confirm this is intended.
            images = self.add_SHlight(images, lights)
        images[~masks_all.expand(-1, 3, -1, -1)] = 0.0

        masks_face = None
        with torch.no_grad():
            if hasattr(self, 'flame_mask'):
                textures_verts = TexturesVertex(verts_features=vertices_world.new_ones(vertices_world.shape))
                meshes_masked = Meshes(
                    verts=vertices_world, faces=faces[:, self.flame_mask], textures=textures_verts
                )
                silhouette_renderer = MeshRenderer(
                    rasterizer=MeshRasterizer(cameras=cameras, raster_settings=raster_settings),
                    shader=SoftSilhouetteShader()
                )
                masks_face = silhouette_renderer(meshes_world=meshes_masked)
                masks_face = masks_face[..., 3:].permute(0, 3, 1, 2) > 0.0
        return images, masks_all, masks_face
|
|