import os

import numpy as np
import torch
from datasets import Dataset, load_from_disk, concatenate_datasets
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments

# --- Configuration -----------------------------------------------------------
# Checkpoint name the tokenizer is loaded from.
MODEL_NAME = "roberta-large"
# Local directory the sequence-classification model is loaded from.
SAVE_MODEL_FOLDER = "img_intents_model"
# Directory in which ``training_args.bin`` is looked up.
OUTPUT_DIR = "./results"
# Display names for class index 0 (negative) and class index 1 (positive).
NEG_NAME = "NEGATIVE"
POS_NAME = "POSITIVE"

# --- Model, tokenizer and run arguments --------------------------------------
model = AutoModelForSequenceClassification.from_pretrained(SAVE_MODEL_FOLDER)
# NOTE(review): the tokenizer is loaded from the base checkpoint name, not
# from SAVE_MODEL_FOLDER -- confirm the model was trained with this exact
# tokenizer (no added tokens / changed vocabulary).
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# training_args.bin is a pickled Python object (presumably the
# TrainingArguments of the training run), not a tensor checkpoint, so it needs
# full unpickling. PyTorch >= 2.6 defaults to weights_only=True and would
# refuse to load it, hence the explicit flag.
# SECURITY: unpickling executes arbitrary code -- only load a file produced by
# a trusted run.
training_args = torch.load(
    os.path.join(OUTPUT_DIR, "training_args.bin"),
    weights_only=False,
)

# --- Test data ---------------------------------------------------------------
def _read_examples(path):
    """Return the stripped, non-empty lines of the text file at *path*.

    Blank lines are skipped so that they do not become empty-string examples
    carrying a class label. The file is decoded as UTF-8 instead of the
    platform's locale encoding so results do not depend on the machine.
    """
    with open(path, 'r', encoding='utf-8') as handle:
        stripped_lines = (line.strip() for line in handle)
        return [text for text in stripped_lines if text]


positives_texts = _read_examples('test_positives.txt')
negatives_texts = _read_examples('test_negatives.txt')

# Label 1 marks the positive class, label 0 the negative class.
positives_dataset = Dataset.from_dict(
    {'text': positives_texts, 'label': [1] * len(positives_texts)}
)
negatives_dataset = Dataset.from_dict(
    {'text': negatives_texts, 'label': [0] * len(negatives_texts)}
)

# Single evaluation set: all positives followed by all negatives.
test_dataset = concatenate_datasets([positives_dataset, negatives_dataset])


def preprocess_function(examples):
    """Tokenize the ``"text"`` field of a batch of examples.

    Sequences are truncated to at most 512 tokens and padded up to that same
    length, so every encoded example has an identical length.

    Args:
        examples: Mapping with a ``"text"`` entry, as passed by
            ``Dataset.map(..., batched=True)``.

    Returns:
        Whatever the module-level ``tokenizer`` returns for that input.
    """
    return tokenizer(
        examples["text"],
        truncation=True,
        max_length=512,
        padding='max_length',
    )


# Tokenize the whole test set in batches.
test_dataset = test_dataset.map(preprocess_function, batched=True)

# Drop the raw text, rename the target column to "labels" and have the
# dataset return torch tensors.
test_dataset = (
    test_dataset
    .remove_columns(["text"])
    .rename_column("label", "labels")
    .with_format("torch")
)

# --- Inference ---------------------------------------------------------------
# The Trainer is used only as a prediction harness here; no training happens.
trainer = Trainer(
    model=model,
    args=training_args,
)

# predict() yields (predictions, label_ids, metrics); the metrics are unused.
predictions, labels, _ = trainer.predict(test_dataset)

# Pick the highest-scoring class index for each example.
binary_predictions = np.argmax(predictions, axis=1)

# --- Overall metrics ---------------------------------------------------------
# precision / recall / F1 use scikit-learn's defaults, i.e. they are computed
# for the positive class (label 1).
accuracy = accuracy_score(labels, binary_predictions)
precision = precision_score(labels, binary_predictions)
recall = recall_score(labels, binary_predictions)
f1 = f1_score(labels, binary_predictions)

print(f"Overall accuracy: {accuracy}")
print(f"Overall precision: {precision}")
print(f"Overall recall: {recall}")
print(f"Overall F1 score: {f1}")

# --- Per-class breakdown -----------------------------------------------------
# Pin the label order so the matrix is always 2x2 (row 0 = true negatives'
# row, row 1 = true positives' row) even when one class never occurs in the
# labels or predictions; otherwise the matrix shrinks and cm[1] raises.
cm = confusion_matrix(labels, binary_predictions, labels=[0, 1])

for i, class_name in enumerate([NEG_NAME, POS_NAME]):
    total = cm[i].sum()      # examples whose true class is i
    correct = cm[i][i]       # ... of which were predicted as class i
    loss = total - correct   # ... of which were misclassified
    # A class with no test examples has no meaningful error rate; report 0
    # instead of dividing by zero.
    loss_pct = loss / total * 100 if total else 0.0

    print(f"\n{class_name}:")
    print(f"Total: {total}")
    print(f"Confirmed: {correct}")
    print(f"Loss: {loss} ({loss_pct:.2f}%)")