pantelis-ninja committed
Commit: f6480f8
Parent: f378b77

Major fixes for huggingface endpoints

handler.py CHANGED
@@ -2,10 +2,11 @@ from typing import Dict, List, Any
 from transformers import AutoConfig, AutoTokenizer
 from src.models import DNikudModel, ModelConfig
 from src.running_params import BATCH_SIZE, MAX_LENGTH_SEN
-from src.utiles_data import Nikud
-from src.models_utils import predict_single
+from src.utiles_data import Nikud, NikudDataset
+from src.models_utils import predict_single, predict
 import torch
 import os
+from tqdm import tqdm
 
 
 class EndpointHandler:
@@ -22,28 +23,90 @@ class EndpointHandler:
             len(Nikud.label_2_id["sin"]),
             device=self.DEVICE,
         ).to(self.DEVICE)
+        state_dict_model = self.model.state_dict()
+        state_dict_model.update(torch.load("./models/Dnikud_best_model.pth"))
+        self.model.load_state_dict(state_dict_model)
+        self.max_length = MAX_LENGTH_SEN
 
     def back_2_text(self, labels, text):
         nikud = Nikud()
         new_line = ""
+
         for indx_char, c in enumerate(text):
             new_line += (
                 c
-                + nikud.id_2_char(labels[0][1][1], "dagesh")
-                + nikud.id_2_char(labels[0][1][2], "sin")
-                + nikud.id_2_char(labels[0][1][0], "nikud")
+                + nikud.id_2_char(labels[indx_char][1][1], "dagesh")
+                + nikud.id_2_char(labels[indx_char][1][2], "sin")
+                + nikud.id_2_char(labels[indx_char][1][0], "nikud")
             )
             print(indx_char, c)
             print(labels)
         return new_line
 
+    def prepare_data(self, data, name="train"):
+        print("Data = ", data)
+        dataset = []
+        for index, (sentence, label) in tqdm(
+            enumerate(data), desc=f"Prepare data {name}"
+        ):
+            encoded_sequence = self.tokenizer.encode_plus(
+                sentence,
+                add_special_tokens=True,
+                max_length=self.max_length,
+                padding="max_length",
+                truncation=True,
+                return_attention_mask=True,
+                return_tensors="pt",
+            )
+            label_lists = [
+                [letter.nikud, letter.dagesh, letter.sin] for letter in label
+            ]
+            label = torch.tensor(
+                [
+                    [
+                        Nikud.PAD_OR_IRRELEVANT,
+                        Nikud.PAD_OR_IRRELEVANT,
+                        Nikud.PAD_OR_IRRELEVANT,
+                    ]
+                ]
+                + label_lists[: (self.max_length - 1)]
+                + [
+                    [
+                        Nikud.PAD_OR_IRRELEVANT,
+                        Nikud.PAD_OR_IRRELEVANT,
+                        Nikud.PAD_OR_IRRELEVANT,
+                    ]
+                    for i in range(self.max_length - len(label) - 1)
+                ]
+            )
+
+            dataset.append(
+                (
+                    encoded_sequence["input_ids"][0],
+                    encoded_sequence["attention_mask"][0],
+                    label,
+                )
+            )
+
+        self.prepered_data = dataset
+
     def predict_single_text(
         self,
         text,
     ):
-        data = self.tokenizer(text, return_tensors="pt")
-        all_labels = predict_single(self.model, data, self.DEVICE)
-        return all_labels
+        dataset = NikudDataset(tokenizer=self.tokenizer, max_length=MAX_LENGTH_SEN)
+        data, orig_data = dataset.read_single_text(text)
+        print("data", data, len(data))
+        dataset.prepare_data(name="inference")
+        mtb_prediction_dl = torch.utils.data.DataLoader(
+            dataset.prepered_data, batch_size=BATCH_SIZE
+        )
+        # print("dataset", dataset, len(dataset))
+        # data = self.tokenizer(text, return_tensors="pt")
+        all_labels = predict(self.model, mtb_prediction_dl, self.DEVICE)
+        text_data_with_labels = dataset.back_2_text(labels=all_labels)
+        # all_labels = predict_single(self.model, dataset, self.DEVICE)
+        return text_data_with_labels
 
     def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
         """
@@ -59,5 +122,5 @@
         # result = []
         # for pred in prediction:
         #     result.append(self.back_2_text(pred, inputs))
-        result = self.back_2_text(prediction, inputs)
-        return result
+        # result = self.back_2_text(prediction, inputs)
+        return prediction
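
Taken together, the handler now loads the LFS checkpoint into the model at construction time, predict_single_text routes raw text through NikudDataset (read_single_text → prepare_data → DataLoader → predict) instead of tokenizing directly, and __call__ returns the prediction as-is. A minimal local smoke test might look like the sketch below; the path argument and the {"inputs": ...} payload follow the usual Hugging Face Inference Endpoints custom-handler convention and are assumptions here, as is the sample string.

    # Hedged smoke test for the updated handler (not part of this commit).
    # Assumes handler.py sits next to models/Dnikud_best_model.pth and that
    # EndpointHandler.__init__ takes a path argument, per the standard
    # Hugging Face custom-handler convention.
    from handler import EndpointHandler

    handler = EndpointHandler(path=".")
    payload = {"inputs": "שלום עולם"}  # illustrative unvocalized Hebrew text
    prediction = handler(payload)
    print(prediction)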
models/Dnikud_best_model.pth ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:31b1bb3dc66cebf70ad4bfa52d77257d92c745e6609f0023108e91041447e754
+size 446945642
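
The three lines above are a Git LFS pointer, not the checkpoint itself: the ~447 MB weights that handler.py now loads from ./models/Dnikud_best_model.pth live in LFS storage, so a plain clone must run `git lfs install` followed by `git lfs pull` (or clone with LFS enabled) before the handler can start.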
src/models_utils.py CHANGED
@@ -94,15 +94,17 @@ def predict_single(model, data, device="cpu"):
 
     all_labels = None
     with torch.no_grad():
-        inputs = data["input_ids"].to(device)
-        attention_mask = data["attention_mask"].to(device)
+        (inputs, attention_mask, labels_demo) = data
+        inputs = inputs.to(device)
+        attention_mask = attention_mask.to(device)
+        labels_demo = labels_demo.to(device)
 
-        # mask_cant_be_nikud = np.array(labels_demo.cpu())[:, :, 0] == -1
-        # mask_cant_be_dagesh = np.array(labels_demo.cpu())[:, :, 1] == -1
-        # mask_cant_be_sin = np.array(labels_demo.cpu())[:, :, 2] == -1
+        mask_cant_be_nikud = np.array(labels_demo.cpu())[:, :, 0] == -1
+        mask_cant_be_dagesh = np.array(labels_demo.cpu())[:, :, 1] == -1
+        mask_cant_be_sin = np.array(labels_demo.cpu())[:, :, 2] == -1
 
         nikud_probs, dagesh_probs, sin_probs = model(inputs, attention_mask)
-        print(nikud_probs, dagesh_probs, sin_probs)
+        print("model output: ", nikud_probs, dagesh_probs, sin_probs)
 
         pred_nikud = np.array(torch.max(nikud_probs, 2).indices.cpu()).reshape(
             inputs.shape[0], inputs.shape[1], 1
@@ -114,9 +116,9 @@ def predict_single(model, data, device="cpu"):
             inputs.shape[0], inputs.shape[1], 1
         )
 
-        # pred_nikud[mask_cant_be_nikud] = -1
-        # pred_dagesh[mask_cant_be_dagesh] = -1
-        # pred_sin[mask_cant_be_sin] = -1
+        pred_nikud[mask_cant_be_nikud] = -1
+        pred_dagesh[mask_cant_be_dagesh] = -1
+        pred_sin[mask_cant_be_sin] = -1
         # print(pred_nikud, pred_dagesh, pred_sin)
         pred_labels = np.concatenate((pred_nikud, pred_dagesh, pred_sin), axis=2)
         print(pred_labels)
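
The substantive change to predict_single is that it now unpacks the (input_ids, attention_mask, labels) tuples produced by prepare_data instead of a tokenizer dict, and it re-enables the masks so that positions whose label is -1 (PAD_OR_IRRELEVANT) are forced back to -1 in the predictions. A toy sketch of that masking, with invented shapes and values:

    # Toy illustration of the masking above (shapes and values are invented).
    import numpy as np
    import torch

    nikud_probs = torch.rand(1, 4, 5)  # (batch, seq_len, nikud classes)
    labels_demo = torch.tensor(
        [[[2, 1, 0], [3, 0, 1], [-1, -1, -1], [-1, -1, -1]]]
    )  # -1 marks PAD_OR_IRRELEVANT positions

    pred_nikud = np.array(torch.max(nikud_probs, 2).indices.cpu()).reshape(1, 4, 1)
    mask_cant_be_nikud = np.array(labels_demo.cpu())[:, :, 0] == -1
    pred_nikud[mask_cant_be_nikud] = -1  # padded positions cannot carry nikud
    print(pred_nikud.squeeze(-1))  # e.g. [[ 3  1 -1 -1]]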
src/utiles_data.py CHANGED
@@ -370,6 +370,8 @@ class NikudDataset(Dataset):
         self.max_length = max_length
         self.tokenizer = tokenizer
         self.is_train = is_train
+        self.data = None
+        self.origin_data = None
         if folder is not None:
             self.data, self.origin_data = self.read_data_folder(folder, logger)
         elif file is not None:
@@ -453,6 +455,65 @@ class NikudDataset(Dataset):
 
         return data, orig_data
 
+    def read_single_text(self, text: str, logger=None) -> List[Tuple[str, list]]:
+        # msg = f"read file: {filepath}"
+        # if logger:
+        #     logger.debug(msg)
+        # else:
+        #     print(msg)
+        data = []
+        orig_data = []
+        # with open(filepath, "r", encoding="utf-8") as file:
+        #     file_data = file.read()
+        data_list = self.split_text(text)
+        # print("data_list", data_list)
+        for sen in tqdm(data_list, desc=f"Source: {data}"):
+            if sen == "":
+                continue
+
+            labels = []
+            text = ""
+            text_org = ""
+            index = 0
+            sentence_length = len(sen)
+            while index < sentence_length:
+                if (
+                    ord(sen[index]) == Nikud.nikud_dict["PUNCTUATION MAQAF"]
+                    or ord(sen[index]) == Nikud.nikud_dict["PUNCTUATION PASEQ"]
+                    or ord(sen[index]) == Nikud.nikud_dict["METEG"]
+                ):
+                    index += 1
+                    continue
+
+                label = []
+                l = Letter(sen[index])
+                if not (l.letter not in Nikud.all_nikud_chr):
+                    if sen[index - 1] == "\n":
+                        index += 1
+                        continue
+                assert l.letter not in Nikud.all_nikud_chr
+                if sen[index] in Letters.hebrew:
+                    index += 1
+                    while (
+                        index < sentence_length
+                        and ord(sen[index]) in Nikud.all_nikud_ord
+                    ):
+                        label.append(ord(sen[index]))
+                        index += 1
+                else:
+                    index += 1
+
+                l.get_label_letter(label)
+                text += l.normalized
+                text_org += l.letter
+                labels.append(l)
+
+            data.append((text, labels))
+            orig_data.append(text_org)
+        self.data = data
+        self.origin_data = orig_data
+        return data, orig_data
+
     def split_text(self, file_data):
         file_data = file_data.replace("\n", f"\n{unique_key}")
         data_list = file_data.split(unique_key)
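
read_single_text mirrors the existing file-reading path but takes the text directly: it splits it into sentences, walks each character to separate base letters from trailing diacritic code points (skipping maqaf, paseq, and meteg), and caches the resulting (normalized_text, labels) pairs on the instance so prepare_data can consume them. A hedged usage sketch matching the call site in handler.py (the sample string is invented, and `tokenizer` stands in for the AutoTokenizer instance the handler already builds):

    # Hedged usage sketch for the new read_single_text.
    from src.running_params import MAX_LENGTH_SEN
    from src.utiles_data import NikudDataset

    dataset = NikudDataset(tokenizer=tokenizer, max_length=MAX_LENGTH_SEN)  # tokenizer assumed in scope
    data, orig_data = dataset.read_single_text("טקסט עברי ללא ניקוד")
    # data: list of (normalized sentence, [Letter, ...]) pairs
    # orig_data: the original characters as read, minus skipped punctuation
    dataset.prepare_data(name="inference")  # consumes self.data, fills self.prepered_data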