import torch
import torch.nn as nn


class LSTMModel(nn.Module):
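    """Binary classifier over a (title, text) pair: frozen pretrained embeddings -> stacked LSTM -> linear head -> sigmoid."""
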
    def __init__(self, embedding_matrix, hidden_size=256, num_layers=2, dropout=0.2):
        super().__init__()
        num_embeddings, embedding_dim = embedding_matrix.shape
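        # Initialize from the pretrained matrix and freeze it so the
        # embedding weights are not updated during training.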
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        self.embedding.weight = nn.Parameter(
            torch.tensor(embedding_matrix, dtype=torch.float32)
        )
        self.embedding.weight.requires_grad = False

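        # nn.LSTM applies dropout between stacked layers only, so this
        # argument has no effect when num_layers == 1.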
        self.lstm = nn.LSTM(
            input_size=embedding_dim,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout,
        )
        self.fc = nn.Linear(hidden_size, 1)

    def forward(self, title, text):
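        # Embed both inputs and concatenate along the sequence dimension (dim=1),
        # so the LSTM reads the title tokens followed by the text tokens.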
        title_emb = self.embedding(title)
        text_emb = self.embedding(text)
        combined = torch.cat((title_emb, text_emb), dim=1)
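        # hidden[-1] is the final hidden state of the top LSTM layer,
        # shape (batch, hidden_size); the sequence outputs are unused.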
        output, (hidden, _) = self.lstm(combined)
        out = self.fc(hidden[-1])
        return torch.sigmoid(out)
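
# A minimal smoke test (an assumption, not part of the original module): builds a
# random embedding matrix and dummy token-id batches to check shapes end to end.
if __name__ == "__main__":
    import numpy as np

    vocab_size, emb_dim = 1000, 50  # hypothetical sizes for this sketch
    embedding_matrix = np.random.rand(vocab_size, emb_dim).astype("float32")
    model = LSTMModel(embedding_matrix)

    title = torch.randint(0, vocab_size, (4, 12))   # batch of 4 titles, 12 tokens each
    text = torch.randint(0, vocab_size, (4, 100))   # batch of 4 bodies, 100 tokens each
    probs = model(title, text)
    print(probs.shape)  # torch.Size([4, 1]); each value is a probability in (0, 1)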