Spaces:
Running
Running
import torch.nn as nn | |
from spiga.models.gnn.layers import MLP | |
from spiga.models.gnn.gat import GAT | |
class StepRegressor(nn.Module):
    """Stack of GAT layers followed by an MLP offset decoder.

    Args:
        input_dim: feature width fed to the first GAT layer.
        feature_dim: feature width produced by every GAT layer and consumed
            by the decoder.
        nstack: number of stacked GAT layers; must be positive.
        decoding: hidden-layer widths for the ``OffsetDecoder`` MLP.
            Defaults to ``[256, 128, 64, 32]``; any sequence of ints works.

    Raises:
        ValueError: if ``nstack`` is not positive.
    """

    def __init__(self, input_dim: int, feature_dim: int, nstack: int = 4, decoding=None):
        super().__init__()
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if nstack <= 0:
            raise ValueError(f"nstack must be positive, got {nstack}")
        # None sentinel avoids a shared mutable default. Copy to a list because
        # OffsetDecoder concatenates it with other lists.
        decoding = [256, 128, 64, 32] if decoding is None else list(decoding)
        self.nstack = nstack
        # The first layer maps input_dim -> feature_dim; the rest keep feature_dim.
        # NOTE(review): the third GAT argument (4) is presumably the number of
        # attention heads — confirm against spiga.models.gnn.gat.GAT.
        self.gat = nn.ModuleList([GAT(input_dim, feature_dim, 4)])
        for _ in range(nstack - 1):
            self.gat.append(GAT(feature_dim, feature_dim, 4))
        self.decoder = OffsetDecoder(feature_dim, decoding)

    def forward(self, embedded, prob_list=None):
        """Run the GAT stack and decode offsets.

        Args:
            embedded: tensor whose last two axes are swapped before the GAT
                stack and swapped back on output (presumably
                (batch, feature, nodes) — TODO confirm with callers).
            prob_list: optional list to which each GAT layer's second output
                is appended. When omitted, a fresh list is created per call.

        Returns:
            Tuple ``(offset, prob_list)``. ``offset`` is the decoder output
            with its last two axes swapped back; ``prob_list`` is the list
            that received one entry per GAT layer.
        """
        # A fresh list per call: the previous `prob_list=[]` default was shared
        # across calls and kept growing with every call that omitted it.
        if prob_list is None:
            prob_list = []
        embedded = embedded.transpose(-1, -2)
        for gat_layer in self.gat:
            embedded, prob = gat_layer(embedded)
            prob_list.append(prob)
        offset = self.decoder(embedded)
        return offset.transpose(-1, -2), prob_list
class OffsetDecoder(nn.Module):
    """Decode features into a width-2 output with an MLP.

    The input tensor is passed to the MLP unchanged (no axis swapping).
    NOTE(review): the width 2 presumably corresponds to an (x, y) offset —
    confirm.

    Args:
        feature_dim: first MLP width (width of the incoming features).
        layers: list of hidden MLP widths inserted between ``feature_dim``
            and the final width of 2.
    """

    def __init__(self, feature_dim, layers):
        super().__init__()
        # Full width list handed to MLP: feature_dim, *layers, 2.
        tail_widths = layers + [2]
        self.decoder = MLP([feature_dim] + tail_widths)

    def forward(self, embedded):
        decoded = self.decoder(embedded)
        return decoded
class RelativePositionEncoder(nn.Module):
    """Encode features with an MLP applied between two axis swaps.

    The last two axes of the input are swapped before the MLP and swapped
    back afterwards. Presumably the MLP acts on the last axis, so it is
    effectively applied along the input's axis -2 — confirm against MLP.

    Args:
        input_dim: first MLP width (width of the incoming features).
        feature_dim: last MLP width (width of the encoded output).
        layers: list of hidden MLP widths between the two.
    """

    def __init__(self, input_dim, feature_dim, layers):
        super().__init__()
        # Full width list handed to MLP: input_dim, *layers, feature_dim.
        leading_widths = [input_dim] + layers
        self.encoder = MLP(leading_widths + [feature_dim])

    def forward(self, feature):
        swapped = feature.transpose(-1, -2)
        encoded = self.encoder(swapped)
        # Restore the caller's axis order.
        return encoded.transpose(-1, -2)