import torch
import torch.nn as nn

from src.shape_utils import load_shape_with_lbo


class CSE(nn.Module):
    """Per-point embeddings expressed as a learned linear combination of
    Laplace-Beltrami operator (LBO) basis functions."""

    def __init__(self, class_name, num_basis=64, skip_first=True, dim=16,
                 num_vert=None, barebones=False, device=torch.device('cuda'),
                 rand_init=False):
        super().__init__()

        # Load the template shape for `class_name` together with its LBO
        # eigenbasis; `skip_first` presumably drops the constant first
        # eigenfunction.
        self.shape = load_shape_with_lbo(class_name, num_basis, skip_first)

        # A barebones instance carries only the shape and no learnable state.
        if barebones:
            return

        # Functional basis of shape (num_points, num_basis); expected to be
        # assigned externally before forward() is called.
        self.functional_basis = None

        # Learnable spectral coefficients mapping the `num_basis` basis
        # functions to a `dim`-dimensional embedding per point; nn.Parameter
        # enables gradients by default.
        if not rand_init:
            self.weight_matrix = nn.Parameter(torch.zeros(num_basis, dim))
        else:
            self.weight_matrix = nn.Parameter(torch.randn(num_basis, dim))

        self.to(device)
        self.nns = None  # populated externally (presumably nearest-neighbour data)
        self.num_vert = num_vert

    def forward(self):
        # Project the functional basis onto the learned coefficients:
        # (num_points, num_basis) @ (num_basis, dim) -> (num_points, dim).
        output = torch.matmul(self.functional_basis, self.weight_matrix)

        # Optionally zero-pad the embeddings up to a fixed vertex count.
        if self.num_vert is not None:
            output_tmp = torch.zeros((self.num_vert, output.shape[1]),
                                     device=output.device)
            output_tmp[:output.shape[0], :] = output
            output = output_tmp
        return output
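

# Usage sketch: forward() multiplies `functional_basis` by `weight_matrix`,
# so the basis must be assigned before the first call. `shape.evecs` and the
# class name 'human' below are hypothetical placeholders; substitute whatever
# eigenfunction field the shape returned by load_shape_with_lbo exposes.
#
#   cse = CSE('human', num_basis=64, dim=16, device=torch.device('cpu'))
#   cse.functional_basis = cse.shape.evecs  # (num_points, num_basis), hypothetical
#   embeddings = cse()                      # (num_points, 16)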