import pytorch_lightning as pl
import torch
from torch.optim import AdamW
from transformers import AutoModelForSeq2SeqLM, get_linear_schedule_with_warmup

class T5(pl.LightningModule):
    def __init__(self, lr=5e-5, num_train_epochs=15, warmup_steps=1000):
        super().__init__()
        self.model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")

        # Per-step losses, cleared at the end of every epoch
        self.train_losses = []
        self.validation_losses = []
        self.test_losses = []

        # Per-epoch average losses, kept for the whole run
        self.train_losses_epoch = []
        self.validation_losses_epoch = []

        # Stores lr, num_train_epochs and warmup_steps under self.hparams
        self.save_hyperparameters()

    def forward(self, input_ids, attention_mask, labels=None):
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        return outputs

    def common_step(self, batch, batch_idx):
        # When labels are present in the batch, the model computes the
        # cross-entropy loss itself and exposes it on the output object
        outputs = self(**batch)
        loss = outputs.loss
        return loss

    def training_step(self, batch, batch_idx):
        loss = self.common_step(batch, batch_idx)
        self.log("training_loss", loss)
        # Detach so the stored tensors don't keep the autograd graph alive
        self.train_losses.append(loss.detach())
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self.common_step(batch, batch_idx)
        self.log("validation_loss", loss, on_epoch=True)
        self.validation_losses.append(loss.detach())
        return loss

    def on_train_epoch_end(self):
        # Average the per-step losses, then reset for the next epoch
        avg_train_loss = sum(self.train_losses) / len(self.train_losses)
        self.train_losses_epoch.append(avg_train_loss.item())
        self.train_losses = []

    def on_validation_epoch_end(self):
        avg_val_loss = sum(self.validation_losses) / len(self.validation_losses)
        self.validation_losses_epoch.append(avg_val_loss.item())
        self.validation_losses = []

    def configure_optimizers(self):
        optimizer = AdamW(self.model.parameters(), lr=self.hparams.lr)
        # One optimization step per batch, for num_train_epochs epochs
        num_train_optimization_steps = self.hparams.num_train_epochs * len(self.train_dataloader())
        lr_scheduler = {
            "scheduler": get_linear_schedule_with_warmup(
                optimizer,
                num_warmup_steps=self.hparams.warmup_steps,
                num_training_steps=num_train_optimization_steps,
            ),
            "name": "learning_rate",
            "interval": "step",
            "frequency": 1,
        }
        return {"optimizer": optimizer, "lr_scheduler": lr_scheduler}

    def generate(self, input_ids, max_new_tokens=30):
        # self.device is the device Lightning has placed the module on
        input_ids = input_ids.clone().detach().reshape((1, -1)).to(self.device)
        return self.model.generate(input_ids, max_new_tokens=max_new_tokens)

    def push_to_hub(self, model_name, organization):
        # The Hub expects a single "organization/model_name" repo id
        self.model.push_to_hub(f"{organization}/{model_name}")

    def from_pretrained(self, model_path):
        # Replace the wrapped model with the weights stored at model_path
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_path)

    def train_dataloader(self):
        # train_dataloader, valid_dataloader and test_dataloader are
        # module-level DataLoaders assumed to be defined before training
        return train_dataloader

    def val_dataloader(self):
        return valid_dataloader

    def test_dataloader(self):
        return test_dataloader
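
# A minimal usage sketch, not part of the original module: it assumes the
# train_dataloader / valid_dataloader globals returned by the hooks above
# have already been built, and that a tokenizer matching "t5-small" is used.
#
# from transformers import AutoTokenizer
#
# tokenizer = AutoTokenizer.from_pretrained("t5-small")
# model = T5()
# trainer = pl.Trainer(max_epochs=model.hparams.num_train_epochs)
# trainer.fit(model)
#
# # Generation after training:
# encoding = tokenizer("translate English to German: Hello!", return_tensors="pt")
# pred_ids = model.generate(encoding["input_ids"], max_new_tokens=30)
# print(tokenizer.decode(pred_ids[0], skip_special_tokens=True))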