import torch
from torch.nn import BCEWithLogitsLoss

import lightning.pytorch as pl
from sklearn.metrics import accuracy_score, f1_score
from tqdm import tqdm
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    get_constant_schedule_with_warmup,
)


class FinanciaMultilabel(pl.LightningModule):
    """LightningModule wrapping a HuggingFace sequence classifier for
    multilabel classification with a BCE-with-logits loss.

    Validation predictions are collected per step and aggregated at
    epoch end into micro-F1 and exact-match (subset) accuracy.
    """

    def __init__(self, model, num_labels):
        """
        Args:
            model: a HuggingFace ``*ForSequenceClassification`` model whose
                forward returns an object with a ``.logits`` attribute.
            num_labels: number of independent labels per sample.
        """
        super().__init__()
        self.model = model
        self.num_labels = num_labels
        self.loss = BCEWithLogitsLoss()
        # Buffer of per-step validation results, consumed (and cleared)
        # in on_validation_epoch_end.
        self.validation_step_outputs = []

    def forward(self, input_ids, attention_mask, token_type_ids):
        """Return raw (pre-sigmoid) logits of shape (batch, num_labels)."""
        return self.model(input_ids, attention_mask, token_type_ids).logits

    def training_step(self, batch, batch_idx):
        """Compute and log the multilabel BCE loss for one training batch."""
        outputs = self(
            batch["input_ids"], batch["attention_mask"], batch["token_type_ids"]
        )
        # labels are cast to the logits' dtype: BCEWithLogitsLoss needs
        # float targets, while the dataloader typically yields ints.
        loss = self.loss(
            outputs.view(-1, self.num_labels),
            batch["labels"].type_as(outputs).view(-1, self.num_labels),
        )
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute the loss and stash sigmoid probabilities for epoch-end metrics."""
        labels = batch["labels"]
        outputs = self(
            batch["input_ids"], batch["attention_mask"], batch["token_type_ids"]
        )
        loss = self.loss(
            outputs.view(-1, self.num_labels),
            labels.type_as(outputs).view(-1, self.num_labels),
        )
        self.validation_step_outputs.append(
            {
                "val_loss": loss,
                "pred_labels": torch.sigmoid(outputs),
                "labels": labels,
            }
        )

    def on_validation_epoch_end(self):
        """Aggregate buffered step outputs into epoch-level metrics.

        Logs:
            val_loss: mean of per-step losses.
            val_f1_accuracy: micro-averaged F1 (percent) at a 0.5 threshold.
            val_flat_accuracy: exact-match (subset) accuracy (percent).
        """
        outputs = self.validation_step_outputs
        avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
        pred_labels = torch.cat([x["pred_labels"] for x in outputs])
        labels = torch.cat([x["labels"] for x in outputs])

        threshold = 0.50
        pred_bools = pred_labels > threshold
        true_bools = labels == 1
        val_f1_accuracy = (
            f1_score(true_bools.cpu(), pred_bools.cpu(), average="micro") * 100
        )
        val_flat_accuracy = accuracy_score(true_bools.cpu(), pred_bools.cpu()) * 100

        self.log("val_loss", avg_loss)
        self.log("val_f1_accuracy", val_f1_accuracy, prog_bar=True)
        self.log("val_flat_accuracy", val_flat_accuracy, prog_bar=True)
        # Free the buffer so the next epoch starts clean.
        self.validation_step_outputs.clear()

    def configure_optimizers(self):
        """AdamW + ReduceLROnPlateau driven by the monitored validation loss.

        NOTE(fix): the original passed ``mode='max'`` while monitoring
        ``val_loss``; a loss must be *minimized*, so with ``mode='max'``
        the LR would only drop when the loss plateaued at a high value.
        Corrected to ``mode='min'``.
        """
        optimizer = torch.optim.AdamW(self.parameters(), lr=2e-5)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode="min",
            factor=0.1,
            patience=2,
            min_lr=1e-6,
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": "val_loss",
            },
        }


def load_model(checkpoint_path, model, num_labels, device):
    """Load a ``FinanciaMultilabel`` Lightning checkpoint.

    Args:
        checkpoint_path: path to the ``.ckpt`` file.
        model: HuggingFace model name or path passed to ``from_pretrained``.
        num_labels: number of labels; mismatched head sizes are tolerated
            via ``ignore_mismatched_sizes=True``.
        device: ``map_location`` for checkpoint tensors.

    Returns:
        The restored ``FinanciaMultilabel`` module.
    """
    hf_model = AutoModelForSequenceClassification.from_pretrained(
        model, num_labels=num_labels, ignore_mismatched_sizes=True
    )
    return FinanciaMultilabel.load_from_checkpoint(
        checkpoint_path,
        model=hf_model,
        num_labels=num_labels,
        map_location=device,
    )