| import torch | |
| import torch.nn as nn | |
class LSTMModel(nn.Module):
    """Token-level LSTM language model: embedding -> LSTM -> linear projection.

    Maps integer token ids to unnormalized per-position scores over the
    vocabulary. No softmax is applied; pair the output with a loss such as
    ``nn.CrossEntropyLoss``.

    Args:
        vocab_size: Number of tokens. This is the size of both the embedding
            table and the output projection.
        embed_dim: Embedding width, which is also the LSTM input size.
        hidden_dim: Width of the LSTM hidden state.
        pad_token_id: Passed to ``nn.Embedding(padding_idx=...)``. That row
            of the table is zero-initialised and receives no gradient.
        num_layers: Number of stacked LSTM layers. The default of 1
            reproduces the single-layer model, with the same ``state_dict``
            keys.
        dropout: Dropout probability applied to the LSTM outputs before the
            projection. When ``num_layers > 1`` it is also applied between
            LSTM layers. The default of 0.0 disables dropout.
    """

    def __init__(
        self,
        vocab_size,
        embed_dim,
        hidden_dim,
        pad_token_id,
        *,
        num_layers=1,
        dropout=0.0,
    ):
        super().__init__()
        self.embedding = nn.Embedding(
            vocab_size,
            embed_dim,
            padding_idx=pad_token_id,
        )
        self.lstm = nn.LSTM(
            input_size=embed_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            # nn.LSTM applies dropout only *between* stacked layers and warns
            # when dropout > 0 with a single layer. Pass it only when it
            # actually has an effect.
            dropout=dropout if num_layers > 1 else 0.0,
            batch_first=True,
        )
        # Dropout on the top layer's outputs. With p=0.0 this is an identity
        # op and adds no parameters or buffers.
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, input_ids, hidden=None, *, return_hidden=False):
        """Compute next-token logits for every position of ``input_ids``.

        Args:
            input_ids: Integer token ids, batch-first (``batch_first=True``),
                i.e. ``(batch, seq_len)``.
            hidden: Optional ``(h, c)`` state from a previous call, used to
                continue a sequence incrementally. Following ``nn.LSTM``,
                each tensor is ``(num_layers, batch, hidden_dim)``; this
                layout is not affected by ``batch_first``. ``None`` starts
                from zeros, as in the original model.
            return_hidden: If True, also return the final ``(h, c)`` state.
                Feeding that state back in as ``hidden`` avoids re-running
                the prefix during generation.

        Returns:
            ``logits`` of shape ``(batch, seq_len, vocab_size)``, or
            ``(logits, (h, c))`` when ``return_hidden`` is True.

        Padding is not masked. The LSTM is unidirectional, so right-padding
        does not change the logits at earlier, real positions. The logits at
        pad positions should still be excluded from the loss (for example
        with ``ignore_index``).
        """
        emb = self.embedding(input_ids)
        out, new_hidden = self.lstm(emb, hidden)
        logits = self.fc(self.dropout(out))
        if return_hidden:
            return logits, new_hidden
        return logits