#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Training code for the model.
import logging
from sklearn.metrics import accuracy_score
import torch
from tqdm import tqdm
logger = logging.getLogger()
def train(
        train_dataloader: torch.utils.data.DataLoader,
        val_dataloader: torch.utils.data.DataLoader,
        model: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        epochs: int = 5,
        max_grad_norm: float = 10) -> None:
    """Train a model on the given dataset.

    Args:
        train_dataloader (torch.utils.data.DataLoader): training dataset
        val_dataloader (torch.utils.data.DataLoader): validation dataset
        model (torch.nn.Module): model to train
        optimizer (torch.optim.Optimizer): optimizer to use
        epochs (int, optional): number of training epochs. Defaults to 5.
        max_grad_norm (float, optional): gradient clipping threshold. Defaults to 10.
    """
    model.train()
    for epoch in range(1, epochs + 1):
        running_loss = 0.0
        train_preds = []
        train_labels = []
        for batch, labels in tqdm(
                train_dataloader,
                total=len(train_dataloader),
                desc=f"Epoch {epoch}"):
            optimizer.zero_grad()
            ids = batch['input_ids']
            mask = batch['attention_mask']
            token_type_ids = batch['token_type_ids']
            out = model(
                input_ids=ids,
                attention_mask=mask,
                token_type_ids=token_type_ids,
                labels=labels
            )
            loss = out.loss
            # move predictions to the CPU so sklearn can consume them
            train_preds.extend(out.logits.argmax(-1).cpu().tolist())
            train_labels.extend(labels.cpu().tolist())
            loss.backward()
            # clip gradients for stability; this must run after backward()
            # (once gradients exist) and before the optimizer step
            torch.nn.utils.clip_grad_norm_(
                parameters=model.parameters(), max_norm=max_grad_norm
            )
            optimizer.step()
            running_loss += loss.item()
        logger.info("Epoch Train Accuracy %.4f",
                    accuracy_score(train_labels, train_preds))
        # evaluate on the validation set with dropout disabled and no
        # gradient tracking, then return to training mode for the next epoch
        model.eval()
        val_preds = []
        val_labels = []
        with torch.no_grad():
            for batch, labels in val_dataloader:
                ids = batch['input_ids']
                mask = batch['attention_mask']
                token_type_ids = batch['token_type_ids']
                pred = model(
                    input_ids=ids,
                    attention_mask=mask,
                    token_type_ids=token_type_ids
                )
                val_preds.extend(pred.logits.argmax(-1).cpu().tolist())
                val_labels.extend(labels.cpu().tolist())
        logger.info("Validation Accuracy %.4f",
                    accuracy_score(val_labels, val_preds))
        model.train()
logger.info("Epoch %d, Loss %.4f", epoch, running_loss)