"""
pytorch3d_minimal.py
====================
Drop-in replacement for the pytorch3d subset used by PSHuman's project_mesh.py
and mesh_utils.py. Uses nvdiffrast for GPU rasterization.

Implements:
- Meshes / TexturesVertex
- look_at_view_transform
- FoVOrthographicCameras / OrthographicCameras (orthographic projection only)
- RasterizationSettings / MeshRasterizer (via nvdiffrast)
- render_pix2faces_py3d (compatibility shim)
"""
from __future__ import annotations

import math

import numpy as np
import torch
import torch.nn.functional as F
# ---------------------------------------------------------------------------
# Texture / Mesh containers
# ---------------------------------------------------------------------------
class TexturesVertex:
    """Per-vertex color features, mimicking pytorch3d's TexturesVertex.

    Holds one ``[N, C]`` feature tensor per mesh in the batch; only the
    first mesh is ever unpacked by this minimal shim.
    """

    def __init__(self, verts_features):
        # verts_features: list of [N, C] tensors (one per mesh in batch).
        self._feats = verts_features

    def verts_features_packed(self):
        """Return the feature tensor of the first (only) mesh."""
        return self._feats[0]

    def clone(self):
        """Deep-copy every feature tensor into a new TexturesVertex."""
        return TexturesVertex([feat.clone() for feat in self._feats])

    def detach(self):
        """Return a TexturesVertex with all tensors detached from autograd."""
        return TexturesVertex([feat.detach() for feat in self._feats])

    def to(self, device):
        """Move all feature tensors to *device* in place; returns self."""
        self._feats = [feat.to(device) for feat in self._feats]
        return self
class Meshes:
    """Minimal batched triangle-mesh container (pytorch3d-compatible subset)."""

    def __init__(self, verts, faces, textures=None):
        # verts: list of [N, 3] float tensors; faces: list of [F, 3] long tensors.
        self._verts = verts
        self._faces = faces
        self.textures = textures

    # ---- accessors --------------------------------------------------------
    def verts_padded(self):
        return torch.stack(self._verts)

    def faces_padded(self):
        return torch.stack(self._faces)

    def verts_packed(self):
        return self._verts[0]

    def faces_packed(self):
        return self._faces[0]

    def verts_list(self):
        return self._verts

    def faces_list(self):
        return self._faces

    def verts_normals_packed(self):
        """Unit vertex normals of the first mesh.

        Each vertex normal is the renormalized sum of the *unit* normals of
        its incident faces (uniform weighting, not area-weighted).
        """
        verts, faces = self._verts[0], self._faces[0]
        corner = [verts[faces[:, i]] for i in range(3)]
        face_n = torch.cross(corner[1] - corner[0], corner[2] - corner[0], dim=1)
        face_n = F.normalize(face_n, dim=1)
        vert_n = torch.zeros_like(verts)
        for i in range(3):
            # Accumulate each face normal onto its three corner vertices.
            vert_n.index_add_(0, faces[:, i], face_n)
        return F.normalize(vert_n, dim=1)

    # ---- device / copy ----------------------------------------------------
    def to(self, device):
        """Move verts, faces and textures to *device* in place; returns self."""
        self._verts = [v.to(device) for v in self._verts]
        self._faces = [f.to(device) for f in self._faces]
        if self.textures is not None:
            self.textures.to(device)
        return self

    def _apply(self, method):
        # Apply a tensor method name ('clone'/'detach') uniformly to
        # verts, faces and (if present) textures.
        out = Meshes([getattr(v, method)() for v in self._verts],
                     [getattr(f, method)() for f in self._faces])
        if self.textures is not None:
            out.textures = getattr(self.textures, method)()
        return out

    def clone(self):
        return self._apply("clone")

    def detach(self):
        return self._apply("detach")
