# Source: CatVTON/densepose/modeling/cse/vertex_direct_embedder.py
# Author: zhengchong
# Commit message: "chore: Update dependencies and code structure"
# Commit: 6eb1d7d
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# pyre-unsafe
import pickle
import torch
from torch import nn
from detectron2.utils.file_io import PathManager
from .utils import normalize_embeddings
class VertexDirectEmbedder(nn.Module):
    """
    Class responsible for embedding vertices. Vertex embeddings take
    the form of a tensor of size [N, D], where
        N = number of vertices
        D = number of dimensions in the embedding space

    The embeddings are stored directly as a single learnable parameter
    (one row per vertex); `forward` returns that parameter passed through
    `normalize_embeddings`.
    """

    def __init__(self, num_vertices: int, embed_dim: int):
        """
        Initialize embedder, set all embeddings to zero

        Args:
            num_vertices (int): number of vertices to embed
            embed_dim (int): number of dimensions in the embedding space
        """
        super().__init__()
        # Allocate uninitialized storage; reset_parameters() fills it immediately.
        self.embeddings = nn.Parameter(torch.empty(num_vertices, embed_dim))
        self.reset_parameters()

    @torch.no_grad()
    def reset_parameters(self):
        """
        Reset embeddings to zeros (in place, without tracking gradients)
        """
        # NOTE: values are zeros, not random; presumably they are meant to be
        # overwritten by training or by `load` — confirm with callers.
        self.embeddings.zero_()

    def forward(self) -> torch.Tensor:
        """
        Produce vertex embeddings, a tensor of shape [N, D] where:
            N = number of vertices
            D = number of dimensions in the embedding space

        Return:
            Full vertex embeddings, a tensor of shape [N, D]: the stored
            parameter passed through `normalize_embeddings`
        """
        return normalize_embeddings(self.embeddings)

    @torch.no_grad()
    def load(self, fpath: str):
        """
        Load embeddings from a pickle file

        The file must unpickle to a mapping. If it has an "embeddings" entry,
        that value (tensor, array or nested sequence) is converted to float32
        and copied in place into `self.embeddings` on the parameter's current
        device. A missing entry leaves the parameter unchanged.

        Args:
            fpath (str): file path to load data from
        """
        # SECURITY: pickle.load can execute arbitrary code embedded in the
        # file; only load files from trusted sources.
        with PathManager.open(fpath, "rb") as hFile:
            data = pickle.load(hFile)
        for name in ["embeddings"]:
            if name in data:
                param = getattr(self, name)
                # as_tensor (unlike torch.tensor) does not warn or make an
                # extra copy when the pickle already holds a tensor.
                value = torch.as_tensor(data[name]).float()
                param.copy_(value.to(device=param.device))