Spaces:
Running
Running
import torch | |
import torch.nn as nn | |
from .fourier_features import FourierFeatures | |
class MLP(nn.Module): | |
def __init__( | |
self, | |
in_features: int, | |
hidden_features: int = None, | |
out_features: int = None, | |
) -> None: | |
super().__init__() | |
out_features = out_features or in_features | |
hidden_features = hidden_features or in_features * 4 | |
self.fc1 = nn.Linear(in_features, hidden_features) | |
self.act = nn.GELU(approximate="tanh") | |
self.fc2 = nn.Linear(hidden_features, out_features) | |
torch.nn.init.kaiming_normal_( | |
self.fc1.weight, mode="fan_in", nonlinearity="relu" | |
) | |
torch.nn.init.kaiming_normal_( | |
self.fc2.weight, mode="fan_in", nonlinearity="relu" | |
) | |
def forward(self, x: torch.Tensor) -> torch.Tensor: | |
x = self.fc1(x) | |
x = self.act(x) | |
x = self.fc2(x) | |
return x | |
class RegionModel(nn.Module): | |
def __init__(self): | |
super().__init__() | |
self.coordinate_features = FourierFeatures(1, 256) | |
self.coordinate_encoder = nn.Linear(256, 2048) | |
self.size_features = FourierFeatures(2, 512) | |
self.size_encoder = nn.Linear(512, 2048) | |
self.coordinate_decoder = MLP(2048, 8192, 1024) | |
self.size_decoder = MLP(2048, 8192, 2048) | |
def encode_coordinate(self, coordinate): | |
return self.coordinate_encoder(self.coordinate_features(coordinate)) | |
def encode_size(self, size): | |
return self.size_encoder(self.size_features(size)) | |
def decode_coordinate(self, logit): | |
return self.coordinate_decoder(logit) | |
def decode_size(self, logit): | |
o = self.size_decoder(logit) | |
return o.view(-1, 2, 1024) | |
def encode(self, position, size): | |
c = self.encode_coordinate(position.view(2, 1)).view(2, 2048) | |
return torch.stack([c[0], c[1], self.encode_size(size)], dim=0) | |
def decode(self, position_logits, size_logits): | |
return ( | |
self.decode_coordinate(position_logits), | |
self.decode_size(size_logits), | |
) | |