# Image2Model / utils / pytorch3d_minimal.py
# Author: Daankular
# Commit: Port MeshForge features to ZeroGPU Space: FireRed, PSHuman, Motion Search (8f1bcd9)
"""
pytorch3d_minimal.py
====================
Drop-in replacement for the pytorch3d subset used by PSHuman's project_mesh.py
and mesh_utils.py. Uses nvdiffrast for GPU rasterization.
Implements:
- Meshes / TexturesVertex
- look_at_view_transform
- FoVOrthographicCameras / OrthographicCameras (orthographic projection only)
- RasterizationSettings / MeshRasterizer (via nvdiffrast)
- render_pix2faces_py3d (compatibility shim)
"""
from __future__ import annotations
import math
import torch
import torch.nn.functional as F
import numpy as np
# ---------------------------------------------------------------------------
# Texture / Mesh containers
# ---------------------------------------------------------------------------
class TexturesVertex:
    """Per-vertex feature (color) container: one [N, C] tensor per mesh."""

    def __init__(self, verts_features):
        # One feature tensor per mesh in the batch; this minimal shim only
        # ever reads index 0 (see verts_features_packed).
        self._feats = verts_features

    def verts_features_packed(self):
        # Single-mesh shim: "packed" features are just the first mesh's.
        return self._feats[0]

    def clone(self):
        return TexturesVertex([feat.clone() for feat in self._feats])

    def detach(self):
        return TexturesVertex([feat.detach() for feat in self._feats])

    def to(self, device):
        # Moves in place and returns self for chaining.
        self._feats = [feat.to(device) for feat in self._feats]
        return self
class Meshes:
    """Minimal batched triangle-mesh container (lists of per-mesh tensors)."""

    def __init__(self, verts, faces, textures=None):
        self._verts = verts   # list of [N,3] float tensors
        self._faces = faces   # list of [F,3] long tensors
        self.textures = textures

    # ---- accessors --------------------------------------------------------
    def verts_padded(self):
        return torch.stack(self._verts)

    def faces_padded(self):
        return torch.stack(self._faces)

    def verts_packed(self):
        return self._verts[0]

    def faces_packed(self):
        return self._faces[0]

    def verts_list(self):
        return self._verts

    def faces_list(self):
        return self._faces

    def verts_normals_packed(self):
        """Per-vertex normals for the first mesh: average of the (unit)
        normals of all faces touching each vertex, renormalized."""
        verts, faces = self._verts[0], self._faces[0]
        corner_a, corner_b, corner_c = (verts[faces[:, i]] for i in range(3))
        face_n = F.normalize(
            torch.cross(corner_b - corner_a, corner_c - corner_a, dim=1), dim=1)
        vert_n = torch.zeros_like(verts)
        # Scatter each face normal onto its three corner vertices.
        for corner in range(3):
            idx = faces[:, corner:corner + 1].expand(-1, 3)
            vert_n.scatter_add_(0, idx, face_n)
        return F.normalize(vert_n, dim=1)

    # ---- device / copy ----------------------------------------------------
    def to(self, device):
        # In-place move; returns self for chaining.
        self._verts = [v.to(device) for v in self._verts]
        self._faces = [f.to(device) for f in self._faces]
        if self.textures is not None:
            self.textures.to(device)
        return self

    def clone(self):
        out = Meshes([v.clone() for v in self._verts],
                     [f.clone() for f in self._faces])
        if self.textures is not None:
            out.textures = self.textures.clone()
        return out

    def detach(self):
        out = Meshes([v.detach() for v in self._verts],
                     [f.detach() for f in self._faces])
        if self.textures is not None:
            out.textures = self.textures.detach()
        return out
# ---------------------------------------------------------------------------
# Camera math (mirrors pytorch3d look_at_view_transform + Orthographic)
# ---------------------------------------------------------------------------
def _look_at_rotation(camera_pos: torch.Tensor,
at: torch.Tensor,
up: torch.Tensor) -> torch.Tensor:
"""Return (3,3) rotation matrix: world → camera."""
z = F.normalize(camera_pos - at, dim=-1) # cam looks along -Z
x = F.normalize(torch.cross(up, z, dim=-1), dim=-1)
y = torch.cross(z, x, dim=-1)
R = torch.stack([x, y, z], dim=-1) # columns = cam axes
return R # shape (3,3)
def look_at_view_transform(dist=1.0, elev=0.0, azim=0.0,
                           degrees=True, device="cpu"):
    """Compute (R, T) for a camera orbiting the world origin.

    dist/elev/azim place the camera on a sphere around (0, 0, 0); `degrees`
    selects the angle unit. Returns R of shape (1, 3, 3) and T of shape
    (1, 3), laid out the way pytorch3d expects them.
    """
    elev_r = math.radians(float(elev)) if degrees else elev
    azim_r = math.radians(float(azim)) if degrees else azim
    # Spherical → Cartesian camera position in world space.
    eye = torch.tensor(
        [[dist * math.cos(elev_r) * math.sin(azim_r),
          dist * math.sin(elev_r),
          dist * math.cos(elev_r) * math.cos(azim_r)]],
        dtype=torch.float32, device=device)
    target = torch.zeros(1, 3, device=device)
    up = torch.tensor([[0, 1, 0]], dtype=torch.float32, device=device)
    # pytorch3d stores R transposed: each ROW is a camera axis in world space.
    R = _look_at_rotation(eye[0], target[0], up[0]).T.unsqueeze(0)  # (1,3,3)
    # T expresses the camera position in camera space.
    T = torch.bmm(-R, eye.unsqueeze(-1)).squeeze(-1)  # (1,3)
    return R, T
