Spaces:
Running
on
L40S
Running
on
L40S
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import logging | |
import numpy as np | |
import pickle | |
from enum import Enum | |
from typing import Optional | |
import torch | |
from torch import nn | |
from detectron2.config import CfgNode | |
from detectron2.utils.file_io import PathManager | |
from .vertex_direct_embedder import VertexDirectEmbedder | |
from .vertex_feature_embedder import VertexFeatureEmbedder | |
class EmbedderType(Enum):
    """
    Kinds of embedders, i.e. the ways vertices can be mapped into the
    embedding space:
     - "vertex_direct": vertices are embedded directly
     - "vertex_feature": vertex features are embedded
    """

    # The string values are what `EmbedderType(<str>)` accepts for lookup.
    VERTEX_DIRECT = "vertex_direct"
    VERTEX_FEATURE = "vertex_feature"
def create_embedder(embedder_spec: CfgNode, embedder_dim: int) -> nn.Module:
    """
    Build a single embedder from its configuration node.

    Args:
        embedder_spec (CfgNode): embedder configuration; the fields read here
            are TYPE, NUM_VERTICES, INIT_FILE, IS_TRAINABLE and, for the
            "vertex_feature" type only, FEATURE_DIM and FEATURES_TRAINABLE
        embedder_dim (int): embedding space dimensionality
    Return:
        The embedder instance described by the configuration
    Raises ValueError, in case of unexpected embedder type
    """
    # Enum lookup by value; an unknown TYPE string already raises ValueError here.
    embedder_type = EmbedderType(embedder_spec.TYPE)

    if embedder_type == EmbedderType.VERTEX_FEATURE:
        embedder = VertexFeatureEmbedder(
            num_vertices=embedder_spec.NUM_VERTICES,
            feature_dim=embedder_spec.FEATURE_DIM,
            embed_dim=embedder_dim,
            train_features=embedder_spec.FEATURES_TRAINABLE,
        )
    elif embedder_type == EmbedderType.VERTEX_DIRECT:
        embedder = VertexDirectEmbedder(
            num_vertices=embedder_spec.NUM_VERTICES,
            embed_dim=embedder_dim,
        )
    else:
        # Guards against enum members that have no construction branch above.
        raise ValueError(f"Unexpected embedder type {embedder_type}")

    # Both embedder kinds are initialized from INIT_FILE the same way,
    # so the load is done once after construction.
    init_file = embedder_spec.INIT_FILE
    if init_file != "":
        embedder.load(init_file)

    # Freeze all parameters of embedders that are configured as non-trainable.
    if not embedder_spec.IS_TRAINABLE:
        embedder.requires_grad_(False)

    return embedder
class Embedder(nn.Module):
    """
    Container module holding one embedder per mesh. Being an `nn.Module`,
    it saves / loads the state of all registered embedders through the
    regular state dict machinery.
    """

    # Key prefix under which embedder weights are looked up in a model checkpoint.
    DEFAULT_MODEL_CHECKPOINT_PREFIX = "roi_heads.embedder."

    def __init__(self, cfg: CfgNode):
        """
        Create the mesh embedders listed in the configuration. The embedder
        for mesh `i` is registered as the submodule "embedder_{i}".

        Args:
            cfg (CfgNode): configuration options
        """
        super().__init__()
        self.mesh_names = set()
        cse_cfg = cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE
        embedder_dim = cse_cfg.EMBED_SIZE
        logger = logging.getLogger(__name__)
        for mesh_name, embedder_spec in cse_cfg.EMBEDDERS.items():
            logger.info(f"Adding embedder embedder_{mesh_name} with spec {embedder_spec}")
            mesh_embedder = create_embedder(embedder_spec, embedder_dim)
            self.add_module(f"embedder_{mesh_name}", mesh_embedder)
            self.mesh_names.add(mesh_name)
        # Initialize from the model checkpoint only if one is configured.
        if cfg.MODEL.WEIGHTS != "":
            self.load_from_model_checkpoint(cfg.MODEL.WEIGHTS)

    def load_from_model_checkpoint(self, fpath: str, prefix: Optional[str] = None):
        """
        Load embedder weights from a full model checkpoint file.

        Entries of the checkpoint's "model" dictionary whose keys start with
        `prefix` are selected, the prefix is stripped from their keys and the
        result is loaded non-strictly into this module. If the checkpoint is
        `None` or has no "model" entry, nothing is loaded.

        Args:
            fpath (str): checkpoint path; files ending in ".pkl" are read with
                `pickle`, all others with `torch.load` onto the CPU
            prefix (Optional[str]): key prefix of embedder weights in the
                checkpoint; defaults to `DEFAULT_MODEL_CHECKPOINT_PREFIX`
        """
        if prefix is None:
            prefix = Embedder.DEFAULT_MODEL_CHECKPOINT_PREFIX

        with PathManager.open(fpath, "rb") as stream:
            if fpath.endswith(".pkl"):
                # NOTE(review): unpickling can execute arbitrary code; only
                # point this at checkpoints from a trusted source.
                checkpoint = pickle.load(stream, encoding="latin1")
            else:
                checkpoint = torch.load(stream, map_location=torch.device("cpu"))

        if checkpoint is None or "model" not in checkpoint:
            return

        model_state = checkpoint["model"]
        prefix_len = len(prefix)
        # Keep only embedder entries, drop the prefix from their keys and
        # convert numpy arrays to tensors.
        embedder_state = {
            name[prefix_len:]: (
                torch.from_numpy(model_state[name])
                if isinstance(model_state[name], np.ndarray)
                else model_state[name]
            )
            for name in model_state
            if name.startswith(prefix)
        }
        # non-strict loading to finetune on different meshes
        self.load_state_dict(embedder_state, strict=False)

    def forward(self, mesh_name: str) -> torch.Tensor:
        """
        Compute vertex embeddings for the given mesh by calling the embedder
        registered for it. The result is a tensor of shape [N, D], where:
            N = number of vertices
            D = number of dimensions in the embedding space

        Args:
            mesh_name (str): name of a mesh for which to obtain vertex embeddings
        Return:
            Vertex embeddings, a tensor of shape [N, D]
        """
        mesh_embedder = getattr(self, f"embedder_{mesh_name}")
        return mesh_embedder()

    def has_embeddings(self, mesh_name: str) -> bool:
        """
        Tell whether this module has an attribute "embedder_{mesh_name}",
        i.e. whether an embedder is available for the given mesh.

        Args:
            mesh_name (str): name of the mesh to check
        Return:
            True if the attribute exists, False otherwise
        """
        return hasattr(self, f"embedder_{mesh_name}")