bkhmsi committed
Commit d7c4b94
Parent: 75c487d

support for TD2

Files changed (9)
  1. .gitignore +2 -1
  2. PartialDD.png +0 -0
  3. app.py +45 -14
  4. config.yaml +27 -2
  5. dataloader.py +1 -1
  6. dataloader_plm.py +126 -0
  7. model_partial.py +10 -6
  8. model_plm.py +360 -0
  9. predict.py +16 -7
.gitignore CHANGED
@@ -2,4 +2,5 @@
 *.pt
 *.vec
 *.pem
-.DS_Store
+.DS_Store
+gradio_cached_examples
PartialDD.png ADDED
app.py CHANGED
@@ -1,10 +1,9 @@
 import os
 import yaml
 import gdown
-import time
 import gradio as gr
 from predict import PredictTri
-from gradio import blocks
+from huggingface_hub import hf_hub_download
 
 output_path = "tashkeela-d2.pt"
 gdrive_templ = "https://drive.google.com/file/d/{}/view?usp=sharing"
@@ -12,13 +11,14 @@ if not os.path.exists(output_path):
     model_gdrive_id = "1FGelqImFkESbTyRsx_elkKIOZ9VbhRuo"
     gdown.download(gdrive_templ.format(model_gdrive_id), output=output_path, quiet=False, fuzzy=True)
 
-time.sleep(1)
-
 output_path = "vocab.vec"
 if not os.path.exists(output_path):
     vocab_gdrive_id = "1-0muGvcSYEf8RAVRcwXay4MRex6kmCii"
     gdown.download(gdrive_templ.format(vocab_gdrive_id), output=output_path, quiet=False, fuzzy=True)
 
+if not os.path.exists("td2/tashkeela-ashaar-td2.pt"):
+    hf_hub_download(repo_id="munael/Partial-Arabic-Diacritization-TD2", filename="tashkeela-ashaar-td2.pt", local_dir="td2")
+
 with open("config.yaml", 'r', encoding="utf-8") as file:
     config = yaml.load(file, Loader=yaml.FullLoader)
 
@@ -27,16 +27,31 @@ config["train"]["max-token-count"] = config["predictor"]["window"] * 3
 
 predictor = PredictTri(config)
 
-def diacritze_full(text):
+current_model_name = "TD2"
+config["model-name"] = current_model_name
+
+def diacritze_full(text, model_name):
+    global current_model_name, predictor
+    if model_name != current_model_name:
+        config["model-name"] = model_name
+        current_model_name = model_name
+        predictor = PredictTri(config)
+
     do_hard_mask = None
     threshold = None
-    predictor.create_dataloader(text, False, do_hard_mask, threshold)
+    predictor.create_dataloader(text, False, do_hard_mask, threshold, model_name)
     diacritized_lines = predictor.predict_partial(do_partial=False, lines=text.split('\n'))
     return diacritized_lines
 
-def diacritze_partial(text, mask_mode, threshold):
+def diacritze_partial(text, mask_mode, threshold, model_name):
+    global current_model_name, predictor
+    if model_name != current_model_name:
+        config["model-name"] = model_name
+        current_model_name = model_name
+        predictor = PredictTri(config)
+
     do_partial = True
-    predictor.create_dataloader(text, do_partial, mask_mode=="Hard", threshold)
+    predictor.create_dataloader(text, do_partial, mask_mode=="Hard", threshold, model_name)
     diacritized_lines = predictor.predict_partial(do_partial=do_partial, lines=text.split('\n'))
     return diacritized_lines
 
@@ -45,9 +60,19 @@ with gr.Blocks(theme=gr.themes.Default(text_size="lg")) as demo:
         """
         # Partial Diacritization: A Context-Contrastive Inference Approach
        ### Authors: Muhammad ElNokrashy, Badr AlKhamissi
-        ### Paper Link: TBD
+        ### Paper Link: TBD (abstract below)
         """)
 
+    gr.HTML(
+        "<img src='./PartialDD.png' style='float:right'/>"
+    )
+
+    model_choice = gr.Dropdown(
+        choices=["D2", "TD2"],
+        label="Diacritization Model",
+        value=current_model_name
+    )
+
     with gr.Tab(label="Full Diacritization"):
 
         full_input_txt = gr.Textbox(
@@ -69,11 +94,11 @@ with gr.Blocks(theme=gr.themes.Default(text_size="lg")) as demo:
         )
 
         full_btn = gr.Button(value="Shakkel")
-        full_btn.click(diacritze_full, inputs=[full_input_txt], outputs=[full_output_txt])
+        full_btn.click(diacritze_full, inputs=[full_input_txt, model_choice], outputs=[full_output_txt])
 
         gr.Examples(
             examples=[
-                "ولو حمل من مجلس الخيار ، ولم يمنع من الكلام"
+                "ولو حمل من مجلس الخيار ، ولم يمنع من الكلام", "TD2"
             ],
             inputs=full_input_txt,
             outputs=full_output_txt,
@@ -105,11 +130,13 @@ with gr.Blocks(theme=gr.themes.Default(text_size="lg")) as demo:
         )
 
         partial_btn = gr.Button(value="Shakkel")
-        partial_btn.click(diacritze_partial, inputs=[partial_input_txt, masking_mode, threshold_slider], outputs=[partial_output_txt])
+        partial_btn.click(diacritze_partial, inputs=[partial_input_txt, masking_mode, threshold_slider, model_choice], outputs=[partial_output_txt])
 
         gr.Examples(
             examples=[
-                ["ولو حمل من مجلس الخيار ، ولم يمنع من الكلام", "Hard", 0],
+                ["ولو حمل من مجلس الخيار ، ولم يمنع من الكلام", "Hard", 0, "TD2"],
+                ["ولو حمل من مجلس الخيار ، ولم يمنع من الكلام", "Soft", 0.1, "TD2"],
+                ["ولو حمل من مجلس الخيار ، ولم يمنع من الكلام", "Soft", 0.01, "TD2"],
             ],
             inputs=[partial_input_txt, masking_mode, threshold_slider],
             outputs=partial_output_txt,
@@ -117,7 +144,11 @@ with gr.Blocks(theme=gr.themes.Default(text_size="lg")) as demo:
             cache_examples=True,
         )
 
-
+    gr.Markdown(
+        """
+        ### Abstract
+        > Diacritization plays a pivotal role in improving readability and disambiguating the meaning of Arabic texts. Efforts have so far focused on marking every eligible character (Full Diacritization). Comparatively overlooked, Partial Diacritzation (PD) is the selection of a subset of characters to be marked to aid comprehension where needed.Research has indicated that excessive diacritic marks can hinder skilled readers---reducing reading speed and accuracy. We conduct a behavioral experiment and show that partially marked text is often easier to read than fully marked text, and sometimes easier than plain text. In this light, we introduce Context-Contrastive Partial Diacritization (CCPD)---a novel approach to PD which integrates seamlessly with existing Arabic diacritization systems. CCPD processes each word twice, once with context and once without, and diacritizes only the characters with disparities between the two inferences. Further, we introduce novel indicators for measuring partial diacritization quality {SR, PDER, HDER, ERE}, essential for establishing this as a machine learning task. Lastly, we introduce TD2, a Transformer-variant of an established model which offers a markedly different performance profile on our proposed indicators compared to all other known systems.
+        """)
 
 if __name__ == "__main__":
     demo.queue().launch(
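
The model-selection wiring above can also be exercised without the UI. The following is an illustrative sketch, not part of the commit; it assumes app.py's globals (config, predictor, current_model_name) are in scope and the checkpoints have already been downloaded as above.

# Illustrative smoke test for the updated callbacks (not part of the commit).
text = "ولو حمل من مجلس الخيار ، ولم يمنع من الكلام"

# Full diacritization; the second argument carries the dropdown value and
# rebuilds the predictor when it differs from current_model_name.
print(diacritze_full(text, "TD2"))

# Partial diacritization; mask_mode ("Hard"/"Soft") and the threshold are
# forwarded to predictor.create_dataloader together with the model name.
print(diacritze_partial(text, "Hard", 0.0, "TD2"))
print(diacritze_partial(text, "Soft", 0.01, "D2"))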
config.yaml CHANGED
@@ -1,22 +1,47 @@
 run-title: tashkeela-d2
 debug: false
+model-name: TD2
 
 paths:
   base: ./dataset/ashaar
   save: ./models
   load: tashkeela-d2.pt
+  load-td2: td2/tashkeela-ashaar-td2.pt
   resume: ./models/Tashkeela-D2/tashkeela-d2.pt
   constants: ./dataset/helpers/constants
   word-embs: vocab.vec
   test: test
 
+modeling:
+  "checkpoint": munael/Partial-Arabic-Diacritization-TD2
+  "base_model": CAMeL-Lab/bert-base-arabic-camelbert-mix-ner
+  # "base_model": UBC-NLP/MARBERTv2
+  # "base_model": UBC-NLP/ARBERTv2
+  "deep-cls": true
+  "full-finetune": true #< From true
+  "keep-token-model-layers": 2
+  # "num-finetune-last-layers": 2 #
+  "num-chars": 40
+  "char-embed-dim": 128
+  "token_hidden_size": 768
+  "deep-down-proj": true
+  "dropout": 0.2
+  "sentence_dropout": 0.1
+  "diac_model_config": {
+    "vocab_size": 1,
+    "num_hidden_layers": 2,
+    "hidden_size": 768,
+    "intermediate_size": 2304,
+    "num_attention_heads": 8,
+  }
+
 loader:
   wembs-limit: -1
   num-workers: 0
 
 train:
   epochs: 1000
-  batch-size: 32
+  batch-size: 1
   char-embed-dim: 32
   resume: false
   resume-lr: false
@@ -51,7 +76,7 @@ train:
   stopping-patience: 3
 
 predictor:
-  batch-size: 75
+  batch-size: 1
   stride: 2
   window: 20
   gt-signal-prob: 0
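
The new `modeling` block is consumed by the TD2 model added in model_plm.py below. A minimal sketch of that wiring (not part of the commit), assuming transformers is installed and the listed base model is reachable:

# Sketch: how config.yaml's `modeling` block feeds the TD2 Diacritizer.
import yaml
from transformers import AutoTokenizer, BertConfig

with open("config.yaml", "r", encoding="utf-8") as fin:
    config = yaml.load(fin, Loader=yaml.FullLoader)

modeling = config["modeling"]
# The token-level encoder and its tokenizer come from `base_model` ...
tokenizer = AutoTokenizer.from_pretrained(modeling["base_model"])
# ... while the character-level diacritic encoder is a small BERT built
# directly from `diac_model_config` (2 layers, hidden size 768, 8 heads).
diac_config = BertConfig(**modeling["diac_model_config"])
print(diac_config.num_hidden_layers, diac_config.num_attention_heads)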
dataloader.py CHANGED
@@ -24,7 +24,7 @@ class DataRetriever(Dataset):
 
     def __getitem__(self, idx):
         word_x, char_x, diac_x, diac_y = self.create_sentence(idx)
-        return self.preprocess((word_x, char_x, diac_x)), T.tensor(diac_y, dtype=T.long)
+        return self.preprocess((word_x, char_x, diac_x)), T.tensor(diac_y, dtype=T.long), [0]
 
     def create_sentence(self, idx):
         line = self.lines[idx]
dataloader_plm.py ADDED
@@ -0,0 +1,126 @@
from typing import List, Tuple, Any

import os
from functools import lru_cache

from pyarabic.araby import tokenize, strip_tashkeel

import numpy as np
import torch as T
from torch.utils.data import Dataset

try:
    from transformers import PreTrainedTokenizer
except:
    from typing import Any as PreTrainedTokenizer

from data_utils import DatasetUtils
import diac_utils as du

class DataRetriever(Dataset):
    def __init__(
        self,
        lines,
        data_utils: DatasetUtils,
        is_test: bool = False,
        *,
        tokenizer: PreTrainedTokenizer,
        lines_mode: bool = False,
        **kwargs,
    ):
        super(DataRetriever).__init__()

        self.data_utils = data_utils
        self.is_test = is_test
        self.tokenizer = tokenizer

        self.stride = data_utils.test_stride

        self.data_points = lines

        self.bos_token_id = int(self.tokenizer.bos_token_id or self.tokenizer.cls_token_id)
        self.eos_token_id = int(self.tokenizer.eos_token_id or self.tokenizer.sep_token_id)

        self.max_tokens = self.data_utils.max_token_count
        self.max_slen = self.data_utils.max_sent_len
        self.max_wlen = self.data_utils.max_word_len
        # self.p_val = self.data_utils.pad_val
        self.p_val = self.tokenizer.pad_token_id
        self.pc_val = self.data_utils.pad_char_id
        self.pt_val = self.data_utils.pad_target_val

        self.char_x_padding = [self.pc_val] * self.max_wlen
        self.diac_x_padding = [[self.pc_val]*8] * self.max_wlen
        self.diac_y_padding = [self.pt_val] * self.max_wlen

    def preprocess(self, data, dtype=T.long):
        return [T.tensor(np.array(x), dtype=dtype) for x in data]

    def __len__(self):
        return len(self.data_points)

    @lru_cache(maxsize=1024 * 2)
    def __getitem__(self, idx: int) -> Tuple[List[T.Tensor], T.Tensor, T.Tensor]:
        word_x, char_x, diac_x, diac_y, subword_lengths = self.create_sentence(idx)
        return (
            self.preprocess([word_x, char_x, diac_x]),
            T.tensor(diac_y, dtype=T.long),
            T.tensor(subword_lengths, dtype=T.long)
        )

    def create_sentence(self, idx):
        line = self.data_points[idx]
        # tokens = tokenize(line.strip())
        words: List[str] = tokenize(line.strip())
        # words_: List[str] = []
        # for word in words:
        #     if len(strip_tashkeel(word)) == 0:
        #         words_[-1] += word.strip()
        #     else:
        #         words_.append(word)
        # word_tokens_bin = [self.tokenizer(word) for word in words]
        # tokens_bin = self.tokenizer(line.strip())

        subwords_x = [self.bos_token_id]
        subword_lengths = []

        char_x = []
        diac_x = []
        diac_y = []
        diac_y_tmp = []

        for i_word, word in enumerate(words):
            word = du.strip_unknown_tashkeel(word)
            word_chars = du.split_word_on_characters_with_diacritics(word)
            cx, cy, cy_3head = du.create_label_for_word(word_chars)

            word_strip = strip_tashkeel(word)
            #? List[int: "word_index"]
            #? Strip the BOS/EOS which the tokenizer adds
            word_sub_ids = self.tokenizer(word_strip)['input_ids'][1:-1]
            subword_lengths += [len(word_sub_ids)]

            subwords_x += word_sub_ids
            # word_x += [self.data_utils.w2idx.get(word_strip, self.data_utils.w2idx["<pad>"])]

            char_x += [self.data_utils.pad_and_truncate_sequence(cx, self.max_wlen)]

            diac_y += [self.data_utils.pad_and_truncate_sequence(cy, self.max_wlen, pad=self.data_utils.pad_target_val)]
            diac_y_tmp += [self.data_utils.pad_and_truncate_sequence(cy_3head, self.max_wlen, pad=[self.data_utils.pad_target_val]*3)]

        assert len(char_x) == len(subword_lengths), f"{char_x=}; {subword_lengths=} ;;"
        assert len(char_x) == len(words)

        diac_x = self.data_utils.create_decoder_input(diac_y_tmp)

        subwords_x += [self.eos_token_id]
        # assert len(char_x) + 2 == len(subwords_x), f"{len(char_x)} + 2 != {len(subwords_x)} ;;"  # Because of BOS, EOS
        assert len(subword_lengths) == len(words)
        subwords_x = self.data_utils.pad_and_truncate_sequence(subwords_x, self.max_tokens, pad=self.p_val)
        subword_lengths = self.data_utils.pad_and_truncate_sequence(subword_lengths, self.max_slen, pad=0)

        char_x = self.data_utils.pad_and_truncate_sequence(char_x, self.max_slen, pad=self.char_x_padding)
        diac_x = self.data_utils.pad_and_truncate_sequence(diac_x, self.max_slen, pad=self.diac_x_padding)
        diac_y = self.data_utils.pad_and_truncate_sequence(diac_y, self.max_slen, pad=self.diac_y_padding)

        return subwords_x, char_x, diac_x, diac_y, subword_lengths
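
The key bookkeeping in create_sentence is the per-word subword_lengths list, which later lets the model pool subword embeddings back to words. A standalone illustration of that accounting (not part of the commit), assuming the CAMeLBERT tokenizer configured above:

# Illustration of the per-word subword accounting done in create_sentence.
from pyarabic.araby import tokenize, strip_tashkeel
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("CAMeL-Lab/bert-base-arabic-camelbert-mix-ner")

line = "ولو حمل من مجلس الخيار"
words = tokenize(line.strip())

subwords_x, subword_lengths = [tok.cls_token_id], []
for word in words:
    ids = tok(strip_tashkeel(word))["input_ids"][1:-1]  # drop the per-word BOS/EOS
    subwords_x += ids
    subword_lengths.append(len(ids))
subwords_x.append(tok.sep_token_id)

print(len(words), subword_lengths, len(subwords_x))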
model_partial.py CHANGED
@@ -9,7 +9,7 @@ from torch.nn import functional as F
 from diac_utils import flat_2_3head
 
 from model_dd import DiacritizerD2
-from model_dd import DatasetUtils
+from model_plm import Diacritizer
 
 class Readout(nn.Module):
     def __init__(
@@ -72,8 +72,11 @@ class PartialDD(nn.Module):
         # self.config_d2 = yaml.safe_load(fin)
         # self.device = T.device('cuda' if T.cuda.is_available() else 'cpu')
         self.config = config
-        self._use_d2 = True
-        self.sentence_diac = DiacritizerD2(self.config)
+        self._use_d2 = config["model-name"] == "D2"
+        if self._use_d2:
+            self.sentence_diac = DiacritizerD2(self.config)
+        else:
+            self.sentence_diac = Diacritizer(self.config, load_pretrained=False)
 
         # self.sentence_diac.to(self.device)
         # self.build()
@@ -90,9 +93,10 @@ class PartialDD(nn.Module):
 
     def load_state_dict(
         self,
-        state_dict: dict
+        state_dict: dict,
+        strict: bool = True,
     ):
-        self.sentence_diac.load_state_dict(state_dict)
+        self.sentence_diac.load_state_dict(state_dict, strict=strict)
 
     def _slim_batch(
         self,
@@ -277,7 +281,7 @@
         }
         print("> Predicting...")
         # breakpoint()
-        for i_batch, (inputs, _) in enumerate(tqdm(dataloader)):
+        for i_batch, (inputs, _, subword_lengths) in enumerate(tqdm(dataloader)):
            # if i_batch > 10:
            #     break
            #^ inputs: [toke_ids, char_ids, diac_ids]
model_plm.py ADDED
@@ -0,0 +1,360 @@
from typing import List, Iterator, cast

import copy
import numpy as np

import torch as T
from torch import nn
from torch.nn import functional as F
from transformers import BertConfig, BertModel
from transformers import AutoTokenizer, AutoModel, AutoConfig
from transformers import PreTrainedModel
from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions

class Diacritizer(nn.Module):
    def __init__(
        self,
        config,
        device=None,
        load_pretrained=True
    ) -> None:
        super().__init__()
        self._dummy = nn.Parameter(T.ones(1))

        if 'modeling' in config:
            config = config['modeling']
        self.config = config

        model_name = config.get('base_model', "CAMeL-Lab/bert-base-arabic-camelbert-mix-ner")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

        if load_pretrained:
            self.token_model: BertModel = AutoModel.from_pretrained(model_name)
        else:
            marbert_config = AutoConfig.from_pretrained(model_name)
            self.token_model = AutoModel.from_config(marbert_config)

        self.num_classes = 15
        self.diac_model_config = BertConfig(**config['diac_model_config'])
        self.token_model_config: BertConfig = self.token_model.config

        self.char_embs = nn.Embedding(config["num-chars"], embedding_dim=config["char-embed-dim"])
        self.diac_emb_model = self.build_diac_model(self.token_model)

        self.down_project_token_embeds_deep = None
        self.down_project_token_embeds = None
        if 'token_hidden_size' in config:
            if config['token_hidden_size'] == 'auto':
                down_proj_size = self.diac_emb_model.config.hidden_size
            else:
                down_proj_size = config['token_hidden_size']
            if config.get('deep-down-proj', False):
                self.down_project_token_embeds_deep = nn.Sequential(
                    nn.Linear(
                        self.token_model_config.hidden_size + config["char-embed-dim"],
                        down_proj_size * 4,
                        bias=False,
                    ),
                    nn.Tanh(),
                    nn.Linear(
                        down_proj_size * 4,
                        down_proj_size,
                        bias=False,
                    )
                )
            # else:
            self.down_project_token_embeds = nn.Linear(
                self.token_model_config.hidden_size + config["char-embed-dim"],
                down_proj_size,
                bias=False,
            )

        # assert self.down_project_token_embeds_deep is None or self.down_project_token_embeds is None
        classifier_feature_size = self.diac_model_config.hidden_size
        if config.get('deep-cls', False):
            # classifier_feature_size = 512
            self.final_feature_transform = nn.Linear(
                self.diac_model_config.hidden_size
                + self.token_model_config.hidden_size,
                #^ diac_features + [residual from token_model]
                out_features=classifier_feature_size,
                bias=False
            )
        else:
            self.final_feature_transform = None

        self.feature_layer_norm = nn.LayerNorm(classifier_feature_size)
        self.classifier = nn.Linear(classifier_feature_size, self.num_classes, bias=True)

        self.trim_model_(config)

        self.dropout = nn.Dropout(config['dropout'])
        self.sent_dropout_p = config['sentence_dropout']
        self.closs = F.cross_entropy

    def build_diac_model(self, token_model=None):
        if self.config.get('pre-init-diac-model', False):
            model = copy.deepcopy(self.token_model)
            model.pooler = None
            model.embeddings.word_embeddings = None

            num_layers = self.config.get('keep-token-model-layers', None)
            model.encoder.layer = nn.ModuleList(
                list(model.encoder.layer[num_layers:num_layers*2])
            )

            model.encoder.config.num_hidden_layers = num_layers
        else:
            model = BertModel(self.diac_model_config)
        return model

    def trim_model_(self, config):
        self.token_model.pooler = None
        self.diac_emb_model.pooler = None
        # self.diac_emb_model.embeddings = None
        self.diac_emb_model.embeddings.word_embeddings = None

        num_token_model_kept_layers = config.get('keep-token-model-layers', None)
        if num_token_model_kept_layers is not None:
            self.token_model.encoder.layer = nn.ModuleList(
                list(self.token_model.encoder.layer[:num_token_model_kept_layers])
            )
            self.token_model.encoder.config.num_hidden_layers = num_token_model_kept_layers

        if not config.get('full-finetune', False):
            for param in self.token_model.parameters():
                param.requires_grad = False
            finetune_last_layers = config.get('num-finetune-last-layers', 4)
            if finetune_last_layers > 0:
                unfrozen_layers = self.token_model.encoder.layer[-finetune_last_layers:]
                for layer in unfrozen_layers:
                    for param in layer.parameters():
                        param.requires_grad = True

    def get_grouped_params(self):
        downstream_params: Iterator[nn.Parameter] = cast(
            Iterator,
            (param
             for module in (self.diac_emb_model, self.classifier, self.char_embs)
             for param in module.parameters())
        )
        pg = {
            'pretrained': self.token_model.parameters(),
            'downstream': downstream_params,
        }
        return pg

    @property
    def device(self):
        return self._dummy.device

    def step(self, xt, yt, mask=None, subword_lengths: T.Tensor=None):
        # ^ word_x, char_x, diac_x are Indices
        # ^ xt : self.preprocess((word_x, char_x, diac_x)),
        # ^ yt : T.tensor(diac_y, dtype=T.long),
        # ^ subword_lengths: T.tensor(subword_lengths, dtype=T.long)
        #< Move char_x, diac_x to device because they're small and trainable
        xt[0], xt[1], yt, subword_lengths = self._slim_batch_size(xt[0], xt[1], yt, subword_lengths)
        xt[0] = xt[0].to(self.device)
        xt[1] = xt[1].to(self.device)
        # xt[2] = xt[2].to(self.device)

        yt = yt.to(self.device)
        #^ yt: [b tw tc]

        Nb, Tword, Tchar = xt[1].shape
        if Tword * Tchar < 500:
            diac = self(*xt, subword_lengths)
            loss = self.closs(diac.view(-1, self.num_classes), yt.view(-1), reduction='sum')
        else:
            num_chunks = Tword * Tchar / 300
            loss = 0
            for i in range(round(num_chunks+0.5)):
                _slice = slice(i*300, (i+1)*300)
                chunk = self._slice_batch(xt, _slice)
                diac = self(*chunk, subword_lengths[_slice])
                chunk_loss = self.closs(diac.view(-1, self.num_classes), yt.view(-1), reduction='sum')
                loss = loss + chunk_loss

        return loss

    def _slice_batch(self, xt: List[T.Tensor], _slice):
        return [xt[0][_slice], xt[1][_slice], xt[2][_slice]]

    def _slim_batch_size(
        self,
        tx: T.Tensor,
        cx: T.Tensor,
        yt: T.Tensor,
        subword_lengths: T.Tensor
    ):
        #^ tx : [b tt]
        #^ cx : [b tw tc]
        #^ yt : [b tw tc]
        token_nonpad_mask = tx.ne(self.tokenizer.pad_token_id)
        Ttoken = token_nonpad_mask.sum(1).max()
        tx = tx[:, :Ttoken]

        char_nonpad_mask = cx.ne(0)
        Tword = char_nonpad_mask.any(2).sum(1).max()
        Tchar = char_nonpad_mask.sum(2).max()
        cx = cx[:, :Tword, :Tchar]
        yt = yt[:, :Tword, :Tchar]
        subword_lengths = subword_lengths[:, :Tword]

        return tx, cx, yt, subword_lengths

    def token_dropout(self, toke_x):
        #^ toke_x : [b tw]
        if self.training:
            q = 1.0 - self.sent_dropout_p
            sdo = T.bernoulli(T.full(toke_x.shape, q))
            toke_x[sdo == 0] = self.tokenizer.pad_token_id
        return toke_x

    def sentence_dropout(self, word_embs: T.Tensor):
        #^ word_embs : [b tw dwe]
        if self.training:
            q = 1.0 - self.sent_dropout_p
            sdo = T.bernoulli(T.full(word_embs.shape[:2], q))
            sdo = sdo.detach().unsqueeze(-1).to(word_embs)
            word_embs = word_embs * sdo
            # toke_x[sdo == 0] = self.tokenizer.pad_token_id
        return word_embs

    def embed_tokens(self, input_ids: T.Tensor, attention_mask: T.Tensor):
        y: BaseModelOutputWithPoolingAndCrossAttentions
        y = self.token_model(input_ids, attention_mask=attention_mask)
        z = y.last_hidden_state
        return z

    def forward(
        self,
        toke_x : T.Tensor,
        char_x : T.Tensor,
        diac_x : T.Tensor,
        subword_lengths : T.Tensor,
    ):
        #^ toke_x : [b tt]
        #^ char_x : [b tw tc]
        #^ diac_x/labels : [b tw tc]
        #^ subword_lengths : [b, tw]
        # !TODO Use `subword_lengths` to aggregate subword embeddings first before ...
        # ... passing concatenated contextual embedding to chars in diac_model

        token_nonpad_mask = toke_x.ne(self.tokenizer.pad_token_id)
        char_nonpad_mask = char_x.ne(0)

        Nb, Tw, Tc = char_x.shape
        # assert Tw == Tw_0 and Tc == Tc_0, f"{Tw=} {Tw_0=}, {Tc=} {Tc_0=}"

        # toke_x = self.token_dropout(toke_x)
        token_embs = self.embed_tokens(toke_x, attention_mask=token_nonpad_mask)
        # token_embs = self.sentence_dropout(token_embs)
        #? Strip BOS,EOS
        token_embs = token_embs[:, 1:-1, ...]

        sent_word_strides = subword_lengths.cumsum(1)
        sent_enc: T.Tensor = T.zeros(Nb, Tw, token_embs.shape[-1]).to(token_embs)
        for i_b in range(Nb):
            token_embs_ib = token_embs[i_b]
            start_iw = 0
            for i_word, end_iw in enumerate(sent_word_strides[i_b]):
                if end_iw == start_iw: break
                word_emb = token_embs_ib[start_iw : end_iw].sum(0) / (end_iw - start_iw)
                sent_enc[i_b, i_word] = word_emb
                start_iw = end_iw
        #^ sent_enc: [b tw dwe]

        char_x_flat = char_x.reshape(Nb*Tw, Tc)
        char_nonpad_mask = char_x_flat.gt(0)
        # ^ char_nonpad_mask [b*tw tc]

        char_x_flat = char_x_flat * char_nonpad_mask

        cembs = self.char_embs(char_x_flat)

        #^ cembs: [b*tw tc dce]
        wembs = sent_enc.unsqueeze(-2).expand(Nb, Tw, Tc, -1).view(Nb*Tw, Tc, -1)
        #^ wembs: [b tw dwe] => [b tw _ dwe] => [b*tw tc dwe]
        cw_embs = T.cat([cembs, wembs], dim=-1)
        #^ char_embs : [b*tw tc dcw] ; dcw = dc + dwe
        cw_embs = self.dropout(cw_embs)

        cw_embs_ = cw_embs
        if self.down_project_token_embeds is not None:
            cw_embs_ = self.down_project_token_embeds(cw_embs)
        if self.down_project_token_embeds_deep is not None:
            cw_embs_ = cw_embs_ + self.down_project_token_embeds_deep(cw_embs)
        cw_embs = cw_embs_

        diac_enc: BaseModelOutputWithPoolingAndCrossAttentions
        diac_enc = self.diac_emb_model(inputs_embeds=cw_embs, attention_mask=char_nonpad_mask)
        diac_emb = diac_enc.last_hidden_state
        diac_emb = self.dropout(diac_emb)
        #^ diac_emb: [b*tw tc dce]
        diac_emb = diac_emb.view(Nb, Tw, Tc, -1)

        sent_residual = sent_enc.unsqueeze(2).expand(-1, -1, Tc, -1)
        final_feature = T.cat([sent_residual, diac_emb], dim=-1)
        if self.final_feature_transform is not None:
            final_feature = self.final_feature_transform(final_feature)
            final_feature = F.tanh(final_feature)
            final_feature = self.dropout(final_feature)
        else:
            final_feature = diac_emb

        # final_feature = self.feature_layer_norm(final_feature)
        diac_out = self.classifier(final_feature)
        # if T.isnan(diac_out).any():
        #     breakpoint()
        return diac_out

    def predict(self, dataloader):
        from tqdm import tqdm
        import diac_utils as du
        training = self.training
        self.eval()

        preds = {'haraka': [], 'shadda': [], 'tanween': []}
        print("> Predicting...")
        for inputs, _, subword_lengths in tqdm(dataloader, total=len(dataloader)):
            inputs[0] = inputs[0].to(self.device)
            inputs[1] = inputs[1].to(self.device)
            output = self(*inputs, subword_lengths).detach()

            marks = np.argmax(T.softmax(output, dim=-1).cpu().numpy(), axis=-1)
            #^ [b ts tw]

            haraka, tanween, shadda = du.flat_2_3head(marks)

            preds['haraka'].extend(haraka)
            preds['tanween'].extend(tanween)
            preds['shadda'].extend(shadda)

        self.train(training)
        return (
            np.array(preds['haraka']),
            np.array(preds["tanween"]),
            np.array(preds["shadda"]),
        )

if __name__ == "__main__":
    model = Diacritizer({
        "num-chars": 36,
        "hidden_size": 768,
        "char-embed-dim": 32,
        "dropout": 0.25,
        "sentence_dropout": 0.2,
        "diac_model_config": {
            "num_layers": 4,
            "hidden_size": 768 + 32,
            "intermediate_size": (768 + 32) * 4,
        },
    }, load_pretrained=False)

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    print(model)
    print(f"{trainable_params:,}/{total_params:,} Trainable Parameters")
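
The loop in Diacritizer.forward that turns subword_lengths into word vectors is the part most specific to TD2: contextual subword embeddings are averaged per word using the cumulative lengths as strides. A self-contained toy version of that pooling step (not part of the commit):

# Toy reproduction of the subword-to-word mean pooling in Diacritizer.forward.
import torch as T

token_embs = T.randn(1, 5, 8)                # [b, subwords (BOS/EOS stripped), d]
subword_lengths = T.tensor([[2, 1, 2, 0]])   # last word slot is padding
strides = subword_lengths.cumsum(1)

Nb, Tw = subword_lengths.shape
sent_enc = T.zeros(Nb, Tw, token_embs.shape[-1])
for i_b in range(Nb):
    start = 0
    for i_w, end in enumerate(strides[i_b]):
        if end == start:
            break                            # remaining word slots are padding
        sent_enc[i_b, i_w] = token_embs[i_b, start:end].mean(0)
        start = end

print(sent_enc.shape)  # torch.Size([1, 4, 8]); rows 0-2 hold one vector per word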
predict.py CHANGED
@@ -14,9 +14,9 @@ from torch.utils.data import DataLoader
 
 from diac_utils import HARAKAT_MAP, shakkel_char, flat2_3head
 from model_partial import PartialDD
-from model_dd import DiacritizerD2
 from data_utils import DatasetUtils
 from dataloader import DataRetriever
+from dataloader_plm import DataRetriever as DataRetrieverPLM
 from segment import segment
 
 from partial_dd_metrics import (
@@ -105,15 +105,19 @@ class Predictor:
             config['predictor'].get('device', 'cuda:0')
             if T.cuda.is_available() else 'cpu'
         )
-
+
         self.model = PartialDD(config)
-        self.model.sentence_diac.build(word_embeddings, vocab_size)
-        state_dict = T.load(config["paths"]["load"], map_location=T.device(self.device))['state_dict']
-        self.model.load_state_dict(state_dict)
+        if config["model-name"] == "D2":
+            self.model.sentence_diac.build(word_embeddings, vocab_size)
+            state_dict = T.load(config["paths"]["load"], map_location=T.device(self.device))['state_dict']
+        else:
+            state_dict = T.load(config["paths"]["load-td2"], map_location=T.device(self.device))['state_dict']
+
+        self.model.load_state_dict(state_dict, strict=False)
        self.model.to(self.device)
        self.model.eval()
 
-    def create_dataloader(self, text, do_partial, do_hard_mask, threshold):
+    def create_dataloader(self, text, do_partial, do_hard_mask, threshold, model_name):
        self.threshold = threshold
        self.do_hard_mask = do_hard_mask
 
@@ -137,7 +141,12 @@ class Predictor:
        self.original_lines = text.split('\n')
 
        self.data_loader = DataLoader(
-            DataRetriever(self.data_utils, segments),
+            DataRetriever(self.data_utils, segments)
+            if model_name == "D2"
+            else DataRetrieverPLM(segments, self.data_utils,
+                is_test=True,
+                tokenizer=self.model.tokenizer
+            ),
            batch_size=self.config["predictor"].get("batch-size", 32),
            shuffle=False,
            num_workers=self.config['loader'].get('num-workers', 0),
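
Putting the pieces together, the prediction path now switches on the configured model name end to end. A hedged sketch mirroring how app.py drives it for TD2 (not part of the commit; assumes this repository's config.yaml and the downloaded checkpoints):

# End-to-end sketch of the updated TD2 prediction path, mirroring app.py.
import yaml
from predict import PredictTri

with open("config.yaml", "r", encoding="utf-8") as fin:
    config = yaml.load(fin, Loader=yaml.FullLoader)
config["train"]["max-token-count"] = config["predictor"]["window"] * 3  # as in app.py
config["model-name"] = "TD2"  # selects the PLM Diacritizer and the td2 checkpoint

predictor = PredictTri(config)
text = "ولو حمل من مجلس الخيار ، ولم يمنع من الكلام"

# model_name is the new fifth argument; "TD2" routes to DataRetrieverPLM.
predictor.create_dataloader(text, True, True, 0.0, "TD2")
print(predictor.predict_partial(do_partial=True, lines=text.split("\n")))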