pantelis-ninja committed
Commit f378b77 • 1 Parent(s): ef82fa0

add code for huggingface endpoints capability
handler.py ADDED
@@ -0,0 +1,63 @@
+ from typing import Dict, List, Any
+ from transformers import AutoConfig, AutoTokenizer
+ from src.models import DNikudModel, ModelConfig
+ from src.running_params import BATCH_SIZE, MAX_LENGTH_SEN
+ from src.utiles_data import Nikud
+ from src.models_utils import predict_single
+ import torch
+ import os
+
+
+ class EndpointHandler:
+     def __init__(self, path=""):
+         self.DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+
+         self.tokenizer = AutoTokenizer.from_pretrained("tau/tavbert-he")
+         dir_model_config = os.path.join("models", "config.yml")
+         self.config = ModelConfig.load_from_file(dir_model_config)
+         self.model = DNikudModel(
+             self.config,
+             len(Nikud.label_2_id["nikud"]),
+             len(Nikud.label_2_id["dagesh"]),
+             len(Nikud.label_2_id["sin"]),
+             device=self.DEVICE,
+         ).to(self.DEVICE)
+
+     def back_2_text(self, labels, text):
+         nikud = Nikud()
+         new_line = ""
+         for indx_char, c in enumerate(text):
+             new_line += (
+                 c
+                 + nikud.id_2_char(labels[0][1][1], "dagesh")
+                 + nikud.id_2_char(labels[0][1][2], "sin")
+                 + nikud.id_2_char(labels[0][1][0], "nikud")
+             )
+             print(indx_char, c)
+         print(labels)
+         return new_line
+
+     def predict_single_text(
+         self,
+         text,
+     ):
+         data = self.tokenizer(text, return_tensors="pt")
+         all_labels = predict_single(self.model, data, self.DEVICE)
+         return all_labels
+
+     def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
+         """
+         data args:
+             text (str): Hebrew text to diacritize
+         """
+
+         # get inputs
+         inputs = data.pop("text", data)
+
+         # run normal prediction
+         prediction = self.predict_single_text(inputs)
+
+         # result = []
+         # for pred in prediction:
+         #     result.append(self.back_2_text(pred, inputs))
+         result = self.back_2_text(prediction, inputs)
+         return result
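A quick way to sanity-check the handler above is to call it locally before deploying the endpoint. The sketch below is illustrative and not part of the commit: it assumes the repository root is the working directory (so `models/config.yml` and the `src` package resolve) and that the `tau/tavbert-he` tokenizer can be downloaded. Note also that `__init__` above builds `DNikudModel` from the saved config but does not load trained weights here, so out of the box the predictions come from untrained heads.

    # Local smoke test for EndpointHandler -- illustrative, not part of the commit.
    from handler import EndpointHandler

    handler = EndpointHandler(path=".")
    result = handler({"text": "שלום עולם"})  # plain Hebrew in, pointed text out
    print(result)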
main.py ADDED
@@ -0,0 +1,596 @@
+ # general
+ import argparse
+ import os
+ import sys
+ from datetime import datetime
+ import logging
+ from logging.handlers import RotatingFileHandler
+ from pathlib import Path
+
+ # ML
+ import torch
+ import torch.nn as nn
+ from transformers import AutoConfig, AutoTokenizer
+
+ # DL
+ from src.models import DNikudModel, ModelConfig
+ from src.models_utils import training, evaluate, predict
+ from src.plot_helpers import (
+     generate_plot_by_nikud_dagesh_sin_dict,
+     generate_word_and_letter_accuracy_plot,
+ )
+ from src.running_params import BATCH_SIZE, MAX_LENGTH_SEN
+ from src.utiles_data import (
+     NikudDataset,
+     Nikud,
+     create_missing_folders,
+     extract_text_to_compare_nakdimon,
+ )
+
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+ assert DEVICE == "cuda"  # training and inference below assume a GPU is available
+
+
+ def get_logger(
+     log_level, name_func, date_time=datetime.now().strftime("%d_%m_%y__%H_%M")
+ ):
+     log_location = os.path.join(
+         os.path.join(Path(__file__).parent, "logging"),
+         f"log_model_{name_func}_{date_time}",
+     )
+     create_missing_folders(log_location)
+
+     log_format = "%(asctime)s %(levelname)-8s Thread_%(thread)-6d ::: %(funcName)s(%(lineno)d) ::: %(message)s"
+     logger = logging.getLogger("algo")
+     logger.setLevel(getattr(logging, log_level))
+     cnsl_log_formatter = logging.Formatter(log_format)
+     cnsl_handler = logging.StreamHandler()
+     cnsl_handler.setFormatter(cnsl_log_formatter)
+     cnsl_handler.setLevel(log_level)
+     logger.addHandler(cnsl_handler)
+
+     create_missing_folders(log_location)
+
+     file_location = os.path.join(log_location, "Diacritization_Model_DEBUG.log")
+     file_log_formatter = logging.Formatter(log_format)
+
+     SINGLE_LOG_SIZE = 2 * 1024 * 1024  # in Bytes
+     MAX_LOG_FILES = 20
+     file_handler = RotatingFileHandler(
+         file_location, mode="a", maxBytes=SINGLE_LOG_SIZE, backupCount=MAX_LOG_FILES
+     )
+     file_handler.setFormatter(file_log_formatter)
+     file_handler.setLevel(log_level)
+     logger.addHandler(file_handler)
+
+     return logger
+
+
+ def evaluate_text(
+     path,
+     dnikud_model,
+     tokenizer_tavbert,
+     logger,
+     plots_folder=None,
+     batch_size=BATCH_SIZE,
+ ):
+     path_name = os.path.basename(path)
+
+     msg = f"evaluate text: {path_name} on D-nikud Model"
+     logger.debug(msg)
+
+     if os.path.isfile(path):
+         dataset = NikudDataset(
+             tokenizer_tavbert, file=path, logger=logger, max_length=MAX_LENGTH_SEN
+         )
+     elif os.path.isdir(path):
+         dataset = NikudDataset(
+             tokenizer_tavbert, folder=path, logger=logger, max_length=MAX_LENGTH_SEN
+         )
+     else:
+         raise Exception("input path doesn't exist")
+
+     dataset.prepare_data(name="evaluate")
+     mtb_dl = torch.utils.data.DataLoader(dataset.prepered_data, batch_size=batch_size)
+
+     word_level_correct, letter_level_correct_dev = evaluate(
+         dnikud_model, mtb_dl, plots_folder, device=DEVICE
+     )
+
+     msg = (
+         f"Dnikud Model\n{path_name} evaluate\nLetter level accuracy: {letter_level_correct_dev}\n"
+         f"Word level accuracy: {word_level_correct}"
+     )
+     logger.debug(msg)
+
+
+ def predict_text(
+     text_file,
+     tokenizer_tavbert,
+     output_file,
+     logger,
+     dnikud_model,
+     compare_nakdimon=False,
+ ):
+     dataset = NikudDataset(
+         tokenizer_tavbert, file=text_file, logger=logger, max_length=MAX_LENGTH_SEN
+     )
+
+     dataset.prepare_data(name="prediction")
+     mtb_prediction_dl = torch.utils.data.DataLoader(
+         dataset.prepered_data, batch_size=BATCH_SIZE
+     )
+     all_labels = predict(dnikud_model, mtb_prediction_dl, DEVICE)
+     text_data_with_labels = dataset.back_2_text(labels=all_labels)
+
+     if output_file is None:
+         for line in text_data_with_labels:
+             print(line)
+     else:
+         with open(output_file, "w", encoding="utf-8") as f:
+             if compare_nakdimon:
+                 f.write(extract_text_to_compare_nakdimon(text_data_with_labels))
+             else:
+                 f.write(text_data_with_labels)
+
+
+ def predict_folder(
+     folder,
+     output_folder,
+     logger,
+     tokenizer_tavbert,
+     dnikud_model,
+     compare_nakdimon=False,
+ ):
+     create_missing_folders(output_folder)
+
+     for filename in os.listdir(folder):
+         file_path = os.path.join(folder, filename)
+
+         if filename.lower().endswith(".txt") and os.path.isfile(file_path):
+             output_file = os.path.join(output_folder, filename)
+             predict_text(
+                 file_path,
+                 output_file=output_file,
+                 logger=logger,
+                 tokenizer_tavbert=tokenizer_tavbert,
+                 dnikud_model=dnikud_model,
+                 compare_nakdimon=compare_nakdimon,
+             )
+         elif (
+             os.path.isdir(file_path) and filename != ".git" and filename != "README.md"
+         ):
+             sub_folder = file_path
+             sub_folder_output = os.path.join(output_folder, filename)
+             predict_folder(
+                 sub_folder,
+                 sub_folder_output,
+                 logger,
+                 tokenizer_tavbert,
+                 dnikud_model,
+                 compare_nakdimon=compare_nakdimon,
+             )
+
+
+ def update_compare_folder(folder, output_folder):
+     create_missing_folders(output_folder)
+
+     for filename in os.listdir(folder):
+         file_path = os.path.join(folder, filename)
+
+         if filename.lower().endswith(".txt") and os.path.isfile(file_path):
+             output_file = os.path.join(output_folder, filename)
+             with open(file_path, "r", encoding="utf-8") as f:
+                 text_data_with_labels = f.read()
+             with open(output_file, "w", encoding="utf-8") as f:
+                 f.write(extract_text_to_compare_nakdimon(text_data_with_labels))
+         elif os.path.isdir(file_path) and filename != ".git":
+             sub_folder = file_path
+             sub_folder_output = os.path.join(output_folder, filename)
+             update_compare_folder(sub_folder, sub_folder_output)
+
+
+ def check_files_excepted(folder):
+     for filename in os.listdir(folder):
+         file_path = os.path.join(folder, filename)
+
+         if filename.lower().endswith(".txt") and os.path.isfile(file_path):
+             try:
+                 x = NikudDataset(None, file=file_path)
+             except:
+                 print(f"failed in file: {filename}")
+         elif os.path.isdir(file_path) and filename != ".git":
+             check_files_excepted(file_path)
+
+
+ def do_predict(
+     input_path, output_path, tokenizer_tavbert, logger, dnikud_model, compare_nakdimon
+ ):
+     if os.path.isdir(input_path):
+         predict_folder(
+             input_path,
+             output_path,
+             logger,
+             tokenizer_tavbert,
+             dnikud_model,
+             compare_nakdimon=compare_nakdimon,
+         )
+     elif os.path.isfile(input_path):
+         predict_text(
+             input_path,
+             output_file=output_path,
+             logger=logger,
+             tokenizer_tavbert=tokenizer_tavbert,
+             dnikud_model=dnikud_model,
+             compare_nakdimon=compare_nakdimon,
+         )
+     else:
+         raise Exception("Input file does not exist")
+
+
+ def evaluate_folder(folder_path, logger, dnikud_model, tokenizer_tavbert, plots_folder):
+     msg = f"evaluate sub folder: {folder_path}"
+     logger.info(msg)
+
+     evaluate_text(
+         folder_path,
+         dnikud_model=dnikud_model,
+         tokenizer_tavbert=tokenizer_tavbert,
+         logger=logger,
+         plots_folder=plots_folder,
+         batch_size=BATCH_SIZE,
+     )
+
+     msg = f"\n***************************************\n"
+     logger.info(msg)
+
+     for sub_folder_name in os.listdir(folder_path):
+         sub_folder_path = os.path.join(folder_path, sub_folder_name)
+
+         if (
+             not os.path.isdir(sub_folder_path)
+             or sub_folder_path == ".git"
+             or "not_use" in sub_folder_path
+             or "NakdanResults" in sub_folder_path
+         ):
+             continue
+
+         evaluate_folder(
+             sub_folder_path, logger, dnikud_model, tokenizer_tavbert, plots_folder
+         )
+
+
+ def do_evaluate(
+     input_path,
+     logger,
+     dnikud_model,
+     tokenizer_tavbert,
+     plots_folder,
+     eval_sub_folders=False,
+ ):
+     msg = f"evaluate all_data: {input_path}"
+     logger.info(msg)
+
+     evaluate_text(
+         input_path,
+         dnikud_model=dnikud_model,
+         tokenizer_tavbert=tokenizer_tavbert,
+         logger=logger,
+         plots_folder=plots_folder,
+         batch_size=BATCH_SIZE,
+     )
+
+     msg = f"\n\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n"
+     logger.info(msg)
+
+     if eval_sub_folders:
+         for sub_folder_name in os.listdir(input_path):
+             sub_folder_path = os.path.join(input_path, sub_folder_name)
+
+             if (
+                 not os.path.isdir(sub_folder_path)
+                 or sub_folder_path == ".git"
+                 or "not_use" in sub_folder_path
+                 or "NakdanResults" in sub_folder_path
+             ):
+                 continue
+
+             evaluate_folder(
+                 sub_folder_path, logger, dnikud_model, tokenizer_tavbert, plots_folder
+             )
+
+
+ def do_train(
+     logger,
+     plots_folder,
+     dir_model_config,
+     tokenizer_tavbert,
+     dnikud_model,
+     output_trained_model_dir,
+     data_folder,
+     n_epochs,
+     checkpoints_frequency,
+     learning_rate,
+     batch_size,
+ ):
+     msg = "Loading data..."
+     logger.debug(msg)
+
+     dataset_train = NikudDataset(
+         tokenizer_tavbert,
+         folder=os.path.join(data_folder, "train"),
+         logger=logger,
+         max_length=MAX_LENGTH_SEN,
+         is_train=True,
+     )
+     dataset_dev = NikudDataset(
+         tokenizer=tokenizer_tavbert,
+         folder=os.path.join(data_folder, "dev"),
+         logger=logger,
+         max_length=dataset_train.max_length,
+         is_train=True,
+     )
+     dataset_test = NikudDataset(
+         tokenizer=tokenizer_tavbert,
+         folder=os.path.join(data_folder, "test"),
+         logger=logger,
+         max_length=dataset_train.max_length,
+         is_train=True,
+     )
+
+     dataset_train.show_data_labels(plots_folder=plots_folder)
+
+     msg = f"Max length of data: {dataset_train.max_length}"
+     logger.debug(msg)
+
+     msg = (
+         f"Num rows in train data: {len(dataset_train.data)}, "
+         f"Num rows in dev data: {len(dataset_dev.data)}, "
+         f"Num rows in test data: {len(dataset_test.data)}"
+     )
+     logger.debug(msg)
+
+     msg = "Loading tokenizer and preparing data..."
+     logger.debug(msg)
+
+     dataset_train.prepare_data(name="train")
+     dataset_dev.prepare_data(name="dev")
+     dataset_test.prepare_data(name="test")
+
+     mtb_train_dl = torch.utils.data.DataLoader(
+         dataset_train.prepered_data, batch_size=batch_size
+     )
+     mtb_dev_dl = torch.utils.data.DataLoader(
+         dataset_dev.prepered_data, batch_size=batch_size
+     )
+
+     if not os.path.isfile(dir_model_config):
+         our_model_config = ModelConfig(dataset_train.max_length)
+         our_model_config.save_to_file(dir_model_config)
+
+     optimizer = torch.optim.Adam(dnikud_model.parameters(), lr=learning_rate)
+
+     msg = "training..."
+     logger.debug(msg)
+
+     criterion_nikud = nn.CrossEntropyLoss(ignore_index=Nikud.PAD_OR_IRRELEVANT).to(
+         DEVICE
+     )
+     criterion_dagesh = nn.CrossEntropyLoss(ignore_index=Nikud.PAD_OR_IRRELEVANT).to(
+         DEVICE
+     )
+     criterion_sin = nn.CrossEntropyLoss(ignore_index=Nikud.PAD_OR_IRRELEVANT).to(DEVICE)
+
+     training_params = {
+         "n_epochs": n_epochs,
+         "checkpoints_frequency": checkpoints_frequency,
+     }
+     (
+         best_model_details,
+         best_accuracy,
+         epochs_loss_train_values,
+         steps_loss_train_values,
+         loss_dev_values,
+         accuracy_dev_values,
+     ) = training(
+         dnikud_model,
+         mtb_train_dl,
+         mtb_dev_dl,
+         criterion_nikud,
+         criterion_dagesh,
+         criterion_sin,
+         training_params,
+         logger,
+         output_trained_model_dir,
+         optimizer,
+         device=DEVICE,
+     )
+
+     generate_plot_by_nikud_dagesh_sin_dict(
+         epochs_loss_train_values, "Train epochs loss", "Loss", plots_folder
+     )
+     generate_plot_by_nikud_dagesh_sin_dict(
+         steps_loss_train_values, "Train steps loss", "Loss", plots_folder
+     )
+     generate_plot_by_nikud_dagesh_sin_dict(
+         loss_dev_values, "Dev epochs loss", "Loss", plots_folder
+     )
+     generate_plot_by_nikud_dagesh_sin_dict(
+         accuracy_dev_values, "Dev accuracy", "Accuracy", plots_folder
+     )
+     generate_word_and_letter_accuracy_plot(
+         accuracy_dev_values, "Accuracy", plots_folder
+     )
+
+     msg = "Done"
+     logger.info(msg)
+
+
+ if __name__ == "__main__":
+     tokenizer_tavbert = AutoTokenizer.from_pretrained("tau/tavbert-he")
+
+     parser = argparse.ArgumentParser(
+         formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+         description="""Predict D-nikud""",
+     )
+     parser.add_argument(
+         "-l",
+         "--log",
+         dest="log_level",
+         choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
+         default="DEBUG",
+         help="Set the logging level",
+     )
+     parser.add_argument(
+         "-m",
+         "--output_model_dir",
+         type=str,
+         default="models",
+         help="save directory for model",
+     )
+     subparsers = parser.add_subparsers(
+         help="sub-command help", dest="command", required=True
+     )
+
+     parser_predict = subparsers.add_parser("predict", help="diacritize text files")
+     parser_predict.add_argument("input_path", help="input file or folder")
+     parser_predict.add_argument("output_path", help="output file")
+     parser_predict.add_argument(
+         "-ptmp",
+         "--pretrain_model_path",
+         type=str,
+         default=os.path.join(Path(__file__).parent, "models", "Dnikud_best_model.pth"),
+         help="pre-train model path - use only if you want to use trained model weights",
+     )
+     parser_predict.add_argument(
+         "-c",
+         "--compare",
+         dest="compare_nakdimon",
+         default=False,
+         help="predict text for comparing with Nakdimon",
+     )
+     parser_predict.set_defaults(func=do_predict)
+
+     parser_evaluate = subparsers.add_parser("evaluate", help="evaluate D-nikud")
+     parser_evaluate.add_argument("input_path", help="input file or folder")
+     parser_evaluate.add_argument(
+         "-ptmp",
+         "--pretrain_model_path",
+         type=str,
+         default=os.path.join(Path(__file__).parent, "models", "Dnikud_best_model.pth"),
+         help="pre-train model path - use only if you want to use trained model weights",
+     )
+     parser_evaluate.add_argument(
+         "-df",
+         "--plots_folder",
+         dest="plots_folder",
+         default=os.path.join(Path(__file__).parent, "plots"),
+         help="set the plots folder",
+     )
+     parser_evaluate.add_argument(
+         "-es",
+         "--eval_sub_folders",
+         dest="eval_sub_folders",
+         default=False,
+         help="accuracy calculation includes the evaluation of sub-folders "
+         "within the input_path folder, providing independent assessments "
+         "for each subfolder.",
+     )
+     parser_evaluate.set_defaults(func=do_evaluate)
+
+     # train --n_epochs 20
+
+     parser_train = subparsers.add_parser("train", help="train D-nikud")
+     parser_train.add_argument(
+         "-ptmp",
+         "--pretrain_model_path",
+         type=str,
+         default=None,
+         help="pre-train model path - use only if you want to use trained model weights",
+     )
+     parser_train.add_argument(
+         "--learning_rate", type=float, default=0.001, help="Learning rate"
+     )
+     parser_train.add_argument("--batch_size", type=int, default=32, help="batch_size")
+     parser_train.add_argument(
+         "--n_epochs", type=int, default=10, help="number of epochs"
+     )
+     parser_train.add_argument(
+         "--data_folder",
+         dest="data_folder",
+         default=os.path.join(Path(__file__).parent, "data"),
+         help="Set the data folder",
+     )
+     parser_train.add_argument(
+         "--checkpoints_frequency",
+         type=int,
+         default=1,
+         help="checkpoint frequency for saving the model",
+     )
+     parser_train.add_argument(
+         "-df",
+         "--plots_folder",
+         dest="plots_folder",
+         default=os.path.join(Path(__file__).parent, "plots"),
+         help="Set the plots folder",
+     )
+     parser_train.set_defaults(func=do_train)
+
+     args = parser.parse_args()
+     kwargs = vars(args).copy()
+     date_time = datetime.now().strftime("%d_%m_%y__%H_%M")
+     logger = get_logger(kwargs["log_level"], args.command, date_time)
+
+     del kwargs["log_level"]
+
+     kwargs["tokenizer_tavbert"] = tokenizer_tavbert
+     kwargs["logger"] = logger
+
+     msg = "Loading model..."
+     logger.debug(msg)
+
+     if args.command in ["evaluate", "predict"] or (
+         args.command == "train" and args.pretrain_model_path is not None
+     ):
+         dir_model_config = os.path.join("models", "config.yml")
+         config = ModelConfig.load_from_file(dir_model_config)
+
+         dnikud_model = DNikudModel(
+             config,
+             len(Nikud.label_2_id["nikud"]),
+             len(Nikud.label_2_id["dagesh"]),
+             len(Nikud.label_2_id["sin"]),
+             device=DEVICE,
+         ).to(DEVICE)
+         state_dict_model = dnikud_model.state_dict()
+         state_dict_model.update(torch.load(args.pretrain_model_path))
+         dnikud_model.load_state_dict(state_dict_model)
+     else:
+         base_model_name = "tau/tavbert-he"
+         config = AutoConfig.from_pretrained(base_model_name)
+         dnikud_model = DNikudModel(
+             config,
+             len(Nikud.label_2_id["nikud"]),
+             len(Nikud.label_2_id["dagesh"]),
+             len(Nikud.label_2_id["sin"]),
+             pretrain_model=base_model_name,
+             device=DEVICE,
+         ).to(DEVICE)
+
+     if args.command == "train":
+         output_trained_model_dir = os.path.join(
+             kwargs["output_model_dir"], "latest", f"output_models_{date_time}"
+         )
+         create_missing_folders(output_trained_model_dir)
+         dir_model_config = os.path.join(kwargs["output_model_dir"], "config.yml")
+         kwargs["dir_model_config"] = dir_model_config
+         kwargs["output_trained_model_dir"] = output_trained_model_dir
+     del kwargs["pretrain_model_path"]
+     del kwargs["output_model_dir"]
+     kwargs["dnikud_model"] = dnikud_model
+
+     del kwargs["command"]
+     del kwargs["func"]
+     args.func(**kwargs)
+
+     sys.exit(0)
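For orientation, each subcommand wired up above dispatches through `set_defaults(func=...)` to one of the `do_*` functions. Hypothetical invocations (the file paths are placeholders, not from the commit):

    # python main.py predict input.txt output.txt
    # python main.py evaluate data/test
    # python main.py train --n_epochs 20
    #
    # or, scripted from Python:
    import subprocess
    import sys

    subprocess.run(
        [sys.executable, "main.py", "predict", "input.txt", "output.txt"],
        check=True,
    )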
models/config.yml ADDED
@@ -0,0 +1,78 @@
+ _commit_hash: 41265b09a862144b2517afdfd46da4388f1380df
+ _name_or_path: tau/tavbert-he
+ add_cross_attention: false
+ architectures:
+ - RobertaForMaskedLM
+ attention_probs_dropout_prob: 0.1
+ bad_words_ids: null
+ begin_suppress_tokens: null
+ bos_token_id: 0
+ chunk_size_feed_forward: 0
+ classifier_dropout: null
+ cross_attention_hidden_size: null
+ decoder_start_token_id: null
+ diversity_penalty: 0.0
+ do_sample: false
+ early_stopping: false
+ encoder_no_repeat_ngram_size: 0
+ eos_token_id: 2
+ exponential_decay_length_penalty: null
+ finetuning_task: null
+ forced_bos_token_id: null
+ forced_eos_token_id: null
+ gradient_checkpointing: false
+ hidden_act: gelu
+ hidden_dropout_prob: 0.1
+ hidden_size: 768
+ id2label:
+   0: LABEL_0
+   1: LABEL_1
+ initializer_range: 0.02
+ intermediate_size: 3072
+ is_decoder: false
+ is_encoder_decoder: false
+ label2id:
+   LABEL_0: 0
+   LABEL_1: 1
+ layer_norm_eps: 1.0e-05
+ length_penalty: 1.0
+ max_length: 512
+ max_position_embeddings: 2050
+ min_length: 0
+ model_type: roberta
+ no_repeat_ngram_size: 0
+ num_attention_heads: 12
+ num_beam_groups: 1
+ num_beams: 1
+ num_hidden_layers: 12
+ num_return_sequences: 1
+ output_attentions: false
+ output_hidden_states: false
+ output_scores: false
+ pad_token_id: 1
+ position_embedding_type: absolute
+ prefix: null
+ problem_type: null
+ pruned_heads: {}
+ remove_invalid_values: false
+ repetition_penalty: 1.0
+ return_dict: true
+ return_dict_in_generate: false
+ sep_token_id: null
+ suppress_tokens: null
+ task_specific_params: null
+ temperature: 1.0
+ tf_legacy_loss: false
+ tie_encoder_decoder: false
+ tie_word_embeddings: true
+ tokenizer_class: null
+ top_k: 50
+ top_p: 1.0
+ torch_dtype: null
+ torchscript: false
+ transformers_version: 4.6.0.dev0
+ type_vocab_size: 2
+ typical_p: 1.0
+ use_bfloat16: false
+ use_cache: true
+ vocab_size: 345
src/models.py ADDED
@@ -0,0 +1,74 @@
+ # general
+ import subprocess
+ import yaml
+
+ # ML
+ import torch.nn as nn
+ from transformers import AutoConfig, RobertaForMaskedLM, PretrainedConfig
+
+
+ class DNikudModel(nn.Module):
+     def __init__(self, config, nikud_size, dagesh_size, sin_size, pretrain_model=None, device='cpu'):
+         super(DNikudModel, self).__init__()
+
+         if pretrain_model is not None:
+             model_base = RobertaForMaskedLM.from_pretrained(pretrain_model).to(device)
+         else:
+             model_base = RobertaForMaskedLM(config=config).to(device)
+
+         self.model = model_base.roberta
+         for name, param in self.model.named_parameters():
+             param.requires_grad = False
+
+         self.lstm1 = nn.LSTM(config.hidden_size, config.hidden_size, bidirectional=True, dropout=0.1, batch_first=True)
+         self.lstm2 = nn.LSTM(2 * config.hidden_size, config.hidden_size, bidirectional=True, dropout=0.1, batch_first=True)
+         self.dense = nn.Linear(2 * config.hidden_size, config.hidden_size)
+         self.out_n = nn.Linear(config.hidden_size, nikud_size)
+         self.out_d = nn.Linear(config.hidden_size, dagesh_size)
+         self.out_s = nn.Linear(config.hidden_size, sin_size)
+
+     def forward(self, input_ids, attention_mask):
+         last_hidden_state = self.model(input_ids, attention_mask=attention_mask).last_hidden_state
+         lstm1, _ = self.lstm1(last_hidden_state)
+         lstm2, _ = self.lstm2(lstm1)
+         dense = self.dense(lstm2)
+
+         nikud = self.out_n(dense)
+         dagesh = self.out_d(dense)
+         sin = self.out_s(dense)
+
+         return nikud, dagesh, sin
+
+
+ def get_git_commit_hash():
+     try:
+         commit_hash = subprocess.check_output(['git', 'rev-parse', 'HEAD']).decode('ascii').strip()
+         return commit_hash
+     except subprocess.CalledProcessError:
+         # This will be raised if you're not in a Git repository
+         print("Not inside a Git repository!")
+         return None
+
+
+ class ModelConfig(PretrainedConfig):
+     def __init__(self, max_length=None, dict=None):
+         super(ModelConfig, self).__init__()
+         if dict is None:
+             self.__dict__.update(AutoConfig.from_pretrained("tau/tavbert-he").__dict__)
+             self.max_length = max_length
+             self._commit_hash = get_git_commit_hash()
+         else:
+             self.__dict__.update(dict)
+
+     def print(self):
+         print(self.__dict__)
+
+     def save_to_file(self, file_path):
+         with open(file_path, "w") as yaml_file:
+             yaml.dump(self.__dict__, yaml_file, default_flow_style=False)
+
+     @classmethod
+     def load_from_file(cls, file_path):
+         with open(file_path, "r") as yaml_file:
+             config_dict = yaml.safe_load(yaml_file)
+         return cls(dict=config_dict)
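The forward pass above freezes the tavbert backbone and feeds its last hidden state through two BiLSTMs and a dense layer into three per-character heads. A minimal shape check with dummy inputs, not part of the commit; it assumes the `tau/tavbert-he` config can be downloaded and uses the label sizes defined in `src/utiles_data.py`:

    import torch
    from transformers import AutoConfig
    from src.models import DNikudModel
    from src.utiles_data import Nikud

    config = AutoConfig.from_pretrained("tau/tavbert-he")
    model = DNikudModel(config, Nikud.LEN_NIKUD, Nikud.LEN_DAGESH, Nikud.LEN_SIN)
    ids = torch.randint(0, config.vocab_size, (2, 16))  # dummy batch of token ids
    mask = torch.ones_like(ids)
    nikud, dagesh, sin = model(ids, mask)
    print(nikud.shape, dagesh.shape, sin.shape)  # each: (2, 16, head_size)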
src/models_utils.py ADDED
@@ -0,0 +1,559 @@
+ # general
+ import json
+ import os
+
+ # ML
+ import numpy as np
+ import pandas as pd
+ import torch
+
+ # visual
+ import matplotlib.pyplot as plt
+ import seaborn as sns
+ from sklearn.metrics import confusion_matrix
+ from tqdm import tqdm
+
+ from src.running_params import DEBUG_MODE
+ from src.utiles_data import Nikud, create_missing_folders
+
+ CLASSES_LIST = ["nikud", "dagesh", "sin"]
+
+
+ def calc_num_correct_words(input, letter_correct_mask):
+     SPACE_TOKEN = 104
+     START_SENTENCE_TOKEN = 1
+     END_SENTENCE_TOKEN = 2
+
+     correct_words_count = 0
+     words_count = 0
+     for index in range(input.shape[0]):
+         input[index][np.where(input[index] == SPACE_TOKEN)[0]] = 0
+         input[index][np.where(input[index] == START_SENTENCE_TOKEN)[0]] = 0
+         input[index][np.where(input[index] == END_SENTENCE_TOKEN)[0]] = 0
+         words_end_index = np.concatenate(
+             (np.array([-1]), np.where(input[index] == 0)[0])
+         )
+         is_correct_words_array = [
+             bool(
+                 letter_correct_mask[index][
+                     list(range((words_end_index[s] + 1), words_end_index[s + 1]))
+                 ].all()
+             )
+             for s in range(len(words_end_index) - 1)
+             if words_end_index[s + 1] - (words_end_index[s] + 1) > 1
+         ]
+         correct_words_count += np.array(is_correct_words_array).sum()
+         words_count += len(is_correct_words_array)
+
+     return correct_words_count, words_count
+
+
+ def predict(model, data_loader, device="cpu"):
+     model.to(device)
+
+     all_labels = None
+     with torch.no_grad():
+         for index_data, data in enumerate(data_loader):
+             (inputs, attention_mask, labels_demo) = data
+             inputs = inputs.to(device)
+             attention_mask = attention_mask.to(device)
+             labels_demo = labels_demo.to(device)
+
+             mask_cant_be_nikud = np.array(labels_demo.cpu())[:, :, 0] == -1
+             mask_cant_be_dagesh = np.array(labels_demo.cpu())[:, :, 1] == -1
+             mask_cant_be_sin = np.array(labels_demo.cpu())[:, :, 2] == -1
+
+             nikud_probs, dagesh_probs, sin_probs = model(inputs, attention_mask)
+
+             pred_nikud = np.array(torch.max(nikud_probs, 2).indices.cpu()).reshape(
+                 inputs.shape[0], inputs.shape[1], 1
+             )
+             pred_dagesh = np.array(torch.max(dagesh_probs, 2).indices.cpu()).reshape(
+                 inputs.shape[0], inputs.shape[1], 1
+             )
+             pred_sin = np.array(torch.max(sin_probs, 2).indices.cpu()).reshape(
+                 inputs.shape[0], inputs.shape[1], 1
+             )
+
+             pred_nikud[mask_cant_be_nikud] = -1
+             pred_dagesh[mask_cant_be_dagesh] = -1
+             pred_sin[mask_cant_be_sin] = -1
+
+             pred_labels = np.concatenate((pred_nikud, pred_dagesh, pred_sin), axis=2)
+
+             if all_labels is None:
+                 all_labels = pred_labels
+             else:
+                 all_labels = np.concatenate((all_labels, pred_labels), axis=0)
+
+     return all_labels
+
+
+ def predict_single(model, data, device="cpu"):
+     # model.to(device)
+
+     all_labels = None
+     with torch.no_grad():
+         inputs = data["input_ids"].to(device)
+         attention_mask = data["attention_mask"].to(device)
+
+         # mask_cant_be_nikud = np.array(labels_demo.cpu())[:, :, 0] == -1
+         # mask_cant_be_dagesh = np.array(labels_demo.cpu())[:, :, 1] == -1
+         # mask_cant_be_sin = np.array(labels_demo.cpu())[:, :, 2] == -1
+
+         nikud_probs, dagesh_probs, sin_probs = model(inputs, attention_mask)
+         print(nikud_probs, dagesh_probs, sin_probs)
+
+         pred_nikud = np.array(torch.max(nikud_probs, 2).indices.cpu()).reshape(
+             inputs.shape[0], inputs.shape[1], 1
+         )
+         pred_dagesh = np.array(torch.max(dagesh_probs, 2).indices.cpu()).reshape(
+             inputs.shape[0], inputs.shape[1], 1
+         )
+         pred_sin = np.array(torch.max(sin_probs, 2).indices.cpu()).reshape(
+             inputs.shape[0], inputs.shape[1], 1
+         )
+
+         # pred_nikud[mask_cant_be_nikud] = -1
+         # pred_dagesh[mask_cant_be_dagesh] = -1
+         # pred_sin[mask_cant_be_sin] = -1
+         # print(pred_nikud, pred_dagesh, pred_sin)
+         pred_labels = np.concatenate((pred_nikud, pred_dagesh, pred_sin), axis=2)
+         print(pred_labels)
+         if all_labels is None:
+             all_labels = pred_labels
+         else:
+             all_labels = np.concatenate((all_labels, pred_labels), axis=0)
+
+     return all_labels
+
+
+ def training(
+     model,
+     train_loader,
+     dev_loader,
+     criterion_nikud,
+     criterion_dagesh,
+     criterion_sin,
+     training_params,
+     logger,
+     output_model_path,
+     optimizer,
+     device="cpu",
+ ):
+     max_length = None
+     best_accuracy = 0.0
+
+     logger.info(f"start training with training_params: {training_params}")
+     model = model.to(device)
+
+     criteria = {
+         "nikud": criterion_nikud.to(device),
+         "dagesh": criterion_dagesh.to(device),
+         "sin": criterion_sin.to(device),
+     }
+
+     output_checkpoints_path = os.path.join(output_model_path, "checkpoints")
+     create_missing_folders(output_checkpoints_path)
+
+     train_steps_loss_values = {"nikud": [], "dagesh": [], "sin": []}
+     train_epochs_loss_values = {"nikud": [], "dagesh": [], "sin": []}
+     dev_loss_values = {"nikud": [], "dagesh": [], "sin": []}
+     dev_accuracy_values = {
+         "nikud": [],
+         "dagesh": [],
+         "sin": [],
+         "all_nikud_letter": [],
+         "all_nikud_word": [],
+     }
+
+     for epoch in tqdm(range(training_params["n_epochs"]), desc="Training"):
+         model.train()
+         train_loss = {"nikud": 0.0, "dagesh": 0.0, "sin": 0.0}
+         relevant_count = {"nikud": 0.0, "dagesh": 0.0, "sin": 0.0}
+
+         for index_data, data in enumerate(train_loader):
+             (inputs, attention_mask, labels) = data
+
+             if max_length is None:
+                 max_length = labels.shape[1]
+
+             inputs = inputs.to(device)
+             attention_mask = attention_mask.to(device)
+             labels = labels.to(device)
+
+             optimizer.zero_grad()
+             nikud_probs, dagesh_probs, sin_probs = model(inputs, attention_mask)
+
+             for i, (probs, class_name) in enumerate(
+                 zip([nikud_probs, dagesh_probs, sin_probs], CLASSES_LIST)
+             ):
+                 reshaped_tensor = (
+                     torch.transpose(probs, 1, 2)
+                     .contiguous()
+                     .view(probs.shape[0], probs.shape[2], probs.shape[1])
+                 )
+                 loss = criteria[class_name](reshaped_tensor, labels[:, :, i]).to(device)
+
+                 num_relevant = (labels[:, :, i] != -1).sum()
+                 train_loss[class_name] += loss.item() * num_relevant
+                 relevant_count[class_name] += num_relevant
+
+                 loss.backward(retain_graph=True)
+
+             for i, class_name in enumerate(CLASSES_LIST):
+                 train_steps_loss_values[class_name].append(
+                     float(train_loss[class_name] / relevant_count[class_name])
+                 )
+
+             optimizer.step()
+             if (index_data + 1) % 100 == 0:
+                 msg = f"epoch: {epoch} , index_data: {index_data + 1}\n"
+                 for i, class_name in enumerate(CLASSES_LIST):
+                     msg += f"mean loss train {class_name}: {float(train_loss[class_name] / relevant_count[class_name])}, "
+
+                 logger.debug(msg[:-2])
+
+         for i, class_name in enumerate(CLASSES_LIST):
+             train_epochs_loss_values[class_name].append(
+                 float(train_loss[class_name] / relevant_count[class_name])
+             )
+
+         for class_name in train_loss.keys():
+             train_loss[class_name] /= relevant_count[class_name]
+
+         msg = f"Epoch {epoch + 1}/{training_params['n_epochs']}\n"
+         for i, class_name in enumerate(CLASSES_LIST):
+             msg += f"mean loss train {class_name}: {train_loss[class_name]}, "
+         logger.debug(msg[:-2])
+
+         model.eval()
+         dev_loss = {"nikud": 0.0, "dagesh": 0.0, "sin": 0.0}
+         dev_accuracy = {"nikud": 0.0, "dagesh": 0.0, "sin": 0.0}
+         relevant_count = {"nikud": 0.0, "dagesh": 0.0, "sin": 0.0}
+         correct_preds = {"nikud": 0.0, "dagesh": 0.0, "sin": 0.0}
+         un_masks = {"nikud": 0.0, "dagesh": 0.0, "sin": 0.0}
+         predictions = {"nikud": 0.0, "dagesh": 0.0, "sin": 0.0}
+         labels_class = {"nikud": 0.0, "dagesh": 0.0, "sin": 0.0}
+
+         all_nikud_types_correct_preds_letter = 0.0
+
+         letter_count = 0.0
+         correct_words_count = 0.0
+         word_count = 0.0
+         with torch.no_grad():
+             for index_data, data in enumerate(dev_loader):
+                 (inputs, attention_mask, labels) = data
+                 inputs = inputs.to(device)
+                 attention_mask = attention_mask.to(device)
+                 labels = labels.to(device)
+
+                 nikud_probs, dagesh_probs, sin_probs = model(inputs, attention_mask)
+
+                 for i, (probs, class_name) in enumerate(
+                     zip([nikud_probs, dagesh_probs, sin_probs], CLASSES_LIST)
+                 ):
+                     reshaped_tensor = (
+                         torch.transpose(probs, 1, 2)
+                         .contiguous()
+                         .view(probs.shape[0], probs.shape[2], probs.shape[1])
+                     )
+                     loss = criteria[class_name](reshaped_tensor, labels[:, :, i]).to(
+                         device
+                     )
+                     un_masked = labels[:, :, i] != -1
+                     num_relevant = un_masked.sum()
+                     relevant_count[class_name] += num_relevant
+                     _, preds = torch.max(probs, 2)
+                     dev_loss[class_name] += loss.item() * num_relevant
+                     correct_preds[class_name] += torch.sum(
+                         preds[un_masked] == labels[:, :, i][un_masked]
+                     )
+                     un_masks[class_name] = un_masked
+                     predictions[class_name] = preds
+                     labels_class[class_name] = labels[:, :, i]
+
+                 un_mask_all_or = torch.logical_or(
+                     torch.logical_or(un_masks["nikud"], un_masks["dagesh"]),
+                     un_masks["sin"],
+                 )
+
+                 correct = {
+                     class_name: (torch.ones(un_mask_all_or.shape) == 1).to(device)
+                     for class_name in CLASSES_LIST
+                 }
+
+                 for i, class_name in enumerate(CLASSES_LIST):
+                     correct[class_name][un_masks[class_name]] = (
+                         predictions[class_name][un_masks[class_name]]
+                         == labels_class[class_name][un_masks[class_name]]
+                     )
+
+                 letter_correct_mask = torch.logical_and(
+                     torch.logical_and(correct["sin"], correct["dagesh"]),
+                     correct["nikud"],
+                 )
+                 all_nikud_types_correct_preds_letter += torch.sum(
+                     letter_correct_mask[un_mask_all_or]
+                 )
+
+                 letter_correct_mask[~un_mask_all_or] = True
+                 correct_num, total_words_num = calc_num_correct_words(
+                     inputs.cpu(), letter_correct_mask
+                 )
+
+                 word_count += total_words_num
+                 correct_words_count += correct_num
+                 letter_count += un_mask_all_or.sum()
+
+         for class_name in CLASSES_LIST:
+             dev_loss[class_name] /= relevant_count[class_name]
+             dev_accuracy[class_name] = float(
+                 correct_preds[class_name].double() / relevant_count[class_name]
+             )
+
+             dev_loss_values[class_name].append(float(dev_loss[class_name]))
+             dev_accuracy_values[class_name].append(float(dev_accuracy[class_name]))
+
+         dev_all_nikud_types_accuracy_letter = float(
+             all_nikud_types_correct_preds_letter / letter_count
+         )
+
+         dev_accuracy_values["all_nikud_letter"].append(
+             dev_all_nikud_types_accuracy_letter
+         )
+
+         word_all_nikud_accuracy = correct_words_count / word_count
+         dev_accuracy_values["all_nikud_word"].append(word_all_nikud_accuracy)
+
+         msg = (
+             f"Epoch {epoch + 1}/{training_params['n_epochs']}\n"
+             f'mean loss Dev nikud: {train_loss["nikud"]}, '
+             f'mean loss Dev dagesh: {train_loss["dagesh"]}, '
+             f'mean loss Dev sin: {train_loss["sin"]}, '
+             f"Dev all nikud types letter Accuracy: {dev_all_nikud_types_accuracy_letter}, "
+             f'Dev nikud letter Accuracy: {dev_accuracy["nikud"]}, '
+             f'Dev dagesh letter Accuracy: {dev_accuracy["dagesh"]}, '
+             f'Dev sin letter Accuracy: {dev_accuracy["sin"]}, '
+             f"Dev word Accuracy: {word_all_nikud_accuracy}"
+         )
+         logger.debug(msg)
+
+         save_progress_details(
+             dev_accuracy_values,
+             train_epochs_loss_values,
+             dev_loss_values,
+             train_steps_loss_values,
+         )
+
+         if dev_all_nikud_types_accuracy_letter > best_accuracy:
+             best_accuracy = dev_all_nikud_types_accuracy_letter
+             best_model = {
+                 "epoch": epoch,
+                 "model_state_dict": model.state_dict(),
+                 "optimizer_state_dict": optimizer.state_dict(),
+                 "loss": loss,
+             }
+
+         if epoch % training_params["checkpoints_frequency"] == 0:
+             save_checkpoint_path = os.path.join(
+                 output_checkpoints_path, f"checkpoint_model_epoch_{epoch + 1}.pth"
+             )
+             checkpoint = {
+                 "epoch": epoch,
+                 "model_state_dict": model.state_dict(),
+                 "optimizer_state_dict": optimizer.state_dict(),
+                 "loss": loss,
+             }
+             torch.save(checkpoint["model_state_dict"], save_checkpoint_path)
+
+     save_model_path = os.path.join(output_model_path, "best_model.pth")
+     torch.save(best_model["model_state_dict"], save_model_path)
+     return (
+         best_model,
+         best_accuracy,
+         train_epochs_loss_values,
+         train_steps_loss_values,
+         dev_loss_values,
+         dev_accuracy_values,
+     )
+
+
+ def save_progress_details(
+     accuracy_dev_values,
+     epochs_loss_train_values,
+     loss_dev_values,
+     steps_loss_train_values,
+ ):
+     epochs_data_path = "epochs_data"
+     create_missing_folders(epochs_data_path)
+
+     save_dict_as_json(
+         steps_loss_train_values, epochs_data_path, "steps_loss_train_values.json"
+     )
+     save_dict_as_json(
+         epochs_loss_train_values, epochs_data_path, "epochs_loss_train_values.json"
+     )
+     save_dict_as_json(loss_dev_values, epochs_data_path, "loss_dev_values.json")
+     save_dict_as_json(accuracy_dev_values, epochs_data_path, "accuracy_dev_values.json")
+
+
+ def save_dict_as_json(dict, file_path, file_name):
+     json_data = json.dumps(dict, indent=4)
+     with open(os.path.join(file_path, file_name), "w") as json_file:
+         json_file.write(json_data)
+
+
+ def evaluate(model, test_data, plots_folder=None, device="cpu"):
+     model.to(device)
+     model.eval()
+
+     true_labels = {"nikud": [], "dagesh": [], "sin": []}
+     predictions = {"nikud": 0, "dagesh": 0, "sin": 0}
+     predicted_labels_2_report = {"nikud": [], "dagesh": [], "sin": []}
+     not_masks = {"nikud": 0, "dagesh": 0, "sin": 0}
+     correct_preds = {"nikud": 0, "dagesh": 0, "sin": 0}
+     relevant_count = {"nikud": 0, "dagesh": 0, "sin": 0}
+     labels_class = {"nikud": 0.0, "dagesh": 0.0, "sin": 0.0}
+
+     all_nikud_types_letter_level_correct = 0.0
+     nikud_letter_level_correct = 0.0
+     dagesh_letter_level_correct = 0.0
+     sin_letter_level_correct = 0.0
+
+     letters_count = 0.0
+     words_count = 0.0
+     correct_words_count = 0.0
+     with torch.no_grad():
+         for index_data, data in enumerate(test_data):
+             if DEBUG_MODE and index_data > 100:
+                 break
+
+             (inputs, attention_mask, labels) = data
+
+             inputs = inputs.to(device)
+             attention_mask = attention_mask.to(device)
+             labels = labels.to(device)
+
+             nikud_probs, dagesh_probs, sin_probs = model(inputs, attention_mask)
+
+             for i, (probs, class_name) in enumerate(
+                 zip([nikud_probs, dagesh_probs, sin_probs], CLASSES_LIST)
+             ):
+                 labels_class[class_name] = labels[:, :, i]
+                 not_masked = labels_class[class_name] != -1
+                 num_relevant = not_masked.sum()
+                 relevant_count[class_name] += num_relevant
+                 _, preds = torch.max(probs, 2)
+                 correct_preds[class_name] += torch.sum(
+                     preds[not_masked] == labels_class[class_name][not_masked]
+                 )
+                 predictions[class_name] = preds
+                 not_masks[class_name] = not_masked
+
+                 if len(true_labels[class_name]) == 0:
+                     true_labels[class_name] = (
+                         labels_class[class_name][not_masked].cpu().numpy()
+                     )
+                 else:
+                     true_labels[class_name] = np.concatenate(
+                         (
+                             true_labels[class_name],
+                             labels_class[class_name][not_masked].cpu().numpy(),
+                         )
+                     )
+
+                 if len(predicted_labels_2_report[class_name]) == 0:
+                     predicted_labels_2_report[class_name] = (
+                         preds[not_masked].cpu().numpy()
+                     )
+                 else:
+                     predicted_labels_2_report[class_name] = np.concatenate(
+                         (
+                             predicted_labels_2_report[class_name],
+                             preds[not_masked].cpu().numpy(),
+                         )
+                     )
+
+             not_mask_all_or = torch.logical_or(
+                 torch.logical_or(not_masks["nikud"], not_masks["dagesh"]),
+                 not_masks["sin"],
+             )
+
+             correct_nikud = (torch.ones(not_mask_all_or.shape) == 1).to(device)
+             correct_dagesh = (torch.ones(not_mask_all_or.shape) == 1).to(device)
+             correct_sin = (torch.ones(not_mask_all_or.shape) == 1).to(device)
+
+             correct_nikud[not_masks["nikud"]] = (
+                 predictions["nikud"][not_masks["nikud"]]
+                 == labels_class["nikud"][not_masks["nikud"]]
+             )
+             correct_dagesh[not_masks["dagesh"]] = (
+                 predictions["dagesh"][not_masks["dagesh"]]
+                 == labels_class["dagesh"][not_masks["dagesh"]]
+             )
+             correct_sin[not_masks["sin"]] = (
+                 predictions["sin"][not_masks["sin"]]
+                 == labels_class["sin"][not_masks["sin"]]
+             )
+
+             letter_correct_mask = torch.logical_and(
+                 torch.logical_and(correct_sin, correct_dagesh), correct_nikud
+             )
+             all_nikud_types_letter_level_correct += torch.sum(
+                 letter_correct_mask[not_mask_all_or]
+             )
+
+             letter_correct_mask[~not_mask_all_or] = True
+             total_correct_count, total_words_num = calc_num_correct_words(
+                 inputs.cpu(), letter_correct_mask
+             )
+
+             words_count += total_words_num
+             correct_words_count += total_correct_count
+
+             letters_count += not_mask_all_or.sum()
+
+             nikud_letter_level_correct += torch.sum(correct_nikud[not_mask_all_or])
+             dagesh_letter_level_correct += torch.sum(correct_dagesh[not_mask_all_or])
+             sin_letter_level_correct += torch.sum(correct_sin[not_mask_all_or])
+
+     for i, name in enumerate(CLASSES_LIST):
+         index_labels = np.unique(true_labels[name])
+         cm = confusion_matrix(
+             true_labels[name], predicted_labels_2_report[name], labels=index_labels
+         )
+
+         vowel_label = [Nikud.id_2_label[name][l] for l in index_labels]
+         unique_vowels_names = [
+             Nikud.sign_2_name[int(vowel)] for vowel in vowel_label if vowel != "WITHOUT"
+         ]
+         if "WITHOUT" in vowel_label:
+             unique_vowels_names += ["WITHOUT"]
+         cm_df = pd.DataFrame(cm, index=unique_vowels_names, columns=unique_vowels_names)
+
+         # Display confusion matrix (rows are true labels, columns are predictions)
+         plt.figure(figsize=(10, 8))
+         sns.heatmap(cm_df, annot=True, cmap="Blues", fmt="d")
+         plt.title("Confusion Matrix")
+         plt.xlabel("Predicted Label")
+         plt.ylabel("True Label")
+         if plots_folder is None:
+             plt.show()
+         else:
+             plt.savefig(os.path.join(plots_folder, f"Confusion_Matrix_{name}.jpg"))
+
+     all_nikud_types_letter_level_correct = (
+         all_nikud_types_letter_level_correct / letters_count
+     )
+     all_nikud_types_word_level_correct = correct_words_count / words_count
+     nikud_letter_level_correct = nikud_letter_level_correct / letters_count
+     dagesh_letter_level_correct = dagesh_letter_level_correct / letters_count
+     sin_letter_level_correct = sin_letter_level_correct / letters_count
+     print("\n")
+     print(f"nikud_letter_level_correct = {nikud_letter_level_correct}")
+     print(f"dagesh_letter_level_correct = {dagesh_letter_level_correct}")
+     print(f"sin_letter_level_correct = {sin_letter_level_correct}")
+     print(f"word_level_correct = {all_nikud_types_word_level_correct}")
+
+     return all_nikud_types_word_level_correct, all_nikud_types_letter_level_correct
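The word-level accuracy above hinges on `calc_num_correct_words`, which zeroes out space/start/end tokens and then treats the gaps between zeros as words, counting a word as correct only if every letter in it is predicted correctly (one-letter spans are skipped). A toy trace, not from the repo, with made-up letter token ids:

    import numpy as np
    from src.models_utils import calc_num_correct_words

    # 1 = <s>, 2 = </s>, 104 = space, per the constants inside the function;
    # 50/51 and 60/61/62 stand in for letter tokens of a 2- and a 3-letter word.
    tokens = np.array([[1, 50, 51, 104, 60, 61, 62, 2]])
    mask = np.array([[True, True, True, True, True, False, True, True]])
    print(calc_num_correct_words(tokens, mask))  # (1, 2): one of two words fully correct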
src/plot_helpers.py ADDED
@@ -0,0 +1,58 @@
+ # general
+ import os
+
+ # visual
+ import matplotlib.pyplot as plt
+
+ cols = ["precision", "recall", "f1-score", "support"]
+
+
+ def generate_plot_by_nikud_dagesh_sin_dict(nikud_dagesh_sin_dict, title, y_axis, plot_folder=None):
+     # Create a figure and axis
+     plt.figure(figsize=(8, 6))
+     plt.title(title)
+
+     ax = plt.gca()
+     indexes = list(range(1, len(nikud_dagesh_sin_dict["nikud"]) + 1))
+
+     # Plot data series with different colors and labels
+     ax.plot(indexes, nikud_dagesh_sin_dict["nikud"], color='blue', label='Nikud')
+     ax.plot(indexes, nikud_dagesh_sin_dict["dagesh"], color='green', label='Dagesh')
+     ax.plot(indexes, nikud_dagesh_sin_dict["sin"], color='red', label='Sin')
+
+     # Add legend
+     ax.legend()
+
+     # Set labels and title
+     ax.set_xlabel('Epoch')
+     ax.set_ylabel(y_axis)
+
+     if plot_folder is None:
+         plt.show()
+     else:
+         plt.savefig(os.path.join(plot_folder, f'{title.replace(" ", "_")}_plot.jpg'))
+
+
+ def generate_word_and_letter_accuracy_plot(word_and_letter_accuracy_dict, title, plot_folder=None):
+     # Create a figure and axis
+     plt.figure(figsize=(8, 6))
+     plt.title(title)
+
+     ax = plt.gca()
+     indexes = list(range(1, len(word_and_letter_accuracy_dict["all_nikud_letter"]) + 1))
+
+     # Plot data series with different colors and labels
+     ax.plot(indexes, word_and_letter_accuracy_dict["all_nikud_letter"], color='blue', label='Letter')
+     ax.plot(indexes, word_and_letter_accuracy_dict["all_nikud_word"], color='green', label='Word')
+
+     # Add legend
+     ax.legend()
+
+     # Set labels and title
+     ax.set_xlabel("Epoch")
+     ax.set_ylabel("Accuracy")
+
+     if plot_folder is None:
+         plt.show()
+     else:
+         plt.savefig(os.path.join(plot_folder, 'word_and_letter_accuracy_plot.jpg'))
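Both helpers expect dicts of per-epoch series keyed exactly as produced by `training()` in `src/models_utils.py`. An illustrative call with made-up values:

    from src.plot_helpers import generate_plot_by_nikud_dagesh_sin_dict

    losses = {
        "nikud": [0.9, 0.5, 0.3],
        "dagesh": [0.7, 0.4, 0.2],
        "sin": [0.6, 0.3, 0.1],
    }
    # leaving plot_folder=None shows the figure instead of saving a .jpg
    generate_plot_by_nikud_dagesh_sin_dict(losses, "Train epochs loss", "Loss")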
src/running_params.py ADDED
@@ -0,0 +1,3 @@
+ DEBUG_MODE = False
+ BATCH_SIZE = 32
+ MAX_LENGTH_SEN = 1024
src/utiles_data.py ADDED
@@ -0,0 +1,676 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # general
2
+ import os.path
3
+ from datetime import datetime
4
+ from pathlib import Path
5
+ from typing import List, Tuple
6
+ from uuid import uuid1
7
+ import re
8
+ import glob2
9
+
10
+ # visual
11
+ import matplotlib
12
+ import matplotlib.pyplot as plt
13
+ from tqdm import tqdm
14
+
15
+ # ML
16
+ import numpy as np
17
+ import torch
18
+ from torch.utils.data import Dataset
19
+
20
+ from src.running_params import DEBUG_MODE, MAX_LENGTH_SEN
21
+
22
+ matplotlib.use("agg")
23
+ unique_key = str(uuid1())
24
+
25
+
26
+ class Nikud:
27
+ """
28
+ 1456 HEBREW POINT SHEVA
29
+ 1457 HEBREW POINT HATAF SEGOL
30
+ 1458 HEBREW POINT HATAF PATAH
31
+ 1459 HEBREW POINT HATAF QAMATS
32
+ 1460 HEBREW POINT HIRIQ
33
+ 1461 HEBREW POINT TSERE
34
+ 1462 HEBREW POINT SEGOL
35
+ 1463 HEBREW POINT PATAH
36
+ 1464 HEBREW POINT QAMATS
37
+ 1465 HEBREW POINT HOLAM
38
+ 1466 HEBREW POINT HOLAM HASER FOR VAV ***EXTENDED***
39
+ 1467 HEBREW POINT QUBUTS
40
+ 1468 HEBREW POINT DAGESH OR MAPIQ
41
+ 1469 HEBREW POINT METEG ***EXTENDED***
42
+ 1470 HEBREW PUNCTUATION MAQAF ***EXTENDED***
43
+ 1471 HEBREW POINT RAFE ***EXTENDED***
44
+ 1472 HEBREW PUNCTUATION PASEQ ***EXTENDED***
45
+ 1473 HEBREW POINT SHIN DOT
46
+ 1474 HEBREW POINT SIN DOT
47
+ """
48
+
49
+ nikud_dict = {
50
+ "SHVA": 1456,
51
+ "REDUCED_SEGOL": 1457,
52
+ "REDUCED_PATAKH": 1458,
53
+ "REDUCED_KAMATZ": 1459,
54
+ "HIRIK": 1460,
55
+ "TZEIRE": 1461,
56
+ "SEGOL": 1462,
57
+ "PATAKH": 1463,
58
+ "KAMATZ": 1464,
59
+ "KAMATZ_KATAN": 1479,
60
+ "HOLAM": 1465,
61
+ "HOLAM HASER VAV": 1466,
62
+ "KUBUTZ": 1467,
63
+ "DAGESH OR SHURUK": 1468,
64
+ "METEG": 1469,
65
+ "PUNCTUATION MAQAF": 1470,
66
+ "RAFE": 1471,
67
+ "PUNCTUATION PASEQ": 1472,
68
+ "SHIN_YEMANIT": 1473,
69
+ "SHIN_SMALIT": 1474,
70
+ }
71
+
72
+ skip_nikud = (
73
+ []
74
+ ) # [nikud_dict["KAMATZ_KATAN"], nikud_dict["HOLAM HASER VAV"], nikud_dict["METEG"], nikud_dict["PUNCTUATION MAQAF"], nikud_dict["PUNCTUATION PASEQ"]]
75
+ sign_2_name = {sign: name for name, sign in nikud_dict.items()}
76
+ sin = [nikud_dict["RAFE"], nikud_dict["SHIN_YEMANIT"], nikud_dict["SHIN_SMALIT"]]
77
+ dagesh = [
78
+ nikud_dict["RAFE"],
79
+ nikud_dict["DAGESH OR SHURUK"],
80
+ ] # note that DAGESH and SHURUK are one and the same
81
+ nikud = []
82
+ for v in nikud_dict.values():
83
+ if v not in sin and v not in skip_nikud:
84
+ nikud.append(v)
85
+ all_nikud_ord = {v for v in nikud_dict.values()}
86
+ all_nikud_chr = {chr(v) for v in nikud_dict.values()}
87
+
88
+ label_2_id = {
89
+ "nikud": {label: i for i, label in enumerate(nikud + ["WITHOUT"])},
90
+ "dagesh": {label: i for i, label in enumerate(dagesh + ["WITHOUT"])},
91
+ "sin": {label: i for i, label in enumerate(sin + ["WITHOUT"])},
92
+ }
93
+ id_2_label = {
94
+ "nikud": {i: label for i, label in enumerate(nikud + ["WITHOUT"])},
95
+ "dagesh": {i: label for i, label in enumerate(dagesh + ["WITHOUT"])},
96
+ "sin": {i: label for i, label in enumerate(sin + ["WITHOUT"])},
97
+ }
98
+
99
+ DAGESH_LETTER = nikud_dict["DAGESH OR SHURUK"]
100
+ RAFE = nikud_dict["RAFE"]
101
+ PAD_OR_IRRELEVANT = -1
102
+
103
+ LEN_NIKUD = len(label_2_id["nikud"])
104
+ LEN_DAGESH = len(label_2_id["dagesh"])
105
+ LEN_SIN = len(label_2_id["sin"])
106
+
107
+ def id_2_char(self, c, class_type):
108
+ if c == -1:
109
+ return ""
110
+
111
+ label = self.id_2_label[class_type][c]
112
+
113
+ if label != "WITHOUT":
114
+ print("Label =", chr(self.id_2_label[class_type][c]))
115
+ return chr(self.id_2_label[class_type][c])
116
+ return ""
117
+
118
+
119
+ class Letters:
120
+ hebrew = [chr(c) for c in range(0x05D0, 0x05EA + 1)]
121
+ VALID_LETTERS = [
122
+ " ",
123
+ "!",
124
+ '"',
125
+ "'",
126
+ "(",
127
+ ")",
128
+ ",",
129
+ "-",
130
+ ".",
131
+ ":",
132
+ ";",
133
+ "?",
134
+ ] + hebrew
135
+ SPECIAL_TOKENS = ["H", "O", "5", "1"]
136
+ ENDINGS_TO_REGULAR = dict(zip("ืšืืŸืฃืฅ", "ื›ืžื ืคืฆ"))
137
+ vocab = VALID_LETTERS + SPECIAL_TOKENS
138
+ vocab_size = len(vocab)
139
+
140
+
141
+ class Letter:
142
+ def __init__(self, letter):
143
+ self.letter = letter
144
+ self.normalized = None
145
+ self.dagesh = None
146
+ self.sin = None
147
+ self.nikud = None
148
+
149
+ def normalize(self, letter):
150
+ if letter in Letters.VALID_LETTERS:
151
+ return letter
152
+ if letter in Letters.ENDINGS_TO_REGULAR:
153
+ return Letters.ENDINGS_TO_REGULAR[letter]
154
+ if letter in ["\n", "\t"]:
155
+ return " "
156
+ if letter in ["โ€’", "โ€“", "โ€”", "โ€•", "โˆ’", "+"]:
157
+ return "-"
158
+ if letter == "[":
159
+ return "("
160
+ if letter == "]":
161
+ return ")"
162
+ if letter in ["ยด", "โ€˜", "โ€™"]:
163
+ return "'"
164
+ if letter in ["โ€œ", "โ€", "ืด"]:
165
+ return '"'
166
+ if letter.isdigit():
167
+ if int(letter) == 1:
168
+ return "1"
169
+ else:
170
+ return "5"
171
+ if letter == "โ€ฆ":
172
+ return ","
173
+ if letter in ["ืฒ", "ืฐ", "ืฑ"]:
174
+ return "H"
175
+ return "O"
176
+
177
+ def can_dagesh(self, letter):
178
+ return letter in ("ื‘ื’ื“ื”ื•ื–ื˜ื™ื›ืœืžื ืกืคืฆืงืฉืช" + "ืšืฃ")
179
+
180
+ def can_sin(self, letter):
181
+ return letter == "ืฉ"
182
+
183
+ def can_nikud(self, letter):
184
+ return letter in ("ืื‘ื’ื“ื”ื•ื–ื—ื˜ื™ื›ืœืžื ืกืขืคืฆืงืจืฉืช" + "ืšืŸ")
185
+
+     def get_label_letter(self, labels):
+         dagesh_sin_nikud = [
+             self.can_dagesh(self.letter),
+             self.can_sin(self.letter),
+             self.can_nikud(self.letter),
+         ]
+
+         labels_ids = {
+             "nikud": Nikud.PAD_OR_IRRELEVANT,
+             "dagesh": Nikud.PAD_OR_IRRELEVANT,
+             "sin": Nikud.PAD_OR_IRRELEVANT,
+         }
+
+         normalized = self.normalize(self.letter)
+
+         i = 0
+         if Nikud.nikud_dict["PUNCTUATION PASEQ"] in labels:
+             labels.remove(Nikud.nikud_dict["PUNCTUATION PASEQ"])
+         if Nikud.nikud_dict["PUNCTUATION MAQAF"] in labels:
+             labels.remove(Nikud.nikud_dict["PUNCTUATION MAQAF"])
+         if Nikud.nikud_dict["HOLAM HASER VAV"] in labels:
+             labels.remove(Nikud.nikud_dict["HOLAM HASER VAV"])
+         if Nikud.nikud_dict["METEG"] in labels:
+             labels.remove(Nikud.nikud_dict["METEG"])
+         if Nikud.nikud_dict["KAMATZ_KATAN"] in labels:
+             labels[labels.index(Nikud.nikud_dict["KAMATZ_KATAN"])] = Nikud.nikud_dict[
+                 "KAMATZ"
+             ]
+         # notice - order is important: dagesh, then sin, and then nikud
+         for index, (class_name, group) in enumerate(
+             zip(
+                 ["dagesh", "sin", "nikud"],
+                 [[Nikud.DAGESH_LETTER], Nikud.sin, Nikud.nikud],
+             )
+         ):
+             if dagesh_sin_nikud[index]:
+                 if i < len(labels) and labels[i] in group:
+                     labels_ids[class_name] = Nikud.label_2_id[class_name][labels[i]]
+                     i += 1
+                 else:
+                     labels_ids[class_name] = Nikud.label_2_id[class_name]["WITHOUT"]
+
+         if (
+             np.array(dagesh_sin_nikud).all()
+             and len(labels) == 3
+             and labels[0] in Nikud.sin
+         ):
+             labels_ids["nikud"] = Nikud.label_2_id["nikud"][labels[2]]
+             labels_ids["dagesh"] = Nikud.label_2_id["dagesh"][labels[1]]
+
+         if (
+             self.can_sin(self.letter)
+             and len(labels) == 2
+             and labels[1] == Nikud.DAGESH_LETTER
+         ):
+             labels_ids["dagesh"] = Nikud.label_2_id["dagesh"][labels[1]]
+             labels_ids["nikud"] = Nikud.label_2_id["nikud"]["WITHOUT"]
+
+         if (
+             self.letter == "ื•"
+             and labels_ids["dagesh"] == Nikud.DAGESH_LETTER
+             and labels_ids["nikud"] == Nikud.label_2_id["nikud"]["WITHOUT"]
+         ):
+             labels_ids["dagesh"] = Nikud.label_2_id["dagesh"]["WITHOUT"]
+             labels_ids["nikud"] = Nikud.DAGESH_LETTER
+
+         self.normalized = normalized
+         self.dagesh = labels_ids["dagesh"]
+         self.sin = labels_ids["sin"]
+         self.nikud = labels_ids["nikud"]
+
+     def name_of(self, letter):
+         if "ื" <= letter <= "ืช":
+             return letter
+         if letter == Nikud.DAGESH_LETTER:
+             return "ื“ื’ืฉ/ืฉื•ืจื•ืง"
+         if letter == Nikud.KAMATZ:
+             return "ืงืžืฅ"
+         if letter == Nikud.PATAKH:
+             return "ืคืชื—"
+         if letter == Nikud.TZEIRE:
+             return "ืฆื™ืจื”"
+         if letter == Nikud.SEGOL:
+             return "ืกื’ื•ืœ"
+         if letter == Nikud.SHVA:
+             return "ืฉื•ื"
+         if letter == Nikud.HOLAM:
+             return "ื—ื•ืœื"
+         if letter == Nikud.KUBUTZ:
+             return "ืงื•ื‘ื•ืฅ"
+         if letter == Nikud.HIRIK:
+             return "ื—ื™ืจื™ืง"
+         if letter == Nikud.REDUCED_KAMATZ:
+             return "ื—ื˜ืฃ-ืงืžืฅ"
+         if letter == Nikud.REDUCED_PATAKH:
+             return "ื—ื˜ืฃ-ืคืชื—"
+         if letter == Nikud.REDUCED_SEGOL:
+             return "ื—ื˜ืฃ-ืกื’ื•ืœ"
+         if letter == Nikud.SHIN_SMALIT:
+             return "ืฉื™ืŸ-ืฉืžืืœื™ืช"
+         if letter == Nikud.SHIN_YEMANIT:
+             return "ืฉื™ืŸ-ื™ืžื ื™ืช"
+         if letter.isprintable():
+             return letter
+         return "ืœื ื™ื“ื•ืข ({})".format(hex(ord(letter)))
+
+
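A small sketch of the Letter flow above: normalize() folds a raw character into the model vocabulary, and get_label_letter() consumes the list of raw nikud codepoints that followed the letter in the text. This assumes the module is importable as src.utiles_data and that PATAKH is not among the skipped signs:

from src.utiles_data import Letter, Nikud

letter = Letter("ื‘")  # bet can carry dagesh and nikud, but not sin
assert letter.normalize("โ€”") == "-"  # dash variants fold to "-"
assert letter.normalize("7") == "5"  # digits other than 1 become the token "5"
assert letter.normalize("@") == "O"  # anything unknown becomes the token "O"

letter.get_label_letter([Nikud.nikud_dict["PATAKH"]])
assert letter.sin == Nikud.PAD_OR_IRRELEVANT  # sin never applies to ื‘
assert letter.dagesh == Nikud.label_2_id["dagesh"]["WITHOUT"]  # no dagesh seen
assert letter.nikud == Nikud.label_2_id["nikud"][Nikud.nikud_dict["PATAKH"]]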
+ def text_contains_nikud(text):
+     return len(set(text) & Nikud.all_nikud_chr) > 0
+
+
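text_contains_nikud is a set intersection over characters, which keeps it cheap to call per line. For example:

from src.utiles_data import text_contains_nikud

assert text_contains_nikud("ืฉึธืœื•ึนื")    # vocalized: contains kamatz and holam
assert not text_contains_nikud("ืฉืœื•ื")  # bare letters only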
+ def combine_sentences(list_sentences, max_length=0, is_train=False):
+     all_new_sentences = []
+     new_sen = ""
+     index = 0
+     while index < len(list_sentences):
+         sen = list_sentences[index]
+
+         if not text_contains_nikud(sen) and (
+             "------------------" in sen or sen == "\n"
+         ):
+             if len(new_sen) > 0:
+                 all_new_sentences.append(new_sen)
+             if not is_train:
+                 all_new_sentences.append(sen)
+             new_sen = ""
+             index += 1
+             continue
+
+         if not text_contains_nikud(sen) and is_train:
+             index += 1
+             continue
+
+         if len(sen) > max_length:
+             update_sen = sen.replace(". ", f". {unique_key}")
+             update_sen = update_sen.replace("? ", f"? {unique_key}")
+             update_sen = update_sen.replace("! ", f"! {unique_key}")
+             update_sen = update_sen.replace("โ€ ", f"โ€ {unique_key}")
+             update_sen = update_sen.replace("\t", f"\t{unique_key}")
+             part_sentence = update_sen.split(unique_key)
+
+             good_parts = []
+             for p in part_sentence:
+                 if len(p) < max_length:
+                     good_parts.append(p)
+                 else:
+                     # split an over-long part on the last space inside each window
+                     prev = 0
+                     while prev <= len(p):
+                         part = p[prev : (prev + max_length)]
+                         last_space = 0
+                         if " " in part:
+                             last_space = part[::-1].index(" ") + 1
+                         next_index = prev + max_length - last_space
+                         part = p[prev:next_index]
+                         good_parts.append(part)
+                         prev = next_index
+             list_sentences = (
+                 list_sentences[:index] + good_parts + list_sentences[index + 1 :]
+             )
+             continue
+         if new_sen == "":
+             new_sen = sen
+         elif len(new_sen) + len(sen) < max_length:
+             new_sen += sen
+         else:
+             all_new_sentences.append(new_sen)
+             new_sen = sen
+
+         index += 1
+     if len(new_sen) > 0:
+         all_new_sentences.append(new_sen)
+     return all_new_sentences
+
+
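A sketch of combine_sentences on a toy input (the strings are hypothetical; is_train=False keeps separator lines so the output can be re-aligned with the source text):

from src.utiles_data import combine_sentences

parts = ["ืฉึธืœื•ึนื ", "ืขื•ึนืœึธื ", "\n"]
merged = combine_sentences(parts, max_length=30, is_train=False)
# -> ["ืฉึธืœื•ึนื ืขื•ึนืœึธื ", "\n"]: the two short vocalized fragments are merged into
# one chunk under max_length, and the bare newline is kept as a boundary marker.
print(merged)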
+ class NikudDataset(Dataset):
+     def __init__(
+         self,
+         tokenizer,
+         folder=None,
+         file=None,
+         logger=None,
+         max_length=0,
+         is_train=False,
+     ):
+         self.max_length = max_length
+         self.tokenizer = tokenizer
+         self.is_train = is_train
+         if folder is not None:
+             self.data, self.origin_data = self.read_data_folder(folder, logger)
+         elif file is not None:
+             self.data, self.origin_data = self.read_data(file, logger)
+         self.prepered_data = None
+
+     def read_data_folder(self, folder_path: str, logger=None):
+         all_files = glob2.glob(f"{folder_path}/**/*.txt", recursive=True)
+         msg = f"number of files: {len(all_files)}"
+         if logger:
+             logger.debug(msg)
+         else:
+             print(msg)
+         all_data = []
+         all_origin_data = []
+         if DEBUG_MODE:
+             all_files = all_files[0:2]
+         for file in all_files:
+             if "not_use" in file or "NakdanResults" in file:
+                 continue
+             data, origin_data = self.read_data(file, logger)
+             all_data.extend(data)
+             all_origin_data.extend(origin_data)
+         return all_data, all_origin_data
+
+     def read_data(self, filepath: str, logger=None) -> List[Tuple[str, list]]:
+         msg = f"read file: {filepath}"
+         if logger:
+             logger.debug(msg)
+         else:
+             print(msg)
+         data = []
+         orig_data = []
+         with open(filepath, "r", encoding="utf-8") as file:
+             file_data = file.read()
+         data_list = self.split_text(file_data)
+
+         for sen in tqdm(data_list, desc=f"Source: {os.path.basename(filepath)}"):
+             if sen == "":
+                 continue
+
+             labels = []
+             text = ""
+             text_org = ""
+             index = 0
+             sentence_length = len(sen)
+             while index < sentence_length:
+                 if (
+                     ord(sen[index]) == Nikud.nikud_dict["PUNCTUATION MAQAF"]
+                     or ord(sen[index]) == Nikud.nikud_dict["PUNCTUATION PASEQ"]
+                     or ord(sen[index]) == Nikud.nikud_dict["METEG"]
+                 ):
+                     index += 1
+                     continue
+
+                 label = []
+                 l = Letter(sen[index])
+                 if l.letter in Nikud.all_nikud_chr:
+                     # a stray nikud character is only tolerated right after a
+                     # newline; anywhere else the sentence is malformed, so fail loudly
+                     if sen[index - 1] == "\n":
+                         index += 1
+                         continue
+                     assert l.letter not in Nikud.all_nikud_chr
+                 if sen[index] in Letters.hebrew:
+                     index += 1
+                     while (
+                         index < sentence_length
+                         and ord(sen[index]) in Nikud.all_nikud_ord
+                     ):
+                         label.append(ord(sen[index]))
+                         index += 1
+                 else:
+                     index += 1
+
+                 l.get_label_letter(label)
+                 text += l.normalized
+                 text_org += l.letter
+                 labels.append(l)
+
+             data.append((text, labels))
+             orig_data.append(text_org)
+
+         return data, orig_data
+
+     def split_text(self, file_data):
+         file_data = file_data.replace("\n", f"\n{unique_key}")
+         data_list = file_data.split(unique_key)
+         data_list = combine_sentences(
+             data_list, is_train=self.is_train, max_length=MAX_LENGTH_SEN
+         )
+         return data_list
+
+     def show_data_labels(self, plots_folder=None):
+         nikud = [
+             Nikud.id_2_label["nikud"][label.nikud]
+             for _, label_list in self.data
+             for label in label_list
+             if label.nikud != -1
+         ]
+         dagesh = [
+             Nikud.id_2_label["dagesh"][label.dagesh]
+             for _, label_list in self.data
+             for label in label_list
+             if label.dagesh != -1
+         ]
+         sin = [
+             Nikud.id_2_label["sin"][label.sin]
+             for _, label_list in self.data
+             for label in label_list
+             if label.sin != -1
+         ]
+
+         vowels = nikud + dagesh + sin
+         unique_vowels, label_counts = np.unique(vowels, return_counts=True)
+         unique_vowels_names = [
+             Nikud.sign_2_name[int(vowel)]
+             for vowel in unique_vowels
+             if vowel != "WITHOUT"
+         ] + ["WITHOUT"]
+         fig, ax = plt.subplots(figsize=(16, 6))
+
+         bar_positions = np.arange(len(unique_vowels))
+         bar_width = 0.15
+         ax.bar(bar_positions, list(label_counts), bar_width)
+
+         ax.set_title("Distribution of Vowels in dataset")
+         ax.set_xlabel("Vowels")
+         ax.set_ylabel("Count")
+         ax.legend(loc="right", bbox_to_anchor=(1, 0.85))
+         ax.set_xticks(bar_positions)
+         ax.set_xticklabels(unique_vowels_names, rotation=30, ha="right", fontsize=8)
+
+         if plots_folder is None:
+             plt.show()
+         else:
+             plt.savefig(os.path.join(plots_folder, "show_data_labels.jpg"))
+
+     def calc_max_length(self, maximum=MAX_LENGTH_SEN):
+         if self.max_length > maximum:
+             self.max_length = maximum
+         return self.max_length
+
+     def prepare_data(self, name="train"):
+         dataset = []
+         for index, (sentence, label) in tqdm(
+             enumerate(self.data), desc=f"prepare data {name}"
+         ):
+             encoded_sequence = self.tokenizer.encode_plus(
+                 sentence,
+                 add_special_tokens=True,
+                 max_length=self.max_length,
+                 padding="max_length",
+                 truncation=True,
+                 return_attention_mask=True,
+                 return_tensors="pt",
+             )
+             label_lists = [
+                 [letter.nikud, letter.dagesh, letter.sin] for letter in label
+             ]
+             label = torch.tensor(
+                 [
+                     [
+                         Nikud.PAD_OR_IRRELEVANT,
+                         Nikud.PAD_OR_IRRELEVANT,
+                         Nikud.PAD_OR_IRRELEVANT,
+                     ]
+                 ]
+                 + label_lists[: (self.max_length - 1)]
+                 + [
+                     [
+                         Nikud.PAD_OR_IRRELEVANT,
+                         Nikud.PAD_OR_IRRELEVANT,
+                         Nikud.PAD_OR_IRRELEVANT,
+                     ]
+                     for _ in range(self.max_length - len(label) - 1)
+                 ]
+             )
+
+             dataset.append(
+                 (
+                     encoded_sequence["input_ids"][0],
+                     encoded_sequence["attention_mask"][0],
+                     label,
+                 )
+             )
+
+         self.prepered_data = dataset
+
+     def back_2_text(self, labels):
+         nikud = Nikud()
+         all_text = ""
+         for indx_sentance, (input_ids, _, label) in enumerate(self.prepered_data):
+             new_line = ""
+             for indx_char, c in enumerate(self.origin_data[indx_sentance]):
+                 # position 0 holds the pad row prepended for the special token
+                 # in prepare_data, hence the +1 offset
+                 new_line += (
+                     c
+                     + nikud.id_2_char(labels[indx_sentance, indx_char + 1, 1], "dagesh")
+                     + nikud.id_2_char(labels[indx_sentance, indx_char + 1, 2], "sin")
+                     + nikud.id_2_char(labels[indx_sentance, indx_char + 1, 0], "nikud")
+                 )
+             all_text += new_line
+         return all_text
+
+     def __len__(self):
+         # self.data is a plain list of (text, labels) pairs
+         return len(self.data)
+
+     def __getitem__(self, idx):
+         return self.data[idx]
+
+
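Putting the class together, a minimal end-to-end sketch. The tokenizer name matches the one used elsewhere in this repository; the data path is hypothetical, and the shapes assume the sentence fits within max_length:

from transformers import AutoTokenizer
from src.utiles_data import NikudDataset

tokenizer = AutoTokenizer.from_pretrained("tau/tavbert-he")
dataset = NikudDataset(tokenizer, file="data/example.txt")  # hypothetical file
dataset.max_length = 128
dataset.prepare_data(name="demo")

# Each prepared item is (input_ids, attention_mask, labels); labels holds one
# [nikud, dagesh, sin] id triple per token position, padded with -1.
input_ids, attention_mask, labels = dataset.prepered_data[0]
print(input_ids.shape, attention_mask.shape, labels.shape)  # (128,), (128,), (128, 3)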
+ def get_sub_folders_paths(main_folder):
+     list_paths = []
+     for filename in os.listdir(main_folder):
+         path = os.path.join(main_folder, filename)
+         if os.path.isdir(path) and filename != ".git":
+             list_paths.append(path)
+             list_paths.extend(get_sub_folders_paths(path))
+     return list_paths
+
+
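get_sub_folders_paths walks the tree depth-first and skips .git checkouts. For a hypothetical layout like data/train/books and data/train/news:

from src.utiles_data import get_sub_folders_paths

for path in get_sub_folders_paths("data"):  # hypothetical root folder
    print(path)  # every nested directory, .git excluded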
+ def create_missing_folders(folder_path):
+     # Check if the folder doesn't exist and create it if needed
+     if not os.path.exists(folder_path):
+         os.makedirs(folder_path)
+
+
+ def info_folder(folder, num_files, num_hebrew_letters):
+     """
+     Recursively counts the number of files and the number of Hebrew letters in all subfolders of the given folder path.
+
+     Args:
+         folder (str): The path of the folder to be analyzed.
+         num_files (int): The running total of the number of files encountered so far.
+         num_hebrew_letters (int): The running total of the number of Hebrew letters encountered so far.
+
+     Returns:
+         Tuple[int, int]: A tuple containing the total number of files and the total number of Hebrew letters.
+     """
+     for filename in os.listdir(folder):
+         file_path = os.path.join(folder, filename)
+         if filename.lower().endswith(".txt") and os.path.isfile(file_path):
+             num_files += 1
+             dataset = NikudDataset(None, file=file_path)
+             for line in dataset.data:
+                 for c in line[0]:
+                     if c in Letters.hebrew:
+                         num_hebrew_letters += 1
+
+         elif os.path.isdir(file_path) and filename != ".git":
+             # the recursive call returns the updated running totals, so assign
+             # them directly instead of adding them on top (which double-counts)
+             num_files, num_hebrew_letters = info_folder(
+                 file_path, num_files, num_hebrew_letters
+             )
+     return num_files, num_hebrew_letters
+
+
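info_folder threads its two counters through the recursion, so callers start them at zero:

from src.utiles_data import info_folder

num_files, num_letters = info_folder("data/corpus", 0, 0)  # hypothetical folder
print(f"{num_files} .txt files, {num_letters} Hebrew letters")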
+ def extract_text_to_compare_nakdimon(text):
+     res = text.replace("|", "")
+     res = res.replace(
+         chr(Nikud.nikud_dict["KUBUTZ"]) + "ื•" + chr(Nikud.nikud_dict["METEG"]),
+         "ื•" + chr(Nikud.nikud_dict["DAGESH OR SHURUK"]),
+     )
+     res = res.replace(
+         chr(Nikud.nikud_dict["HOLAM"]) + "ื•" + chr(Nikud.nikud_dict["METEG"]), "ื•"
+     )
+     res = res.replace(
+         "ื•" + chr(Nikud.nikud_dict["HOLAM"]) + chr(Nikud.nikud_dict["KAMATZ"]),
+         "ื•" + chr(Nikud.nikud_dict["KAMATZ"]),
+     )
+     res = res.replace(chr(Nikud.nikud_dict["METEG"]), "")
+     res = res.replace(
+         chr(Nikud.nikud_dict["KAMATZ"]) + chr(Nikud.nikud_dict["HIRIK"]),
+         chr(Nikud.nikud_dict["KAMATZ"]) + "ื™" + chr(Nikud.nikud_dict["HIRIK"]),
+     )
+     res = res.replace(
+         chr(Nikud.nikud_dict["PATAKH"]) + chr(Nikud.nikud_dict["HIRIK"]),
+         chr(Nikud.nikud_dict["PATAKH"]) + "ื™" + chr(Nikud.nikud_dict["HIRIK"]),
+     )
+     res = res.replace(chr(Nikud.nikud_dict["PUNCTUATION MAQAF"]), "")
+     res = res.replace(chr(Nikud.nikud_dict["PUNCTUATION PASEQ"]), "")
+     res = res.replace(
+         chr(Nikud.nikud_dict["KAMATZ_KATAN"]), chr(Nikud.nikud_dict["KAMATZ"])
+     )
+
+     res = re.sub(chr(Nikud.nikud_dict["KUBUTZ"]) + "ื•" + "(?=[ื-ืช])", "ื•", res)
+     res = res.replace(chr(Nikud.nikud_dict["REDUCED_KAMATZ"]) + "ื•", "ื•")
+
+     res = res.replace(
+         chr(Nikud.nikud_dict["DAGESH OR SHURUK"]) * 2,
+         chr(Nikud.nikud_dict["DAGESH OR SHURUK"]),
+     )
+     res = res.replace("\u05be", "-")
+     res = res.replace("ื™ึฐื”ื•ึนึธื”", "ื™ื”ื•ื”")
+
+     return res
+
+
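One of the canonicalizations above, shown concretely: a meteg that survives the special vav patterns is stripped outright. A small sketch, assuming METEG is present in nikud_dict as used above:

from src.utiles_data import Nikud, extract_text_to_compare_nakdimon

meteg = chr(Nikud.nikud_dict["METEG"])
raw = "ืฉึธืœื•ึนื" + meteg
assert meteg not in extract_text_to_compare_nakdimon(raw)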
+ def orgenize_data(main_folder, logger):
+     x = NikudDataset(None)
+     x.delete_files(os.path.join(Path(main_folder).parent, "train"))
+     x.delete_files(os.path.join(Path(main_folder).parent, "dev"))
+     x.delete_files(os.path.join(Path(main_folder).parent, "test"))
+     x.split_data(
+         main_folder, main_folder_name=os.path.basename(main_folder), logger=logger
+     )