pantelis-ninja committed
Commit f6480f8 • Parent: f378b77
Major fixes for huggingface endpoints

Files changed:
- handler.py (+73 -10)
- models/Dnikud_best_model.pth (+3 -0)
- src/models_utils.py (+11 -9)
- src/utiles_data.py (+61 -0)
handler.py CHANGED

@@ -2,10 +2,11 @@ from typing import Dict, List, Any
 from transformers import AutoConfig, AutoTokenizer
 from src.models import DNikudModel, ModelConfig
 from src.running_params import BATCH_SIZE, MAX_LENGTH_SEN
-from src.utiles_data import Nikud
-from src.models_utils import predict_single
+from src.utiles_data import Nikud, NikudDataset
+from src.models_utils import predict_single, predict
 import torch
 import os
+from tqdm import tqdm
 
 
 class EndpointHandler:
@@ -22,28 +23,90 @@ class EndpointHandler:
             len(Nikud.label_2_id["sin"]),
             device=self.DEVICE,
         ).to(self.DEVICE)
+        state_dict_model = self.model.state_dict()
+        state_dict_model.update(torch.load("./models/Dnikud_best_model.pth"))
+        self.model.load_state_dict(state_dict_model)
+        self.max_length = MAX_LENGTH_SEN
 
     def back_2_text(self, labels, text):
         nikud = Nikud()
         new_line = ""
+
         for indx_char, c in enumerate(text):
             new_line += (
                 c
-                + nikud.id_2_char(labels[
-                + nikud.id_2_char(labels[
-                + nikud.id_2_char(labels[
+                + nikud.id_2_char(labels[indx_char][1][1], "dagesh")
+                + nikud.id_2_char(labels[indx_char][1][2], "sin")
+                + nikud.id_2_char(labels[indx_char][1][0], "nikud")
             )
             print(indx_char, c)
             print(labels)
         return new_line
 
+    def prepare_data(self, data, name="train"):
+        print("Data = ", data)
+        dataset = []
+        for index, (sentence, label) in tqdm(
+            enumerate(data), desc=f"Prepare data {name}"
+        ):
+            encoded_sequence = self.tokenizer.encode_plus(
+                sentence,
+                add_special_tokens=True,
+                max_length=self.max_length,
+                padding="max_length",
+                truncation=True,
+                return_attention_mask=True,
+                return_tensors="pt",
+            )
+            label_lists = [
+                [letter.nikud, letter.dagesh, letter.sin] for letter in label
+            ]
+            label = torch.tensor(
+                [
+                    [
+                        Nikud.PAD_OR_IRRELEVANT,
+                        Nikud.PAD_OR_IRRELEVANT,
+                        Nikud.PAD_OR_IRRELEVANT,
+                    ]
+                ]
+                + label_lists[: (self.max_length - 1)]
+                + [
+                    [
+                        Nikud.PAD_OR_IRRELEVANT,
+                        Nikud.PAD_OR_IRRELEVANT,
+                        Nikud.PAD_OR_IRRELEVANT,
+                    ]
+                    for i in range(self.max_length - len(label) - 1)
+                ]
+            )
+
+            dataset.append(
+                (
+                    encoded_sequence["input_ids"][0],
+                    encoded_sequence["attention_mask"][0],
+                    label,
+                )
+            )
+
+        self.prepered_data = dataset
+
     def predict_single_text(
         self,
         text,
     ):
-
-
-
+        dataset = NikudDataset(tokenizer=self.tokenizer, max_length=MAX_LENGTH_SEN)
+        data, orig_data = dataset.read_single_text(text)
+        print("data", data, len(data))
+        dataset.prepare_data(name="inference")
+        mtb_prediction_dl = torch.utils.data.DataLoader(
+            dataset.prepered_data, batch_size=BATCH_SIZE
+        )
+        # print("dataset", dataset, len(dataset))
+        # data = self.tokenizer(text, return_tensors="pt")
+        all_labels = predict(self.model, mtb_prediction_dl, self.DEVICE)
+        text_data_with_labels = dataset.back_2_text(labels=all_labels)
+        # all_labels = predict_single(self.model, dataset, self.DEVICE)
+        return text_data_with_labels
 
     def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
         """
@@ -59,5 +122,5 @@ class EndpointHandler:
         # result = []
         # for pred in prediction:
         #     result.append(self.back_2_text(pred, inputs))
-        result = self.back_2_text(prediction, inputs)
-        return
+        # result = self.back_2_text(prediction, inputs)
+        return prediction
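The rewritten handler now loads the fine-tuned weights at construction time and routes a request through predict_single_text -> predict -> back_2_text. A minimal local smoke test might look like the sketch below; the path argument and the "inputs" payload key are assumptions based on the usual Hugging Face custom-handler convention, not something this commit shows.

    # Hypothetical smoke test (not part of the commit). Assumes EndpointHandler
    # follows the Hugging Face custom-handler convention of taking a repository
    # path, and that __call__ reads the text to diacritize from an "inputs" key.
    from handler import EndpointHandler

    handler = EndpointHandler(path=".")  # builds tokenizer + DNikudModel, loads Dnikud_best_model.pth
    payload = {"inputs": "שלום עולם"}    # undiacritized Hebrew text
    prediction = handler(payload)        # __call__ -> predict_single_text -> predict
    print(prediction)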
models/Dnikud_best_model.pth ADDED

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:31b1bb3dc66cebf70ad4bfa52d77257d92c745e6609f0023108e91041447e754
+size 446945642
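This entry is a Git LFS pointer, so the repository itself stores only the three lines above; the real 446,945,642-byte weight file is fetched by git-lfs (git lfs pull) before handler.py can torch.load it. A quick integrity check against the pointer's oid, as a sketch:

    # Verify the LFS-fetched weights against the sha256 oid in the pointer file.
    import hashlib

    EXPECTED_OID = "31b1bb3dc66cebf70ad4bfa52d77257d92c745e6609f0023108e91041447e754"

    h = hashlib.sha256()
    with open("./models/Dnikud_best_model.pth", "rb") as f:  # same path handler.py loads
        for chunk in iter(lambda: f.read(1 << 20), b""):     # hash in 1 MiB chunks
            h.update(chunk)

    print("weights OK" if h.hexdigest() == EXPECTED_OID else "checksum mismatch")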
src/models_utils.py CHANGED

@@ -94,15 +94,17 @@ def predict_single(model, data, device="cpu"):
 
     all_labels = None
     with torch.no_grad():
-        inputs = data
-
+        (inputs, attention_mask, labels_demo) = data
+        inputs = inputs.to(device)
+        attention_mask = attention_mask.to(device)
+        labels_demo = labels_demo.to(device)
 
-
-
-
+        mask_cant_be_nikud = np.array(labels_demo.cpu())[:, :, 0] == -1
+        mask_cant_be_dagesh = np.array(labels_demo.cpu())[:, :, 1] == -1
+        mask_cant_be_sin = np.array(labels_demo.cpu())[:, :, 2] == -1
 
         nikud_probs, dagesh_probs, sin_probs = model(inputs, attention_mask)
-        print(nikud_probs, dagesh_probs, sin_probs)
+        print("model output: ", nikud_probs, dagesh_probs, sin_probs)
 
         pred_nikud = np.array(torch.max(nikud_probs, 2).indices.cpu()).reshape(
             inputs.shape[0], inputs.shape[1], 1
@@ -114,9 +116,9 @@ def predict_single(model, data, device="cpu"):
             inputs.shape[0], inputs.shape[1], 1
         )
 
-
-
-
+        pred_nikud[mask_cant_be_nikud] = -1
+        pred_dagesh[mask_cant_be_dagesh] = -1
+        pred_sin[mask_cant_be_sin] = -1
         # print(pred_nikud, pred_dagesh, pred_sin)
         pred_labels = np.concatenate((pred_nikud, pred_dagesh, pred_sin), axis=2)
         print(pred_labels)
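The substantive change here is that predict_single now unpacks the (input_ids, attention_mask, labels) triple produced by prepare_data and uses the label tensor to build boolean masks: any position whose label is -1 (Nikud.PAD_OR_IRRELEVANT) can never carry a nikud, dagesh, or sin mark, so its prediction is forced back to -1 before the three grids are concatenated. A minimal numpy sketch of that masking step, with toy shapes and values:

    # Toy illustration of the mask-and-overwrite step (not the real model sizes).
    import numpy as np

    # labels_demo mirrors the dataset labels: -1 marks pad/irrelevant positions.
    labels_demo = np.array([[[5, 0, 0], [-1, -1, -1], [3, 1, 0], [-1, -1, -1]]])
    pred_nikud = np.array([[[2], [7], [4], [9]]])    # argmax output, shape (batch, seq, 1)

    mask_cant_be_nikud = labels_demo[:, :, 0] == -1  # shape (batch, seq)
    pred_nikud[mask_cant_be_nikud] = -1              # padded positions get no nikud
    print(pred_nikud.reshape(1, 4))                  # [[ 2 -1  4 -1]]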
src/utiles_data.py CHANGED

@@ -370,6 +370,8 @@ class NikudDataset(Dataset):
         self.max_length = max_length
         self.tokenizer = tokenizer
         self.is_train = is_train
+        self.data = None
+        self.origin_data = None
         if folder is not None:
             self.data, self.origin_data = self.read_data_folder(folder, logger)
         elif file is not None:
@@ -453,6 +455,65 @@ class NikudDataset(Dataset):
 
         return data, orig_data
 
+    def read_single_text(self, text: str, logger=None) -> List[Tuple[str, list]]:
+        # msg = f"read file: {filepath}"
+        # if logger:
+        #     logger.debug(msg)
+        # else:
+        #     print(msg)
+        data = []
+        orig_data = []
+        # with open(filepath, "r", encoding="utf-8") as file:
+        #     file_data = file.read()
+        data_list = self.split_text(text)
+        # print("data_list", data_list)
+        for sen in tqdm(data_list, desc=f"Source: {data}"):
+            if sen == "":
+                continue
+
+            labels = []
+            text = ""
+            text_org = ""
+            index = 0
+            sentence_length = len(sen)
+            while index < sentence_length:
+                if (
+                    ord(sen[index]) == Nikud.nikud_dict["PUNCTUATION MAQAF"]
+                    or ord(sen[index]) == Nikud.nikud_dict["PUNCTUATION PASEQ"]
+                    or ord(sen[index]) == Nikud.nikud_dict["METEG"]
+                ):
+                    index += 1
+                    continue
+
+                label = []
+                l = Letter(sen[index])
+                if not (l.letter not in Nikud.all_nikud_chr):
+                    if sen[index - 1] == "\n":
+                        index += 1
+                        continue
+                assert l.letter not in Nikud.all_nikud_chr
+                if sen[index] in Letters.hebrew:
+                    index += 1
+                    while (
+                        index < sentence_length
+                        and ord(sen[index]) in Nikud.all_nikud_ord
+                    ):
+                        label.append(ord(sen[index]))
+                        index += 1
+                else:
+                    index += 1
+
+                l.get_label_letter(label)
+                text += l.normalized
+                text_org += l.letter
+                labels.append(l)
+
+            data.append((text, labels))
+            orig_data.append(text_org)
+        self.data = data
+        self.origin_data = orig_data
+        return data, orig_data
+
     def split_text(self, file_data):
         file_data = file_data.replace("\n", f"\n{unique_key}")
         data_list = file_data.split(unique_key)
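read_single_text works because Unicode Hebrew stores nikud as combining marks that immediately follow their base letter: the inner while loop collects the run of mark code points after each letter, and Letter.get_label_letter turns that run into the (nikud, dagesh, sin) label. A self-contained sketch of the same grouping walk, without the repo's Letter/Nikud classes (the NIKUD_ORDS set below is a small illustrative subset, not the project's full table):

    # Group each base character with the combining nikud marks that follow it.
    NIKUD_ORDS = {0x05B0, 0x05B4, 0x05B7, 0x05B8, 0x05B9, 0x05BC, 0x05C1, 0x05C2}

    def group_letters(sen: str):
        pairs, i = [], 0
        while i < len(sen):
            base, i = sen[i], i + 1
            marks = []
            while i < len(sen) and ord(sen[i]) in NIKUD_ORDS:  # run of combining marks
                marks.append(hex(ord(sen[i])))
                i += 1
            pairs.append((base, marks))
        return pairs

    print(group_letters("שָׁלוֹם"))  # ש carries two marks, ו carries holam, ל and ם carry none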