# File size: 5,213 Bytes
# 2252f3d
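# This script is borrowed and extended from another project; the source URL
# is missing here (presumably the PIFu surface classifier -- TODO confirm).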
from packaging import version
import torch
import scipy
import os
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from lib.common.config import cfg
from lib.pymaf.utils.geometry import projection
from lib.pymaf.core.path_config import MESH_DOWNSAMPLEING
import logging
logger = logging.getLogger(__name__)
class MAF_Extractor(nn.Module):
''' Mesh-aligned Feature Extrator
As discussed in the paper, we extract mesh-aligned features based on 2D projection of the mesh vertices.
The features extrated from spatial feature maps will go through a MLP for dimension reduction.
def __init__(self, device=torch.device('cuda')):
self.device = device
self.filters = []
self.num_views = 1
filter_channels = cfg.MODEL.PyMAF.MLP_DIM
self.last_op = nn.ReLU(True)
for l in range(0, len(filter_channels) - 1):
if 0 != l:
nn.Conv1d(filter_channels[l] + filter_channels[0],
filter_channels[l + 1], 1))
nn.Conv1d(filter_channels[l], filter_channels[l + 1], 1))
self.add_module("conv%d" % l, self.filters[l])
self.im_feat = None = None
# downsample SMPL mesh and assign part labels
# from
smpl_mesh_graph = np.load(MESH_DOWNSAMPLEING,
A = smpl_mesh_graph['A']
U = smpl_mesh_graph['U']
D = smpl_mesh_graph['D'] # shape: (2,)
# downsampling
ptD = []
for i in range(len(D)):
d = scipy.sparse.coo_matrix(D[i])
i = torch.LongTensor(np.array([d.row, d.col]))
v = torch.FloatTensor(
ptD.append(torch.sparse.FloatTensor(i, v, d.shape))
# downsampling mapping from 6890 points to 431 points
# ptD[0].to_dense() - Size: [1723, 6890]
# ptD[1].to_dense() - Size: [431. 1723]
Dmap = torch.matmul(ptD[1].to_dense(),
ptD[0].to_dense()) # 6890 -> 431
self.register_buffer('Dmap', Dmap)
def reduce_dim(self, feature):
Dimension reduction by multi-layer perceptrons
:param feature: list of [B, C_s, N] point-wise features before dimension reduction
:return: [B, C_p x N] concatantion of point-wise features after dimension reduction
y = feature
tmpy = feature
for i, f in enumerate(self.filters):
y = self._modules['conv' +
str(i)](y if i == 0 else[y, tmpy], 1))
if i != len(self.filters) - 1:
y = F.leaky_relu(y)
if self.num_views > 1 and i == len(self.filters) // 2:
y = y.view(-1, self.num_views, y.shape[1],
tmpy = feature.view(-1, self.num_views, feature.shape[1],
y = self.last_op(y)
y = y.view(y.shape[0], -1)
return y
def sampling(self, points, im_feat=None, z_feat=None):
Given 2D points, sample the point-wise features for each point,
the dimension of point-wise features will be reduced from C_s to C_p by MLP.
Image features should be pre-computed before this call.
:param points: [B, N, 2] image coordinates of points
:im_feat: [B, C_s, H_s, W_s] spatial feature maps
:return: [B, C_p x N] concatantion of point-wise features after dimension reduction
if im_feat is None:
im_feat = self.im_feat
batch_size = im_feat.shape[0]
if version.parse(torch.__version__) >= version.parse('1.3.0'):
# Default grid_sample behavior has changed to align_corners=False since 1.3.0.
point_feat = torch.nn.functional.grid_sample(
im_feat, points.unsqueeze(2), align_corners=True)[..., 0]
point_feat = torch.nn.functional.grid_sample(
im_feat, points.unsqueeze(2))[..., 0]
mesh_align_feat = self.reduce_dim(point_feat)
return mesh_align_feat
def forward(self, p, s_feat=None, cam=None, **kwargs):
''' Returns mesh-aligned features for the 3D mesh points.
p (tensor): [B, N_m, 3] mesh vertices
s_feat (tensor): [B, C_s, H_s, W_s] spatial feature maps
cam (tensor): [B, 3] camera
mesh_align_feat (tensor): [B, C_p x N_m] mesh-aligned features
if cam is None:
cam =
p_proj_2d = projection(p, cam, retain_z=False)
mesh_align_feat = self.sampling(p_proj_2d, s_feat)
return mesh_align_feat