| import nvdiffrast.torch as dr |
| import torch |
| from matplotlib import image |
|
|
|
|
def _warmup(glctx):
    """Rasterize a single throwaway triangle with ``glctx``.

    The result is discarded; presumably this exists so one-time rasterizer
    initialization happens here rather than on the first real render —
    NOTE(review): confirm intent.

    Args:
        glctx: nvdiffrast rasterizer context to exercise.
    """
    # One instance holding three clip-space vertices: shape (1, 3, 4).
    corners = [[-0.8, -0.8, 0, 1], [0.8, -0.8, 0, 1], [-0.8, 0.8, 0, 1]]
    pos = torch.tensor([corners], dtype=torch.float32, device="cuda")
    tri = torch.tensor([[0, 1, 2]], dtype=torch.int32, device="cuda")
    dr.rasterize(glctx, pos, tri, resolution=[256, 256])
|
|
|
|
class NormalsRenderer:
    """Render a triangle mesh's per-vertex normals as RGBA images with nvdiffrast.

    RGB is the interpolated normal mapped from [-1, 1] to [0, 1]. Alpha is 1
    where a triangle covers the pixel and 0 elsewhere. The result is then
    antialiased with ``dr.antialias``.
    """

    def __init__(
        self,
        mv: torch.Tensor,
        proj: torch.Tensor,
        image_size: tuple[int, int],
    ):
        """Store the combined camera transform and create a CUDA rasterizer.

        Args:
            mv: Model-view matrix, presumably (4, 4) or a camera batch
                (C, 4, 4) — TODO confirm against callers.
            proj: Projection matrix, same layout as ``mv``.
            image_size: Output resolution, forwarded to ``dr.rasterize`` as
                ``resolution`` (nvdiffrast reads it as (height, width)).
        """
        self._mvp = self._compose_mvp(mv, proj)
        self._image_size = image_size
        self._glctx = dr.RasterizeCudaContext()
        _warmup(self._glctx)

    @staticmethod
    def _compose_mvp(mv: torch.Tensor, proj: torch.Tensor) -> torch.Tensor:
        """Return ``proj @ mv``, guaranteed to have a leading batch dimension."""
        mvp = proj @ mv
        # With an unbatched camera the clip positions would be 2D, which
        # nvdiffrast treats as "range mode" and rejects unless a `ranges`
        # tensor is supplied. A batch of one keeps rasterize() in instanced mode.
        if mvp.dim() == 2:
            mvp = mvp.unsqueeze(0)
        return mvp

    def _to_clip(self, vertices: torch.Tensor) -> torch.Tensor:
        """Map vertex positions to homogeneous clip space.

        Appends w = 1 (same dtype and device as ``vertices``) and applies the
        stored MVP in row-vector form: ``p_hom @ mvp^T``.
        """
        w = torch.ones_like(vertices[..., :1])
        vert_hom = torch.cat((vertices, w), dim=-1)
        return vert_hom @ self._mvp.transpose(-2, -1)

    def render(
        self,
        vertices: torch.Tensor,
        normals: torch.Tensor,
        faces: torch.Tensor,
    ) -> torch.Tensor:
        """Rasterize the mesh and shade it with its vertex normals.

        Args:
            vertices: Vertex positions, presumably (V, 3) — TODO confirm.
            normals: Per-vertex normals, presumably (V, 3) and unit length so
                that (n + 1) / 2 lands in [0, 1] — TODO confirm.
            faces: Triangle vertex indices, presumably (F, 3). They are cast to
                int32 because nvdiffrast requires that dtype.

        Returns:
            The RGBA tensor produced by ``dr.antialias``. With the batched MVP
            this follows nvdiffrast's (minibatch, height, width, 4) layout.
        """
        faces = faces.type(torch.int32)
        vertices_clip = self._to_clip(vertices)
        rast_out, _ = dr.rasterize(
            self._glctx,
            vertices_clip,
            faces,
            resolution=self._image_size,
            grad_db=False,
        )
        # Map normal components from [-1, 1] to [0, 1] for use as colour.
        vert_col = (normals + 1) / 2
        col, _ = dr.interpolate(vert_col, rast_out, faces)
        # The last rasterizer channel is triangle_id + 1 (0 = background), so
        # clamping at 1 turns it into a binary coverage mask.
        alpha = torch.clamp(rast_out[..., -1:], max=1)
        col = torch.concat((col, alpha), dim=-1)
        # Antialias silhouette edges using the same clip-space geometry.
        col = dr.antialias(col, rast_out, vertices_clip, faces)
        return col
|
|