File size: 5,338 Bytes
2ac1c2d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
import torch
import pytorch3d
import torch.nn.functional as F

from pytorch3d.ops import interpolate_face_attributes

from pytorch3d.renderer import (
    look_at_view_transform,
    FoVPerspectiveCameras,
    AmbientLights,
    PointLights,
    DirectionalLights,
    Materials,
    RasterizationSettings,
    MeshRenderer,
    MeshRasterizer,
    SoftPhongShader,
    SoftSilhouetteShader,
    HardPhongShader,
    TexturesVertex,
    TexturesUV,
    Materials,
)
from pytorch3d.renderer.blending import BlendParams, hard_rgb_blend
from pytorch3d.renderer.utils import convert_to_tensors_and_broadcast, TensorProperties
from pytorch3d.renderer.mesh.shader import ShaderBase


def get_cos_angle(points, normals, camera_position):
    """
    Compute the clamped cosine between the surface->camera direction and the normal.

    Args:
        points: (N, ..., 3) surface positions; the first dim is treated as the
            batch dim and the last dim as xyz.
        normals: surface normals with exactly the same shape as ``points``. They
            do not need to be unit length; they are normalized here.
        camera_position: camera position(s). It is broadcast against the batch
            dim of ``points`` with ``convert_to_tensors_and_broadcast``.

    Returns:
        Tensor whose last dim has size 1 (``keepdim=True``). When the batch sizes
        match, its shape is ``points.shape[:-1] + (1,)``. Values lie in [0, 1];
        surfaces facing away from the camera (negative cosine) are clamped to 0.

    Raises:
        ValueError: if ``points`` and ``normals`` have different shapes.
    """
    if points.shape != normals.shape:
        msg = "Expected points and normals to have the same shape: got %r, %r"
        raise ValueError(msg % (points.shape, normals.shape))

    # Give camera_position the same batch dimension as points. Only the
    # broadcast camera position is kept; points are used as given.
    _, camera_position = convert_to_tensors_and_broadcast(
        points, camera_position, device=points.device
    )

    # Insert singleton dims so the per-batch camera position broadcasts over
    # every intermediate dim of points, e.g. (N, 3) -> (N, 1, 1, 1, 3) for
    # (N, H, W, K, 3) points.
    if camera_position.shape != normals.shape:
        n_middle_dims = len(points.shape[1:-1])
        camera_position = camera_position.view((-1,) + (1,) * n_middle_dims + (3,))

    unit_normals = F.normalize(normals, p=2, dim=-1, eps=1e-6)
    view_direction = F.normalize(camera_position - points, p=2, dim=-1, eps=1e-6)

    # Cosine between the direction from the surface point toward the camera and
    # the surface normal.
    cos_angle = torch.sum(view_direction * unit_normals, dim=-1, keepdim=True)
    return cos_angle.clamp(0, 1)


def _geometry_shading_with_pixels(
    meshes, fragments, lights, cameras, materials, texels
):
    """
    Interpolate per-pixel geometry (position, normal, depth, view cosine).

    Args:
        meshes: batch of meshes that was rasterized into ``fragments``.
        fragments: Fragments named tuple output by the rasterizer. Its
            ``pix_to_face``, ``bary_coords`` and ``zbuf`` fields are read.
        lights: unused here; kept so the signature mirrors the other shading
            helpers.
        cameras: cameras used for rasterization. Only ``get_camera_center()`` is
            called.
        materials: unused here; kept for signature compatibility.
        texels: unused here; kept for signature compatibility.

    Returns:
        Tuple ``(pixel_coords, pixel_normals, depths, cos_angles)``:
            pixel_coords: surface positions interpolated from
                ``meshes.verts_packed()``. They are in the mesh's own coordinate
                frame, which is presumably world space when called from a
                standard MeshRenderer. They are NOT camera-space coordinates.
            pixel_normals: vertex normals interpolated with the same barycentric
                weights; not re-normalized.
            depths: ``fragments.zbuf`` with a trailing size-1 channel axis.
            cos_angles: output of ``get_cos_angle`` for the positions/normals
                above and ``cameras.get_camera_center()``.
    """
    verts = meshes.verts_packed()  # (V, 3)
    faces = meshes.faces_packed()  # (F, 3)
    vertex_normals = meshes.verts_normals_packed()  # (V, 3)

    # Per-face attribute tables: for each face, its 3 vertices' values.
    faces_verts = verts[faces]
    faces_normals = vertex_normals[faces]

    pixel_coords = interpolate_face_attributes(
        fragments.pix_to_face, fragments.bary_coords, faces_verts
    )
    pixel_normals = interpolate_face_attributes(
        fragments.pix_to_face, fragments.bary_coords, faces_normals
    )

    cos_angles = get_cos_angle(
        pixel_coords, pixel_normals, cameras.get_camera_center()
    )

    return pixel_coords, pixel_normals, fragments.zbuf[..., None], cos_angles


