import torch | |
from torch import nn | |
class EmbeddingMLP(nn.Module): | |
def __init__(self, size=4): | |
super().__init__() | |
self.net = nn.Sequential( | |
nn.Linear(768 * size, 900 * size), | |
nn.BatchNorm1d(900 * size), | |
nn.ReLU(), | |
nn.Linear(900 * size, 300 * size) | |
) | |
def forward(self, data): | |
res = self.net(data) | |
return res | |
class PairClassifier(nn.Module): | |
def __init__(self, size=4): | |
super().__init__() | |
self.encoder = EmbeddingMLP(size) | |
self.net = nn.Sequential( | |
nn.Linear(300 * size * 2, 3000), | |
nn.ReLU(), | |
nn.Linear(3000, 1000), | |
nn.ReLU(), | |
nn.Linear(1000, 2), | |
) | |
def forward(self, data1, data2): | |
# modify the logic of loading the data | |
e1 = self.encoder(data1) | |
e2 = self.encoder(data2) | |
twins = torch.cat([e1, e2], dim=1) | |
res = self.net(twins) | |
return res | |