holiday_testing / test_models /test_model.py
svystun-taras's picture
tested the model on all dataset
501f2e5
raw
history blame contribute delete
No virus
2.41 kB
from create_setfit_model import model
from time import perf_counter
import os
import sys
from statistics import mean
from langchain.text_splitter import RecursiveCharacterTextSplitter
import torch
from collections import Counter
from sklearn.metrics import accuracy_score, f1_score, recall_score, precision_score, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
# Time the evaluation run; the imports above are not included in the timing.
start = perf_counter()

# Locate the financial dataset two directories above this script and make its
# loader module importable. Resolving relative to the script file (instead of
# os.getcwd()) lets the script be launched from any directory; when launched
# from its own directory the resulting path is the same as before.
script_dir = os.path.dirname(os.path.abspath(__file__))
dataset_dir = os.path.abspath(os.path.join(script_dir, '..', '..', 'financial_dataset'))
sys.path.append(dataset_dir)
from load_test_data import get_labels_df, get_texts  # noqa: E402 (needs the sys.path entry above)

# Both directory strings keep a trailing separator, matching the previous
# string concatenation, in case the loaders append file names to them directly.
labels_dir = os.path.join(dataset_dir, 'csvs', '')
df = get_labels_df(labels_dir)
texts_dir = os.path.join(dataset_dir, 'txts', '')
texts = get_texts(texts_dir)

# Debug: uncomment to evaluate on a small subset. Keep df and texts the same
# length and aligned, because label rows are paired with texts by position.
# df = df.iloc[:20, :]
# print(df.loc[:, 'Label'])
# texts = [texts[0]] + [texts[13]] + [texts[113]]
# texts = texts[:20]

print(len(df), len(texts))
# Mean len() of the loaded texts.
print(mean(map(len, texts)))
# Splitter used to cut each document into chunks that are classified one by
# one. Sizes are measured with len(), i.e. in characters for str input, not in
# model tokens. RecursiveCharacterTextSplitter tries the separators in list
# order, so with " " first chunks will normally break between words; "," and
# "\n" act only as fallbacks for space-free stretches of text.
# NOTE(review): confirm that 3200 characters fits the encoder's maximum
# sequence length; longer input is presumably truncated by the model body.
text_splitter = RecursiveCharacterTextSplitter(
    separators=[" ", ",", "\n"],
    length_function=len,
    chunk_overlap=200,
    chunk_size=3200,
)
# Classify each document by majority vote over its chunks, collecting the
# ground-truth label next to each prediction (same order) for the metrics.
labels = []
pred_labels = []

# zip() pairs texts with label rows purely by position and silently stops at
# the shorter sequence, so a length mismatch would drop or misalign documents
# without notice; surface it instead.
n_pairs = min(len(texts), len(df))
if len(texts) != len(df):
    print(
        f'Warning: {len(texts)} texts but {len(df)} label rows; '
        f'evaluating only the first {n_pairs} pairs, matched by position.',
        file=sys.stderr,
    )

# Put the head in eval mode once, before the loop (it used to be re-applied on
# every iteration).
model.model_head.eval()
with torch.no_grad():
    # Each df row is unpacked into three values, so this assumes the labels
    # frame has exactly three columns, in (year, label, company) order.
    for text, (_idx, (_year, label, _company)) in tqdm(
        zip(texts, df.iterrows()), total=n_pairs
    ):
        documents = text_splitter.create_documents([text])
        # Use a separate name: rebinding `texts` here would replace the
        # module-level document list with the chunks of the last document.
        chunks = [document.page_content for document in documents]
        chunk_preds = model(chunks)
        # NOTE(review): the model may return a torch tensor of label ids.
        # Tensors hash by identity, so Counter would treat every element as a
        # distinct key and most_common() would just return the first chunk's
        # prediction. Convert tensors / ndarrays to plain Python values so that
        # equal labels are actually tallied together.
        if hasattr(chunk_preds, 'tolist'):
            chunk_preds = chunk_preds.tolist()
        # On a tie, the label encountered first (earliest chunk) wins.
        pred_label = Counter(chunk_preds).most_common(1)[0][0]
        labels.append(label)
        pred_labels.append(pred_label)
# Document-level scores. average='weighted' averages the per-class precision,
# recall and F1, weighting each class by its number of true instances.
accuracy = accuracy_score(labels, pred_labels)
precision, recall, f1 = (
    score_fn(labels, pred_labels, average='weighted')
    for score_fn in (precision_score, recall_score, f1_score)
)
# normalize='true' divides each row (true class) by its total, so a row shows
# how that class's documents were spread across the predicted classes.
confusion_mat = confusion_matrix(labels, pred_labels, normalize='true')

for metric_name, metric_value in (
    ('Accuracy', accuracy),
    ('Precision', precision),
    ('Recall', recall),
    ('F1 Score', f1),
):
    print(f'{metric_name}:', metric_value)
# Plot the row-normalised confusion matrix as a heatmap.
# confusion_matrix() orders classes by sorted label value, so these names
# assume the sorted labels correspond to hold, buy, sell in that order.
# NOTE(review): confirm against the dataset's label encoding.
# A separate name keeps the ground-truth `labels` list from being overwritten.
class_names = ['hold', 'buy', 'sell']

# If a class never occurs in either the true or the predicted labels, the
# matrix is smaller than 3x3 and the fixed names would no longer line up with
# its rows/columns; fall back to seaborn's default tick labels in that case.
names_match_matrix = confusion_mat.shape == (len(class_names), len(class_names))
if not names_match_matrix:
    print(
        f'Warning: confusion matrix has shape {confusion_mat.shape} but '
        f'{len(class_names)} class names are defined; using default tick labels.',
        file=sys.stderr,
    )
tick_labels = class_names if names_match_matrix else 'auto'

plt.figure(figsize=(8, 6))
sns.heatmap(confusion_mat, annot=True, fmt='.2%', cmap='Blues',
            xticklabels=tick_labels, yticklabels=tick_labels)
plt.xlabel('Predicted labels')
plt.ylabel('True labels')
plt.title('Confusion Matrix')
plt.show()
# Report total wall-clock time since `start`. Read the clock once and split the
# rounded total with divmod so minutes and seconds agree; two separate clock
# reads with independent rounding could print e.g. "1 mins 60 secs".
elapsed_mins, elapsed_secs = divmod(round(perf_counter() - start), 60)
print(f'It took me: {elapsed_mins} mins {elapsed_secs} secs')