class HardGeometryShader(ShaderBase):
    """
    Shader that renders per-pixel geometric information instead of lit colors.

    ``forward`` returns interpolated surface positions, normals, depth, the
    view/normal cosine and sampled texture colors. Each of these is passed
    through ``hard_rgb_blend`` with the active ``blend_params``.
    """

    def forward(self, fragments, meshes, **kwargs):
        """
        Render geometry buffers for a rasterized batch of meshes.

        Args:
            fragments: rasterizer output for ``meshes``.
            meshes: the rasterized meshes. Their textures must provide
                ``maps_padded()``, ``faces_uvs_padded()`` and
                ``verts_uvs_padded()`` (e.g. TexturesUV), because
                ``texel_from_uv`` is called.
            **kwargs: may override ``lights``, ``materials`` and
                ``blend_params``. Cameras are resolved through
                ``ShaderBase._get_cameras``.

        Returns:
            Tuple ``(verts, normals, depths, cos_angles, texels, fragments)``.
            The first five are the outputs of ``hard_rgb_blend``; ``fragments``
            is returned unchanged.
        """
        cameras = super()._get_cameras(**kwargs)

        # NOTE(review): these UV-coordinate texels are only handed to
        # _geometry_shading_with_pixels, which does not read them. The texels
        # returned below are sampled from the mesh's own texture instead.
        # Confirm which of the two was intended.
        uv_texels = self.texel_from_uv(fragments, meshes)

        lights = kwargs.get("lights", self.lights)
        materials = kwargs.get("materials", self.materials)
        blend_params = kwargs.get("blend_params", self.blend_params)
        verts, normals, depths, cos_angles = _geometry_shading_with_pixels(
            meshes=meshes,
            fragments=fragments,
            texels=uv_texels,
            lights=lights,
            cameras=cameras,
            materials=materials,
        )
        texels = meshes.sample_textures(fragments)

        verts = hard_rgb_blend(verts, fragments, blend_params)
        normals = hard_rgb_blend(normals, fragments, blend_params)
        depths = hard_rgb_blend(depths, fragments, blend_params)
        cos_angles = hard_rgb_blend(cos_angles, fragments, blend_params)
        texels = hard_rgb_blend(texels, fragments, blend_params)
        return verts, normals, depths, cos_angles, texels, fragments

    def texel_from_uv(self, fragments, meshes):
        """
        Sample a synthetic UV-gradient texture through the mesh's UV layout.

        A 2x2, two-channel texture holding the corner values (0, 0), (0, 1),
        (1, 0) and (1, 1) is bilinearly sampled with the mesh's own face/vertex
        UVs. The first two output channels therefore vary with each pixel's UV
        coordinate. The exact channel/axis order follows TexturesUV's map
        orientation; verify before relying on it. A zero channel is appended as
        the last channel.

        ``meshes.textures`` is swapped out temporarily and is always restored,
        even if sampling raises.

        Args:
            fragments: rasterizer output for ``meshes``.
            meshes: meshes whose textures provide ``maps_padded()``,
                ``faces_uvs_padded()`` and ``verts_uvs_padded()``.

        Returns:
            The sampled texels with one extra all-zero channel concatenated on
            the last dim.
        """
        original_textures = meshes.textures
        texture_maps = original_textures.maps_padded()
        reference_map = texture_maps[0]
        uv_color = torch.tensor(
            [[[1, 0], [1, 1]], [[0, 0], [0, 1]]],
            device=reference_map.device,
            dtype=reference_map.dtype,
        )
        uv_texture = TexturesUV(
            # One copy of the gradient map per mesh in the batch.
            [uv_color.clone() for _ in texture_maps],
            original_textures.faces_uvs_padded(),
            original_textures.verts_uvs_padded(),
            sampling_mode="bilinear",
        )

        meshes.textures = uv_texture
        try:
            texels = meshes.sample_textures(fragments)
        finally:
            # Never leave the caller's mesh holding the synthetic texture.
            meshes.textures = original_textures

        return torch.cat((texels, torch.zeros_like(texels[..., -1:])), dim=-1)