import torch
from torch import nn
import lightning as L
from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR
import numpy as np


class CustomFinetuneModel(L.LightningModule):
    """Fine-tuning module: a small GELU-activated regression head on top of the pre-trained `model_mtr` backbone."""

    def __init__(
        self,
        model_mtr,
        steps_per_epoch,
        warmup_epochs,
        max_epochs,
        learning_rate,
        linear_param: int = 64,
        use_freeze: bool = True,
        *args,
        **kwargs,
    ):
        super().__init__()

        # Pre-trained backbone; optionally frozen so that only the new head
        # below is updated during fine-tuning.
        self.model_mtr = model_mtr
        if use_freeze:
            self.model_mtr.freeze()

        # Hyperparameters used by the optimizer / LR scheduler.
        self.steps_per_epoch = steps_per_epoch
        self.warmup_epochs = warmup_epochs
        self.max_epochs = max_epochs
        self.learning_rate = learning_rate

        # Per-epoch validation losses, filled in on_validation_epoch_end.
        self.list_val_loss = list()
        # Buffer for validation step outputs; Lightning 2.x no longer passes
        # them to the epoch-end hook, so they are collected manually.
        self.validation_step_outputs = list()

        # Regression head: backbone outputs -> linear_param -> linear_param -> 5 targets.
        self.gelu = nn.GELU()
        self.linear1 = nn.Linear(self.model_mtr.num_labels, linear_param)
        self.linear2 = nn.Linear(linear_param, linear_param)
        self.regression = nn.Linear(linear_param, 5)

        # Mean absolute error, used for training, validation, and testing.
        self.loss_fn = nn.L1Loss()

    def forward(self, input_ids, attention_mask, labels=None):
        # Backbone predictions are passed through the GELU-activated MLP head
        # that maps them to the 5 regression targets.
        x = self.model_mtr(input_ids=input_ids, attention_mask=attention_mask)
        x = self.gelu(x)
        x = self.linear1(x)
        x = self.gelu(x)
        x = self.linear2(x)
        x = self.gelu(x)
        x = self.regression(x)

        return x

def training_step(self, batch, batch_idx): |
|
|
|
loss, logits, labels = self._common_step(batch=batch, batch_idx=batch_idx) |
|
|
|
self.log_dict( |
|
{ |
|
"train_loss": loss, |
|
}, |
|
on_step=True, |
|
on_epoch=True, |
|
prog_bar=True, |
|
|
|
) |
|
|
|
return {"loss": loss, "logits": logits, "labels": labels} |
|
|
|
    def validation_step(self, batch, batch_idx):
        loss, logits, labels = self._common_step(batch=batch, batch_idx=batch_idx)

        # Keep the step outputs so the epoch-end hook can compute an
        # epoch-level metric (Lightning 2.x does not pass them automatically).
        self.validation_step_outputs.append({"logits": logits, "labels": labels})

        self.log("val_loss", loss, sync_dist=True)

        return loss

    def on_validation_epoch_end(self):
        # `validation_epoch_end(outputs)` was removed in Lightning 2.x, so the
        # outputs collected in `validation_step` are read from the buffer here.
        scores = torch.cat([x["logits"] for x in self.validation_step_outputs])
        labels = torch.cat([x["labels"] for x in self.validation_step_outputs])
        self.validation_step_outputs.clear()

        epoch_loss = self.loss_fn(scores, labels)
        self.list_val_loss.append(epoch_loss)

        # `log_dict` expects scalars, so log the epoch loss itself rather than
        # the growing list of per-epoch values.
        self.log_dict(
            {"val_loss_epoch": epoch_loss},
            on_step=False,
            on_epoch=True,
            prog_bar=True,
        )

    def test_step(self, batch, batch_idx):
        loss, logits, labels = self._common_step(batch=batch, batch_idx=batch_idx)

        self.log("test_loss", loss, sync_dist=True)

        return loss

    def _common_step(self, batch, batch_idx):
        # Drop singleton dimensions added by the tokenizer/dataloader before
        # the forward pass, and again on the model output.
        logits = self.forward(
            input_ids=batch["input_ids"].squeeze(),
            attention_mask=batch["attention_mask"].squeeze(),
        ).squeeze()

        labels = batch["labels"]
        loss = self.loss_fn(logits, labels)

        return loss, logits, labels

    def predict_step(self, batch, batch_idx):
        _, logits, labels = self._common_step(batch=batch, batch_idx=batch_idx)

        return logits, labels

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(
            params=self.parameters(),
            lr=self.learning_rate,
            betas=(0.9, 0.999),
            weight_decay=0.01,
        )

        # The scheduler is stepped once per optimizer step (interval="step"),
        # so the warmup and annealing horizons are converted from epochs to steps.
        scheduler = LinearWarmupCosineAnnealingLR(
            optimizer,
            warmup_epochs=self.warmup_epochs * self.steps_per_epoch,
            max_epochs=self.max_epochs * self.steps_per_epoch,
        )

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "step",
                "frequency": 1,
                "reduce_on_plateau": False,
                "monitor": "val_loss",
            },
        }
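

# ---------------------------------------------------------------------------
# Usage sketch (illustrative addition, not part of the original module).
# `PretrainedMTRModel`, `train_loader`, and `val_loader` are hypothetical
# stand-ins; the backbone only needs to expose `.freeze()` and `.num_labels`
# and accept `input_ids` / `attention_mask`.
#
# model_mtr = PretrainedMTRModel.load_from_checkpoint("mtr.ckpt")
# steps_per_epoch = len(train_loader)  # scheduler horizons are given in steps
#
# model = CustomFinetuneModel(
#     model_mtr=model_mtr,
#     steps_per_epoch=steps_per_epoch,
#     warmup_epochs=5,
#     max_epochs=50,
#     learning_rate=3e-5,
# )
#
# trainer = L.Trainer(max_epochs=50, accelerator="auto", devices="auto")
# trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)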
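

# Minimal smoke test (illustrative addition, assuming the interface used
# above). `_DummyMTR` is a hypothetical stand-in for the pre-trained backbone:
# it maps token ids to a (batch, num_labels) tensor and, being a
# LightningModule, provides `.freeze()`.
if __name__ == "__main__":
    class _DummyMTR(L.LightningModule):
        def __init__(self, num_labels: int = 200, vocab_size: int = 100):
            super().__init__()
            self.num_labels = num_labels
            self.embed = nn.Embedding(vocab_size, 16)
            self.proj = nn.Linear(16, num_labels)

        def forward(self, input_ids, attention_mask):
            # The attention mask is ignored in this toy backbone.
            return self.proj(self.embed(input_ids).mean(dim=1))

    model = CustomFinetuneModel(
        model_mtr=_DummyMTR(),
        steps_per_epoch=10,
        warmup_epochs=1,
        max_epochs=3,
        learning_rate=1e-4,
    )
    batch = {
        "input_ids": torch.randint(0, 100, (4, 1, 32)),            # (batch, 1, seq_len)
        "attention_mask": torch.ones(4, 1, 32, dtype=torch.long),
        "labels": torch.rand(4, 5),                                 # 5 regression targets
    }
    loss, logits, labels = model._common_step(batch, batch_idx=0)
    print("smoke-test L1 loss:", float(loss), "logits shape:", tuple(logits.shape))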