#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Training code for the model

import logging

from sklearn.metrics import accuracy_score
import torch
from tqdm import tqdm
logger = logging.getLogger(__name__)


def train(
        train_dataloader: torch.utils.data.DataLoader,
        val_dataloader: torch.utils.data.DataLoader,
        model: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        epochs: int = 5,
        max_grad_norm: float = 10) -> None:
    """train a model on the given dataset

    Args:
        dataloader (torch.utils.data.DataLoader): training dataset
        model (torch.nn.Module): model to train
        optimizer (torch.optim.Optimizer): optimizer to use
        epochs (int, optional): number of training epochs. Defaults to 5.
        max_grad_norm (float, optional): gradient clipping. Defaults to 10.
    """

    for epoch in range(1, epochs + 1):
        # The validation pass below switches the model to eval mode, so
        # re-enable training mode at the start of every epoch.
        model.train()
        running_loss = 0.0
        train_preds = []
        train_labels = []
        for batch, labels in tqdm(
                train_dataloader,
                total=len(train_dataloader),
                desc=f"Epoch {epoch}"):

            optimizer.zero_grad()

            ids = batch['input_ids']
            mask = batch['attention_mask']
            token_type_ids = batch['token_type_ids']

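            # Assuming a Hugging Face-style model here: passing labels makes
            # the forward pass compute the loss, and the returned object
            # exposes both .loss and .logits.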
            out = model(
                input_ids=ids,
                attention_mask=mask,
                token_type_ids=token_type_ids,
                labels=labels
            )

            loss = out.loss
            train_preds.extend(out.logits.argmax(-1).cpu().tolist())
            train_labels.extend(labels.cpu().tolist())

            loss.backward()

            # Clip gradients for stability. Clipping must happen after
            # backward() (which populates the gradients) and before the
            # optimizer step that consumes them.
            torch.nn.utils.clip_grad_norm_(
                parameters=model.parameters(), max_norm=max_grad_norm
            )

            optimizer.step()

            running_loss += loss.item()
        logger.info("Epoch Train Accuracy %.4f",
                    accuracy_score(train_preds, train_labels)
                    )

        # Evaluate on the validation set with gradients disabled.
        model.eval()
        val_preds = []
        val_labels = []
        with torch.no_grad():
            for batch, labels in val_dataloader:
                ids = batch['input_ids']
                mask = batch['attention_mask']
                token_type_ids = batch['token_type_ids']

                pred = model(
                    input_ids=ids,
                    attention_mask=mask,
                    token_type_ids=token_type_ids
                )
                val_preds.extend(pred.logits.argmax(-1).cpu().tolist())
                val_labels.extend(labels.cpu().tolist())
        logger.info("Validation Accuracy %.4f",
                    accuracy_score(val_labels, val_preds))

        logger.info("Epoch %d, Loss %.4f", epoch, running_loss)