# P5_yelp_small / pretrain_model.py
# Uploaded by makitanikaze, commit cf42adb ("Upload P5Pretraining").
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from modeling_p5 import P5
class P5Pretraining(P5):
    """P5 model with train / validation / generation step helpers.

    Task names come from ``config.losses`` (a comma-separated string). Every
    sample's ``batch['task']`` entry must be one of these names; per-task loss
    sums and counts are reported for each task present in a batch.
    """

    def __init__(self, config):
        super().__init__(config)
        # Strip whitespace so "rating, sequential" and "rating,sequential"
        # produce the same keys that batch['task'] entries are matched against.
        self.losses = [task.strip() for task in self.config.losses.split(',')]

    def _loss_results(self, output, lm_labels, loss_weights, tasks):
        """Reduce the token-level LM loss to per-sample losses and build the results dict.

        Args:
            output: model output. ``output['loss']`` is reshaped to
                ``lm_labels.size()``, so it must hold one loss value per target
                token (presumably the forward pass returns an unreduced loss —
                TODO confirm against P5.forward).
            lm_labels: (B, L) target ids; positions equal to -100 are ignored.
            loss_weights: per-sample weights multiplied into the returned 'loss'.
            tasks: B task names, each expected to be a key of ``self.losses``.

        Returns:
            dict with 'loss' (weighted mean of per-sample losses, keeps the
            graph), 'total_loss' (detached sum of per-sample losses),
            'total_loss_count' (B), plus '<task>_loss' (detached sum) and
            '<task>_loss_count' for every task that occurs in ``tasks``.
        """
        assert 'loss' in output

        # Zero out ignored positions and average over each sample's real target
        # tokens; clamp(min=1) avoids dividing by zero for all-ignored rows.
        lm_mask = (lm_labels != -100).float()
        B, L = lm_labels.size()
        loss = output['loss'].view(B, L) * lm_mask
        loss = loss.sum(dim=1) / lm_mask.sum(dim=1).clamp(min=1)

        results = {
            'loss': (loss * loss_weights).mean(),
            'total_loss': loss.detach().sum(),
            'total_loss_count': len(loss),
        }

        task_counts = {task: 0 for task in self.losses}
        task_loss = {task: 0 for task in self.losses}
        for _loss, task in zip(loss.detach(), tasks):
            task_loss[task] += _loss
            task_counts[task] += 1

        for task in self.losses:
            if task_counts[task] > 0:
                results[f'{task}_loss'] = task_loss[task]
                results[f'{task}_loss_count'] = task_counts[task]
        return results

    def train_step(self, batch):
        """Run a training forward pass and return loss statistics.

        Reads 'input_ids', 'whole_word_ids', 'target_ids', 'loss_weights' and
        'task' from ``batch``; tensors are moved to the model's device. The
        returned dict is described in ``_loss_results``.
        """
        device = next(self.parameters()).device
        input_ids = batch['input_ids'].to(device)
        whole_word_ids = batch['whole_word_ids'].to(device)
        lm_labels = batch["target_ids"].to(device)
        loss_weights = batch["loss_weights"].to(device)

        output = self(
            input_ids=input_ids,
            whole_word_ids=whole_word_ids,
            labels=lm_labels,
            return_dict=True,
        )
        return self._loss_results(output, lm_labels, loss_weights, batch['task'])

    @torch.no_grad()
    def valid_step(self, batch):
        """Compute validation loss statistics (and rating predictions) without gradients.

        Switches the model to eval mode and does not switch it back. Reads the
        same batch keys as ``train_step``; 'whole_word_ids' is optional. If
        'rating' is one of the configured tasks, the decoded generations for
        the whole batch are added under 'rating_pred'.
        """
        self.eval()
        device = next(self.parameters()).device
        input_ids = batch['input_ids'].to(device)
        lm_labels = batch["target_ids"].to(device)
        loss_weights = batch["loss_weights"].to(device)

        # Feed whole-word ids the same way train_step does so validation loss
        # is measured on the same model input as training loss.
        whole_word_ids = batch.get('whole_word_ids')
        if whole_word_ids is not None:
            whole_word_ids = whole_word_ids.to(device)

        output = self(
            input_ids=input_ids,
            whole_word_ids=whole_word_ids,
            labels=lm_labels,
            return_dict=True,
        )
        results = self._loss_results(output, lm_labels, loss_weights, batch['task'])

        if 'rating' in self.losses:
            # NOTE(review): generation does not receive whole_word_ids although
            # the training forward does — confirm whether P5.generate forwards
            # them to the encoder before adding them here.
            generated = self.generate(input_ids=input_ids)
            # self.tokenizer is not assigned in this class; presumably attached
            # by the training script — TODO confirm.
            results['rating_pred'] = self.tokenizer.batch_decode(
                generated, skip_special_tokens=True
            )
        return results

    @torch.no_grad()
    def generate_step(self, batch):
        """Generate from ``batch['input_ids']`` and return the decoded strings.

        Switches the model to eval mode and does not switch it back.
        """
        self.eval()
        device = next(self.parameters()).device
        input_ids = batch['input_ids'].to(device)
        output = self.generate(
            input_ids=input_ids,
        )
        generated_sents = self.tokenizer.batch_decode(output, skip_special_tokens=True)
        return generated_sents