|
|
|
|
|
import pickle |
|
from functools import lru_cache |
|
from typing import Dict, Optional, Tuple |
|
import torch |
|
|
|
from detectron2.utils.file_io import PathManager |
|
|
|
from densepose.data.meshes.catalog import MeshCatalog, MeshInfo |
|
|
|
|
|
def _maybe_copy_to_device( |
|
attribute: Optional[torch.Tensor], device: torch.device |
|
) -> Optional[torch.Tensor]: |
|
if attribute is None: |
|
return None |
|
return attribute.to(device) |
|
|
|
|
|
class Mesh: |
|
def __init__( |
|
self, |
|
vertices: Optional[torch.Tensor] = None, |
|
faces: Optional[torch.Tensor] = None, |
|
geodists: Optional[torch.Tensor] = None, |
|
symmetry: Optional[Dict[str, torch.Tensor]] = None, |
|
texcoords: Optional[torch.Tensor] = None, |
|
mesh_info: Optional[MeshInfo] = None, |
|
device: Optional[torch.device] = None, |
|
): |
|
""" |
|
Args: |
|
vertices (tensor [N, 3] of float32): vertex coordinates in 3D |
|
faces (tensor [M, 3] of long): triangular face represented as 3 |
|
vertex indices |
|
geodists (tensor [N, N] of float32): geodesic distances from |
|
vertex `i` to vertex `j` (optional, default: None) |
|
symmetry (dict: str -> tensor): various mesh symmetry data: |
|
- "vertex_transforms": vertex mapping under horizontal flip, |
|
tensor of size [N] of type long; vertex `i` is mapped to |
|
vertex `tensor[i]` (optional, default: None) |
|
texcoords (tensor [N, 2] of float32): texture coordinates, i.e. global |
|
and normalized mesh UVs (optional, default: None) |
|
mesh_info (MeshInfo type): necessary to load the attributes on-the-go, |
|
can be used instead of passing all the variables one by one |
|
device (torch.device): device of the Mesh. If not provided, will use |
|
the device of the vertices |
|
""" |
|
self._vertices = vertices |
|
self._faces = faces |
|
self._geodists = geodists |
|
self._symmetry = symmetry |
|
self._texcoords = texcoords |
|
self.mesh_info = mesh_info |
|
self.device = device |
|
|
|
assert self._vertices is not None or self.mesh_info is not None |
|
|
|
all_fields = [self._vertices, self._faces, self._geodists, self._texcoords] |
|
|
|
if self.device is None: |
|
for field in all_fields: |
|
if field is not None: |
|
self.device = field.device |
|
break |
|
if self.device is None and symmetry is not None: |
|
for key in symmetry: |
|
self.device = symmetry[key].device |
|
break |
|
self.device = torch.device("cpu") if self.device is None else self.device |
|
|
|
assert all([var.device == self.device for var in all_fields if var is not None]) |
|
if symmetry: |
|
assert all(symmetry[key].device == self.device for key in symmetry) |
|
if texcoords and vertices: |
|
assert len(vertices) == len(texcoords) |
|
|
|
def to(self, device: torch.device): |
|
device_symmetry = self._symmetry |
|
if device_symmetry: |
|
device_symmetry = {key: value.to(device) for key, value in device_symmetry.items()} |
|
return Mesh( |
|
_maybe_copy_to_device(self._vertices, device), |
|
_maybe_copy_to_device(self._faces, device), |
|
_maybe_copy_to_device(self._geodists, device), |
|
device_symmetry, |
|
_maybe_copy_to_device(self._texcoords, device), |
|
self.mesh_info, |
|
device, |
|
) |
|
|
|
@property |
|
def vertices(self): |
|
if self._vertices is None and self.mesh_info is not None: |
|
self._vertices = load_mesh_data(self.mesh_info.data, "vertices", self.device) |
|
return self._vertices |
|
|
|
@property |
|
def faces(self): |
|
if self._faces is None and self.mesh_info is not None: |
|
self._faces = load_mesh_data(self.mesh_info.data, "faces", self.device) |
|
return self._faces |
|
|
|
@property |
|
def geodists(self): |
|
if self._geodists is None and self.mesh_info is not None: |
|
self._geodists = load_mesh_auxiliary_data(self.mesh_info.geodists, self.device) |
|
return self._geodists |
|
|
|
@property |
|
def symmetry(self): |
|
if self._symmetry is None and self.mesh_info is not None: |
|
self._symmetry = load_mesh_symmetry(self.mesh_info.symmetry, self.device) |
|
return self._symmetry |
|
|
|
@property |
|
def texcoords(self): |
|
if self._texcoords is None and self.mesh_info is not None: |
|
self._texcoords = load_mesh_auxiliary_data(self.mesh_info.texcoords, self.device) |
|
return self._texcoords |
|
|
|
def get_geodists(self): |
|
if self.geodists is None: |
|
self.geodists = self._compute_geodists() |
|
return self.geodists |
|
|
|
def _compute_geodists(self): |
|
|
|
geodists = None |
|
return geodists |
|
|
|
|
|
def load_mesh_data(
    mesh_fpath: str,
    field: str,
    device: Optional[torch.device] = None,
    dtype: torch.dtype = torch.float,
) -> torch.Tensor:
    """
    Load one field from a pickled mesh data file as a tensor.

    Args:
        mesh_fpath: path (opened through ``PathManager``) to a pickle whose
            unpickled object is indexable by field name
        field: name of the field to load, e.g. "vertices" or "faces"
        device: device to move the resulting tensor to
        dtype: dtype of the resulting tensor; defaults to ``torch.float``,
            which is what this function has always produced
    Returns:
        tensor holding the requested field
    """
    # NOTE: ``pickle.load`` can execute arbitrary code; only load mesh files
    # from trusted sources.
    with PathManager.open(mesh_fpath, "rb") as hFile:
        mesh_data = pickle.load(hFile)
    return torch.as_tensor(mesh_data[field], dtype=dtype).to(device)