| # --------------------------------------------------------------------------- | |
| # Camera math (mirrors pytorch3d look_at_view_transform + Orthographic) | |
| # --------------------------------------------------------------------------- | |
| def _look_at_rotation(camera_pos: torch.Tensor, | |
| at: torch.Tensor, | |
| up: torch.Tensor) -> torch.Tensor: | |
| """Return (3,3) rotation matrix: world → camera.""" | |
| z = F.normalize(camera_pos - at, dim=-1) # cam looks along -Z | |
| x = F.normalize(torch.cross(up, z, dim=-1), dim=-1) | |
| y = torch.cross(z, x, dim=-1) | |
| R = torch.stack([x, y, z], dim=-1) # columns = cam axes | |
| return R # shape (3,3) | |
| def look_at_view_transform(dist=1.0, elev=0.0, azim=0.0, | |
| degrees=True, device="cpu"): | |
| """Matches pytorch3d convention exactly.""" | |
| if degrees: | |
| elev = math.radians(float(elev)) | |
| azim = math.radians(float(azim)) | |
| # camera position in world | |
| cx = dist * math.cos(elev) * math.sin(azim) | |
| cy = dist * math.sin(elev) | |
| cz = dist * math.cos(elev) * math.cos(azim) | |
| eye = torch.tensor([[cx, cy, cz]], dtype=torch.float32, device=device) | |
| at = torch.zeros(1, 3, device=device) | |
| up = torch.tensor([[0, 1, 0]], dtype=torch.float32, device=device) | |
| # pytorch3d stores R transposed (row = cam axis in world space) | |
| R = _look_at_rotation(eye[0], at[0], up[0]).T.unsqueeze(0) # (1,3,3) | |
| # T = camera position expressed in camera space | |
| T = torch.bmm(-R, eye.unsqueeze(-1)).squeeze(-1) # (1,3) | |
| return R, T | |
| class _OrthoCamera: | |
| """Minimal orthographic camera, matches FoVOrthographicCameras API.""" | |
| def __init__(self, R, T, focal_length=1.0, device="cpu"): | |
| self.R = R.to(device) # (B,3,3) | |
| self.T = T.to(device) # (B,3) | |
| self.focal = float(focal_length) | |
| self.device = device | |
| def to(self, device): | |
| self.R = self.R.to(device) | |
| self.T = self.T.to(device) | |
| self.device = device | |
| return self | |
| def get_znear(self): | |
| return torch.tensor(0.01, device=self.device) | |
| def is_perspective(self): | |
| return False | |
| def transform_points_ndc(self, points): | |
| """ | |
| points: (B, N, 3) world coords | |
| returns: (B, N, 3) NDC coords (X,Y in [-1,1], Z = depth) | |
| """ | |
| # world → camera | |
| pts_cam = torch.bmm(points, self.R) + self.T.unsqueeze(1) # (B,N,3) | |
| # orthographic NDC: scale by focal, flip Y to match image convention | |
| ndc_x = pts_cam[..., 0] * self.focal | |
| ndc_y = -pts_cam[..., 1] * self.focal # pytorch3d flips Y | |
| ndc_z = pts_cam[..., 2] | |
| return torch.stack([ndc_x, ndc_y, ndc_z], dim=-1) | |
| def _world_to_clip(self, verts: torch.Tensor) -> torch.Tensor: | |
| """verts: (N,3) → clip (N,4) for nvdiffrast.""" | |
| pts_cam = (verts @ self.R[0].T) + self.T[0] # (N,3) | |
| cx = pts_cam[:, 0] * self.focal | |
| cy = -pts_cam[:, 1] * self.focal # flip Y | |
| cz = pts_cam[:, 2] | |
| w = torch.ones_like(cz) | |
| return torch.stack([cx, cy, cz, w], dim=1) # (N,4) | |
# Aliases used in project_mesh.py
def FoVOrthographicCameras(device="cpu", R=None, T=None,
                           min_x=-1, max_x=1, min_y=-1, max_y=1,
                           focal_length=None, **kwargs):
    """Orthographic camera factory.

    Only ``focal_length`` (or, failing that, ``max_x``) affects the result;
    the remaining bounds are accepted for pytorch3d API compatibility.
    """
    if focal_length is None:
        focal_length = 1.0 / (max_x + 1e-9)  # epsilon guards max_x == 0
    return _OrthoCamera(R, T, focal_length=focal_length, device=device)
