Spaces:
Sleeping
Sleeping
import logging | |
import os | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import wandb | |
from torchmetrics.functional import accuracy, f1_score, precision, recall | |
from tqdm import tqdm, trange | |
logger = logging.getLogger(__name__) | |
class ATISTrainer:
    """Training / evaluation driver for a joint slot-filling and intent model.

    The wrapped model is called as ``model(input_ids, attention_mask, labels)``
    and must return a mapping with the keys ``"dst_logits"``, ``"intent_loss"``
    and ``"dst_loss"``.  Batches are mappings with the keys ``"input_ids"``,
    ``"attention_mask"``, ``"labels"`` and ``"intent"``.

    Args:
        model: Callable model (see the contract above).
        optimizer: Optimizer whose ``zero_grad``/``step`` are driven per batch.
        scheduler: Optional LR scheduler; stepped once per optimizer step.
            May be ``None``.
        criterion: Loss applied to ``(intent_logits, intent_targets)``.
        accelerate: Accelerator-like object providing ``backward``,
            ``optimizer_step_was_skipped``, ``wait_for_everyone`` and
            ``save_state``.
        output_dir: Directory for checkpoints; created if missing.
        num_labels: Number of slot classes.
        num_intents: Number of intent classes.
        run: Experiment tracker exposing ``log(dict)``.
    """

    # Positions whose slot label equals this value are left out of slot metrics.
    _IGNORED_LABEL = -100

    def __init__(
        self,
        model,
        optimizer,
        scheduler,
        criterion,
        accelerate,
        output_dir,
        num_labels,
        num_intents,
        run,
    ):
        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.accelerator = accelerate
        self.output_dir = output_dir
        self.num_labels = num_labels
        self.num_intents = num_intents
        # exist_ok avoids the check-then-create race of os.path.exists().
        os.makedirs(self.output_dir, exist_ok=True)
        self.run = run
        logging.info(f"Strating Training, outputs are saved in {self.output_dir}")

    @staticmethod
    def _active_slots(slot_logits, labels, num_labels):
        """Return ``(predictions, targets)`` for positions with a real label.

        Both inputs are flattened so that every token position becomes one
        row; rows whose label is ``-100`` are dropped.
        """
        flat_targets = labels.reshape(-1)  # [batch_size * seq_len]
        flat_preds = torch.argmax(
            slot_logits.reshape(-1, num_labels), dim=-1
        )  # [batch_size * seq_len]
        keep = flat_targets != ATISTrainer._IGNORED_LABEL
        return flat_preds[keep], flat_targets[keep]

    def _intent_loss(self, intent_logits, intents):
        """Apply the criterion to the intent head and reduce it to a scalar."""
        raw_loss = self.criterion(intent_logits, intents)
        # NOTE(review): a softmax over a single element is always 1.0, so this
        # "random loss weight" is currently a no-op.  Real random loss
        # weighting would sample one logit per task.  Kept as-is so the
        # optimisation objective does not silently change.
        weight = F.softmax(torch.randn(1), dim=-1)
        # Follow the loss tensor's device instead of hard-coding CUDA.
        return torch.sum(raw_loss * weight.to(raw_loss.device))

    def _score_batch(self, batch):
        """Run the model on one batch and return ``(total_loss, metrics)``.

        ``metrics`` maps ``intent_acc``, ``intent_f1``, ``slot_acc`` and
        ``slot_f1`` to the weighted-average scores of this batch.
        """
        labels = batch["labels"]
        intents = batch["intent"]
        outputs = self.model(batch["input_ids"], batch["attention_mask"], labels)
        slot_logits = outputs["dst_logits"]
        # NOTE(review): the value under "intent_loss" is consumed as intent
        # *logits* (argmax + criterion input) -- confirm against the model.
        intent_logits = outputs["intent_loss"]
        slot_loss = outputs["dst_loss"]

        slot_preds, slot_labels = self._active_slots(
            slot_logits, labels, self.num_labels
        )
        intent_preds = torch.argmax(intent_logits, dim=1)
        total_loss = slot_loss + self._intent_loss(intent_logits, intents)

        intent_kwargs = {"num_classes": self.num_intents, "average": "weighted"}
        slot_kwargs = {"num_classes": self.num_labels, "average": "weighted"}
        metrics = {
            "intent_acc": accuracy(intent_preds, intents, **intent_kwargs),
            "intent_f1": f1_score(intent_preds, intents, **intent_kwargs),
            "slot_acc": accuracy(slot_preds, slot_labels, **slot_kwargs),
            "slot_f1": f1_score(slot_preds, slot_labels, **slot_kwargs),
        }
        return total_loss, metrics

    def _record_step(self, prefix, loss, metrics, totals):
        """Log one batch to the tracker and add its values to ``totals``."""
        step_values = {"loss": loss, **metrics}
        self.run.log(
            {
                f"{prefix}_{name}_step": float(value)
                for name, value in step_values.items()
            }
        )
        for name, value in step_values.items():
            # detach() so the running sums never keep an autograd graph alive.
            totals[name] = totals.get(name, 0) + torch.as_tensor(value).detach()

    @staticmethod
    def _epoch_means(prefix, totals, n_batches):
        """Turn per-batch sums into ``{prefix}_{name}_epoch`` means.

        Raises:
            ValueError: If no batch was processed.
        """
        if n_batches == 0:
            raise ValueError(
                f"cannot compute {prefix} epoch metrics: the iterator yielded no batches"
            )
        return {
            f"{prefix}_{name}_epoch": total / n_batches
            for name, total in totals.items()
        }

    def train_step(self, iterator):
        """Train for one pass over ``iterator``.

        Returns:
            dict: ``train_loss_epoch``, ``train_intent_f1_epoch``,
            ``train_intent_acc_epoch``, ``train_slot_f1_epoch`` and
            ``train_slot_acc_epoch`` -- each the mean of the per-batch values
            (detached tensors).

        Raises:
            ValueError: If ``iterator`` yields no batches.
        """
        self.model.train()
        totals = {}
        n_batches = 0
        for batch in tqdm(iterator, desc="training"):
            self.optimizer.zero_grad()
            train_loss, metrics = self._score_batch(batch)
            self.accelerator.backward(train_loss)
            self.optimizer.step()
            # Step the schedule exactly once per *applied* optimizer step.
            if (
                self.scheduler is not None
                and not self.accelerator.optimizer_step_was_skipped
            ):
                self.scheduler.step()
            self._record_step("train", train_loss, metrics, totals)
            n_batches += 1
        return self._epoch_means("train", totals, n_batches)

    def eval_step(self, iterator):
        """Evaluate on one pass over ``iterator`` without updating the model.

        Returns:
            dict: ``eval_loss_epoch``, ``eval_intent_f1_epoch``,
            ``eval_intent_acc_epoch``, ``eval_slot_f1_epoch`` and
            ``eval_slot_acc_epoch`` -- each the mean of the per-batch values.

        Raises:
            ValueError: If ``iterator`` yields no batches.
        """
        self.model.eval()
        totals = {}
        n_batches = 0
        # No gradients are needed here; this also keeps memory flat.
        with torch.no_grad():
            for batch in tqdm(iterator, desc="Evaluating"):
                eval_loss, metrics = self._score_batch(batch)
                self._record_step("eval", eval_loss, metrics, totals)
                n_batches += 1
        return self._epoch_means("eval", totals, n_batches)

    def fit(self, n_epochs, train_dataloader, eval_dataloader, patience):
        """Alternate training and evaluation for ``n_epochs`` epochs.

        After the last epoch the accelerator state is saved to
        ``self.output_dir``.

        Args:
            n_epochs: Number of epochs to run.
            train_dataloader: Iterable of training batches.
            eval_dataloader: Iterable of evaluation batches.
            patience: Only recorded in the per-epoch log; no early stopping
                is performed.

        Returns:
            tuple: ``(best_model, best_eval_loss)``.  ``best_model`` is a
            reference to ``self.model`` (not a weight snapshot), and
            ``best_eval_loss`` is the lowest epoch evaluation loss seen
            (``float("inf")`` if no epoch improved on it).
        """
        best_eval_loss = float("inf")
        # Bound up front so the return below is valid even when no epoch
        # improves the loss (e.g. n_epochs == 0 or a NaN loss).
        best_model = self.model
        for epoch in trange(n_epochs):
            train_metrics = self.train_step(train_dataloader)
            eval_metrics = self.eval_step(eval_dataloader)

            eval_loss = eval_metrics["eval_loss_epoch"]
            if eval_loss < best_eval_loss:
                best_model = self.model
                best_eval_loss = eval_loss

            train_logs = {
                "epoch": epoch,
                "train_loss": train_metrics["train_loss_epoch"],
                "eval_loss": eval_loss,
                "train_intent_acc": train_metrics["train_intent_acc_epoch"],
                "train_intent_f1": train_metrics["train_intent_f1_epoch"],
                "eval_intent_f1": eval_metrics["eval_intent_f1_epoch"],
                "eval_intent_acc": eval_metrics["eval_intent_acc_epoch"],
                "train_slot_f1": train_metrics["train_slot_f1_epoch"],
                "train_slot_acc": train_metrics["train_slot_acc_epoch"],
                "eval_slot_f1": eval_metrics["eval_slot_f1_epoch"],
                "eval_slot_acc": eval_metrics["eval_slot_acc_epoch"],
                "lr": self.optimizer.param_groups[0]["lr"],
                "patience": patience,
            }
            logging.info(train_logs)
            logging.info(eval_metrics)

        self.accelerator.wait_for_everyone()
        self.accelerator.save_state(self.output_dir)
        logging.info(f"Checkpoint is saved in {self.output_dir}")
        return best_model, best_eval_loss