|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.sparse as sp |
|
|
|
|
|
|
|
class LocalAffine(nn.Module): |
|
|
|
def __init__(self, num_points, batch_size=1, edges=None): |
|
''' |
|
specify the number of points, the number of points should be constant across the batch |
|
and the edges torch.Longtensor() with shape N * 2 |
|
the local affine operator supports batch operation |
|
batch size must be constant |
|
add additional pooling on top of w matrix |
|
''' |
|
super(LocalAffine, self).__init__() |
|
self.A = nn.Parameter( |
|
torch.eye(3).unsqueeze(0).unsqueeze(0).repeat( |
|
batch_size, num_points, 1, 1)) |
|
self.b = nn.Parameter( |
|
torch.zeros(3).unsqueeze(0).unsqueeze(0).unsqueeze(3).repeat( |
|
batch_size, num_points, 1, 1)) |
|
self.edges = edges |
|
self.num_points = num_points |
|
|
|
def stiffness(self): |
|
''' |
|
calculate the stiffness of local affine transformation |
|
f norm get infinity gradient when w is zero matrix, |
|
''' |
|
if self.edges is None: |
|
raise Exception("edges cannot be none when calculate stiff") |
|
idx1 = self.edges[:, 0] |
|
idx2 = self.edges[:, 1] |
|
affine_weight = torch.cat((self.A, self.b), dim=3) |
|
w1 = torch.index_select(affine_weight, dim=1, index=idx1) |
|
w2 = torch.index_select(affine_weight, dim=1, index=idx2) |
|
w_diff = (w1 - w2)**2 |
|
w_rigid = (torch.linalg.det(self.A) - 1.0)**2 |
|
return w_diff, w_rigid |
|
|
|
def forward(self, x, return_stiff=False): |
|
''' |
|
x should have shape of B * N * 3 |
|
''' |
|
x = x.unsqueeze(3) |
|
out_x = torch.matmul(self.A, x) |
|
out_x = out_x + self.b |
|
out_x.squeeze_(3) |
|
if return_stiff: |
|
stiffness, rigid = self.stiffness() |
|
return out_x, stiffness, rigid |
|
else: |
|
return out_x |
|
|