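# Evaluate the SetFit model on the held-out financial filings:
# split each document into chunks, classify every chunk, take a majority
# vote for the document label, then report accuracy/precision/recall/F1
# and plot a normalized confusion matrix.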
import os
import sys
from collections import Counter
from statistics import mean
from time import perf_counter

import matplotlib.pyplot as plt
import seaborn as sns
import torch
from langchain.text_splitter import RecursiveCharacterTextSplitter
from sklearn.metrics import (accuracy_score, confusion_matrix, f1_score,
                             precision_score, recall_score)
from tqdm import tqdm

from create_setfit_model import model

start = perf_counter()


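# The dataset loaders (load_test_data) live alongside the data, outside this
# package, so the dataset directory is added to sys.path before importing.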
dataset_dir = os.path.abspath(os.path.join(os.getcwd(), '..', '..', 'financial_dataset'))
sys.path.append(dataset_dir)
from load_test_data import get_labels_df, get_texts

labels_dir = dataset_dir + '/csvs/'
df = get_labels_df(labels_dir)
texts_dir = dataset_dir + '/txts/'
texts = get_texts(texts_dir)
# Uncomment to evaluate on a small subset while debugging:
# df = df.iloc[:20, :]
# texts = texts[:20]
print(f'{len(df)} labels, {len(texts)} documents')
print(f'mean document length: {mean(map(len, texts)):.0f} characters')


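# Filings are far longer than the encoder's input window, so split each one
# into ~3200-character chunks with 200 characters of overlap.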
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=3200,
    chunk_overlap=200,
    length_function=len,
    separators=[" ", ",", "\n"],
)

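# Document-level gold and predicted labels, collected per filing.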
labels = []
pred_labels = []

# Put the classification head in eval mode once, before inference.
model.model_head.eval()

for text, (idx, (year, label, company)) in tqdm(zip(texts, df.iterrows()), total=len(df)):
    documents = text_splitter.create_documents([text])
    # Use a separate name so the outer `texts` list is not shadowed mid-iteration.
    chunks = [document.page_content for document in documents]

    with torch.no_grad():
        chunk_preds = model(chunks)
    # Tensors hash by identity, not value, so unwrap to plain Python values
    # before counting.
    if isinstance(chunk_preds, torch.Tensor):
        chunk_preds = chunk_preds.tolist()

    # The document's label is the majority vote over its chunk predictions.
    pred_label = Counter(chunk_preds).most_common(1)[0][0]

    labels.append(label)
    pred_labels.append(pred_label)


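# Weighted averages weight each class's score by its support, which matters
# for the imbalanced buy/hold/sell distribution.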
accuracy = accuracy_score(labels, pred_labels)
precision = precision_score(labels, pred_labels, average='weighted')
recall = recall_score(labels, pred_labels, average='weighted')
f1 = f1_score(labels, pred_labels, average='weighted')

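# Normalized over true labels, so each row shows where that class's
# examples were sent.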
confusion_mat = confusion_matrix(labels, pred_labels, normalize='true')

print("Accuracy:", accuracy)
print("Precision:", precision)
print("Recall:", recall)
print("F1 Score:", f1)

# Tick labels must follow confusion_matrix's sorted label order
# (assumed mapping here: 0 = hold, 1 = buy, 2 = sell).
class_names = ['hold', 'buy', 'sell']
plt.figure(figsize=(8, 6))
sns.heatmap(confusion_mat, annot=True, fmt='.2%', cmap='Blues',
            xticklabels=class_names, yticklabels=class_names)
plt.xlabel('Predicted labels')
plt.ylabel('True labels')
plt.title('Confusion Matrix')
plt.show()

elapsed = perf_counter() - start
print(f'It took me: {elapsed // 60:.0f} mins {elapsed % 60:.0f} secs')