import torch
import torch.nn as nn


class LSTMModel(nn.Module):
    """Binary classifier over a (title, text) pair of token-ID sequences."""

    def __init__(self, vocab_size, embedding_dim=128, hidden_size=256,
                 num_layers=2, dropout=0.2):
        super().__init__()
        # Shared embedding table for both title and body tokens.
        self.embedding = nn.Embedding(num_embeddings=vocab_size,
                                      embedding_dim=embedding_dim)
        # Stacked LSTM; dropout is applied between the two layers.
        self.lstm = nn.LSTM(input_size=embedding_dim, hidden_size=hidden_size,
                            num_layers=num_layers, batch_first=True,
                            dropout=dropout)
        # Single-logit head for binary classification.
        self.fc = nn.Linear(hidden_size, 1)

    def forward(self, title, text):
        # title, text: (batch, seq_len) integer token IDs.
        title_emb = self.embedding(title)
        text_emb = self.embedding(text)
        # Concatenate along the sequence dimension so the LSTM reads the
        # title tokens followed by the body tokens.
        combined = torch.cat((title_emb, text_emb), dim=1)
        _, (hidden, _) = self.lstm(combined)
        # hidden[-1] is the final hidden state of the top layer: (batch, hidden_size).
        out = self.fc(hidden[-1])
        return torch.sigmoid(out)
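

# --- Usage sketch (illustrative, not part of the original file) ---
# Assumes inputs are padded integer token-ID tensors of shape (batch, seq_len);
# the vocab size and sequence lengths below are placeholder values.
if __name__ == "__main__":
    model = LSTMModel(vocab_size=10_000)
    title = torch.randint(0, 10_000, (4, 16))     # 4 titles, 16 tokens each
    text = torch.randint(0, 10_000, (4, 200))     # 4 bodies, 200 tokens each
    probs = model(title, text)                    # (4, 1), probabilities in (0, 1)

    # Because forward() already applies sigmoid, BCELoss (not
    # BCEWithLogitsLoss) is the matching training criterion.
    labels = torch.randint(0, 2, (4, 1)).float()  # dummy binary targets
    loss = nn.BCELoss()(probs, labels)
    print(probs.shape, loss.item())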