import numpy as np | |
from sklearn.metrics import accuracy_score, precision_recall_fscore_support | |
from transformers import BertForSequenceClassification, BertTokenizer | |
def compute_metrics(pred, average="binary"):
    """Compute accuracy, precision, recall and F1 for one evaluation pass.

    Shaped as a ``compute_metrics`` callback (it reads ``label_ids`` and
    ``predictions`` from its argument).

    Args:
        pred: Object exposing ``label_ids`` (true class ids) and
            ``predictions`` (per-class scores/logits with classes on the last
            axis). ``predictions`` may also be a tuple whose first element is
            that score array; the remaining elements are ignored.
        average: Averaging mode forwarded to scikit-learn's
            ``precision_recall_fscore_support``. The default ``"binary"``
            preserves the original behavior and only works for two-class
            labels; pass ``"macro"``, ``"micro"`` or ``"weighted"`` for
            multi-class problems.

    Returns:
        dict with keys ``"accuracy"``, ``"precision"``, ``"recall"`` and
        ``"f1"``.
    """
    labels = pred.label_ids
    scores = pred.predictions
    # When a model returns extra outputs alongside the logits, predictions
    # arrive as a tuple; the class scores are the first entry.
    if isinstance(scores, tuple):
        scores = scores[0]
    # Classes live on the last axis; for the usual (batch, num_labels) array
    # this is the same as axis=1.
    preds = np.argmax(scores, axis=-1)
    acc = accuracy_score(labels, preds)
    # zero_division=0 returns the same 0.0 that sklearn substitutes by
    # default for ill-defined precision/recall, but without emitting a
    # warning on every evaluation pass.
    precision, recall, f1, _ = precision_recall_fscore_support(
        labels, preds, average=average, zero_division=0
    )
    return {
        "accuracy": acc,
        "precision": precision,
        "recall": recall,
        "f1": f1,
    }
def load_model_and_tokenizer(model_dir, device):
    """Load a BERT sequence classifier and its tokenizer from one location.

    Args:
        model_dir: Directory or identifier handed to both ``from_pretrained``
            calls.
        device: Destination passed to the model's ``.to(...)``.

    Returns:
        Tuple ``(model, tokenizer)``; only the model is moved to ``device``.
    """
    classifier = BertForSequenceClassification.from_pretrained(model_dir)
    classifier = classifier.to(device)
    bert_tokenizer = BertTokenizer.from_pretrained(model_dir)
    return classifier, bert_tokenizer