# svjack's picture
# Upload folder using huggingface_hub
# d015578 verified
import torch.nn as nn
from spiga.models.gnn.layers import MLP
from spiga.models.gnn.gat import GAT
class StepRegressor(nn.Module):
    """Stack of GAT layers followed by an MLP offset decoder.

    Args:
        input_dim: Feature size expected by the first GAT layer.
        feature_dim: Feature size produced by every GAT layer and consumed
            by the offset decoder.
        nstack: Number of stacked GAT layers; must be greater than 0.
        decoding: Hidden layer sizes for the offset decoder MLP. When
            ``None`` (the default), ``[256, 128, 64, 32]`` is used. Any
            sequence of ints is accepted; it is copied into a new list.
    """

    def __init__(self, input_dim: int, feature_dim: int, nstack=4, decoding=None):
        super().__init__()
        assert nstack > 0
        # None sentinel instead of a list literal default, so no list object
        # is shared between instances.
        if decoding is None:
            decoding = [256, 128, 64, 32]
        self.nstack = nstack
        # The first layer maps input_dim -> feature_dim; every later layer keeps
        # feature_dim. The third GAT argument (4) is passed through unchanged;
        # its meaning is defined by GAT (presumably the number of heads -- verify).
        self.gat = nn.ModuleList([GAT(input_dim, feature_dim, 4)])
        for _ in range(nstack - 1):
            self.gat.append(GAT(feature_dim, feature_dim, 4))
        # OffsetDecoder concatenates `layers` with lists via `+`, so it must
        # receive a list.
        self.decoder = OffsetDecoder(feature_dim, list(decoding))

    def forward(self, embedded, prob_list=None):
        """Run the GAT stack and decode offsets.

        Args:
            embedded: Tensor whose last two dims are swapped before the GAT
                stack and swapped back on the decoder output
                (layout assumed by GAT -- confirm with caller).
            prob_list: Optional list to which the second value returned by
                each GAT layer is appended in place. A fresh list is
                created when omitted, so calls do not share state.

        Returns:
            Tuple ``(offset, prob_list)``. ``offset`` is the decoder output
            with its last two dims swapped back. ``prob_list`` is the list
            that was appended to.
        """
        # A shared `[]` default would accumulate entries across calls.
        if prob_list is None:
            prob_list = []
        embedded = embedded.transpose(-1, -2)
        for gat_layer in self.gat:
            embedded, prob = gat_layer(embedded)
            prob_list.append(prob)
        offset = self.decoder(embedded)
        return offset.transpose(-1, -2), prob_list
class OffsetDecoder(nn.Module):
    """MLP head applied to embedded features.

    Args:
        feature_dim: Size of the incoming feature dimension (first MLP size).
        layers: List of hidden sizes placed between ``feature_dim`` and the
            final size of 2. It must be a list, because it is concatenated
            with ``+``.
    """

    def __init__(self, feature_dim, layers):
        super().__init__()
        # Size list handed to MLP: input size, hidden sizes, then 2
        # (presumably a 2-D offset -- confirm against MLP semantics).
        channels = [feature_dim] + layers + [2]
        self.decoder = MLP(channels)

    def forward(self, embedded):
        """Return the MLP output for ``embedded`` without any reshaping."""
        decoded = self.decoder(embedded)
        return decoded
class RelativePositionEncoder(nn.Module):
    """MLP encoder applied across the second-to-last dimension of its input.

    Args:
        input_dim: First size in the MLP size list.
        feature_dim: Last size in the MLP size list.
        layers: List of hidden sizes placed between the two. It must be a
            list, because it is concatenated with ``+``.
    """

    def __init__(self, input_dim, feature_dim, layers):
        super().__init__()
        sizes = [input_dim] + layers + [feature_dim]
        self.encoder = MLP(sizes)

    def forward(self, feature):
        """Swap the last two dims, encode, then swap them back."""
        swapped = feature.transpose(-1, -2)
        encoded = self.encoder(swapped)
        return encoded.transpose(-1, -2)