import numpy as np
import torch
from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import get_scheduler

from util.metrics import perplexity


class BloomTrainer:
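    """Trainer for prompt-tuning a BLOOM-style causal language model.

    Runs the training loop with AdamW and a linear LR schedule, periodically
    evaluates perplexity on the validation set, logs to an optional wandb run,
    and checkpoints only the learned prompt embeddings (see save_model).
    """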
    DEFAULT_VAL_FREQ = 5    # run validation every val_freq training iterations
    ITERATION_LIMIT = 150   # hard cap on total training iterations before train() returns

    def __init__(self, model, config, train_dataset, val_dataset, wandb_run=None, prompt_path=None, val_freq=None):
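        """Build data loaders, optimizer, and LR scheduler.

        `config` is expected to provide BATCH_SIZE, LR, WEIGHT_DECAY, N_EPOCH and DEVICE.
        `prompt_path` is where the best prompt embeddings are checkpointed during training.
        """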
        self.model = model
        self.config = config
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.wandb_run = wandb_run
        self.val_freq = val_freq if val_freq is not None else self.DEFAULT_VAL_FREQ
        self.prompt_path = prompt_path

        self.best_loss = np.inf

        self.train_loader = DataLoader(self.train_dataset,
                                       shuffle=True,
                                       batch_size=config.BATCH_SIZE,
                                       drop_last=True)
        self.val_loader = DataLoader(self.val_dataset,
                                     shuffle=True,
                                     batch_size=config.BATCH_SIZE,
                                     drop_last=False)

        self.optimizer = AdamW(self.model.parameters(), lr=config.LR, weight_decay=config.WEIGHT_DECAY)

        # Linear decay over the full training run, with no warmup.
        self.lr_scheduler = get_scheduler(
            name="linear",
            optimizer=self.optimizer,
            num_warmup_steps=0,
            num_training_steps=len(self.train_loader) * self.config.N_EPOCH
        )

    def train(self):
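        """Train for up to N_EPOCH epochs, validating every val_freq iterations and stopping early at ITERATION_LIMIT."""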
        self.model.train()
        iter_counter = 0
        for epoch in range(self.config.N_EPOCH):
            for batch in self.train_loader:
                # The default collate yields per-position lists of tensors; stack and
                # transpose them into [batch_size, seq_len] tensors on the target device.
                batch = {'input_ids': torch.stack(batch['input_ids']).T.to(self.config.DEVICE),
                         'labels': torch.stack(batch['labels']).T.to(self.config.DEVICE)}
                outputs = self.model(**batch)
                loss = outputs.loss
                loss.backward()
                self.optimizer.step()
                self.lr_scheduler.step()
                self.optimizer.zero_grad()
                if self.wandb_run is not None:
                    self.wandb_run.log({'loss': loss.item()})
                iter_counter += 1
                if (iter_counter + 1) % self.val_freq == 0:
                    eval_perplexity = self.evaluate(perplexity)
                    if self.wandb_run is not None:
                        self.wandb_run.log({'perplexity': eval_perplexity})
                    # evaluate() switches the model to eval mode; restore train mode before continuing
                    self.model.train()
                    # checkpoint whenever the current training-batch loss is the best seen so far
                    if loss.item() < self.best_loss:
                        self.best_loss = loss.item()
                        self.save_model(self.prompt_path)
                        print('Model saved')
                if iter_counter >= self.ITERATION_LIMIT:
                    return

    def evaluate(self, eval_fn):
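        """Collect logits and labels over the validation set and return eval_fn(logits, labels)."""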
        logits = []
        labels = []
        self.model.eval()
        with torch.no_grad():
            for batch in self.val_loader:
                batch = {'input_ids': torch.stack(batch['input_ids']).T.to(self.config.DEVICE),
                         'labels': torch.stack(batch['labels']).T.to(self.config.DEVICE)}
                outputs = self.model(**batch)
                labels.extend(batch['labels'])  # pair logits with the labels that were fed to the model
                logits.extend(outputs.logits)
        metric = eval_fn(logits, labels)
        return metric

    def save_model(self, path):
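        """Save only the prompt embedding weights (the trainable parameters in prompt tuning)."""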
        torch.save(self.model.transformer.prompt_embeddings.state_dict(), path)

    def load_model(self, path):
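        """Load previously saved prompt embedding weights from `path`."""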
        self.model.transformer.prompt_embeddings.load_state_dict(torch.load(path))
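

# Example usage (sketch only; assumes a config object exposing BATCH_SIZE, LR,
# WEIGHT_DECAY, N_EPOCH and DEVICE, and datasets yielding 'input_ids'/'labels'):
#
#     trainer = BloomTrainer(model, config, train_dataset, val_dataset,
#                            wandb_run=run, prompt_path='prompt_embeddings.pt')
#     trainer.train()
#     trainer.load_model('prompt_embeddings.pt')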