# YAML Metadata Warning: empty or missing YAML metadata in repo card.
# Check out the documentation for more information.
class LSTMModel(nn.Module):
    """Next-token model: embedding -> single-layer LSTM -> dropout -> linear head.

    Args:
        vocab_size: Number of tokens; size of the embedding table and of the
            output logit dimension.
        n_embd: Embedding dimension fed into the LSTM.
        n_hidden: LSTM hidden-state size (input size of the output head).
        block_size: Context length. Accepted for interface compatibility but
            not used by this module.
        dropout: Dropout probability applied to the LSTM outputs before the
            output projection.
    """

    def __init__(self, vocab_size, n_embd, n_hidden, block_size, dropout):
        # Must be the dunder ``__init__``: a method named ``init`` is never
        # called by Python, so no layers would be created.
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, n_embd)
        # batch_first=True: inputs/outputs are laid out as (batch, seq, feature).
        self.lstm = nn.LSTM(n_embd, n_hidden, batch_first=True)
        self.fc = nn.Linear(n_hidden, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Return per-position logits over the vocabulary.

        Args:
            x: Integer token ids, presumably shaped (batch, seq) to match
                ``batch_first=True`` -- confirm against the data loader.

        Returns:
            Logits whose last dimension is ``vocab_size``
            (i.e. (batch, seq, vocab_size) for a (batch, seq) input).
        """
        x = self.embedding(x)
        # The LSTM returns (outputs, (h_n, c_n)); only the per-step outputs are used.
        x, _ = self.lstm(x)
        x = self.fc(self.dropout(x))
        return x
# Model, optimizer, and loss function initialization
# The vocabulary size is the number of entries in the token -> index mapping.
vocab_size = len(word_to_idx)

# Build the network, then move it to the target device (Module.to returns the module itself).
net = LSTMModel(vocab_size, n_embd, n_hidden, block_size, dropout)
model = net.to(device)

optimizer = optim.Adam(model.parameters(), lr=learning_rate)
loss_fn = nn.CrossEntropyLoss()

# Per-epoch average losses, filled in by the training loop.
training_losses, validation_losses = [], []
# Training and validation function
def _batch_loss(inputs, targets):
    """Return the ``loss_fn`` value of the global ``model`` on one batch.

    Both tensors are moved to ``device``. Logits are flattened to
    (N, vocab_size) and targets to (N,) so that every position is scored.
    ``reshape`` is used instead of ``view`` because ``view`` raises on
    non-contiguous tensors (e.g. ones produced by a custom collate function).
    """
    inputs = inputs.to(device)
    targets = targets.to(device)
    logits = model(inputs)
    return loss_fn(logits.reshape(-1, vocab_size), targets.reshape(-1))


def _mean_or_nan(total, count):
    """Return ``total / count``, or NaN when no batches were seen."""
    return total / count if count else float("nan")


def train_model():
    """Train the global ``model`` for ``max_iters`` epochs, validating after each.

    Each epoch does one optimisation pass over ``train_loader``, then one
    gradient-free pass over ``val_loader``. The per-batch losses are averaged
    over the number of batches actually yielded (NaN if a loader is empty,
    instead of raising ZeroDivisionError). The averages are appended to the
    module-level ``training_losses`` / ``validation_losses`` lists and printed.

    Note: ``max_iters`` is used here as the number of epochs.
    """
    for epoch in range(max_iters):
        # --- training pass ---
        model.train()
        total_train_loss = 0
        n_train_batches = 0
        for x, y in train_loader:
            loss = _batch_loss(x, y)

            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_train_loss += loss.item()
            # Count batches instead of relying on len(train_loader), which
            # may be unavailable (iterable datasets) or zero.
            n_train_batches += 1
        avg_train_loss = _mean_or_nan(total_train_loss, n_train_batches)

        # --- validation pass ---
        model.eval()  # switches dropout layers to inference mode
        total_val_loss = 0
        n_val_batches = 0
        with torch.no_grad():
            for val_x, val_y in val_loader:
                total_val_loss += _batch_loss(val_x, val_y).item()
                n_val_batches += 1
        avg_val_loss = _mean_or_nan(total_val_loss, n_val_batches)

        # Append loss values to the lists
        training_losses.append(avg_train_loss)
        validation_losses.append(avg_val_loss)
        print(f'Epoch {epoch + 1}/{max_iters}, Train Loss: {avg_train_loss:.4f}, Validation Loss: {avg_val_loss:.4f}')
# Step 1: Train the model and collect losses
# Run the training/validation loop; per-epoch average losses are appended to
# training_losses and validation_losses.
train_model()