"""Fine-tune a Hugging Face transformer for sequence classification, with helpers
for evaluation, error analysis, and inference on the resulting model."""

from datasets import load_dataset
from transformers import (AutoTokenizer, AutoModelForSequenceClassification,
                          Trainer, TrainingArguments, pipeline,
                          DataCollatorWithPadding)
from sklearn.metrics import (accuracy_score, f1_score,
                             confusion_matrix, ConfusionMatrixDisplay)
import torch
import numpy as np
import torch.nn.functional as F
import matplotlib.pyplot as plt
from typing import List

# Run on GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class TransformersSequenceClassifier:
    """Wraps a pretrained checkpoint with helpers for tokenization, training,
    evaluation, error analysis, and inference."""

    def __init__(self,
                 model_output_dir,
                 num_labels,
                 tokenizer: AutoTokenizer = None,
                 model_checkpoint="distilbert-base-uncased"):
        self.model_output_dir = model_output_dir
        # Reuse a tokenizer supplied by the caller; otherwise load the one
        # that matches the checkpoint.
        self.tokenizer = tokenizer if tokenizer is not None else AutoTokenizer.from_pretrained(model_checkpoint)
        self.model = AutoModelForSequenceClassification.from_pretrained(
            model_checkpoint, num_labels=num_labels).to(device)

    def tokenizer_batch(self, batch):
        # Tokenize a batch of raw text; padding is applied later, per batch,
        # by the data collator. A max_length (e.g. 386) could be set here.
        return self.tokenizer(batch["inputs"], truncation=True)

    def tokenize_dataset(self, dataset):
        # Drop the raw text column (and the pandas index column, when present)
        # so that only model inputs and labels remain.
        drop = [c for c in ("inputs", "__index_level_0__") if c in dataset.column_names]
        return dataset.map(self.tokenizer_batch, batched=True, remove_columns=drop)

    def train(self, train_dataset, eval_dataset, batch_size, epochs):
        # Pad each batch dynamically to its longest sequence.
        data_collator = DataCollatorWithPadding(tokenizer=self.tokenizer, padding="longest")
        # Note: push_to_hub=True requires being logged in to the Hugging Face Hub.
        training_args = TrainingArguments(output_dir=self.model_output_dir,
                                          num_train_epochs=epochs,
                                          learning_rate=2e-5,
                                          per_device_train_batch_size=batch_size,
                                          per_device_eval_batch_size=batch_size,
                                          weight_decay=0.01,
                                          evaluation_strategy="epoch",
                                          save_strategy="epoch",
                                          disable_tqdm=False,
                                          logging_steps=len(train_dataset) // batch_size,
                                          push_to_hub=True,
                                          load_best_model_at_end=True,
                                          log_level="error")
        self.trainer = Trainer(
            model=self.model,
            args=training_args,
            compute_metrics=self._compute_metrics,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            tokenizer=self.tokenizer,
            data_collator=data_collator
        )
        self.trainer.train()
        self.trainer.push_to_hub(commit_message="Training completed!")

    @staticmethod
    def _compute_metrics(pred):
        # Accuracy and weighted F1 computed from the Trainer's EvalPrediction.
        labels = pred.label_ids
        preds = pred.predictions.argmax(-1)
        f1 = f1_score(labels, preds, average="weighted")
        acc = accuracy_score(labels, preds)
        return {"accuracy": acc, "f1": f1}

    def forward_pass_with_label(self, batch):
        # Place all input tensors on the same device as the model
        inputs = {k:v.to(device) for k,v in batch.items()
                if k in self.tokenizer.model_input_names}

        with torch.no_grad():
            output = self.model(**inputs)
            pred_label = torch.argmax(output.logits, axis=-1)
            loss = F.cross_entropy(output.logits, batch["label"].to(device), 
                                reduction="none")

        # Place outputs on CPU for compatibility with other dataset columns
        return {"loss": loss.cpu().numpy(), 
                "predicted_label": pred_label.cpu().numpy()}

    def compute_loss_per_pred(self, valid_dataset):
        # Per-example loss and predicted label for error analysis; expects
        # torch-formatted (and padded) columns, e.g. after set_format("torch").
        return valid_dataset.map(self.forward_pass_with_label, batched=True, batch_size=16)

    @staticmethod
    def plot_confusion_matrix(y_preds, y_true, labels):
        cm = confusion_matrix(y_true, y_preds, normalize="true")
        fig, ax = plt.subplots(figsize=(6, 6))
        disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=labels)
        disp.plot(cmap="Blues", values_format=".2f", ax=ax, colorbar=False)
        plt.title("Normalized confusion matrix")
        plt.show()

    def predict_valid_data(self, valid_dataset):
        # Requires train() to have been called first so that self.trainer exists.
        preds_output = self.trainer.predict(valid_dataset)
        print(preds_output.metrics)
        y_preds = np.argmax(preds_output.predictions, axis=1)
        return y_preds
    
    @staticmethod
    def predict_test_data(model_checkpoint, test_list: List[str]) -> List:
        # Load the fine-tuned checkpoint through the pipeline API and return
        # the score of every class for each input text.
        pipe_classifier = pipeline("text-classification", model=model_checkpoint)
        preds = pipe_classifier(test_list, return_all_scores=True)
        return preds
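

if __name__ == "__main__":
    # --- Usage sketch (illustrative, not part of the original module) ---
    # Assumptions: the public "emotion" dataset (6 labels) stands in for your
    # own data, its text column is renamed to "inputs" to match
    # tokenizer_batch, and "distilbert-finetuned" is a placeholder output
    # directory / Hub repo name. Because train() sets push_to_hub=True, a
    # Hugging Face Hub login is required.
    raw = load_dataset("emotion")
    raw = raw.rename_column("text", "inputs")

    clf = TransformersSequenceClassifier(model_output_dir="distilbert-finetuned",
                                         num_labels=6)
    encoded_train = clf.tokenize_dataset(raw["train"])
    encoded_valid = clf.tokenize_dataset(raw["validation"])

    clf.train(encoded_train, encoded_valid, batch_size=64, epochs=2)

    # Evaluate on the validation split and plot the normalized confusion matrix.
    y_preds = clf.predict_valid_data(encoded_valid)
    label_names = raw["train"].features["label"].names
    clf.plot_confusion_matrix(y_preds, raw["validation"]["label"], label_names)

    # Classify new texts with the saved / pushed checkpoint.
    scores = TransformersSequenceClassifier.predict_test_data(
        "distilbert-finetuned", ["I am so happy today!"])
    print(scores)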