import torch
import torch.nn as nn


class LSTMModel(nn.Module):
    """Binary classifier over a (title, text) pair of token-index sequences.

    Both sequences are embedded with one shared, frozen embedding table,
    concatenated along the sequence axis, and run through a stacked LSTM.
    The top layer's final hidden state is projected to a single value and
    passed through a sigmoid.

    Args:
        embedding_matrix: 2-D array-like of shape
            ``(num_embeddings, embedding_dim)`` (NumPy array, tensor, or
            nested sequence). It is copied into a new float32 tensor, so later
            changes to the caller's array do not affect the model.
        hidden_size: Number of features in the LSTM hidden state.
        num_layers: Number of stacked LSTM layers.
        dropout: Dropout probability between stacked LSTM layers. ``nn.LSTM``
            only applies it between layers, so it is forced to 0 when
            ``num_layers == 1`` (where it would have no effect and only
            trigger a PyTorch warning).
    """

    def __init__(
        self,
        embedding_matrix,
        hidden_size: int = 256,
        num_layers: int = 2,
        dropout: float = 0.2,
    ) -> None:
        super().__init__()
        # as_tensor(...).detach().clone() accepts ndarrays, tensors and lists
        # alike and always yields an independent float32 copy; torch.tensor()
        # on an existing tensor emits a UserWarning.
        weights = torch.as_tensor(embedding_matrix, dtype=torch.float32).detach().clone()
        # freeze=True sets requires_grad=False: the embedding layer is not trained.
        self.embedding = nn.Embedding.from_pretrained(weights, freeze=True)
        self.lstm = nn.LSTM(
            input_size=self.embedding.embedding_dim,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            # Inter-layer dropout has no effect with a single layer and makes
            # nn.LSTM emit a UserWarning, so disable it in that case.
            dropout=dropout if num_layers > 1 else 0.0,
        )
        self.fc = nn.Linear(hidden_size, 1)

    def forward(self, title: torch.Tensor, text: torch.Tensor) -> torch.Tensor:
        """Return the predicted probability for each (title, text) pair.

        Args:
            title: Integer token indices, expected shape ``(batch, title_len)``
                (the LSTM is ``batch_first``).
            text: Integer token indices, expected shape ``(batch, text_len)``.

        Returns:
            Tensor of shape ``(batch, 1)`` with values in ``[0, 1]``.
        """
        # NOTE(review): the two sequences are concatenated as-is. If they are
        # padded, title padding ends up in the middle of the combined sequence
        # and hidden[-1] is the state after any trailing text padding; consider
        # nn.utils.rnn.pack_padded_sequence — confirm against the data pipeline.
        combined = torch.cat((self.embedding(title), self.embedding(text)), dim=1)
        _, (hidden, _) = self.lstm(combined)
        # hidden is (num_layers, batch, hidden_size); take the top layer's state.
        logits = self.fc(hidden[-1])
        # NOTE(review): returning probabilities pairs with nn.BCELoss; returning
        # logits for nn.BCEWithLogitsLoss is more numerically stable, but would
        # change what callers receive.
        return torch.sigmoid(logits)