RepoSnipy / common /pair_classifier.py
Henry65's picture
Update similaritycal model
e0d476d
raw
history blame
No virus
985 Bytes
import torch
from torch import nn
class EmbeddingMLP(nn.Module):
    """Small feed-forward encoder: Linear -> BatchNorm1d -> ReLU -> Linear.

    All layer widths scale with ``size``:

    * input width:  ``768 * size``
    * hidden width: ``900 * size``
    * output width: ``300 * size``

    Args:
        size: Multiplier applied to every layer width (default ``4``).
    """

    def __init__(self, size=4):
        super().__init__()
        in_width = 768 * size
        hidden_width = 900 * size
        out_width = 300 * size

        # The stack is registered under ``net`` and built in forward order, so
        # parameter names (``net.0.weight``, ``net.1.running_mean``, ...) and
        # the order in which layers are initialised are fixed by this list.
        stack = [
            nn.Linear(in_width, hidden_width),
            nn.BatchNorm1d(hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, out_width),
        ]
        self.net = nn.Sequential(*stack)

    def forward(self, data):
        """Run ``data`` through the layer stack and return the result.

        Args:
            data: Tensor whose last dimension is ``768 * size``.
                Presumably shaped ``(batch, 768 * size)`` -- confirm with callers.

        Returns:
            The output of the final linear layer (last dimension ``300 * size``).
        """
        return self.net(data)
class PairClassifier(nn.Module):
    """Two-output classifier over a pair of inputs.

    Both inputs go through the same ``EmbeddingMLP`` instance (shared
    weights). The two encodings are concatenated along dimension 1 and fed to
    a three-layer MLP head that ends in two output units.

    Args:
        size: Width multiplier forwarded to ``EmbeddingMLP``; it also sets the
            head's input width to ``300 * size * 2`` (default ``4``).
    """

    def __init__(self, size=4):
        super().__init__()
        self.encoder = EmbeddingMLP(size)

        # Two encodings of width ``300 * size`` are joined side by side.
        joined_width = 300 * size * 2
        head = [
            nn.Linear(joined_width, 3000),
            nn.ReLU(),
            nn.Linear(3000, 1000),
            nn.ReLU(),
            nn.Linear(1000, 2),
        ]
        self.net = nn.Sequential(*head)

    def forward(self, data1, data2):
        """Score a pair of inputs.

        Args:
            data1: First input, passed to the shared encoder.
            data2: Second input, passed to the same encoder.

        Returns:
            The output of the final linear layer, with 2 values per row. No
            softmax or other normalisation is applied here.
        """
        # Encode in the original order (data1, then data2) with one encoder.
        encoded_pair = [self.encoder(item) for item in (data1, data2)]
        return self.net(torch.cat(encoded_pair, dim=1))