import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from modeling_p5 import P5


class P5Pretraining(P5):
    """Multi-task pretraining head on top of the P5 encoder-decoder.

    Adds per-batch training / validation steps that compute a per-example,
    padding-masked LM loss and break it down per pretraining task
    (task names come from ``config.losses``, a comma-separated string).
    """

    def __init__(self, config):
        super().__init__(config)
        # Task names tracked in the per-task loss breakdown,
        # e.g. "rating,sequential,explanation".
        self.losses = self.config.losses.split(',')

    @staticmethod
    def _per_example_loss(token_loss, lm_labels):
        """Collapse a flat per-token loss into a per-example mean.

        Args:
            token_loss: 1-D tensor of per-token losses (B * L entries),
                as returned by the underlying model with reduction='none'
                — assumed; confirm against P5's forward. TODO(review)
            lm_labels: (B, L) target ids, with -100 marking padding
                positions to ignore (HF convention).

        Returns:
            (B,) tensor: mean loss over non-ignored tokens of each example.
        """
        lm_mask = (lm_labels != -100).float()
        B, L = lm_labels.size()
        loss = token_loss.view(B, L) * lm_mask
        # clamp(min=1) guards against division by zero when a target
        # row is entirely padding.
        return loss.sum(dim=1) / lm_mask.sum(dim=1).clamp(min=1)

    def _accumulate_task_losses(self, results, loss, tasks):
        """Add summed per-task loss and example counts into ``results`` in place.

        Args:
            results: dict to receive ``f'{task}_loss'`` / ``f'{task}_loss_count'``
                entries (only for tasks present in this batch).
            loss: (B,) per-example loss tensor.
            tasks: sequence of B task-name strings aligned with ``loss``.
        """
        task_counts = {task: 0 for task in self.losses}
        task_loss = {task: 0 for task in self.losses}
        for _loss, task in zip(loss.detach(), tasks):
            task_loss[task] += _loss
            task_counts[task] += 1
        for task in self.losses:
            if task_counts[task] > 0:
                results[f'{task}_loss'] = task_loss[task]
                results[f'{task}_loss_count'] = task_counts[task]

    def train_step(self, batch):
        """Run one training forward pass and return loss statistics.

        Args:
            batch: dict with 'input_ids', 'whole_word_ids', 'target_ids',
                'loss_weights' tensors and a 'task' list of task names.

        Returns:
            dict with the weighted scalar 'loss' (for backward), detached
            'total_loss' / 'total_loss_count', and per-task breakdowns.
        """
        device = next(self.parameters()).device
        input_ids = batch['input_ids'].to(device)
        whole_word_ids = batch['whole_word_ids'].to(device)
        lm_labels = batch["target_ids"].to(device)
        loss_weights = batch["loss_weights"].to(device)

        output = self(
            input_ids=input_ids,
            whole_word_ids=whole_word_ids,
            labels=lm_labels,
            return_dict=True
        )
        assert 'loss' in output

        loss = self._per_example_loss(output['loss'], lm_labels)

        results = {}
        # Weighted mean drives the optimizer; the detached sums are for logging.
        results['loss'] = (loss * loss_weights).mean()
        results['total_loss'] = loss.detach().sum()
        results['total_loss_count'] = len(loss)
        self._accumulate_task_losses(results, loss, batch['task'])
        return results

    @torch.no_grad()
    def valid_step(self, batch):
        """Run one validation pass and return loss statistics.

        Like :meth:`train_step` but in eval mode, without gradients and
        without whole-word ids. When the 'rating' task is configured, also
        decodes greedy generations into ``results['rating_pred']``.
        """
        self.eval()
        device = next(self.parameters()).device
        input_ids = batch['input_ids'].to(device)
        lm_labels = batch["target_ids"].to(device)
        loss_weights = batch["loss_weights"].to(device)

        output = self(
            input_ids=input_ids,
            labels=lm_labels,
            return_dict=True
        )
        assert 'loss' in output

        loss = self._per_example_loss(output['loss'], lm_labels)

        results = {}
        results['loss'] = (loss * loss_weights).mean()
        results['total_loss'] = loss.detach().sum()
        results['total_loss_count'] = len(loss)
        self._accumulate_task_losses(results, loss, batch['task'])

        if 'rating' in self.losses:
            output = self.generate(
                input_ids=input_ids
            )
            generated_score = self.tokenizer.batch_decode(output, skip_special_tokens=True)
            results['rating_pred'] = generated_score
        return results

    @torch.no_grad()
    def generate_step(self, batch):
        """Generate and decode output text for a batch of inputs.

        Returns:
            list[str]: decoded generations, special tokens stripped.
        """
        self.eval()
        device = next(self.parameters()).device
        input_ids = batch['input_ids'].to(device)
        output = self.generate(
            input_ids=input_ids,
        )
        generated_sents = self.tokenizer.batch_decode(output, skip_special_tokens=True)
        return generated_sents