|
import torch |
|
import torch.nn as nn |
|
|
|
|
|
class LSTMModel(nn.Module):
    """Binary classifier that reads a title and a body text with one shared LSTM.

    The title and text token-id sequences are embedded with a single shared
    embedding table, concatenated along the sequence axis (title first), and
    run through a stacked LSTM. The final hidden state of the top LSTM layer
    is projected to one value and passed through a sigmoid.

    Args:
        vocab_size: Number of rows in the embedding table; every token id
            passed to ``forward`` must lie in ``[0, vocab_size)``.
        embedding_dim: Size of each token embedding (the LSTM input size).
        hidden_size: Number of features in the LSTM hidden state.
        num_layers: Number of stacked LSTM layers.
        dropout: Dropout probability applied between stacked LSTM layers.
            ``nn.LSTM`` never applies it after the last layer, so with
            ``num_layers == 1`` it is a no-op and is passed as ``0.0``.
    """

    def __init__(self, vocab_size, embedding_dim=128, hidden_size=256,
                 num_layers=2, dropout=0.2):
        super().__init__()
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size, embedding_dim=embedding_dim)
        # nn.LSTM only applies dropout *between* recurrent layers; a non-zero
        # value with a single layer has no effect and raises a UserWarning.
        lstm_dropout = dropout if num_layers > 1 else 0.0
        self.lstm = nn.LSTM(input_size=embedding_dim, hidden_size=hidden_size,
                            num_layers=num_layers, batch_first=True,
                            dropout=lstm_dropout)
        self.fc = nn.Linear(hidden_size, 1)

    def forward(self, title, text):
        """Return a sigmoid probability for each (title, text) pair.

        Args:
            title: Integer token-id tensor, presumably ``(batch, title_len)``.
                The LSTM is ``batch_first`` and the embeddings are joined on
                dim 1, so 2-D batched input is expected. TODO: confirm
                against the data loader.
            text: Integer token-id tensor, presumably ``(batch, text_len)``,
                with the same batch size as ``title``.

        Returns:
            Tensor of shape ``(batch, 1)`` with values in ``[0, 1]``.
        """
        title_emb = self.embedding(title)
        text_emb = self.embedding(text)
        # Join along the sequence axis so the LSTM reads the title, then the text.
        # NOTE(review): if ``title`` is right-padded, its pad tokens sit between
        # the title and the text and are fed through the LSTM. Confirm this
        # is intended.
        combined = torch.cat((title_emb, text_emb), dim=1)
        _, (hidden, _) = self.lstm(combined)
        # hidden is (num_layers, batch, hidden_size); keep only the top layer.
        logits = self.fc(hidden[-1])
        return torch.sigmoid(logits)
|
|