JohnnyBoy00 committed on
Commit 028cc28
1 Parent(s): 31c4de0

Upload evaluation.py

Files changed (1)
  1. evaluation.py +174 -0
evaluation.py ADDED
@@ -0,0 +1,174 @@
+import numpy as np
+import torch
+
+from evaluate import load as load_metric
+
+from sklearn.metrics import accuracy_score, f1_score
+from tqdm.auto import tqdm
+
+MAX_TARGET_LENGTH = 128
+
+# load evaluation metrics
+sacrebleu = load_metric('sacrebleu')
+rouge = load_metric('rouge')
+meteor = load_metric('meteor')
+bertscore = load_metric('bertscore')
+
+# use gpu if it's available
+device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+
+def flatten_list(l):
+    """
+    Utility function to convert a list of lists into a flattened list
+
+    Params:
+        l (list of lists): list to be flattened
+    Returns:
+        A flattened list with the elements of the original list
+    """
+    return [item for sublist in l for item in sublist]
+
+def extract_feedback(predictions):
+    """
+    Utility function to extract the feedback from the predictions of the model
+
+    Params:
+        predictions (list): complete model predictions
+    Returns:
+        feedback (list): extracted feedback from the model's predictions
+    """
+    feedback = []
+    # iterate through predictions and try to extract predicted feedback
+    for pred in predictions:
+        try:
+            fb = pred.split(':', 1)[1]
+        except IndexError:
+            try:
+                if pred.lower().startswith('partially correct'):
+                    # drop the two-word label before the feedback text
+                    fb = pred.split(' ', 2)[2]
+                else:
+                    fb = pred.split(' ', 1)[1]
+            except IndexError:
+                fb = pred
+        feedback.append(fb.strip())
+
+    return feedback
+
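+# illustrative example (the prediction format '<label>: <feedback>' is assumed):
+#   extract_feedback(['Incorrect: Die Formel ist unvollständig.'])
+#   returns ['Die Formel ist unvollständig.']
+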
+def extract_labels(predictions):
+    """
+    Utility function to extract the labels from the predictions of the model
+
+    Params:
+        predictions (list): complete model predictions
+    Returns:
+        labels (list): extracted labels from the model's predictions
+    """
+    labels = []
+    for pred in predictions:
+        if pred.lower().startswith('correct'):
+            label = 'Correct'
+        elif pred.lower().startswith('partially correct'):
+            label = 'Partially correct'
+        elif pred.lower().startswith('incorrect'):
+            label = 'Incorrect'
+        else:
+            label = 'Unknown label'
+        labels.append(label)
+
+    return labels
+
+def compute_metrics(predictions, labels):
+    """
+    Compute evaluation metrics from the predictions of the model
+
+    Params:
+        predictions (list): complete model predictions
+        labels (list): decoded golden labels
+    Returns:
+        results (dict): dictionary with the computed evaluation metrics
+    """
+    # extract feedback and labels from the model's predictions
+    predicted_feedback = extract_feedback(predictions)
+    predicted_labels = extract_labels(predictions)
+
+    # extract feedback and labels from the golden labels
+    reference_feedback = [x.split('Feedback:', 1)[1].strip() for x in labels]
+    reference_labels = [x.split('Feedback:', 1)[0].strip() for x in labels]
+
+    # compute HF metrics
+    sacrebleu_score = sacrebleu.compute(predictions=predicted_feedback, references=[[x] for x in reference_feedback])['score']
+    rouge_score = rouge.compute(predictions=predicted_feedback, references=reference_feedback)['rouge2']
+    meteor_score = meteor.compute(predictions=predicted_feedback, references=reference_feedback)['meteor']
+    bert_score = bertscore.compute(
+        predictions=predicted_feedback,
+        references=reference_feedback,
+        lang='de',
+        model_type='bert-base-multilingual-cased',
+        rescale_with_baseline=True)
+
+    # use sklearn to compute accuracy and f1 score
+    reference_labels_np = np.array(reference_labels)
+    accuracy = accuracy_score(reference_labels_np, predicted_labels)
+    f1_weighted = f1_score(reference_labels_np, predicted_labels, average='weighted')
+    f1_macro = f1_score(
+        reference_labels_np,
+        predicted_labels,
+        average='macro',
+        labels=['Incorrect', 'Partially correct', 'Correct'])
+
+    results = {
+        'sacrebleu': sacrebleu_score,
+        'rouge': rouge_score,
+        'meteor': meteor_score,
+        'bert_score': np.array(bert_score['f1']).mean().item(),
+        'accuracy': accuracy,
+        'f1_weighted': f1_weighted,
+        'f1_macro': f1_macro
+    }
+
+    return results
+
+def evaluate(model, tokenizer, dataloader):
+    """
+    Evaluate model on the given dataset
+
+    Params:
+        model (PreTrainedModel): seq2seq model
+        tokenizer (PreTrainedTokenizer): tokenizer from HuggingFace
+        dataloader (torch DataLoader): dataloader of the dataset to be used for evaluation
+    Returns:
+        results (dict): dictionary with the computed evaluation metrics
+        predictions (list): list of the decoded predictions of the model
+    """
+    decoded_preds, decoded_labels = [], []
+
+    model.eval()
+    # iterate through batches in the dataloader
+    for batch in tqdm(dataloader):
+        with torch.no_grad():
+            batch = {k: v.to(device) for k, v in batch.items()}
+            # generate tokens from batch
+            generated_tokens = model.generate(
+                batch['input_ids'],
+                attention_mask=batch['attention_mask'],
+                max_length=MAX_TARGET_LENGTH
+            )
+            # get golden labels from batch
+            labels_batch = batch['labels']
+
+            # decode model predictions and golden labels
+            decoded_preds_batch = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
+            decoded_labels_batch = tokenizer.batch_decode(labels_batch, skip_special_tokens=True)
+
+            decoded_preds.append(decoded_preds_batch)
+            decoded_labels.append(decoded_labels_batch)
+
+    # convert predictions and golden labels into flattened lists
+    predictions = flatten_list(decoded_preds)
+    labels = flatten_list(decoded_labels)
+
+    # compute metrics based on predictions and golden labels
+    results = compute_metrics(predictions, labels)
+
+    return results, predictions
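A minimal sketch of how evaluate() could be driven, assuming the file is importable as evaluation; the checkpoint name and the toy German example below are placeholders, not the project's actual fine-tuned model or dataset:

from torch.utils.data import DataLoader
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

from evaluation import evaluate, device

# placeholder checkpoint; substitute the fine-tuned seq2seq model for this task
checkpoint = 'google/mt5-base'
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint).to(device)

# toy single-example dataset; targets are assumed to follow the
# '<label> Feedback: <feedback text>' format that compute_metrics parses
inputs = tokenizer('Antwort: Die Formel ist unvollständig.', return_tensors='pt')
targets = tokenizer('Incorrect Feedback: Die Formel ist unvollständig.', return_tensors='pt')
dataset = [{'input_ids': inputs['input_ids'][0],
            'attention_mask': inputs['attention_mask'][0],
            'labels': targets['input_ids'][0]}]
dataloader = DataLoader(dataset, batch_size=1)

results, predictions = evaluate(model, tokenizer, dataloader)
print(results)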