Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn as nn | |
| class ResBlock(nn.Module): | |
| """Linear block with residuals""" | |
| def __init__(self, ch): | |
| super().__init__() | |
| self.join = nn.ReLU() | |
| self.long = nn.Sequential( | |
| nn.Linear(ch, ch), | |
| nn.LeakyReLU(0.1), | |
| nn.Linear(ch, ch), | |
| nn.LeakyReLU(0.1), | |
| nn.Linear(ch, ch), | |
| ) | |
| def forward(self, x): | |
| return self.join(self.long(x) + x) | |
| class PredictorModel(nn.Module): | |
| """Main predictor class""" | |
| def __init__(self, features=768, outputs=1, hidden=1024): | |
| super().__init__() | |
| self.features = features | |
| self.outputs = outputs | |
| self.hidden = hidden | |
| self.up = nn.Sequential( | |
| nn.Linear(self.features, self.hidden), | |
| ResBlock(ch=self.hidden), | |
| ) | |
| self.down = nn.Sequential( | |
| nn.Linear(self.hidden, 128), | |
| nn.Linear(128, 64), | |
| nn.Dropout(0.1), | |
| nn.LeakyReLU(), | |
| nn.Linear(64, 32), | |
| nn.Linear(32, self.outputs), | |
| ) | |
| self.out = nn.Softmax(dim=1) if self.outputs > 1 else nn.Tanh() | |
| def forward(self, x): | |
| y = self.up(x) | |
| z = self.down(y) | |
| if self.outputs > 1: | |
| return self.out(z) | |
| else: | |
| return (self.out(z)+1.0)/2.0 | |