|  |  | 
					
						
						|  |  | 
					
						
						|  | import pickle | 
					
						
						|  | from functools import lru_cache | 
					
						
						|  | from typing import Dict, Optional, Tuple | 
					
						
						|  | import torch | 
					
						
						|  |  | 
					
						
						|  | from detectron2.utils.file_io import PathManager | 
					
						
						|  |  | 
					
						
						|  | from densepose.data.meshes.catalog import MeshCatalog, MeshInfo | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def _maybe_copy_to_device( | 
					
						
						|  | attribute: Optional[torch.Tensor], device: torch.device | 
					
						
						|  | ) -> Optional[torch.Tensor]: | 
					
						
						|  | if attribute is None: | 
					
						
						|  | return None | 
					
						
						|  | return attribute.to(device) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class Mesh: | 
					
						
						|  | def __init__( | 
					
						
						|  | self, | 
					
						
						|  | vertices: Optional[torch.Tensor] = None, | 
					
						
						|  | faces: Optional[torch.Tensor] = None, | 
					
						
						|  | geodists: Optional[torch.Tensor] = None, | 
					
						
						|  | symmetry: Optional[Dict[str, torch.Tensor]] = None, | 
					
						
						|  | texcoords: Optional[torch.Tensor] = None, | 
					
						
						|  | mesh_info: Optional[MeshInfo] = None, | 
					
						
						|  | device: Optional[torch.device] = None, | 
					
						
						|  | ): | 
					
						
						|  | """ | 
					
						
						|  | Args: | 
					
						
						|  | vertices (tensor [N, 3] of float32): vertex coordinates in 3D | 
					
						
						|  | faces (tensor [M, 3] of long): triangular face represented as 3 | 
					
						
						|  | vertex indices | 
					
						
						|  | geodists (tensor [N, N] of float32): geodesic distances from | 
					
						
						|  | vertex `i` to vertex `j` (optional, default: None) | 
					
						
						|  | symmetry (dict: str -> tensor): various mesh symmetry data: | 
					
						
						|  | - "vertex_transforms": vertex mapping under horizontal flip, | 
					
						
						|  | tensor of size [N] of type long; vertex `i` is mapped to | 
					
						
						|  | vertex `tensor[i]` (optional, default: None) | 
					
						
						|  | texcoords (tensor [N, 2] of float32): texture coordinates, i.e. global | 
					
						
						|  | and normalized mesh UVs (optional, default: None) | 
					
						
						|  | mesh_info (MeshInfo type): necessary to load the attributes on-the-go, | 
					
						
						|  | can be used instead of passing all the variables one by one | 
					
						
						|  | device (torch.device): device of the Mesh. If not provided, will use | 
					
						
						|  | the device of the vertices | 
					
						
						|  | """ | 
					
						
						|  | self._vertices = vertices | 
					
						
						|  | self._faces = faces | 
					
						
						|  | self._geodists = geodists | 
					
						
						|  | self._symmetry = symmetry | 
					
						
						|  | self._texcoords = texcoords | 
					
						
						|  | self.mesh_info = mesh_info | 
					
						
						|  | self.device = device | 
					
						
						|  |  | 
					
						
						|  | assert self._vertices is not None or self.mesh_info is not None | 
					
						
						|  |  | 
					
						
						|  | all_fields = [self._vertices, self._faces, self._geodists, self._texcoords] | 
					
						
						|  |  | 
					
						
						|  | if self.device is None: | 
					
						
						|  | for field in all_fields: | 
					
						
						|  | if field is not None: | 
					
						
						|  | self.device = field.device | 
					
						
						|  | break | 
					
						
						|  | if self.device is None and symmetry is not None: | 
					
						
						|  | for key in symmetry: | 
					
						
						|  | self.device = symmetry[key].device | 
					
						
						|  | break | 
					
						
						|  | self.device = torch.device("cpu") if self.device is None else self.device | 
					
						
						|  |  | 
					
						
						|  | assert all([var.device == self.device for var in all_fields if var is not None]) | 
					
						
						|  | if symmetry: | 
					
						
						|  | assert all(symmetry[key].device == self.device for key in symmetry) | 
					
						
						|  | if texcoords and vertices: | 
					
						
						|  | assert len(vertices) == len(texcoords) | 
					
						
						|  |  | 
					
						
						|  | def to(self, device: torch.device): | 
					
						
						|  | device_symmetry = self._symmetry | 
					
						
						|  | if device_symmetry: | 
					
						
						|  | device_symmetry = {key: value.to(device) for key, value in device_symmetry.items()} | 
					
						
						|  | return Mesh( | 
					
						
						|  | _maybe_copy_to_device(self._vertices, device), | 
					
						
						|  | _maybe_copy_to_device(self._faces, device), | 
					
						
						|  | _maybe_copy_to_device(self._geodists, device), | 
					
						
						|  | device_symmetry, | 
					
						
						|  | _maybe_copy_to_device(self._texcoords, device), | 
					
						
						|  | self.mesh_info, | 
					
						
						|  | device, | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | @property | 
					
						
						|  | def vertices(self): | 
					
						
						|  | if self._vertices is None and self.mesh_info is not None: | 
					
						
						|  | self._vertices = load_mesh_data(self.mesh_info.data, "vertices", self.device) | 
					
						
						|  | return self._vertices | 
					
						
						|  |  | 
					
						
						|  | @property | 
					
						
						|  | def faces(self): | 
					
						
						|  | if self._faces is None and self.mesh_info is not None: | 
					
						
						|  | self._faces = load_mesh_data(self.mesh_info.data, "faces", self.device) | 
					
						
						|  | return self._faces | 
					
						
						|  |  | 
					
						
						|  | @property | 
					
						
						|  | def geodists(self): | 
					
						
						|  | if self._geodists is None and self.mesh_info is not None: | 
					
						
						|  | self._geodists = load_mesh_auxiliary_data(self.mesh_info.geodists, self.device) | 
					
						
						|  | return self._geodists | 
					
						
						|  |  | 
					
						
						|  | @property | 
					
						
						|  | def symmetry(self): | 
					
						
						|  | if self._symmetry is None and self.mesh_info is not None: | 
					
						
						|  | self._symmetry = load_mesh_symmetry(self.mesh_info.symmetry, self.device) | 
					
						
						|  | return self._symmetry | 
					
						
						|  |  | 
					
						
						|  | @property | 
					
						
						|  | def texcoords(self): | 
					
						
						|  | if self._texcoords is None and self.mesh_info is not None: | 
					
						
						|  | self._texcoords = load_mesh_auxiliary_data(self.mesh_info.texcoords, self.device) | 
					
						
						|  | return self._texcoords | 
					
						
						|  |  | 
					
						
						|  | def get_geodists(self): | 
					
						
						|  | if self.geodists is None: | 
					
						
						|  | self.geodists = self._compute_geodists() | 
					
						
						|  | return self.geodists | 
					
						
						|  |  | 
					
						
						|  | def _compute_geodists(self): | 
					
						
						|  |  | 
					
						
						|  | geodists = None | 
					
						
						|  | return geodists | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def load_mesh_data( | 
					
						
						|  | mesh_fpath: str, field: str, device: Optional[torch.device] = None | 
					
						
						|  | ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]: | 
					
						
						|  | with PathManager.open(mesh_fpath, "rb") as hFile: | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | return torch.as_tensor(pickle.load(hFile)[field], dtype=torch.float).to(device) | 
					
						
						|  | return None | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def load_mesh_auxiliary_data( | 
					
						
						|  | fpath: str, device: Optional[torch.device] = None | 
					
						
						|  | ) -> Optional[torch.Tensor]: | 
					
						
						|  | fpath_local = PathManager.get_local_path(fpath) | 
					
						
						|  | with PathManager.open(fpath_local, "rb") as hFile: | 
					
						
						|  | return torch.as_tensor(pickle.load(hFile), dtype=torch.float).to(device) | 
					
						
						|  | return None | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | @lru_cache() | 
					
						
						|  | def load_mesh_symmetry( | 
					
						
						|  | symmetry_fpath: str, device: Optional[torch.device] = None | 
					
						
						|  | ) -> Optional[Dict[str, torch.Tensor]]: | 
					
						
						|  | with PathManager.open(symmetry_fpath, "rb") as hFile: | 
					
						
						|  | symmetry_loaded = pickle.load(hFile) | 
					
						
						|  | symmetry = { | 
					
						
						|  | "vertex_transforms": torch.as_tensor( | 
					
						
						|  | symmetry_loaded["vertex_transforms"], dtype=torch.long | 
					
						
						|  | ).to(device), | 
					
						
						|  | } | 
					
						
						|  | return symmetry | 
					
						
						|  | return None | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | @lru_cache() | 
					
						
						|  | def create_mesh(mesh_name: str, device: Optional[torch.device] = None) -> Mesh: | 
					
						
						|  | return Mesh(mesh_info=MeshCatalog[mesh_name], device=device) | 
					
						
						|  |  |