NextWordpredictor / model.py
prashantdubeypng
Deploy word prediction FastAPI backend
858f78f
Raw
History Blame Contribute Delete
663 Bytes
import torch
import torch.nn as nn
class LSTMModel(nn.Module):
    """Embedding -> single-layer LSTM -> linear projection onto the vocabulary.

    The model maps a sequence of token ids to one vector of unnormalised
    vocabulary scores (logits) per input position.
    """

    def __init__(
        self,
        vocab_size: int,
        embed_dim: int,
        hidden_dim: int,
        pad_token_id: int,
    ) -> None:
        """Build the embedding, recurrent and output layers.

        Args:
            vocab_size: Number of rows in the embedding table and number of
                output logits per position.
            embed_dim: Size of each token embedding (the LSTM input size).
            hidden_dim: Size of the LSTM hidden state (the input size of the
                final linear layer).
            pad_token_id: Token id passed to ``nn.Embedding`` as
                ``padding_idx``; its embedding row is zero-initialised and
                receives no gradient.
        """
        super().__init__()
        self.embedding = nn.Embedding(
            vocab_size,
            embed_dim,
            padding_idx=pad_token_id,
        )
        # batch_first=True: batched inputs are laid out as
        # (batch, seq_len, features) rather than (seq_len, batch, features).
        self.lstm = nn.LSTM(
            input_size=embed_dim,
            hidden_size=hidden_dim,
            batch_first=True,
        )
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return vocabulary logits for every position of ``input_ids``.

        Args:
            input_ids: Integer tensor of token ids. Presumably shaped
                ``(batch, seq_len)`` given ``batch_first=True`` -- confirm
                against the caller.

        Returns:
            Tensor with the same leading dimensions as ``input_ids`` and a
            trailing dimension of size ``vocab_size``. No softmax is applied.
        """
        emb = self.embedding(input_ids)
        # The final (h_n, c_n) state is discarded; only the per-step outputs
        # are projected to the vocabulary.
        out, _ = self.lstm(emb)
        logits = self.fc(out)
        return logits