import torch
import pytorch3d
import torch.nn.functional as F
from pytorch3d.ops import interpolate_face_attributes
from pytorch3d.renderer import (
look_at_view_transform,
FoVPerspectiveCameras,
AmbientLights,
PointLights,
DirectionalLights,
Materials,
RasterizationSettings,
MeshRenderer,
MeshRasterizer,
SoftPhongShader,
SoftSilhouetteShader,
HardPhongShader,
TexturesVertex,
TexturesUV,
)
from pytorch3d.renderer.blending import BlendParams, hard_rgb_blend
from pytorch3d.renderer.utils import convert_to_tensors_and_broadcast, TensorProperties
from pytorch3d.renderer.mesh.shader import ShaderBase


def get_cos_angle(points, normals, camera_position):
    """
    Compute the cosine of the angle between the surface-to-camera view
    direction and the surface normal, per pixel.
    """
if points.shape != normals.shape:
msg = "Expected points and normals to have the same shape: got %r, %r"
raise ValueError(msg % (points.shape, normals.shape))
# Ensure all inputs have same batch dimension as points
matched_tensors = convert_to_tensors_and_broadcast(
points, camera_position, device=points.device
)
_, camera_position = matched_tensors
    # Reshape camera_position so it broadcasts over all the arbitrary
    # intermediate dimensions of points. Assume first dim = batch dim and
    # last dim = 3.
points_dims = points.shape[1:-1]
expand_dims = (-1,) + (1,) * len(points_dims)
if camera_position.shape != normals.shape:
camera_position = camera_position.view(expand_dims + (3,))
normals = F.normalize(normals, p=2, dim=-1, eps=1e-6)
# Calculate the cosine value.
view_direction = camera_position - points
view_direction = F.normalize(view_direction, p=2, dim=-1, eps=1e-6)
cos_angle = torch.sum(view_direction * normals, dim=-1, keepdim=True)
    # Clamp to [0, 1]: back-facing surfaces (negative cosine) are zeroed out.
    cos_angle = cos_angle.clamp(0, 1)
return cos_angle
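

# Illustrative usage sketch (an assumption, not part of the original file):
# get_cos_angle takes (N, ..., 3) points and normals plus an (N, 3) camera
# center. The shapes below mirror the rasterizer's per-pixel (N, H, W, K, 3)
# layout; all values are dummies.
#
#   points = torch.rand(2, 8, 8, 1, 3)               # world-space positions
#   normals = torch.randn(2, 8, 8, 1, 3)             # unnormalized is fine
#   cam_pos = torch.tensor([[0.0, 0.0, 3.0],
#                           [0.0, 1.0, 3.0]])        # one center per batch
#   cos = get_cos_angle(points, normals, cam_pos)    # (2, 8, 8, 1, 1) in [0, 1]
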
def _geometry_shading_with_pixels(
meshes, fragments, lights, cameras, materials, texels
):
"""
Render pixel space vertex position, normal(world), depth, and cos angle
Args:
meshes: Batch of meshes
fragments: Fragments named tuple with the outputs of rasterization
lights: Lights class containing a batch of lights
cameras: Cameras class containing a batch of cameras
materials: Materials class containing a batch of material properties
texels: texture per pixel of shape (N, H, W, K, 3)
Returns:
colors: (N, H, W, K, 3)
pixel_coords: (N, H, W, K, 3), camera coordinates of each intersection.
"""
verts = meshes.verts_packed() # (V, 3)
faces = meshes.faces_packed() # (F, 3)
vertex_normals = meshes.verts_normals_packed() # (V, 3)
faces_verts = verts[faces]
faces_normals = vertex_normals[faces]
    # Interpolate packed per-vertex attributes to per-pixel values using the
    # barycentric coordinates from rasterization.
    pixel_coords_in_camera = interpolate_face_attributes(
        fragments.pix_to_face, fragments.bary_coords, faces_verts
    )
    pixel_normals = interpolate_face_attributes(
        fragments.pix_to_face, fragments.bary_coords, faces_normals
    )
cos_angles = get_cos_angle(
pixel_coords_in_camera, pixel_normals, cameras.get_camera_center()
)
return pixel_coords_in_camera, pixel_normals, fragments.zbuf[..., None], cos_angles
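

# Illustrative sketch (an assumption, not part of the original file) of the
# shapes interpolate_face_attributes expects: pix_to_face holds one face index
# per pixel sample (-1 for background), bary_coords holds barycentric weights,
# and the attributes are packed per face as (F, 3 verts, D channels).
#
#   N, H, W, K, F_, D = 1, 4, 4, 1, 10, 3
#   pix_to_face = torch.randint(-1, F_, (N, H, W, K))
#   bary = torch.rand(N, H, W, K, 3)
#   bary = bary / bary.sum(dim=-1, keepdim=True)     # weights sum to 1
#   attrs = torch.rand(F_, 3, D)
#   out = interpolate_face_attributes(pix_to_face, bary, attrs)  # (N, H, W, K, D)
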
class HardGeometryShader(ShaderBase):
"""
renders common geometric informations.
"""
def forward(self, fragments, meshes, **kwargs):
cameras = super()._get_cameras(**kwargs)
texels = self.texel_from_uv(fragments, meshes)
lights = kwargs.get("lights", self.lights)
materials = kwargs.get("materials", self.materials)
blend_params = kwargs.get("blend_params", self.blend_params)
verts, normals, depths, cos_angles = _geometry_shading_with_pixels(
meshes=meshes,
fragments=fragments,
texels=texels,
lights=lights,
cameras=cameras,
materials=materials,
)
        # Composite each geometric quantity over the background color with
        # hard blending; this also appends an alpha (foreground mask) channel.
        verts = hard_rgb_blend(verts, fragments, blend_params)
        normals = hard_rgb_blend(normals, fragments, blend_params)
        depths = hard_rgb_blend(depths, fragments, blend_params)
        cos_angles = hard_rgb_blend(cos_angles, fragments, blend_params)
        texels = hard_rgb_blend(texels, fragments, blend_params)
return verts, normals, depths, cos_angles, texels, fragments

    def texel_from_uv(self, fragments, meshes):
        # Temporarily swap the mesh texture for a 2x2 map whose "colors" are
        # the UV coordinates themselves; sampling it then yields per-pixel
        # interpolated UV coordinates.
        texture_tmp = meshes.textures
        maps_tmp = texture_tmp.maps_padded()
        uv_color = [[[1, 0], [1, 1]], [[0, 0], [0, 1]]]
        uv_color = (
            torch.FloatTensor(uv_color).to(maps_tmp[0].device).type(maps_tmp[0].dtype)
        )
uv_texture = TexturesUV(
[uv_color.clone() for t in maps_tmp],
texture_tmp.faces_uvs_padded(),
texture_tmp.verts_uvs_padded(),
sampling_mode="bilinear",
)
        meshes.textures = uv_texture
        texels = meshes.sample_textures(fragments)
        # Restore the original texture before returning.
        meshes.textures = texture_tmp
        # Pad a zero third channel so the 2-channel UV texels can be treated
        # like RGB colors downstream (e.g. by hard_rgb_blend).
        texels = torch.cat((texels, texels[..., -1:] * 0), dim=-1)
        return texels
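

if __name__ == "__main__":
    # Illustrative smoke test (an assumption, not part of the original file):
    # wires HardGeometryShader into a standard PyTorch3D pipeline on a minimal
    # single-triangle mesh with a dummy UV texture. All sizes, coordinates, and
    # camera parameters below are made up for demonstration.
    from pytorch3d.structures import Meshes

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    verts = torch.tensor(
        [[[-1.0, -1.0, 0.0], [1.0, -1.0, 0.0], [0.0, 1.0, 0.0]]], device=device
    )
    faces = torch.tensor([[[0, 1, 2]]], device=device)
    verts_uvs = torch.tensor([[[0.0, 0.0], [1.0, 0.0], [0.5, 1.0]]], device=device)
    faces_uvs = torch.tensor([[[0, 1, 2]]], device=device)
    maps = torch.ones(1, 4, 4, 3, device=device)  # dummy 4x4 RGB texture map
    mesh = Meshes(
        verts=verts,
        faces=faces,
        textures=TexturesUV(maps=maps, faces_uvs=faces_uvs, verts_uvs=verts_uvs),
    )

    R, T = look_at_view_transform(dist=3.0, elev=0.0, azim=0.0)
    cameras = FoVPerspectiveCameras(device=device, R=R, T=T)
    renderer = MeshRenderer(
        rasterizer=MeshRasterizer(
            cameras=cameras,
            raster_settings=RasterizationSettings(image_size=64),
        ),
        shader=HardGeometryShader(device=device, cameras=cameras),
    )
    verts_img, normals_img, depths_img, cos_img, uv_img, frags = renderer(mesh)
    # Blended outputs gain an alpha channel: verts_img/normals_img/uv_img are
    # (1, 64, 64, 4); depths_img and cos_img are (1, 64, 64, 2).
    print(verts_img.shape, depths_img.shape, uv_img.shape)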