| | import torch
|
| | import torch.nn as nn
|
| | from tqdm.auto import tqdm
|
| |
|
def train_step(model: torch.nn.Module,
               dataloader: torch.utils.data.DataLoader,
               loss_fn: torch.nn.Module,
               optimizer: torch.optim.Optimizer,
               device: torch.device):
    """Run one training epoch over ``dataloader``.

    Puts the model in training mode and, for every batch, performs a forward
    pass, computes the loss, backpropagates, and updates the weights.

    Args:
        model: Network to train. Its output is treated as class logits; the
            predicted class is ``argmax`` over ``dim=1``.
        dataloader: Yields ``(X, y)`` batches, where ``y`` is compared against
            the predicted class indices (i.e. integer class labels).
        loss_fn: Criterion called as ``loss_fn(logits, y)``. Each batch's
            loss is weighted by its batch size, which gives the exact
            per-sample mean when the criterion uses mean reduction (the
            default for e.g. ``nn.CrossEntropyLoss``).
        optimizer: Optimizer that updates ``model``'s parameters.
        device: Device every batch is moved to before the forward pass.

    Returns:
        ``(train_loss, train_acc)``: the sample-weighted mean loss and the
        fraction of correctly classified samples over the whole epoch.

    Raises:
        ValueError: If ``dataloader`` yields no samples.
    """
    model.train()

    total_loss = 0.0
    total_correct = 0
    total_samples = 0

    for X, y in tqdm(dataloader, desc="Training"):
        X, y = X.to(device), y.to(device)
        batch_size = len(y)

        # Forward pass.
        y_pred_logits = model(X)
        loss = loss_fn(y_pred_logits, y)

        # Backward pass and parameter update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Accumulate per-sample totals rather than averaging per-batch ratios,
        # so a smaller final batch is not over-weighted in the epoch metrics.
        total_loss += loss.item() * batch_size
        y_pred_class = torch.argmax(y_pred_logits, dim=1)
        total_correct += (y_pred_class == y).sum().item()
        total_samples += batch_size

    if total_samples == 0:
        raise ValueError("dataloader yielded no samples")

    return total_loss / total_samples, total_correct / total_samples
|
| |
|
def val_step(model: torch.nn.Module,
             dataloader: torch.utils.data.DataLoader,
             loss_fn: torch.nn.Module,
             device: torch.device):
    """Run one validation epoch over ``dataloader``.

    Puts the model in evaluation mode and, with gradient tracking disabled,
    performs a forward pass on every batch and accumulates loss and accuracy.
    No backpropagation or weight update is performed.

    Args:
        model: Network to evaluate. Its output is treated as class logits;
            the predicted class is ``argmax`` over ``dim=1``.
        dataloader: Yields ``(X, y)`` batches, where ``y`` is compared against
            the predicted class indices (i.e. integer class labels).
        loss_fn: Criterion called as ``loss_fn(logits, y)``. Each batch's
            loss is weighted by its batch size, which gives the exact
            per-sample mean when the criterion uses mean reduction (the
            default for e.g. ``nn.CrossEntropyLoss``).
        device: Device every batch is moved to before the forward pass.

    Returns:
        ``(val_loss, val_acc)``: the sample-weighted mean loss and the
        fraction of correctly classified samples over the whole epoch.

    Raises:
        ValueError: If ``dataloader`` yields no samples.
    """
    model.eval()

    total_loss = 0.0
    total_correct = 0
    total_samples = 0

    # Gradients are not needed for evaluation; disabling them saves memory.
    with torch.no_grad():
        for X, y in tqdm(dataloader, desc="Validasi"):
            X, y = X.to(device), y.to(device)
            batch_size = len(y)

            y_pred_logits = model(X)
            loss = loss_fn(y_pred_logits, y)

            # Accumulate per-sample totals rather than averaging per-batch
            # ratios, so a smaller final batch is not over-weighted.
            total_loss += loss.item() * batch_size
            y_pred_class = torch.argmax(y_pred_logits, dim=1)
            total_correct += (y_pred_class == y).sum().item()
            total_samples += batch_size

    if total_samples == 0:
        raise ValueError("dataloader yielded no samples")

    return total_loss / total_samples, total_correct / total_samples