refactor
Browse files- helper_functions.py +7 -7
helper_functions.py
CHANGED
@@ -5,14 +5,16 @@ from transformers import BatchEncoding, PreTrainedTokenizerBase
|
|
5 |
from typing import Optional
|
6 |
from torch import Tensor
|
7 |
|
8 |
-
#
|
9 |
-
|
|
|
10 |
|
11 |
-
#
|
12 |
-
|
|
|
13 |
|
14 |
# Charger le label encoder
|
15 |
-
with open(
|
16 |
label_encoder = pickle.load(f)
|
17 |
|
18 |
class_labels = {
|
@@ -28,8 +30,6 @@ class_labels = {
|
|
28 |
4: ('Computational_Geometry', 'orange', '#fd7e14')
|
29 |
}
|
30 |
|
31 |
-
|
32 |
-
|
33 |
def predict_class(text):
|
34 |
# Tokenisation du texte
|
35 |
inputs = transform_list_of_texts(text, tokenizer, 510, 510, 1, 2550)
|
|
|
5 |
from typing import Optional
|
6 |
from torch import Tensor
|
7 |
|
8 |
+
# Load the model
|
9 |
+
model_path = "Sayado/Model_PFE"
|
10 |
+
model = BertForSequenceClassification.from_pretrained(model_path)
|
11 |
|
12 |
+
# Load the tokenizer
|
13 |
+
tokenizer_path = "Sayado/Model_PFE"
|
14 |
+
tokenizer = BertTokenizer.from_pretrained(tokenizer_path)
|
15 |
|
16 |
# Charger le label encoder
|
17 |
+
with open("label_encoder.pkl", "rb") as f:
|
18 |
label_encoder = pickle.load(f)
|
19 |
|
20 |
class_labels = {
|
|
|
30 |
4: ('Computational_Geometry', 'orange', '#fd7e14')
|
31 |
}
|
32 |
|
|
|
|
|
33 |
def predict_class(text):
|
34 |
# Tokenisation du texte
|
35 |
inputs = transform_list_of_texts(text, tokenizer, 510, 510, 1, 2550)
|