File size: 760 Bytes
932d265 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 |
import numpy as np
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from transformers import BertForSequenceClassification, BertTokenizer
def compute_metrics(pred, average="binary"):
    """Compute accuracy, precision, recall and F1 from a prediction object.

    Intended as a ``compute_metrics`` callback: it reads ``pred.label_ids`` and
    ``pred.predictions`` and turns per-class scores into hard labels via argmax.

    Args:
        pred: Object exposing ``label_ids`` (gold labels) and ``predictions``
            (class scores, assumed shape ``(n_samples, n_classes)`` — confirm
            against the model). ``predictions`` may also be a tuple whose first
            element holds the scores, since some models return extra outputs.
        average: Averaging strategy forwarded to scikit-learn's
            ``precision_recall_fscore_support``. Defaults to ``"binary"``,
            which only accepts two-class labels; use ``"macro"``, ``"micro"``
            or ``"weighted"`` for multi-class tasks.

    Returns:
        dict with keys ``"accuracy"``, ``"precision"``, ``"recall"``, ``"f1"``.
    """
    labels = pred.label_ids
    scores = pred.predictions
    # Models that return (logits, hidden_states, ...) put the logits first;
    # argmax over the raw tuple would try to stack mismatched arrays.
    if isinstance(scores, tuple):
        scores = scores[0]
    # The predicted class is the index of the highest score on the class axis.
    preds = np.argmax(scores, axis=-1)
    acc = accuracy_score(labels, preds)
    precision, recall, f1, _ = precision_recall_fscore_support(
        labels, preds, average=average
    )
    return {
        "accuracy": acc,
        "precision": precision,
        "recall": recall,
        "f1": f1,
    }
def load_model_and_tokenizer(model_dir, device):
    """Load a BERT sequence classifier and its tokenizer from ``model_dir``.

    The model is moved to ``device`` before being returned; the tokenizer is
    loaded from the same directory.

    Args:
        model_dir: Path or hub identifier accepted by ``from_pretrained``.
        device: Target passed to the model's ``.to(...)`` call.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    classifier = BertForSequenceClassification.from_pretrained(model_dir)
    classifier = classifier.to(device)
    bert_tokenizer = BertTokenizer.from_pretrained(model_dir)
    return classifier, bert_tokenizer
|