# evaluation.py, uploaded by JohnnyBoy00 ("Upload evaluation.py", commit 8841a3f)
import numpy as np
import torch
from evaluate import load as load_metric
from sklearn.metrics import mean_squared_error
from tqdm.auto import tqdm
# maximum length (in tokens) of the sequences produced by `model.generate`
MAX_TARGET_LENGTH = 128

# evaluation metrics, loaded once at import time through `evaluate.load`
sacrebleu, rouge, meteor, bertscore = (
    load_metric(metric_name)
    for metric_name in ('sacrebleu', 'rouge', 'meteor', 'bertscore')
)

# run on the GPU when one is available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def flatten_list(l):
    """
    Utility function that concatenates the inner lists of a list of lists
    into one flat list
    Params:
        l (list of lists): nested list to flatten
    Returns:
        A new list holding every element of every inner list, in order
    """
    flattened = []
    for inner in l:
        # append the whole inner list's elements at once
        flattened.extend(inner)
    return flattened
def parse_float(value):
    """
    Utility function to parse a value (normally a string) into a float
    Params:
        value (string): value to be converted to float
    Returns:
        The float representation of the given value, or None if the value
        could not be converted to a float (e.g. a non-numeric string, an
        empty string, or None)
    """
    try:
        return float(value)
    except (TypeError, ValueError):
        # ValueError: malformed string such as 'abc' or ''
        # TypeError: value of an unsupported type such as None
        return None
def extract_scores(predictions):
    """
    Utility function to extract the scores from the predictions of the model
    Params:
        predictions (list): complete model predictions (strings)
    Returns:
        scores (list): extracted scores from the model's predictions; an
            entry is None when the prediction's first space-separated token
            could not be parsed as a float
    """
    # The score is the first space-separated token of each prediction.
    # str.split always returns at least one element, so indexing [0] cannot
    # raise; unparseable tokens (including '') are mapped to None by
    # parse_float.
    return [parse_float(pred.split(' ', 1)[0].strip()) for pred in predictions]
def extract_feedback(predictions):
    """
    Utility function to extract the feedback text from the predictions of
    the model
    Params:
        predictions (list): complete model predictions (strings)
    Returns:
        feedback (list): extracted feedback from the model's predictions,
            stripped of surrounding whitespace
    """
    extracted = []
    for pred in predictions:
        # preferred case: keep everything after the first ':' (such as the
        # "Feedback:" marker)
        _, colon, after_colon = pred.partition(':')
        if colon:
            text = after_colon
        else:
            # no colon: drop the first space-separated token; if there is
            # no space either, keep the prediction unchanged
            _, space, after_space = pred.partition(' ')
            text = after_space if space else pred
        extracted.append(text.strip())
    return extracted
def compute_rmse(predictions, labels):
    """
    Utility function to compute the root mean squared error of the
    score predictions in relation to the golden label scores
    Params:
        predictions (list): model score predictions; entries that are None
            (no score could be extracted) or non-finite (nan/inf) are treated
            as invalid and left out of the rmse
        labels (list): golden label scores, aligned with predictions
    Returns:
        (float, int): rmse of valid samples (capped at 1, and set to 1 when
            there are no valid samples) and number of invalid samples
    """
    labels_array = np.array(labels)
    labels_size = labels_array.size

    # keep (position, score) for every usable prediction; a single nan/inf
    # would otherwise turn the whole rmse into nan/inf, and a nan rmse would
    # also slip past the cap below
    valid_pairs = [
        (i, pred) for i, pred in enumerate(predictions)
        if pred is not None and np.isfinite(pred)
    ]
    num_invalid = labels_size - len(valid_pairs)

    # without any valid score predictions the rmse is undefined; use the cap
    if not valid_pairs:
        return 1, num_invalid

    valid_idx = [i for i, _ in valid_pairs]
    valid_predictions = np.array([pred for _, pred in valid_pairs], dtype=float)
    score_labels = labels_array[valid_idx].astype(float)

    # computed directly instead of via mean_squared_error(..., squared=False),
    # whose `squared` argument was removed in scikit-learn 1.6
    rmse = float(np.sqrt(np.mean((valid_predictions - score_labels) ** 2)))

    # cap rmse at 1
    return min(rmse, 1), num_invalid
