import torch
import lightning.pytorch as pl
from sklearn.metrics import f1_score, accuracy_score
from torch.nn import BCEWithLogitsLoss
from transformers import AutoModelForSequenceClassification

class FinanciaMultilabel(pl.LightningModule):
    """Multi-label classifier wrapping a Hugging Face sequence-classification
    backbone, trained with BCE-with-logits loss."""

    def __init__(self, model, num_labels):
        super().__init__()
        self.model = model
        self.num_labels = num_labels
        self.loss = BCEWithLogitsLoss()
        # Per-batch validation results, aggregated in on_validation_epoch_end
        # (the Lightning >= 2.0 replacement for validation_epoch_end).
        self.validation_step_outputs = []

    def forward(self, input_ids, attention_mask, token_type_ids):
        # Keyword arguments avoid relying on the backbone's positional order.
        # Return raw logits; the sigmoid is applied only where probabilities
        # are needed (BCEWithLogitsLoss works on logits directly).
        return self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        ).logits

    def training_step(self, batch, batch_idx):
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        labels = batch["labels"]
        token_type_ids = batch["token_type_ids"]
        outputs = self(input_ids, attention_mask, token_type_ids)
        # BCEWithLogitsLoss expects float targets, hence the type_as cast.
        loss = self.loss(
            outputs.view(-1, self.num_labels),
            labels.type_as(outputs).view(-1, self.num_labels),
        )
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        labels = batch["labels"]
        token_type_ids = batch["token_type_ids"]
        outputs = self(input_ids, attention_mask, token_type_ids)
        loss = self.loss(
            outputs.view(-1, self.num_labels),
            labels.type_as(outputs).view(-1, self.num_labels),
        )
        # Sigmoid turns logits into per-label probabilities for thresholding.
        pred_labels = torch.sigmoid(outputs)
        self.validation_step_outputs.append(
            {'val_loss': loss, 'pred_labels': pred_labels, 'labels': labels}
        )

    def on_validation_epoch_end(self):
        outputs = self.validation_step_outputs
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        pred_labels = torch.cat([x['pred_labels'] for x in outputs])
        labels = torch.cat([x['labels'] for x in outputs])
        # Binarise probabilities at 0.5, then score against the true label
        # sets: micro-averaged F1 and exact-match (subset) accuracy, in percent.
        threshold = 0.50
        pred_bools = pred_labels > threshold
        true_bools = labels == 1
        val_f1_accuracy = f1_score(true_bools.cpu(), pred_bools.cpu(), average='micro') * 100
        val_flat_accuracy = accuracy_score(true_bools.cpu(), pred_bools.cpu()) * 100
        self.log('val_loss', avg_loss)
        self.log('val_f1_accuracy', val_f1_accuracy, prog_bar=True)
        self.log('val_flat_accuracy', val_flat_accuracy, prog_bar=True)
        # Free the buffer so results do not leak into the next epoch.
        self.validation_step_outputs.clear()

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=2e-5)
        # val_loss is minimised, so the plateau scheduler must use mode='min'
        # (mode='max' would cut the learning rate while the loss is still
        # improving). The deprecated verbose flag is dropped.
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min', factor=0.1, patience=2, min_lr=1e-6
        )
        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'monitor': 'val_loss',
            },
        }
        



def load_model(checkpoint_path, model_name, num_labels, device):
    """Rebuild the Hugging Face backbone, then restore the Lightning checkpoint."""
    backbone = AutoModelForSequenceClassification.from_pretrained(
        model_name, num_labels=num_labels, ignore_mismatched_sizes=True
    )
    return FinanciaMultilabel.load_from_checkpoint(
        checkpoint_path,
        model=backbone,
        num_labels=num_labels,
        map_location=device,
    )
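

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of how this module might be trained and reloaded. The
# model name, label count, dataloaders, and checkpoint path below are
# assumptions for illustration only; batches must be dicts providing
# input_ids / attention_mask / token_type_ids / labels, matching what
# training_step and validation_step expect.
if __name__ == "__main__":
    model_name = "bert-base-uncased"  # assumed BERT-style backbone
    num_labels = 10                   # assumed number of labels

    backbone = AutoModelForSequenceClassification.from_pretrained(
        model_name, num_labels=num_labels, ignore_mismatched_sizes=True
    )
    module = FinanciaMultilabel(backbone, num_labels)

    # train_loader / val_loader are assumed DataLoaders yielding the dict
    # batches described above; uncomment once they are defined.
    # trainer = pl.Trainer(max_epochs=5)
    # trainer.fit(module, train_loader, val_loader)

    # Restore a saved checkpoint for inference (hypothetical path).
    # model = load_model("checkpoints/best.ckpt", model_name, num_labels, "cpu")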