| import torch | |
| from torch import nn | |
class EmbeddingMLP(nn.Module):
    """Project an input feature vector through a two-layer MLP.

    Layer widths all scale with ``size``:
    ``768 * size -> 900 * size (BatchNorm1d, ReLU) -> 300 * size``.

    The input is expected to be a 2-D tensor of shape ``(batch, 768 * size)``.
    ``BatchNorm1d`` normalises over dim 1, so other ranks will not line up.
    In training mode, ``BatchNorm1d`` also needs ``batch > 1``.

    Args:
        size: Width multiplier applied to every layer dimension.
    """

    def __init__(self, size=4):
        super().__init__()
        in_features = 768 * size
        hidden_features = 900 * size
        out_features = 300 * size
        # The order here fixes the ``net.<idx>`` state_dict keys and the order
        # in which parameters are randomly initialised; keep it stable.
        layers = [
            nn.Linear(in_features, hidden_features),
            nn.BatchNorm1d(hidden_features),
            nn.ReLU(),
            nn.Linear(hidden_features, out_features),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, data):
        """Return the ``300 * size``-wide projection of ``data``."""
        return self.net(data)
class PairClassifier(nn.Module):
    """Classify a pair of feature vectors with a shared (siamese) encoder.

    Each input row is two ``768 * size``-wide vectors laid side by side.
    Both halves go through the same ``EmbeddingMLP`` instance, so the weights
    are shared. The two encodings are concatenated and mapped by an MLP head
    to 2 output values.

    Args:
        size: Width multiplier forwarded to ``EmbeddingMLP``. It also
            determines where each input row is split into its two halves.
    """

    def __init__(self, size=4):
        super().__init__()
        # Kept so forward() can split rows at the width the encoder expects.
        # Previously the split was hard-coded for size=4.
        self.size = size
        self.encoder = EmbeddingMLP(size)
        self.net = nn.Sequential(
            nn.Linear(300 * size * 2, 3000),
            nn.ReLU(),
            nn.Linear(3000, 1000),
            nn.ReLU(),
            nn.Linear(1000, 2),
        )

    def forward(self, data):
        """Encode both halves of each row and return a ``(batch, 2)`` output.

        ``data`` is expected to be 2-D with ``2 * 768 * size`` columns. The
        first ``768 * size`` columns are one item and the rest are the other.
        """
        half = 768 * self.size
        e1 = self.encoder(data[:, :half])
        e2 = self.encoder(data[:, half:])
        twins = torch.cat([e1, e2], dim=1)
        return self.net(twins)