|
|
|
|
|
def load_mesh_auxiliary_data(
    fpath: str, device: Optional[torch.device] = None
) -> Optional[torch.Tensor]:
    """
    Load a pickled array-like object (e.g. geodesic distances or texture
    coordinates) as a float tensor.

    Args:
        fpath: path to the pickle; resolved to a local copy via
            ``PathManager.get_local_path`` before opening
        device: device to move the resulting tensor to
    Returns:
        float tensor with the unpickled data
    """
    # NOTE: ``pickle.load`` can execute arbitrary code; only load trusted files.
    local_path = PathManager.get_local_path(fpath)
    with PathManager.open(local_path, "rb") as stream:
        payload = pickle.load(stream)
    return torch.as_tensor(payload, dtype=torch.float).to(device)
|
|
|
|
|
@lru_cache()
def load_mesh_symmetry(
    symmetry_fpath: str, device: Optional[torch.device] = None
) -> Optional[Dict[str, torch.Tensor]]:
    """
    Load mesh symmetry data from a pickle file.

    Results are memoized per ``(symmetry_fpath, device)`` by ``lru_cache``, so
    repeated calls return the same dict object; callers must not mutate it.

    Args:
        symmetry_fpath: path (opened through ``PathManager``) to the pickle
        device: device to move the tensors to
    Returns:
        dict with key "vertex_transforms" mapped to a long tensor
    """
    # NOTE: ``pickle.load`` can execute arbitrary code; only load trusted files.
    with PathManager.open(symmetry_fpath, "rb") as stream:
        raw_symmetry = pickle.load(stream)
    vertex_transforms = torch.as_tensor(
        raw_symmetry["vertex_transforms"], dtype=torch.long
    )
    return {"vertex_transforms": vertex_transforms.to(device)}
|
|
|
|
|
@lru_cache()
def create_mesh(mesh_name: str, device: Optional[torch.device] = None) -> Mesh:
    """
    Build a lazily-loaded Mesh for a mesh registered in ``MeshCatalog``.

    Memoized by ``lru_cache``: the same ``(mesh_name, device)`` pair returns
    the same Mesh instance.
    """
    registered_info = MeshCatalog[mesh_name]
    return Mesh(mesh_info=registered_info, device=device)
|
|