| import torch | |
| import torch.nn as nn | |
class Model(nn.Module):
    """Two-layer MLP with batch norm and ELU, reduced over dim 1 to a scalar head.

    Layer pipeline (all sizes are fixed by the constructor):

    1. ``ll1``: Linear 768 -> 1024 on the last dimension, then ``bn1`` and ELU.
    2. ``ll2``: Linear 1024 -> 512 on the last dimension, then ``bn2`` and ELU.
    3. Sum over dimension 1, then ``llf``: Linear 512 -> 1.

    Both batch-norm layers are ``BatchNorm1d(2)``, so they normalise over
    dimension 1 of a 3-D input and require that dimension to have size 2.
    Together with the 768-wide first linear layer this means ``forward``
    expects a tensor of shape ``(batch, 2, 768)``.
    """

    def __init__(self) -> None:
        """Create the layers; attribute names double as ``state_dict`` keys."""
        super().__init__()
        # Creation order is kept stable: it determines parameter order and
        # the order in which the random initialisers are drawn.
        self.ll1 = nn.Linear(768, 1024)
        self.bn1 = nn.BatchNorm1d(2)
        self.elu1 = nn.ELU()
        self.ll2 = nn.Linear(1024, 512)
        self.bn2 = nn.BatchNorm1d(2)
        self.elu2 = nn.ELU()
        self.llf = nn.Linear(512, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the network.

        Args:
            x: Tensor of shape ``(batch, 2, 768)``.

        Returns:
            Tensor of shape ``(batch, 1)``.
        """
        x = self.elu1(self.bn1(self.ll1(x)))  # -> (batch, 2, 1024)
        x = self.elu2(self.bn2(self.ll2(x)))  # -> (batch, 2, 512)
        # Collapse the size-2 dimension by summation before the output head.
        x = torch.sum(x, dim=1)  # -> (batch, 512)
        x = self.llf(x)  # -> (batch, 1)
        return x
if __name__ == '__main__':
    # Load the serialized object stored in 'model.pth' from the current
    # working directory.
    # NOTE(security): torch.load relies on pickle, which can execute arbitrary
    # code while deserializing -- only load files from a trusted source.
    # NOTE(review): presumably the file holds a whole pickled Model instance
    # (not just a state_dict); if so, recent PyTorch releases that default to
    # weights_only=True may refuse to load it -- confirm against the PyTorch
    # version in use.
    model = torch.load('model.pth')