# Author: j.gilyazev
# Commit: add personalized-chat-bot (c1c5bd9)
import sys
import numpy as np
from torch.utils.data import DataLoader
from torch.optim import AdamW
from transformers import get_scheduler
import torch
from util.metrics import perplexity
class BloomTrainer:
    """Prompt-tuning trainer for a Bloom-style causal LM.

    Trains the model on ``train_dataset``, periodically evaluates perplexity on
    ``val_dataset``, and checkpoints only the prompt-embedding weights
    (``model.transformer.prompt_embeddings``) whenever the training loss improves.

    Args:
        model: Model exposing ``transformer.prompt_embeddings`` (an ``nn.Module``).
        config: Config object with ``BATCH_SIZE``, ``LR``, ``WEIGHT_DECAY``,
            ``N_EPOCH`` and ``DEVICE`` attributes.
        train_dataset / val_dataset: Datasets yielding dicts with
            ``input_ids`` and ``labels``.
        wandb_run: Optional wandb run for metric logging; if ``None``, logging
            is skipped instead of crashing.
        prompt_path: Optional path for saving the best prompt embeddings; if
            ``None``, checkpointing is skipped.
        val_freq: Evaluate every ``val_freq`` training steps
            (defaults to ``DEFAULT_VAL_FREQ``).
    """

    DEFAULT_VAL_FREQ = 5
    ITERATION_LIMIT = 150

    def __init__(self, model, config, train_dataset, val_dataset, wandb_run=None, prompt_path=None, val_freq=None):
        self.model = model
        self.config = config
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.wandb_run = wandb_run
        self.val_freq = val_freq if val_freq is not None else self.DEFAULT_VAL_FREQ
        self.prompt_path = prompt_path
        self.best_loss = np.inf
        self.train_loader = DataLoader(self.train_dataset,
                                       shuffle=True,
                                       batch_size=config.BATCH_SIZE,
                                       drop_last=True)
        self.val_loader = DataLoader(self.val_dataset,
                                     shuffle=True,
                                     batch_size=config.BATCH_SIZE,
                                     drop_last=False)
        self.optimizer = AdamW(self.model.parameters(), lr=config.LR, weight_decay=config.WEIGHT_DECAY)
        self.lr_scheduler = get_scheduler(
            name="linear",
            optimizer=self.optimizer,
            num_warmup_steps=0,
            num_training_steps=len(self.train_loader) * self.config.N_EPOCH,
        )

    def _log(self, metrics):
        # Logging is best-effort: wandb_run defaults to None, so guard instead
        # of crashing (the original dereferenced it unconditionally).
        if self.wandb_run is not None:
            self.wandb_run.log(metrics)

    def _to_device(self, batch):
        # Batches arrive as lists of column tensors; stack and transpose to
        # (batch, seq_len) before moving to the configured device.
        return {'input_ids': torch.stack(batch['input_ids']).T.to(self.config.DEVICE),
                'labels': torch.stack(batch['labels']).T.to(self.config.DEVICE)}

    def train(self):
        """Run the training loop, stopping after N_EPOCH epochs or ITERATION_LIMIT steps."""
        self.model.train()
        iter_counter = 0
        for epoch in range(self.config.N_EPOCH):
            for batch in self.train_loader:
                batch = self._to_device(batch)
                outputs = self.model(**batch)
                loss = outputs.loss
                loss.backward()
                self.optimizer.step()
                self.lr_scheduler.step()
                self.optimizer.zero_grad()
                # Log a detached scalar, not the live graph-attached tensor.
                self._log({'loss': loss.item()})
                iter_counter += 1
                # Fixed off-by-one: evaluate every val_freq completed steps
                # (the original `(iter_counter + 1) % val_freq` fired at step 4, 9, ...).
                if iter_counter % self.val_freq == 0:
                    eval_perplexity = self.evaluate(perplexity)
                    # evaluate() switches to eval mode; restore training mode
                    # (the original left dropout etc. disabled for the rest of training).
                    self.model.train()
                    self._log({'perplexity': eval_perplexity})
                    if loss.item() < self.best_loss:
                        self.best_loss = loss.item()
                        # Only checkpoint when a target path was provided.
                        if self.prompt_path is not None:
                            self.save_model(self.prompt_path)
                            print('Model saved')
                if iter_counter >= self.ITERATION_LIMIT:
                    return

    def evaluate(self, eval_fn):
        """Compute ``eval_fn(logits, labels)`` over the validation set.

        Puts the model in eval mode; callers that keep training must restore
        train mode themselves (``train()`` does).
        """
        logits = []
        labels = []
        self.model.eval()
        with torch.no_grad():
            for batch in self.val_loader:
                batch = self._to_device(batch)
                outputs = self.model(**batch)
                # NOTE(review): input_ids are used as labels here even though the
                # batch carries a separate 'labels' key — standard for causal-LM
                # perplexity, but verify against util.metrics.perplexity.
                labels.extend(batch['input_ids'])
                logits.extend(outputs.logits)
        return eval_fn(logits, labels)

    def save_model(self, path):
        """Save only the prompt-embedding weights to ``path``."""
        torch.save(self.model.transformer.prompt_embeddings.state_dict(), path)

    def load_model(self, path):
        """Load prompt-embedding weights previously written by ``save_model``."""
        self.model.transformer.prompt_embeddings.load_state_dict(torch.load(path))