#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Training code for the model."""
import logging

import torch
from sklearn.metrics import accuracy_score
from tqdm import tqdm

logger = logging.getLogger(__name__)


def train(
        train_dataloader: torch.utils.data.DataLoader,
        val_dataloader: torch.utils.data.DataLoader,
        model: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        epochs: int = 5,
        max_grad_norm: float = 10) -> None:
    """Train a model on the given dataset.

    Args:
        train_dataloader (torch.utils.data.DataLoader): training dataset
        val_dataloader (torch.utils.data.DataLoader): validation dataset
        model (torch.nn.Module): model to train
        optimizer (torch.optim.Optimizer): optimizer to use
        epochs (int, optional): number of training epochs. Defaults to 5.
        max_grad_norm (float, optional): gradient clipping threshold.
            Defaults to 10.
    """
    for epoch in range(1, epochs + 1):
        model.train()
        running_loss = 0.0
        train_preds = []
        train_labels = []
        for batch, labels in tqdm(
                train_dataloader,
                total=len(train_dataloader),
                desc=f"Epoch {epoch}"):
            optimizer.zero_grad()
            ids = batch['input_ids']
            mask = batch['attention_mask']
            token_type_ids = batch['token_type_ids']
            out = model(
                input_ids=ids,
                attention_mask=mask,
                token_type_ids=token_type_ids,
                labels=labels
            )
            loss = out.loss
            train_preds.extend(out.logits.argmax(-1).tolist())
            train_labels.extend(labels.tolist())
            loss.backward()
            # clip gradients for stability; clipping must happen after
            # backward() (so gradients exist) and before step()
            torch.nn.utils.clip_grad_norm_(
                parameters=model.parameters(),
                max_norm=max_grad_norm
            )
            optimizer.step()
            running_loss += loss.item()
        # accuracy_score expects (y_true, y_pred) in that order
        logger.info("Epoch Train Accuracy %.4f",
                    accuracy_score(train_labels, train_preds))

        # evaluate on the validation set without tracking gradients
        model.eval()
        val_preds = []
        val_labels = []
        with torch.no_grad():
            for batch, labels in val_dataloader:
                ids = batch['input_ids']
                mask = batch['attention_mask']
                token_type_ids = batch['token_type_ids']
                pred = model(
                    input_ids=ids,
                    attention_mask=mask,
                    token_type_ids=token_type_ids
                )
                val_preds.extend(pred.logits.argmax(-1).tolist())
                val_labels.extend(labels.tolist())
        logger.info("Validation Accuracy %.4f",
                    accuracy_score(val_labels, val_preds))
        logger.info("Epoch %d, Mean Loss %.4f",
                    epoch, running_loss / len(train_dataloader))
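

# ---------------------------------------------------------------------------
# Usage sketch (assumption-laden, not part of the original module): train()
# expects a HuggingFace-style sequence-classification model whose forward()
# accepts input_ids, attention_mask, token_type_ids, and labels and returns
# an object with .loss and .logits, plus dataloaders yielding (batch_dict,
# labels) tuples. The checkpoint name, toy texts, batch size, and learning
# rate below are illustrative placeholders.
if __name__ == "__main__":
    from torch.utils.data import DataLoader, Dataset
    from transformers import AutoModelForSequenceClassification, AutoTokenizer

    class ToyDataset(Dataset):
        """Yields ({feature_name: tensor}, label) items; PyTorch's default
        collate then batches them into the (batch, labels) tuples train()
        iterates over."""

        def __init__(self, texts, labels, tokenizer):
            self.enc = tokenizer(texts, padding=True, truncation=True,
                                 return_tensors="pt")
            self.labels = torch.tensor(labels)

        def __len__(self):
            return len(self.labels)

        def __getitem__(self, idx):
            return {k: v[idx] for k, v in self.enc.items()}, self.labels[idx]

    logging.basicConfig(level=logging.INFO)
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    model = AutoModelForSequenceClassification.from_pretrained(
        "bert-base-uncased", num_labels=2)

    # tiny synthetic binary-sentiment dataset, reused for train and val
    texts = ["a great movie", "a terrible movie"] * 8
    labels = [1, 0] * 8
    dataset = ToyDataset(texts, labels, tokenizer)
    train_loader = DataLoader(dataset, batch_size=4, shuffle=True)
    val_loader = DataLoader(dataset, batch_size=4)

    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
    train(train_loader, val_loader, model, optimizer, epochs=1)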