Spaces:
Build error
Build error
import torch | |
from torch import nn | |
import random | |
class ScaledDecoder(nn.Module): | |
def __init__(self, ninp, nhid, nout): | |
super().__init__() | |
self.linear = nn.Linear(ninp, nhid) | |
self.linear1 = nn.Linear(nhid, nout) | |
self.linear2 = nn.Linear(nhid, 10) | |
def forward(self, x): | |
#return torch.cat([self.linear1(x), self.linear2(x)], -1) | |
x = self.linear(x) | |
x = nn.GELU()(x) | |
temps = self.linear2(x).softmax(-1) @ torch.tensor([1.,1.4,1.7,2.,5.,10.,20.,40.,80.,160.], device=x.device) | |
if random.random() > .99: | |
print(temps.shape,temps[:,:2]) | |
return self.linear1(x) / temps.unsqueeze(-1) | |
class FixedScaledDecoder(nn.Module): | |
def __init__(self, ninp, nhid, nout): | |
super().__init__() | |
self.mapper = nn.Sequential(nn.Linear(ninp, nhid), nn.GELU(), nn.Linear(nhid, nout)) | |
self.T = nn.Parameter(torch.ones(10000)/10000) | |
def forward(self, x): | |
return self.mapper(x)/self.T.sum() | |