class _OrthoCamera:
"""Minimal orthographic camera, matches FoVOrthographicCameras API."""
def __init__(self, R, T, focal_length=1.0, device="cpu"):
self.R = R.to(device) # (B,3,3)
self.T = T.to(device) # (B,3)
self.focal = float(focal_length)
self.device = device
def to(self, device):
self.R = self.R.to(device)
self.T = self.T.to(device)
self.device = device
return self
def get_znear(self):
return torch.tensor(0.01, device=self.device)
def is_perspective(self):
return False
def transform_points_ndc(self, points):
"""
points: (B, N, 3) world coords
returns: (B, N, 3) NDC coords (X,Y in [-1,1], Z = depth)
"""
# world → camera
pts_cam = torch.bmm(points, self.R) + self.T.unsqueeze(1) # (B,N,3)
# orthographic NDC: scale by focal, flip Y to match image convention
ndc_x = pts_cam[..., 0] * self.focal
ndc_y = -pts_cam[..., 1] * self.focal # pytorch3d flips Y
ndc_z = pts_cam[..., 2]
return torch.stack([ndc_x, ndc_y, ndc_z], dim=-1)
def _world_to_clip(self, verts: torch.Tensor) -> torch.Tensor:
"""verts: (N,3) → clip (N,4) for nvdiffrast."""
pts_cam = (verts @ self.R[0].T) + self.T[0] # (N,3)
cx = pts_cam[:, 0] * self.focal
cy = -pts_cam[:, 1] * self.focal # flip Y
cz = pts_cam[:, 2]
w = torch.ones_like(cz)
return torch.stack([cx, cy, cz, w], dim=1) # (N,4)
# Aliases used in project_mesh.py
def FoVOrthographicCameras(device="cpu", R=None, T=None,
                           min_x=-1, max_x=1, min_y=-1, max_y=1,
                           focal_length=None, **kwargs):
    """Orthographic camera factory; focal length defaults to 1/max_x."""
    if focal_length is None:
        focal_length = 1.0 / (max_x + 1e-9)  # eps guards against max_x == 0
    return _OrthoCamera(R, T, focal_length=focal_length, device=device)


def FoVPerspectiveCameras(device="cpu", R=None, T=None, fov=60, degrees=True, **kwargs):
    """Fallback: treat as orthographic at a fov-derived scale (good enough for PSHuman)."""
    half_fov = math.radians(fov / 2) if degrees else fov / 2
    return _OrthoCamera(R, T, focal_length=1.0 / math.tan(half_fov), device=device)


OrthographicCameras = FoVOrthographicCameras
# ---------------------------------------------------------------------------
# Rasterizer (nvdiffrast-based)
# ---------------------------------------------------------------------------
class RasterizationSettings:
    """Subset of pytorch3d's RasterizationSettings: only the output size.

    blur_radius / faces_per_pixel are accepted for API compatibility but
    ignored (nvdiffrast rasterizes one face per pixel, no blur).
    """
    def __init__(self, image_size=512, blur_radius=0.0, faces_per_pixel=1):
        if isinstance(image_size, (list, tuple)):
            # Cast both entries so H/W are always ints; the scalar branch
            # already did this, the sequence branch previously did not.
            self.H, self.W = int(image_size[0]), int(image_size[1])
        else:
            self.H = self.W = int(image_size)
class _Fragments:
def __init__(self, pix_to_face):
self.pix_to_face = pix_to_face.unsqueeze(-1) # (1,H,W,1)
class MeshRasterizer:
    """nvdiffrast-backed stand-in for pytorch3d's MeshRasterizer."""

    def __init__(self, cameras=None, raster_settings=None):
        self.cameras = cameras
        self.settings = raster_settings
        self._glctx = None  # lazily created nvdiffrast CUDA context

    def _get_ctx(self, device):
        # Build the CUDA rasterization context on first use and cache it.
        if self._glctx is None:
            import nvdiffrast.torch as dr
            self._glctx = dr.RasterizeCudaContext(device=device)
        return self._glctx

    def __call__(self, meshes: Meshes, cameras=None):
        import nvdiffrast.torch as dr

        camera = cameras or self.cameras
        height, width = self.settings.H, self.settings.W
        device = meshes.verts_packed().device
        ctx = self._get_ctx(str(device))

        vertices = meshes.verts_packed().to(device)
        triangles = meshes.faces_packed().to(torch.int32).to(device)
        # Project to homogeneous clip space and add the batch axis.
        clip_verts = camera._world_to_clip(vertices).unsqueeze(0)  # (1,N,4)

        rast_out, _ = dr.rasterize(ctx, clip_verts, triangles,
                                   resolution=(height, width))
        # nvdiffrast's last channel holds triangle_id + 1 (0 = empty), so
        # subtracting 1 yields pytorch3d's "-1 = background" convention.
        face_ids = rast_out[0, :, :, -1].to(torch.int32) - 1
        return _Fragments(face_ids.unsqueeze(0))
# ---------------------------------------------------------------------------
# render_pix2faces_py3d shim (used in get_visible_faces)
# ---------------------------------------------------------------------------
def render_pix2faces_py3d(meshes, cameras, H=512, W=512, **kwargs):
    """Returns {'pix_to_face': (1,H,W)} integer tensor of face indices (-1=bg)."""
    raster = MeshRasterizer(
        cameras=cameras,
        raster_settings=RasterizationSettings(image_size=(H, W)),
    )
    fragments = raster(meshes)
    # Drop the trailing faces-per-pixel axis: (1,H,W,1) → (1,H,W).
    return {"pix_to_face": fragments.pix_to_face[..., 0]}