from typing import List

import numpy as np
import evaluate


def compute_metrics_sentiment(eval_pred):
    """Accuracy for single-label classification (e.g., sentiment)."""
    logits, labels = eval_pred
    # Highest-scoring class per example; argmax works on raw logits, no softmax needed.
    preds = np.argmax(logits, axis=-1)
    acc = (preds == labels).mean().item()
    return {"accuracy": acc}
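# Quick standalone check (illustrative, not from the original snippet): the
# metric function receives the same (logits, labels) pair the Trainer passes in.
_logits = np.array([[0.1, 0.9], [2.0, -1.0], [0.3, 0.7]])
_labels = np.array([1, 0, 0])
assert compute_metrics_sentiment((_logits, _labels)) == {"accuracy": 2 / 3}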
def compute_metrics_ner(eval_pred, label_list: List[str]):
    """Entity-level precision/recall/F1 for token classification via seqeval."""
    # Loaded here for self-containment; load once at module level to avoid
    # repeating the lookup on every evaluation pass.
    seqeval = evaluate.load("seqeval")
    logits, labels = eval_pred
    preds = logits.argmax(-1)
    # Drop positions labeled -100 (special tokens and continuation subwords),
    # then map the remaining ids back to string tags for seqeval.
    true_preds = [
        [label_list[p] for (p, l) in zip(pred, lab) if l != -100]
        for pred, lab in zip(preds, labels)
    ]
    true_labels = [
        [label_list[l] for (p, l) in zip(pred, lab) if l != -100]
        for pred, lab in zip(preds, labels)
    ]
    results = seqeval.compute(predictions=true_preds, references=true_labels)
    # seqeval also reports per-entity-type scores; only the overall ones are returned.
    return {
        "precision": results["overall_precision"],
        "recall": results["overall_recall"],
        "f1": results["overall_f1"],
        "accuracy": results["overall_accuracy"],
    }
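# Wiring sketch (an assumption about intended use, not shown above): the
# Hugging Face Trainer calls compute_metrics with a single EvalPrediction,
# so the extra label_list argument must be bound first, e.g. via
# functools.partial. The model name, tag set, and datasets below are
# hypothetical placeholders for illustration.
import functools
from transformers import AutoModelForTokenClassification, Trainer, TrainingArguments

label_list = ["O", "B-PER", "I-PER", "B-LOC", "I-LOC"]  # hypothetical tag set
model = AutoModelForTokenClassification.from_pretrained(
    "bert-base-cased", num_labels=len(label_list)
)
trainer = Trainer(
    model=model,
    args=TrainingArguments(output_dir="ner-out"),
    train_dataset=tokenized_train,  # placeholder: tokenized, -100-aligned dataset
    eval_dataset=tokenized_eval,    # placeholder
    compute_metrics=functools.partial(compute_metrics_ner, label_list=label_list),
)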