# bert/trainer.py
# Hugging Face file-viewer residue: uploaded by pt-sk ("Upload 7 files"), commit 28dc58b (verified), 3.68 kB.
import torch
import torch.nn as nn
import torch.nn. functional as F
from optimizer_schedule import ScheduledOptim
import tqdm
from torch.optim import Adam
class BERTTrainer:
    """Run BERT pre-training epochs (next-sentence + masked-token objectives).

    The wrapped ``model`` is called as ``model(bert_input, segment_label)``
    and must return ``(next_sent_output, mask_lm_output)``; both outputs are
    fed to ``NLLLoss``, so they are expected to be log-probabilities. The
    model must also expose ``model.bert.d_model``, which is passed to the
    learning-rate scheduler.

    Every batch yielded by the dataloaders must be a dict holding the keys
    ``"bert_input"``, ``"segment_label"``, ``"is_next"`` and ``"bert_label"``,
    with values that support ``.to(device)``.
    """

    def __init__(
        self,
        model,
        train_dataloader,
        test_dataloader=None,
        lr=1e-4,
        weight_decay=0.01,
        betas=(0.9, 0.999),
        warmup_steps=10000,
        log_freq=10,
        device='cuda'
    ):
        """Set up the optimizer, LR schedule and loss functions.

        Args:
            model: Model to train (see the class docstring for its contract).
            train_dataloader: Iterable of training batches; must support ``len``.
            test_dataloader: Optional iterable of evaluation batches.
            lr: Learning rate handed to Adam.
            weight_decay: Weight decay handed to Adam.
            betas: Adam beta coefficients.
            warmup_steps: Warm-up step count handed to ``ScheduledOptim``.
            log_freq: A running-stats line is written every ``log_freq`` batches.
            device: Device on which the model and every batch are placed.
        """
        self.device = device
        # Batches are moved to ``device`` before each forward pass, so the
        # model has to live there as well (a no-op if it already does).
        self.model = model.to(device)
        self.train_data = train_dataloader
        self.test_data = test_dataloader

        # Adam optimizer wrapped by the warm-up learning-rate schedule.
        self.optim = Adam(
            self.model.parameters(), lr=lr, betas=betas, weight_decay=weight_decay
        )
        self.optim_schedule = ScheduledOptim(
            self.optim, self.model.bert.d_model, n_warmup_steps=warmup_steps
        )

        # Masked-token loss: label 0 is skipped via ``ignore_index``
        # (presumably padding / non-masked positions -- confirm with the dataset).
        self.criterion = torch.nn.NLLLoss(ignore_index=0)
        # Next-sentence loss: class 0 is a real class here (the accuracy
        # computation compares argmax against it), so it must NOT be ignored.
        # Re-using ``ignore_index=0`` would drop every label-0 sample and
        # yield NaN for a batch that contains only label 0.
        self.next_criterion = torch.nn.NLLLoss()

        self.log_freq = log_freq
        print("Total Parameters:", sum(p.nelement() for p in self.model.parameters()))

    def train(self, epoch):
        """Run one optimisation epoch over the training dataloader."""
        self.iteration(epoch, self.train_data)

    def test(self, epoch):
        """Run one evaluation epoch (no parameter updates) over the test dataloader."""
        self.iteration(epoch, self.test_data, train=False)

    def iteration(self, epoch, data_loader, train=True):
        """Loop once over ``data_loader``, logging loss and next-sentence accuracy.

        Args:
            epoch: Epoch index, used only for logging.
            data_loader: Iterable of batch dicts; must support ``len``.
            train: When True, back-propagate and step the optimizer; when
                False, run in eval mode with gradient tracking disabled.

        Raises:
            ValueError: If ``data_loader`` is None (e.g. ``test`` was called
                but no test dataloader was supplied).
        """
        if data_loader is None:
            raise ValueError(
                "data_loader is None; pass a dataloader (did you construct "
                "BERTTrainer without test_dataloader?)"
            )

        mode = "train" if train else "test"
        # Switch dropout/normalisation layers to the matching mode.
        self.model.train(train)

        running_loss = 0.0
        total_correct = 0
        total_element = 0
        n_batches = len(data_loader)

        # progress bar
        data_iter = tqdm.tqdm(
            enumerate(data_loader),
            desc="EP_%s:%d" % (mode, epoch),
            total=n_batches,
            bar_format="{l_bar}{r_bar}"
        )

        # Gradients are only needed when training; skipping them during
        # evaluation saves memory and time.
        with torch.set_grad_enabled(train):
            for i, data in data_iter:
                # 0. send the batch to the target device (GPU or CPU)
                data = {key: value.to(self.device) for key, value in data.items()}

                # 1. forward pass; calling the module (not ``.forward``) keeps hooks working
                next_sent_output, mask_lm_output = self.model(
                    data["bert_input"], data["segment_label"]
                )

                # 2-1. NLL loss of the is_next classification result
                next_loss = self.next_criterion(next_sent_output, data["is_next"])

                # 2-2. NLL loss of the masked-token prediction; NLLLoss wants the
                # class dimension second: (m, vocab_size, seq_len) vs (m, seq_len)
                mask_loss = self.criterion(
                    mask_lm_output.transpose(1, 2), data["bert_label"]
                )

                # 2-3. sum of both objectives (BERT paper, 3.4 Pre-training Procedure)
                loss = next_loss + mask_loss

                # 3. backward and optimisation only in train
                if train:
                    self.optim_schedule.zero_grad()
                    loss.backward()
                    self.optim_schedule.step_and_update_lr()

                # Read the scalar once; each ``.item()`` forces a device sync.
                loss_value = loss.item()

                # next sentence prediction accuracy
                correct = next_sent_output.argmax(dim=-1).eq(data["is_next"]).sum().item()
                running_loss += loss_value
                total_correct += correct
                total_element += data["is_next"].nelement()

                if i % self.log_freq == 0:
                    post_fix = {
                        "epoch": epoch,
                        "iter": i,
                        "avg_loss": running_loss / (i + 1),
                        "avg_acc": total_correct / total_element * 100,
                        "loss": loss_value
                    }
                    data_iter.write(str(post_fix))

        # Guard the epoch summary against an empty dataloader.
        mean_loss = running_loss / n_batches if n_batches else 0.0
        total_acc = total_correct * 100.0 / total_element if total_element else 0.0
        print(f"EP{epoch}, {mode}: avg_loss={mean_loss}, total_acc={total_acc}")