|
import torch
import torch.nn as nn
import torch.nn.functional as F
import tqdm
from torch.optim import Adam

from optimizer_schedule import ScheduledOptim
|
|
|
|
|
class BERTTrainer:
    """Pre-trains a BERT model on the two standard objectives:
    masked language modeling (MLM) and next-sentence prediction (NSP).

    Optimization uses Adam wrapped in the Transformer warmup/decay
    learning-rate schedule (``ScheduledOptim``).
    """

    def __init__(
        self,
        model,
        train_dataloader,
        test_dataloader=None,
        lr=1e-4,
        weight_decay=0.01,
        betas=(0.9, 0.999),
        warmup_steps=10000,
        log_freq=10,
        device='cuda'
    ):
        """
        :param model: BERT LM; must expose ``.bert.d_model`` and return
            ``(next_sentence_log_probs, masked_lm_log_probs)`` when called
            with ``(bert_input, segment_label)``. Assumed to already live
            on ``device``.
        :param train_dataloader: DataLoader yielding dict batches with keys
            "bert_input", "segment_label", "is_next", "bert_label".
        :param test_dataloader: optional evaluation DataLoader.
        :param lr: peak Adam learning rate (scaled by the warmup schedule).
        :param weight_decay: Adam weight decay.
        :param betas: Adam beta coefficients.
        :param warmup_steps: number of linear LR warmup steps.
        :param log_freq: emit running stats every ``log_freq`` iterations.
        :param device: device batches are moved to before the forward pass.
        """
        self.device = device
        self.model = model
        self.train_data = train_dataloader
        self.test_data = test_dataloader

        # Adam wrapped in the "Attention Is All You Need" warmup schedule.
        self.optim = Adam(self.model.parameters(), lr=lr, betas=betas, weight_decay=weight_decay)
        self.optim_schedule = ScheduledOptim(
            self.optim, self.model.bert.d_model, n_warmup_steps=warmup_steps
        )

        # MLM loss: label 0 marks non-masked/padding positions, so skip them.
        self.criterion = torch.nn.NLLLoss(ignore_index=0)
        # NSP loss: labels are 0/1 and BOTH are valid classes. Reusing the
        # MLM criterion here (as the original code did) silently dropped
        # every "not next" (label 0) example from the NSP loss.
        self.nsp_criterion = torch.nn.NLLLoss()

        self.log_freq = log_freq
        print("Total Parameters:", sum([p.nelement() for p in self.model.parameters()]))

    def train(self, epoch):
        """Run one training epoch over the train dataloader."""
        self.iteration(epoch, self.train_data)

    def test(self, epoch):
        """Run one evaluation epoch over the test dataloader (no updates)."""
        if self.test_data is None:
            raise ValueError("No test_dataloader was provided to BERTTrainer")
        self.iteration(epoch, self.test_data, train=False)

    def iteration(self, epoch, data_loader, train=True):
        """Run one full pass over ``data_loader``.

        :param epoch: current epoch index (used for logging only).
        :param data_loader: DataLoader yielding dict batches.
        :param train: when True, back-propagate and update the weights;
            when False, run in eval mode with autograd disabled.
        """
        mode = "train" if train else "test"

        # Dropout / batch-norm behavior must match the phase.
        self.model.train(train)

        avg_loss = 0.0
        total_correct = 0
        total_element = 0

        data_iter = tqdm.tqdm(
            enumerate(data_loader),
            desc="EP_%s:%d" % (mode, epoch),
            total=len(data_loader),
            bar_format="{l_bar}{r_bar}"
        )

        # Autograd is only needed when training; disabling it during
        # evaluation saves memory and time.
        with torch.set_grad_enabled(train):
            for i, data in data_iter:
                # Move the whole batch onto the target device.
                data = {key: value.to(self.device) for key, value in data.items()}

                # Use __call__ (not .forward) so module hooks still fire.
                next_sent_output, mask_lm_output = self.model(data["bert_input"], data["segment_label"])

                # NSP loss over the 0/1 "is_next" labels.
                next_loss = self.nsp_criterion(next_sent_output, data["is_next"])

                # MLM loss; NLLLoss wants (batch, classes, seq), hence the
                # transpose. Positions labeled 0 are ignored.
                mask_loss = self.criterion(mask_lm_output.transpose(1, 2), data["bert_label"])

                loss = next_loss + mask_loss

                if train:
                    self.optim_schedule.zero_grad()
                    loss.backward()
                    self.optim_schedule.step_and_update_lr()

                # Running NSP accuracy / loss bookkeeping.
                correct = next_sent_output.argmax(dim=-1).eq(data["is_next"]).sum().item()
                avg_loss += loss.item()
                total_correct += correct
                total_element += data["is_next"].nelement()

                if i % self.log_freq == 0:
                    # Only build the stats dict when it is actually printed.
                    data_iter.write(str({
                        "epoch": epoch,
                        "iter": i,
                        "avg_loss": avg_loss / (i + 1),
                        "avg_acc": total_correct / total_element * 100,
                        "loss": loss.item()
                    }))

        # Guard the denominators so an empty loader doesn't raise
        # ZeroDivisionError in the summary line.
        batches = max(len(data_loader), 1)
        elements = max(total_element, 1)
        print(
            f"EP{epoch}, {mode}: "
            f"avg_loss={avg_loss / batches}, "
            f"total_acc={total_correct * 100.0 / elements}"
        )