import torch import torchvision from tqdm.auto import tqdm from torch import nn device = "cuda" if torch.cuda.is_available() else "cpu" def train_step(model: torch.nn.Module, dataloader, optimizer: torch.optim.Optimizer, loss_fn: torch.nn.Module, device: torch.device): model.to(device) model.train() for X, y in dataloader: X, y = X.to(device), y.to(device) # 1. Forward pass logits = model(X) # 2. Calculate loss loss = loss_fn(logits, y) # 3. Optimizer zero grad optimizer.zero_grad() # 4. Loss backward loss.backward() # 5. Optimizer step optimizer.step() def train(model: torch.nn.Module, train_dataloader, test_dataloader, optimizer: torch.optim.Optimizer, loss_fn: torch.nn.Module, epochs: int=10): for epoch in tqdm(range(epochs)): train_step(model, train_dataloader, optimizer, loss_fn, device)