def FoVPerspectiveCameras(device="cpu", R=None, T=None, fov=60, degrees=True, **kwargs):
    """Fallback: approximate a perspective camera with an orthographic one
    whose scale is derived from the field of view (good enough for PSHuman).
    """
    half_fov = math.radians(fov / 2) if degrees else fov / 2
    return _OrthoCamera(R, T, focal_length=1.0 / math.tan(half_fov), device=device)


OrthographicCameras = FoVOrthographicCameras
# ---------------------------------------------------------------------------
# Rasterizer (nvdiffrast-based)
# ---------------------------------------------------------------------------
class RasterizationSettings:
    """Subset of pytorch3d's RasterizationSettings.

    ``image_size`` may be an int (square image) or an (H, W) pair.
    ``blur_radius`` and ``faces_per_pixel`` are accepted for API
    compatibility; nvdiffrast always rasterizes hard edges with one face
    per pixel, so they are stored for introspection but otherwise ignored
    (the original silently dropped them).
    """

    def __init__(self, image_size=512, blur_radius=0.0, faces_per_pixel=1):
        if isinstance(image_size, (list, tuple)):
            self.H, self.W = int(image_size[0]), int(image_size[1])
        else:
            self.H = self.W = int(image_size)
        self.blur_radius = blur_radius
        self.faces_per_pixel = faces_per_pixel
| class _Fragments: | |
| def __init__(self, pix_to_face): | |
| self.pix_to_face = pix_to_face.unsqueeze(-1) # (1,H,W,1) | |
class MeshRasterizer:
    """nvdiffrast-backed stand-in for pytorch3d's MeshRasterizer.

    Only produces ``pix_to_face`` (no barycentrics / zbuf), which is all
    this module's callers need.  Requires a CUDA device and nvdiffrast.
    """

    def __init__(self, cameras=None, raster_settings=None):
        self.cameras = cameras
        self.settings = raster_settings
        self._glctx = None  # lazily-created nvdiffrast CUDA context

    def _get_ctx(self, device):
        """Create the nvdiffrast CUDA context on first use (cached per rasterizer)."""
        if self._glctx is None:
            import nvdiffrast.torch as dr
            self._glctx = dr.RasterizeCudaContext(device=device)
        return self._glctx

    def __call__(self, meshes: Meshes, cameras=None):
        # BUGFIX: was `cameras or self.cameras`, which falls back on any
        # *falsy* camera object; test for None explicitly instead.
        cam = cameras if cameras is not None else self.cameras
        H, W = self.settings.H, self.settings.W
        device = meshes.verts_packed().device
        import nvdiffrast.torch as dr
        glctx = self._get_ctx(str(device))
        verts = meshes.verts_packed().to(device)
        faces = meshes.faces_packed().to(torch.int32).to(device)  # nvdiffrast wants int32
        clip = cam._world_to_clip(verts).unsqueeze(0)  # (1,N,4)
        rast, _ = dr.rasterize(glctx, clip, faces, resolution=(H, W))
        # rast channels are (u, v, z/w, triangle_id); id 0 marks background,
        # so subtracting 1 maps background pixels to -1.
        pix_to_face = rast[0, :, :, -1].to(torch.int32) - 1
        return _Fragments(pix_to_face.unsqueeze(0))
# ---------------------------------------------------------------------------
# render_pix2faces_py3d shim (used in get_visible_faces)
# ---------------------------------------------------------------------------
def render_pix2faces_py3d(meshes, cameras, H=512, W=512, **kwargs):
    """Rasterize *meshes* and return {'pix_to_face': (1, H, W)}, -1 = background."""
    rasterizer = MeshRasterizer(
        cameras=cameras,
        raster_settings=RasterizationSettings(image_size=(H, W)),
    )
    fragments = rasterizer(meshes)
    return {"pix_to_face": fragments.pix_to_face[..., 0]}  # (1,H,W)