import torch from torch import nn from tqdm.auto import tqdm def train_step(model: torch.nn.Module, dataloader: torch.utils.data.DataLoader, loss_fn: torch.nn.Module, optimizer: torch.optim.Optimizer, device: torch.device): model.train() train_loss, train_acc = 0, 0 for batch, (X, y) in enumerate(dataloader): X, y = X.to(device), y.to(device) y_pred = model(X) y = y.unsqueeze(dim = 1).float() loss = loss_fn(y_pred, y) train_loss = train_loss + loss.item() optimizer.zero_grad() loss.backward() optimizer.step() y_pred_class = torch.sigmoid(y_pred) acc = (y_pred_class == y).sum().item() / len(y_pred) train_acc = train_acc + acc train_loss = train_loss / len(dataloader) train_acc = train_acc / len(dataloader) return train_loss, train_acc def test_step(model: torch.nn.Module, dataloader: torch.utils.data.DataLoader, loss_fn: torch.nn.Module, device: torch.device): model.eval() test_loss, test_acc = 0, 0 with torch.inference_mode(): for batch, (X, y) in enumerate(dataloader): X, y = X.to(device), y.to(device) y_pred = model(X) y = y.unsqueeze(dim = 1).float() loss = loss_fn(y_pred, y) test_loss = test_loss + loss.item() y_pred_class = y_pred.sigmoid() acc = (y_pred_class == y).sum().item() / len(y_pred) test_acc = test_acc + acc test_loss = test_loss / len(dataloader) test_acc = test_acc / len(dataloader) return test_loss, test_acc def train(model: torch.nn.Module, train_dataloader: torch.utils.data.DataLoader, test_dataloader: torch.utils.data.DataLoader, optimizer: torch.optim.Optimizer, loss_fn: torch.nn.Module, epochs: int, device: torch.device, writer: torch.utils.tensorboard.SummaryWriter): results = {"train_loss": [], "train_acc": [], "test_loss": [], "test_acc": []} model.to(device) # loss_fn = nn.CrossEntropyLoss() # optimizer = torch.optim.Adam(model.parameters(),lr = 0.01) for epoch in tqdm(range(epochs)): train_loss, train_acc = train_step(model = model, dataloader = train_dataloader, loss_fn = loss_fn, optimizer = optimizer, device = device) test_loss, test_acc = test_step(model = model, dataloader = test_dataloader, loss_fn = loss_fn, device = device) print( f"| Epoch: {epoch+1} | " f"train_loss: {train_loss:.4f} | " f"train_acc: {train_loss:.4f} | " f"test_loss: {test_loss:.4f} | " f"test_acc: {test_loss:.4f} |" ) results['train_loss'].append(train_loss) results['train_acc'].append(train_acc) results['test_loss'].append(test_loss) results['test_acc'].append(test_acc) writer.add_scalars(main_tag="Loss", tag_scalar_dict={"train_loss": train_loss, "test_loss": test_loss}, global_step=epoch) # Add accuracy results to SummaryWriter writer.add_scalars(main_tag="Accuracy", tag_scalar_dict={"train_acc": train_acc, "test_acc": test_acc}, global_step=epoch) # Track the PyTorch model architecture writer.add_graph(model=model, # Pass in an example input input_to_model=torch.randn(32, 3, 224, 224).to(device)) # Close the writer writer.close() return results