# tamilatis/trainer.py
import copy
import logging
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import wandb
from torchmetrics.functional import accuracy, f1_score, precision, recall
from tqdm import tqdm, trange
logger = logging.getLogger(__name__)
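
# Random Loss Weighting (RLW) is used below: each step draws a random weight,
# normalises it with a softmax, and scales the task loss with it. A minimal
# sketch generalised to several task losses is given here for reference; this
# helper is illustrative only and is not called by ATISTrainer itself.
def random_loss_weighting(losses):
    """Combine per-task losses with softmax-normalised random weights (RLW)."""
    weights = F.softmax(torch.randn(len(losses)), dim=-1).to(losses[0].device)
    return sum(w * l for w, l in zip(weights, losses))
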
class ATISTrainer:
"""A Trainer class consists of utitlity functions for training the model"""
def __init__(
self,
model,
optimizer,
scheduler,
criterion,
accelerate,
output_dir,
num_labels,
num_intents,
run
):
self.model = model
self.criterion = criterion
self.optimizer = optimizer
self.scheduler = scheduler
self.accelerator = accelerate
self.output_dir = output_dir
self.num_labels = num_labels
self.num_intents = num_intents
        os.makedirs(self.output_dir, exist_ok=True)
self.run = run
        logger.info(f"Starting training, outputs are saved in {self.output_dir}")
    def train_step(self, iterator):
        self.model.train()
        # running totals for epoch-level metrics
        total_loss = 0.0
        total_intent_acc, total_intent_f1 = 0.0, 0.0
        total_slot_acc, total_slot_f1 = 0.0, 0.0
        training_progress_bar = tqdm(iterator, desc="Training")
for batch in training_progress_bar:
input_ids, attention_mask, labels, intents = (
batch["input_ids"],
batch["attention_mask"],
batch["labels"],
batch["intent"],
)
self.optimizer.zero_grad()
            loss_dict = self.model(input_ids, attention_mask, labels)
            # assumes the model returns slot logits/loss and intent logits
            slot_logits, intent_logits, slot_loss = (
                loss_dict["dst_logits"],
                loss_dict["intent_logits"],
                loss_dict["dst_loss"],
            )
# compute training accuracy for slots
flattened_target_labels = batch["labels"].view(
-1
) # [batch_size * seq_len, ]
active_logits = slot_logits.view(
-1, self.num_labels
) # [batch_size* seq_len, num_labels]
flattened_preds = torch.argmax(
active_logits, axis=-1
) # [batch_size * seq_len,]
# compute accuracy at active labels
active_accuracy = (
batch["labels"].view(-1) != -100
) # [batch_size * seq_len, ]
slot_labels = torch.masked_select(flattened_target_labels, active_accuracy)
slot_preds = torch.masked_select(flattened_preds, active_accuracy)
            # compute the loss for intents, scaled with Random Loss Weighting (RLW):
            # draw a random weight and normalise it with a softmax
            intent_loss = self.criterion(intent_logits, batch["intent"])
            weight = F.softmax(torch.randn(1), dim=-1).to(intent_loss.device)
            intent_loss = torch.sum(intent_loss * weight)
intent_preds = torch.argmax(intent_logits, axis=1)
train_loss = slot_loss + intent_loss
self.accelerator.backward(train_loss)
self.optimizer.step()
if self.scheduler is not None:
if not self.accelerator.optimizer_step_was_skipped:
self.scheduler.step()
intent_acc = accuracy(
intent_preds, intents, num_classes=self.num_intents, average="weighted"
)
intent_f1 = f1_score(
intent_preds, intents, num_classes=self.num_intents, average="weighted"
)
intent_rec = recall(
intent_preds, intents, num_classes=self.num_intents, average="weighted"
)
intent_prec = precision(
intent_preds, intents, num_classes=self.num_intents, average="weighted"
)
slot_acc = accuracy(
slot_preds, slot_labels, num_classes=self.num_labels, average="weighted"
)
slot_f1 = f1_score(
slot_preds, slot_labels, num_classes=self.num_labels, average="weighted"
)
slot_rec = recall(
slot_preds, slot_labels, num_classes=self.num_labels, average="weighted"
)
slot_prec = precision(
slot_preds, slot_labels, num_classes=self.num_labels, average="weighted"
)
            self.run.log(
                {
                    "train_loss_step": train_loss.item(),
                    "train_intent_acc_step": intent_acc,
                    "train_intent_f1_step": intent_f1,
                    "train_slot_acc_step": slot_acc,
                    "train_slot_f1_step": slot_f1,
                }
            )
            # accumulate running totals for the epoch-level averages
            total_loss += train_loss.item()
            total_intent_acc += intent_acc.item()
            total_intent_f1 += intent_f1.item()
            total_slot_acc += slot_acc.item()
            total_slot_f1 += slot_f1.item()
        # average the accumulated metrics over all the batches
        return {
            "train_loss_epoch": total_loss / len(iterator),
            "train_intent_f1_epoch": total_intent_f1 / len(iterator),
            "train_intent_acc_epoch": total_intent_acc / len(iterator),
            "train_slot_f1_epoch": total_slot_f1 / len(iterator),
            "train_slot_acc_epoch": total_slot_acc / len(iterator),
        }
    @torch.no_grad()
    def eval_step(self, iterator):
        self.model.eval()
        # running totals for epoch-level metrics
        total_loss = 0.0
        total_intent_acc, total_intent_f1 = 0.0, 0.0
        total_slot_acc, total_slot_f1 = 0.0, 0.0
        eval_progress_bar = tqdm(iterator, desc="Evaluating")
        for batch in eval_progress_bar:
input_ids, attention_mask, labels, intents = (
batch["input_ids"],
batch["attention_mask"],
batch["labels"],
batch["intent"],
)
            loss_dict = self.model(input_ids, attention_mask, labels)
            # assumes the model returns slot logits/loss and intent logits
            slot_logits, intent_logits, slot_loss = (
                loss_dict["dst_logits"],
                loss_dict["intent_logits"],
                loss_dict["dst_loss"],
            )
            # compute evaluation accuracy for slots
flattened_target_labels = batch["labels"].view(
-1
) # [batch_size * seq_len, ]
active_logits = slot_logits.view(
-1, self.num_labels
) # [batch_size* seq_len, num_labels]
flattened_preds = torch.argmax(
active_logits, axis=-1
) # [batch_size * seq_len,]
# compute accuracy at active labels
active_accuracy = (
batch["labels"].view(-1) != -100
) # [batch_size * seq_len, ]
slot_labels = torch.masked_select(flattened_target_labels, active_accuracy)
slot_preds = torch.masked_select(flattened_preds, active_accuracy)
            # compute the loss for intents with the same RLW weighting as in training
            intent_loss = self.criterion(intent_logits, batch["intent"])
            weight = F.softmax(torch.randn(1), dim=-1).to(intent_loss.device)
            intent_loss = torch.sum(intent_loss * weight)
intent_preds = torch.argmax(intent_logits, axis=1)
eval_loss = slot_loss + intent_loss
intent_acc = accuracy(
intent_preds, intents, num_classes=self.num_intents, average="weighted"
)
intent_f1 = f1_score(
intent_preds, intents, num_classes=self.num_intents, average="weighted"
)
intent_rec = recall(
intent_preds, intents, num_classes=self.num_intents, average="weighted"
)
intent_prec = precision(
intent_preds, intents, num_classes=self.num_intents, average="weighted"
)
slot_acc = accuracy(
slot_preds, slot_labels, num_classes=self.num_labels, average="weighted"
)
slot_f1 = f1_score(
slot_preds, slot_labels, num_classes=self.num_labels, average="weighted"
)
slot_rec = recall(
slot_preds, slot_labels, num_classes=self.num_labels, average="weighted"
)
slot_prec = precision(
slot_preds, slot_labels, num_classes=self.num_labels, average="weighted"
)
            self.run.log(
                {
                    "eval_loss_step": eval_loss.item(),
                    "eval_intent_acc_step": intent_acc,
                    "eval_intent_f1_step": intent_f1,
                    "eval_slot_acc_step": slot_acc,
                    "eval_slot_f1_step": slot_f1,
                }
            )
            # accumulate running totals for the epoch-level averages
            total_loss += eval_loss.item()
            total_intent_acc += intent_acc.item()
            total_intent_f1 += intent_f1.item()
            total_slot_acc += slot_acc.item()
            total_slot_f1 += slot_f1.item()
        # average the accumulated metrics over all the batches
        return {
            "eval_loss_epoch": total_loss / len(iterator),
            "eval_intent_f1_epoch": total_intent_f1 / len(iterator),
            "eval_intent_acc_epoch": total_intent_acc / len(iterator),
            "eval_slot_f1_epoch": total_slot_f1 / len(iterator),
            "eval_slot_acc_epoch": total_slot_acc / len(iterator),
        }
    def fit(self, n_epochs, train_dataloader, eval_dataloader, patience):
        best_eval_loss = float("inf")
        best_model = self.model
        epochs_without_improvement = 0
        pbar = trange(n_epochs)
for epoch in pbar:
train_metrics_dict = self.train_step(train_dataloader)
eval_metrics_dict = self.eval_step(eval_dataloader)
# access all the values from the dicts
train_loss, eval_loss = (
train_metrics_dict["train_loss_epoch"],
eval_metrics_dict["eval_loss_epoch"],
)
train_intent_f1, eval_intent_f1 = (
train_metrics_dict["train_intent_f1_epoch"],
eval_metrics_dict["eval_intent_f1_epoch"],
)
train_intent_acc, eval_intent_acc = (
train_metrics_dict["train_intent_acc_epoch"],
eval_metrics_dict["eval_intent_acc_epoch"],
)
            train_slot_f1, eval_slot_f1 = (
                train_metrics_dict["train_slot_f1_epoch"],
                eval_metrics_dict["eval_slot_f1_epoch"],
            )
train_slot_acc, eval_slot_acc = (
train_metrics_dict["train_slot_acc_epoch"],
eval_metrics_dict["eval_slot_acc_epoch"],
)
            if eval_loss < best_eval_loss:
                # keep a snapshot of the best weights rather than a live reference
                best_model = copy.deepcopy(self.accelerator.unwrap_model(self.model))
                best_eval_loss = eval_loss
                epochs_without_improvement = 0
            else:
                epochs_without_improvement += 1
train_logs = {
"epoch": epoch,
"train_loss": train_loss,
"eval_loss": eval_loss,
"train_intent_acc": train_intent_acc,
"train_intent_f1": train_intent_f1,
"eval_intent_f1": eval_intent_f1,
"eval_intent_acc": eval_intent_acc,
"train_slot_f1": train_slot_f1,
"train_slot_acc": train_slot_acc,
"lr": {self.optimizer.param_groups[0]["lr"]: 0.2},
}
train_logs["patience"] = patience
logging.info(train_logs)
logging.info(eval_metrics_dict)
            self.accelerator.wait_for_everyone()
            self.accelerator.save_state(self.output_dir)
            logger.info(f"Checkpoint is saved in {self.output_dir}")
            if epochs_without_improvement >= patience:
                logger.info(f"Early stopping: no improvement for {patience} epochs")
                break
return best_model, best_eval_loss
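
# A minimal usage sketch (assumed wiring): `model`, `scheduler`, the
# dataloaders, and the wandb `run` are hypothetical placeholders here, and the
# model is assumed to return the "dst_logits"/"intent_logits"/"dst_loss" dict
# consumed above.
#
#     from accelerate import Accelerator
#
#     accelerator = Accelerator()
#     model, optimizer, train_dl, eval_dl = accelerator.prepare(
#         model, optimizer, train_dataloader, eval_dataloader
#     )
#     trainer = ATISTrainer(
#         model, optimizer, scheduler, nn.CrossEntropyLoss(),
#         accelerator, "outputs", num_labels=num_labels,
#         num_intents=num_intents, run=run,
#     )
#     best_model, best_eval_loss = trainer.fit(
#         n_epochs=10, train_dataloader=train_dl,
#         eval_dataloader=eval_dl, patience=3,
#     )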