Spaces:
Runtime error
Runtime error
File size: 2,809 Bytes
d945eeb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 |
import os
import slangtorch
import torch
import torch.nn as nn
from jaxtyping import Bool, Float
from torch import Tensor
class TextureBaker(nn.Module):
    """Bake per-vertex attributes into a square UV-space texture with Slang CUDA kernels.

    The kernels (``bake_uv`` and ``interpolate``) come from ``texture_baker.slang``
    located next to this file. Every tensor the kernels consume must already live
    on a CUDA device.
    """

    # Side length of the square CUDA thread block used for every kernel launch.
    _BLOCK_SIZE = 16

    def __init__(self):
        super().__init__()
        # Load (and compile, if needed) the Slang module shipped alongside this file.
        self.baker = slangtorch.loadModule(
            os.path.join(os.path.dirname(__file__), "texture_baker.slang")
        )

    @classmethod
    def _grid_size(cls, resolution: int) -> int:
        """Return the number of thread blocks per axis needed to cover ``resolution`` pixels.

        Launches use whole ``_BLOCK_SIZE`` x ``_BLOCK_SIZE`` blocks. A resolution
        that is not a positive multiple of the block size would leave the trailing
        rows/columns unprocessed, so such resolutions are rejected. Previously
        ``rasterize`` returned uninitialized memory there and ``interpolate``
        returned zeros.

        Raises:
            ValueError: if ``resolution`` is not a positive multiple of ``_BLOCK_SIZE``.
        """
        if resolution <= 0 or resolution % cls._BLOCK_SIZE != 0:
            raise ValueError(
                f"Resolution must be a positive multiple of {cls._BLOCK_SIZE}, "
                f"got {resolution}"
            )
        return resolution // cls._BLOCK_SIZE

    def rasterize(
        self,
        uv: Float[Tensor, "Nv 2"],
        face_indices: Float[Tensor, "Nf 3"],
        bake_resolution: int,
    ) -> Float[Tensor, "bake_resolution bake_resolution 4"]:
        """Rasterize the UV-space triangles into a square buffer via the ``bake_uv`` kernel.

        Args:
            uv: per-vertex UV coordinates; cast to float32 before the launch.
            face_indices: per-triangle vertex indices. Despite the ``Float``
                annotation, they are cast to int32 before the launch.
            bake_resolution: side length of the output in pixels; must be a
                positive multiple of 16.

        Returns:
            A float32 tensor of shape ``(bake_resolution, bake_resolution, 4)`` on
            ``uv``'s device, filled by the kernel. ``get_mask`` treats pixels whose
            last channel is ``>= 0`` as valid.

        Raises:
            ValueError: if ``uv`` or ``face_indices`` is not on CUDA, or if
                ``bake_resolution`` is not a positive multiple of 16.
        """
        if not face_indices.is_cuda or not uv.is_cuda:
            raise ValueError("All input tensors must be on cuda")
        grid_size = self._grid_size(bake_resolution)

        face_indices = face_indices.to(torch.int32)
        uv = uv.to(torch.float32)
        # Allocated uninitialized: any pixel the kernel does not write holds garbage,
        # which is why the resolution must be fully covered by the launch grid.
        rast_result = torch.empty(
            bake_resolution, bake_resolution, 4, device=uv.device, dtype=torch.float32
        )
        self.baker.bake_uv(uv=uv, indices=face_indices, output=rast_result).launchRaw(
            blockSize=(self._BLOCK_SIZE, self._BLOCK_SIZE, 1),
            gridSize=(grid_size, grid_size, 1),
        )
        return rast_result

    def get_mask(
        self, rast: Float[Tensor, "bake_resolution bake_resolution 4"]
    ) -> Bool[Tensor, "bake_resolution bake_resolution"]:
        """Return a boolean mask of pixels whose last ``rast`` channel is ``>= 0``.

        NOTE(review): presumably the kernel writes a negative value into that
        channel for pixels not covered by any triangle; confirm against
        ``texture_baker.slang``.
        """
        return rast[..., -1] >= 0

    def interpolate(
        self,
        attr: Float[Tensor, "Nv 3"],
        rast: Float[Tensor, "bake_resolution bake_resolution 4"],
        face_indices: Float[Tensor, "Nf 3"],
        uv: Float[Tensor, "Nv 2"],
    ) -> Float[Tensor, "bake_resolution bake_resolution 3"]:
        """Interpolate per-vertex ``attr`` over the rasterized buffer via the ``interpolate`` kernel.

        Args:
            attr: per-vertex attributes; cast to float32 before the launch.
            rast: square buffer produced by ``rasterize``; its side must be a
                positive multiple of 16.
            face_indices: per-triangle vertex indices. Despite the ``Float``
                annotation, they are cast to int32 before the launch.
            uv: not passed to the kernel; accepted only to keep the signature
                compatible with existing callers.

        Returns:
            A float32 tensor of shape ``(H, W, 3)`` matching ``rast``'s spatial
            size, on ``attr``'s device. Pixels the kernel does not write stay zero.

        Raises:
            ValueError: if ``attr``, ``rast`` or ``face_indices`` is not on CUDA,
                or if ``rast`` is not square with a side that is a positive
                multiple of 16.
        """
        # Make sure all tensors consumed by the kernel are on cuda
        if not attr.is_cuda or not face_indices.is_cuda or not rast.is_cuda:
            raise ValueError("All input tensors must be on cuda")
        # The launch grid is derived from rast.shape[0] for both axes, so the
        # buffer must be square for the grid to match it.
        if rast.shape[0] != rast.shape[1]:
            raise ValueError(
                f"rast must be square, got shape {tuple(rast.shape)}"
            )
        grid_size = self._grid_size(rast.shape[0])

        attr = attr.to(torch.float32)
        face_indices = face_indices.to(torch.int32)
        pos_bake = torch.zeros(
            rast.shape[0],
            rast.shape[1],
            3,
            device=attr.device,
            dtype=attr.dtype,
        )
        self.baker.interpolate(
            attr=attr, indices=face_indices, rast=rast, output=pos_bake
        ).launchRaw(
            blockSize=(self._BLOCK_SIZE, self._BLOCK_SIZE, 1),
            gridSize=(grid_size, grid_size, 1),
        )
        return pos_bake

    def forward(
        self,
        attr: Float[Tensor, "Nv 3"],
        uv: Float[Tensor, "Nv 2"],
        face_indices: Float[Tensor, "Nf 3"],
        bake_resolution: int,
    ) -> Float[Tensor, "bake_resolution bake_resolution 3"]:
        """Rasterize the UV layout at ``bake_resolution``, then interpolate ``attr`` over it.

        Raises:
            ValueError: propagated from ``rasterize`` / ``interpolate``.
        """
        rast = self.rasterize(uv, face_indices, bake_resolution)
        return self.interpolate(attr, rast, face_indices, uv)
|