def compute_metrics(predictions, labels):
    """
    Compute the evaluation metrics of the model's predictions against the
    golden labels
    Params:
        predictions (list): complete (decoded) model predictions
        labels (list): golden labels decoded back to text, each holding the
            score before a "Feedback:" marker and the feedback text after it
    Returns:
        results (dict): dictionary with the computed evaluation metrics
            ('sacrebleu', 'rouge', 'meteor', 'bert_score', 'rmse')
    """
    # split every model prediction into its feedback text and its score
    predicted_feedback = extract_feedback(predictions)
    predicted_scores = extract_scores(predictions)

    # split every golden label on its "Feedback:" marker: text after it is the
    # reference feedback, text before it is the reference score
    reference_feedback = [label.split('Feedback:', 1)[1].strip() for label in labels]
    reference_scores = [float(label.split('Feedback:', 1)[0].strip()) for label in labels]

    # text-similarity metrics between predicted and reference feedback
    feedback_pair = dict(predictions=predicted_feedback, references=reference_feedback)
    # sacrebleu takes a list of reference lists: one reference per prediction
    bleu_references = [[reference] for reference in reference_feedback]
    sacrebleu_score = sacrebleu.compute(
        predictions=predicted_feedback, references=bleu_references)['score']
    rouge_score = rouge.compute(**feedback_pair)['rouge2']
    meteor_score = meteor.compute(**feedback_pair)['meteor']
    bert_f1 = bertscore.compute(
        **feedback_pair,
        lang='de',
        model_type='bert-base-multilingual-cased',
        rescale_with_baseline=True)['f1']

    # rmse of the numeric score predictions (invalid-sample count unused here)
    rmse = compute_rmse(predicted_scores, reference_scores)[0]

    return {
        'sacrebleu': sacrebleu_score,
        'rouge': rouge_score,
        'meteor': meteor_score,
        'bert_score': float(np.mean(bert_f1)),
        'rmse': rmse,
    }
def evaluate(model, tokenizer, dataloader):
    """
    Evaluate model on the given dataset
    Params:
        model (PreTrainedModel): seq2seq model
        tokenizer (PreTrainedTokenizer): tokenizer from HuggingFace
        dataloader (torch Dataloader): dataloader of the dataset to be used for
            evaluation; each batch is a dict of tensors with at least
            'input_ids', 'attention_mask' and 'labels'
    Returns:
        results (dict): dictionary with the computed evaluation metrics
        predictions (list): list of the decoded predictions of the model
    Note:
        puts the model into eval mode and leaves it there.
    """
    predictions, labels = [], []
    model.eval()
    # gradients are never needed for generation, so disable them once for
    # the whole loop
    with torch.no_grad():
        # iterate through batches in the dataloader
        for batch in tqdm(dataloader):
            batch = {k: v.to(device) for k, v in batch.items()}
            # generate tokens from batch
            generated_tokens = model.generate(
                batch['input_ids'],
                attention_mask=batch['attention_mask'],
                max_length=MAX_TARGET_LENGTH,
            )
            # get golden labels from batch
            labels_batch = batch['labels']
            # Label padding may use the ignore index -100 (the default
            # label_pad_token_id of HF's DataCollatorForSeq2Seq). That is not a
            # valid token id and makes batch_decode fail, so map it to the pad
            # token, which skip_special_tokens=True then drops. This is a no-op
            # when no -100 is present.
            if tokenizer.pad_token_id is not None:
                labels_batch = labels_batch.masked_fill(
                    labels_batch == -100, tokenizer.pad_token_id)
            # decode model predictions and golden labels; extending (rather
            # than appending per batch) keeps both lists flat
            predictions.extend(
                tokenizer.batch_decode(generated_tokens, skip_special_tokens=True))
            labels.extend(
                tokenizer.batch_decode(labels_batch, skip_special_tokens=True))
    # compute metrics based on predictions and golden labels
    results = compute_metrics(predictions, labels)
    return results, predictions