support for TD2

- .gitignore +2 -1
- PartialDD.png +0 -0
- app.py +45 -14
- config.yaml +27 -2
- dataloader.py +1 -1
- dataloader_plm.py +126 -0
- model_partial.py +10 -6
- model_plm.py +360 -0
- predict.py +16 -7
.gitignore
CHANGED
@@ -2,4 +2,5 @@
 *.pt
 *.vec
 *.pem
-.DS_Store
+.DS_Store
+gradio_cached_examples
PartialDD.png
ADDED
app.py
CHANGED
@@ -1,10 +1,9 @@
 import os
 import yaml
 import gdown
-import time
 import gradio as gr
 from predict import PredictTri
-from
+from huggingface_hub import hf_hub_download
 
 output_path = "tashkeela-d2.pt"
 gdrive_templ = "https://drive.google.com/file/d/{}/view?usp=sharing"
@@ -12,13 +11,14 @@ if not os.path.exists(output_path):
     model_gdrive_id = "1FGelqImFkESbTyRsx_elkKIOZ9VbhRuo"
     gdown.download(gdrive_templ.format(model_gdrive_id), output=output_path, quiet=False, fuzzy=True)
 
-time.sleep(1)
-
 output_path = "vocab.vec"
 if not os.path.exists(output_path):
     vocab_gdrive_id = "1-0muGvcSYEf8RAVRcwXay4MRex6kmCii"
     gdown.download(gdrive_templ.format(vocab_gdrive_id), output=output_path, quiet=False, fuzzy=True)
 
+if not os.path.exists("td2/tashkeela-ashaar-td2.pt"):
+    hf_hub_download(repo_id="munael/Partial-Arabic-Diacritization-TD2", filename="tashkeela-ashaar-td2.pt", local_dir="td2")
+
 with open("config.yaml", 'r', encoding="utf-8") as file:
     config = yaml.load(file, Loader=yaml.FullLoader)
 
@@ -27,16 +27,31 @@ config["train"]["max-token-count"] = config["predictor"]["window"] * 3
 
 predictor = PredictTri(config)
 
-def diacritze_full(text):
+current_model_name = "TD2"
+config["model-name"] = current_model_name
+
+def diacritze_full(text, model_name):
+    global current_model_name, predictor
+    if model_name != current_model_name:
+        config["model-name"] = model_name
+        current_model_name = model_name
+        predictor = PredictTri(config)
+
     do_hard_mask = None
     threshold = None
-    predictor.create_dataloader(text, False, do_hard_mask, threshold)
+    predictor.create_dataloader(text, False, do_hard_mask, threshold, model_name)
     diacritized_lines = predictor.predict_partial(do_partial=False, lines=text.split('\n'))
     return diacritized_lines
 
-def diacritze_partial(text, mask_mode, threshold):
+def diacritze_partial(text, mask_mode, threshold, model_name):
+    global current_model_name, predictor
+    if model_name != current_model_name:
+        config["model-name"] = model_name
+        current_model_name = model_name
+        predictor = PredictTri(config)
+
     do_partial = True
-    predictor.create_dataloader(text, do_partial, mask_mode=="Hard", threshold)
+    predictor.create_dataloader(text, do_partial, mask_mode=="Hard", threshold, model_name)
    diacritized_lines = predictor.predict_partial(do_partial=do_partial, lines=text.split('\n'))
     return diacritized_lines
 
@@ -45,9 +60,19 @@ with gr.Blocks(theme=gr.themes.Default(text_size="lg")) as demo:
    """
    # Partial Diacritization: A Context-Contrastive Inference Approach
    ### Authors: Muhammad ElNokrashy, Badr AlKhamissi
-    ### Paper Link: TBD
+    ### Paper Link: TBD (abstract below)
    """)
 
+    gr.HTML(
+        "<img src='./PartialDD.png' style='float:right'/>"
+    )
+
+    model_choice = gr.Dropdown(
+        choices=["D2", "TD2"],
+        label="Diacritization Model",
+        value=current_model_name
+    )
+
     with gr.Tab(label="Full Diacritization"):
 
         full_input_txt = gr.Textbox(
@@ -69,11 +94,11 @@ with gr.Blocks(theme=gr.themes.Default(text_size="lg")) as demo:
         )
 
         full_btn = gr.Button(value="Shakkel")
-        full_btn.click(diacritze_full, inputs=[full_input_txt], outputs=[full_output_txt])
+        full_btn.click(diacritze_full, inputs=[full_input_txt, model_choice], outputs=[full_output_txt])
 
         gr.Examples(
             examples=[
-                "ولو حمل من مجلس الخيار ، ولم يمنع من الكلام"
+                "ولو حمل من مجلس الخيار ، ولم يمنع من الكلام", "TD2"
             ],
             inputs=full_input_txt,
             outputs=full_output_txt,
@@ -105,11 +130,13 @@ with gr.Blocks(theme=gr.themes.Default(text_size="lg")) as demo:
         )
 
         partial_btn = gr.Button(value="Shakkel")
-        partial_btn.click(diacritze_partial, inputs=[partial_input_txt, masking_mode, threshold_slider], outputs=[partial_output_txt])
+        partial_btn.click(diacritze_partial, inputs=[partial_input_txt, masking_mode, threshold_slider, model_choice], outputs=[partial_output_txt])
 
         gr.Examples(
             examples=[
-                ["ولو حمل من مجلس الخيار ، ولم يمنع من الكلام", "Hard", 0],
+                ["ولو حمل من مجلس الخيار ، ولم يمنع من الكلام", "Hard", 0, "TD2"],
+                ["ولو حمل من مجلس الخيار ، ولم يمنع من الكلام", "Soft", 0.1, "TD2"],
+                ["ولو حمل من مجلس الخيار ، ولم يمنع من الكلام", "Soft", 0.01, "TD2"],
             ],
             inputs=[partial_input_txt, masking_mode, threshold_slider],
             outputs=partial_output_txt,
@@ -117,7 +144,11 @@ with gr.Blocks(theme=gr.themes.Default(text_size="lg")) as demo:
             cache_examples=True,
         )
 
-
+    gr.Markdown(
+        """
+        ### Abstract
+        > Diacritization plays a pivotal role in improving readability and disambiguating the meaning of Arabic texts. Efforts have so far focused on marking every eligible character (Full Diacritization). Comparatively overlooked, Partial Diacritzation (PD) is the selection of a subset of characters to be marked to aid comprehension where needed.Research has indicated that excessive diacritic marks can hinder skilled readers---reducing reading speed and accuracy. We conduct a behavioral experiment and show that partially marked text is often easier to read than fully marked text, and sometimes easier than plain text. In this light, we introduce Context-Contrastive Partial Diacritization (CCPD)---a novel approach to PD which integrates seamlessly with existing Arabic diacritization systems. CCPD processes each word twice, once with context and once without, and diacritizes only the characters with disparities between the two inferences. Further, we introduce novel indicators for measuring partial diacritization quality {SR, PDER, HDER, ERE}, essential for establishing this as a machine learning task. Lastly, we introduce TD2, a Transformer-variant of an established model which offers a markedly different performance profile on our proposed indicators compared to all other known systems.
+        """)
 
 if __name__ == "__main__":
     demo.queue().launch(
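For reference, the abstract shown above summarizes the inference rule behind CCPD: diacritize each word once with its sentence context and once in isolation, and keep a mark only where the two passes disagree. A minimal sketch of that selection rule follows; the two `diacritize_*` callables and the `0 = unmarked` convention are placeholders for illustration, not APIs from this repo.

from typing import Callable, List

def ccpd_select(words: List[str],
                diacritize_with_context: Callable[[List[str]], List[List[int]]],
                diacritize_without_context: Callable[[str], List[int]]) -> List[List[int]]:
    # Sketch of Context-Contrastive Partial Diacritization (CCPD):
    # keep a character's diacritic only when the contextual and context-free
    # inferences disagree; otherwise leave the character unmarked (0 here).
    marks_ctx = diacritize_with_context(words)            # per-word, per-char class ids
    selected = []
    for word, ctx_marks in zip(words, marks_ctx):
        solo_marks = diacritize_without_context(word)     # same word, no sentence context
        selected.append([c if c != s else 0 for c, s in zip(ctx_marks, solo_marks)])
    return selected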
config.yaml
CHANGED
@@ -1,22 +1,47 @@
 run-title: tashkeela-d2
 debug: false
+model-name: TD2
 
 paths:
   base: ./dataset/ashaar
   save: ./models
   load: tashkeela-d2.pt
+  load-td2: td2/tashkeela-ashaar-td2.pt
   resume: ./models/Tashkeela-D2/tashkeela-d2.pt
   constants: ./dataset/helpers/constants
   word-embs: vocab.vec
   test: test
 
+modeling:
+  "checkpoint": munael/Partial-Arabic-Diacritization-TD2
+  "base_model": CAMeL-Lab/bert-base-arabic-camelbert-mix-ner
+  # "base_model": UBC-NLP/MARBERTv2
+  # "base_model": UBC-NLP/ARBERTv2
+  "deep-cls": true
+  "full-finetune": true #< From true
+  "keep-token-model-layers": 2
+  # "num-finetune-last-layers": 2 #
+  "num-chars": 40
+  "char-embed-dim": 128
+  "token_hidden_size": 768
+  "deep-down-proj": true
+  "dropout": 0.2
+  "sentence_dropout": 0.1
+  "diac_model_config": {
+    "vocab_size": 1,
+    "num_hidden_layers": 2,
+    "hidden_size": 768,
+    "intermediate_size": 2304,
+    "num_attention_heads": 8,
+  }
+
 loader:
   wembs-limit: -1
   num-workers: 0
 
 train:
   epochs: 1000
-  batch-size:
+  batch-size: 1
   char-embed-dim: 32
   resume: false
   resume-lr: false
@@ -51,7 +76,7 @@ train:
   stopping-patience: 3
 
 predictor:
-  batch-size:
+  batch-size: 1
   stride: 2
   window: 20
   gt-signal-prob: 0
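The new `modeling:` block carries the TD2 hyper-parameters (base encoder checkpoint, number of kept layers, char-embedding size, and the config of the small diacritic head). A minimal sketch of how it can be read and handed to the new model; this snippet is illustrative and not part of the commit:

import yaml

with open("config.yaml", "r", encoding="utf-8") as fh:
    cfg = yaml.safe_load(fh)

modeling = cfg["modeling"]
print(modeling["base_model"])         # CAMeL-Lab/bert-base-arabic-camelbert-mix-ner
print(modeling["diac_model_config"])  # kwargs unpacked into transformers.BertConfig

# model_plm.Diacritizer accepts either the full config or just its `modeling` block,
# since its __init__ does: `if 'modeling' in config: config = config['modeling']`.
# from model_plm import Diacritizer
# model = Diacritizer(cfg, load_pretrained=False)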
dataloader.py
CHANGED
@@ -24,7 +24,7 @@ class DataRetriever(Dataset):
 
     def __getitem__(self, idx):
         word_x, char_x, diac_x, diac_y = self.create_sentence(idx)
-        return self.preprocess((word_x, char_x, diac_x)), T.tensor(diac_y, dtype=T.long)
+        return self.preprocess((word_x, char_x, diac_x)), T.tensor(diac_y, dtype=T.long), [0]
 
     def create_sentence(self, idx):
         line = self.lines[idx]
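The dummy `[0]` makes the D2 retriever yield a three-element item, matching the `(inputs, labels, subword_lengths)` shape of the new PLM retriever, so the shared prediction loop can unpack both the same way. A small illustrative helper (the `loader` argument is a stand-in for either DataLoader):

def iterate_batches(loader):
    # Works with both dataloader.DataRetriever (dummy lengths: [0]) and
    # dataloader_plm.DataRetriever (real per-word subword lengths).
    for inputs, labels, subword_lengths in loader:
        yield inputs, labels, subword_lengths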
dataloader_plm.py
ADDED
@@ -0,0 +1,126 @@
+from typing import List, Tuple, Any
+
+import os
+from functools import lru_cache
+
+from pyarabic.araby import tokenize, strip_tashkeel
+
+import numpy as np
+import torch as T
+from torch.utils.data import Dataset
+
+try:
+    from transformers import PreTrainedTokenizer
+except:
+    from typing import Any as PreTrainedTokenizer
+
+from data_utils import DatasetUtils
+import diac_utils as du
+
+class DataRetriever(Dataset):
+    def __init__(
+        self,
+        lines,
+        data_utils: DatasetUtils,
+        is_test: bool = False,
+        *,
+        tokenizer: PreTrainedTokenizer,
+        lines_mode: bool = False,
+        **kwargs,
+    ):
+        super(DataRetriever).__init__()
+
+        self.data_utils = data_utils
+        self.is_test = is_test
+        self.tokenizer = tokenizer
+
+        self.stride = data_utils.test_stride
+
+        self.data_points = lines
+
+        self.bos_token_id = int(self.tokenizer.bos_token_id or self.tokenizer.cls_token_id)
+        self.eos_token_id = int(self.tokenizer.eos_token_id or self.tokenizer.sep_token_id)
+
+        self.max_tokens = self.data_utils.max_token_count
+        self.max_slen = self.data_utils.max_sent_len
+        self.max_wlen = self.data_utils.max_word_len
+        # self.p_val = self.data_utils.pad_val
+        self.p_val = self.tokenizer.pad_token_id
+        self.pc_val = self.data_utils.pad_char_id
+        self.pt_val = self.data_utils.pad_target_val
+
+        self.char_x_padding = [self.pc_val] * self.max_wlen
+        self.diac_x_padding = [[self.pc_val]*8] * self.max_wlen
+        self.diac_y_padding = [self.pt_val] * self.max_wlen
+
+    def preprocess(self, data, dtype=T.long):
+        return [T.tensor(np.array(x), dtype=dtype) for x in data]
+
+    def __len__(self):
+        return len(self.data_points)
+
+    @lru_cache(maxsize=1024 * 2)
+    def __getitem__(self, idx: int) -> Tuple[List[T.Tensor], T.Tensor, T.Tensor]:
+        word_x, char_x, diac_x, diac_y, subword_lengths = self.create_sentence(idx)
+        return (
+            self.preprocess([word_x, char_x, diac_x]),
+            T.tensor(diac_y, dtype=T.long),
+            T.tensor(subword_lengths, dtype=T.long)
+        )
+
+    def create_sentence(self, idx):
+        line = self.data_points[idx]
+        # tokens = tokenize(line.strip())
+        words: List[str] = tokenize(line.strip())
+        # words_: List[str] = []
+        # for word in words:
+        #     if len(strip_tashkeel(word)) == 0:
+        #         words_[-1] += word.strip()
+        #     else:
+        #         words_.append(word)
+        # word_tokens_bin = [self.tokenizer(word) for word in words]
+        # tokens_bin = self.tokenizer(line.strip())
+
+        subwords_x = [self.bos_token_id]
+        subword_lengths = []
+
+        char_x = []
+        diac_x = []
+        diac_y = []
+        diac_y_tmp = []
+
+        for i_word, word in enumerate(words):
+            word = du.strip_unknown_tashkeel(word)
+            word_chars = du.split_word_on_characters_with_diacritics(word)
+            cx, cy, cy_3head = du.create_label_for_word(word_chars)
+
+            word_strip = strip_tashkeel(word)
+            #? List[int: "word_index"]
+            #? Strip the BOS/EOS which the tokenizer adds
+            word_sub_ids = self.tokenizer(word_strip)['input_ids'][1:-1]
+            subword_lengths += [len(word_sub_ids)]
+
+            subwords_x += word_sub_ids
+            # word_x += [self.data_utils.w2idx.get(word_strip, self.data_utils.w2idx["<pad>"])]
+
+            char_x += [self.data_utils.pad_and_truncate_sequence(cx, self.max_wlen)]
+
+            diac_y += [self.data_utils.pad_and_truncate_sequence(cy, self.max_wlen, pad=self.data_utils.pad_target_val)]
+            diac_y_tmp += [self.data_utils.pad_and_truncate_sequence(cy_3head, self.max_wlen, pad=[self.data_utils.pad_target_val]*3)]
+
+        assert len(char_x) == len(subword_lengths), f"{char_x=}; {subword_lengths=} ;;"
+        assert len(char_x) == len(words)
+
+        diac_x = self.data_utils.create_decoder_input(diac_y_tmp)
+
+        subwords_x += [self.eos_token_id]
+        # assert len(char_x) + 2 == len(subwords_x), f"{len(char_x)} + 2 != {len(subwords_x)} ;;" # Because of BOS, EOS
+        assert len(subword_lengths) == len(words)
+        subwords_x = self.data_utils.pad_and_truncate_sequence(subwords_x, self.max_tokens, pad=self.p_val)
+        subword_lengths = self.data_utils.pad_and_truncate_sequence(subword_lengths, self.max_slen, pad=0)
+
+        char_x = self.data_utils.pad_and_truncate_sequence(char_x, self.max_slen, pad=self.char_x_padding)
+        diac_x = self.data_utils.pad_and_truncate_sequence(diac_x, self.max_slen, pad=self.diac_x_padding)
+        diac_y = self.data_utils.pad_and_truncate_sequence(diac_y, self.max_slen, pad=self.diac_y_padding)
+
+        return subwords_x, char_x, diac_x, diac_y, subword_lengths
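A hedged usage sketch for the new retriever, mirroring how predict.py wires it up. The `DatasetUtils(config)` construction is an assumption here (its real setup lives in data_utils.py and may take different arguments); the rest follows the signatures added in this commit.

import yaml
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

from data_utils import DatasetUtils
from dataloader_plm import DataRetriever as DataRetrieverPLM

with open("config.yaml", "r", encoding="utf-8") as fh:
    config = yaml.safe_load(fh)

tokenizer = AutoTokenizer.from_pretrained(config["modeling"]["base_model"])
data_utils = DatasetUtils(config)   # assumption: the real constructor signature may differ
segments = ["ولو حمل من مجلس الخيار ، ولم يمنع من الكلام"]

loader = DataLoader(
    DataRetrieverPLM(segments, data_utils, is_test=True, tokenizer=tokenizer),
    batch_size=1,
    shuffle=False,
)
for (subwords_x, char_x, diac_x), diac_y, subword_lengths in loader:
    print(subwords_x.shape, char_x.shape, diac_y.shape, subword_lengths.shape)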
model_partial.py
CHANGED
@@ -9,7 +9,7 @@ from torch.nn import functional as F
 from diac_utils import flat_2_3head
 
 from model_dd import DiacritizerD2
-from
+from model_plm import Diacritizer
 
 class Readout(nn.Module):
     def __init__(
@@ -72,8 +72,11 @@ class PartialDD(nn.Module):
         # self.config_d2 = yaml.safe_load(fin)
         # self.device = T.device('cuda' if T.cuda.is_available() else 'cpu')
         self.config = config
-        self._use_d2 =
-        self.sentence_diac = DiacritizerD2(self.config)
+        self._use_d2 = config["model-name"] == "D2"
+        if self._use_d2:
+            self.sentence_diac = DiacritizerD2(self.config)
+        else:
+            self.sentence_diac = Diacritizer(self.config, load_pretrained=False)
 
         # self.sentence_diac.to(self.device)
         # self.build()
@@ -90,9 +93,10 @@ class PartialDD(nn.Module):
 
     def load_state_dict(
         self,
-        state_dict: dict
+        state_dict: dict,
+        strict: bool = True,
     ):
-        self.sentence_diac.load_state_dict(state_dict)
+        self.sentence_diac.load_state_dict(state_dict, strict=strict)
 
     def _slim_batch(
         self,
@@ -277,7 +281,7 @@ class PartialDD(nn.Module):
         }
         print("> Predicting...")
         # breakpoint()
-        for i_batch, (inputs, _) in enumerate(tqdm(dataloader)):
+        for i_batch, (inputs, _, subword_lengths) in enumerate(tqdm(dataloader)):
            # if i_batch > 10:
            #     break
            #^ inputs: [toke_ids, char_ids, diac_ids]
model_plm.py
ADDED
@@ -0,0 +1,360 @@
+from typing import List, Iterator, cast
+
+import copy
+import numpy as np
+
+import torch as T
+from torch import nn
+from torch.nn import functional as F
+from transformers import BertConfig, BertModel
+from transformers import AutoTokenizer, AutoModel, AutoConfig
+from transformers import PreTrainedModel
+from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions
+
+class Diacritizer(nn.Module):
+    def __init__(
+        self,
+        config,
+        device=None,
+        load_pretrained=True
+    ) -> None:
+        super().__init__()
+        self._dummy = nn.Parameter(T.ones(1))
+
+        if 'modeling' in config:
+            config = config['modeling']
+        self.config = config
+
+        model_name = config.get('base_model', "CAMeL-Lab/bert-base-arabic-camelbert-mix-ner")
+        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+        if load_pretrained:
+            self.token_model: BertModel = AutoModel.from_pretrained(model_name)
+        else:
+            marbert_config = AutoConfig.from_pretrained(model_name)
+            self.token_model = AutoModel.from_config(marbert_config)
+
+        self.num_classes = 15
+        self.diac_model_config = BertConfig(**config['diac_model_config'])
+        self.token_model_config: BertConfig = self.token_model.config
+
+        self.char_embs = nn.Embedding(config["num-chars"], embedding_dim=config["char-embed-dim"])
+        self.diac_emb_model = self.build_diac_model(self.token_model)
+
+        self.down_project_token_embeds_deep = None
+        self.down_project_token_embeds = None
+        if 'token_hidden_size' in config:
+            if config['token_hidden_size'] == 'auto':
+                down_proj_size = self.diac_emb_model.config.hidden_size
+            else:
+                down_proj_size = config['token_hidden_size']
+            if config.get('deep-down-proj', False):
+                self.down_project_token_embeds_deep = nn.Sequential(
+                    nn.Linear(
+                        self.token_model_config.hidden_size + config["char-embed-dim"],
+                        down_proj_size * 4,
+                        bias=False,
+                    ),
+                    nn.Tanh(),
+                    nn.Linear(
+                        down_proj_size * 4,
+                        down_proj_size,
+                        bias=False,
+                    )
+                )
+            # else:
+            self.down_project_token_embeds = nn.Linear(
+                self.token_model_config.hidden_size + config["char-embed-dim"],
+                down_proj_size,
+                bias=False,
+            )
+
+        # assert self.down_project_token_embeds_deep is None or self.down_project_token_embeds is None
+        classifier_feature_size = self.diac_model_config.hidden_size
+        if config.get('deep-cls', False):
+            # classifier_feature_size = 512
+            self.final_feature_transform = nn.Linear(
+                self.diac_model_config.hidden_size
+                + self.token_model_config.hidden_size,
+                #^ diac_features + [residual from token_model]
+                out_features=classifier_feature_size,
+                bias=False
+            )
+        else:
+            self.final_feature_transform = None
+
+        self.feature_layer_norm = nn.LayerNorm(classifier_feature_size)
+        self.classifier = nn.Linear(classifier_feature_size, self.num_classes, bias=True)
+
+        self.trim_model_(config)
+
+        self.dropout = nn.Dropout(config['dropout'])
+        self.sent_dropout_p = config['sentence_dropout']
+        self.closs = F.cross_entropy
+
+    def build_diac_model(self, token_model=None):
+        if self.config.get('pre-init-diac-model', False):
+            model = copy.deepcopy(self.token_model)
+            model.pooler = None
+            model.embeddings.word_embeddings = None
+
+            num_layers = self.config.get('keep-token-model-layers', None)
+            model.encoder.layer = nn.ModuleList(
+                list(model.encoder.layer[num_layers:num_layers*2])
+            )
+
+            model.encoder.config.num_hidden_layers = num_layers
+        else:
+            model = BertModel(self.diac_model_config)
+        return model
+
+    def trim_model_(self, config):
+        self.token_model.pooler = None
+        self.diac_emb_model.pooler = None
+        # self.diac_emb_model.embeddings = None
+        self.diac_emb_model.embeddings.word_embeddings = None
+
+        num_token_model_kept_layers = config.get('keep-token-model-layers', None)
+        if num_token_model_kept_layers is not None:
+            self.token_model.encoder.layer = nn.ModuleList(
+                list(self.token_model.encoder.layer[:num_token_model_kept_layers])
+            )
+            self.token_model.encoder.config.num_hidden_layers = num_token_model_kept_layers
+
+        if not config.get('full-finetune', False):
+            for param in self.token_model.parameters():
+                param.requires_grad = False
+            finetune_last_layers = config.get('num-finetune-last-layers', 4)
+            if finetune_last_layers > 0:
+                unfrozen_layers = self.token_model.encoder.layer[-finetune_last_layers:]
+                for layer in unfrozen_layers:
+                    for param in layer.parameters():
+                        param.requires_grad = True
+
+    def get_grouped_params(self):
+        downstream_params: Iterator[nn.Parameter] = cast(
+            Iterator,
+            (param
+             for module in (self.diac_emb_model, self.classifier, self.char_embs)
+             for param in module.parameters())
+        )
+        pg = {
+            'pretrained': self.token_model.parameters(),
+            'downstream': downstream_params,
+        }
+        return pg
+
+    @property
+    def device(self):
+        return self._dummy.device
+
+    def step(self, xt, yt, mask=None, subword_lengths: T.Tensor=None):
+        # ^ word_x, char_x, diac_x are Indices
+        # ^ xt : self.preprocess((word_x, char_x, diac_x)),
+        # ^ yt : T.tensor(diac_y, dtype=T.long),
+        # ^ subword_lengths: T.tensor(subword_lengths, dtype=T.long)
+        #< Move char_x, diac_x to device because they're small and trainable
+        xt[0], xt[1], yt, subword_lengths = self._slim_batch_size(xt[0], xt[1], yt, subword_lengths)
+        xt[0] = xt[0].to(self.device)
+        xt[1] = xt[1].to(self.device)
+        # xt[2] = xt[2].to(self.device)
+
+        yt = yt.to(self.device)
+        #^ yt: [b tw tc]
+
+        Nb, Tword, Tchar = xt[1].shape
+        if Tword * Tchar < 500:
+            diac = self(*xt, subword_lengths)
+            loss = self.closs(diac.view(-1, self.num_classes), yt.view(-1), reduction='sum')
+        else:
+            num_chunks = Tword * Tchar / 300
+            loss = 0
+            for i in range(round(num_chunks+0.5)):
+                _slice = slice(i*300, (i+1)*300)
+                chunk = self._slice_batch(xt, _slice)
+                diac = self(*chunk, subword_lengths[_slice])
+                chunk_loss = self.closs(diac.view(-1, self.num_classes), yt.view(-1), reduction='sum')
+                loss = loss + chunk_loss
+
+        return loss
+
+    def _slice_batch(self, xt: List[T.Tensor], _slice):
+        return [xt[0][_slice], xt[1][_slice], xt[2][_slice]]
+
+    def _slim_batch_size(
+        self,
+        tx: T.Tensor,
+        cx: T.Tensor,
+        yt: T.Tensor,
+        subword_lengths: T.Tensor
+    ):
+        #^ tx : [b tt]
+        #^ cx : [b tw tc]
+        #^ yt : [b tw tc]
+        token_nonpad_mask = tx.ne(self.tokenizer.pad_token_id)
+        Ttoken = token_nonpad_mask.sum(1).max()
+        tx = tx[:, :Ttoken]
+
+        char_nonpad_mask = cx.ne(0)
+        Tword = char_nonpad_mask.any(2).sum(1).max()
+        Tchar = char_nonpad_mask.sum(2).max()
+        cx = cx[:, :Tword, :Tchar]
+        yt = yt[:, :Tword, :Tchar]
+        subword_lengths = subword_lengths[:, :Tword]
+
+        return tx, cx, yt, subword_lengths
+
+    def token_dropout(self, toke_x):
+        #^ toke_x : [b tw]
+        if self.training:
+            q = 1.0 - self.sent_dropout_p
+            sdo = T.bernoulli(T.full(toke_x.shape, q))
+            toke_x[sdo == 0] = self.tokenizer.pad_token_id
+        return toke_x
+
+    def sentence_dropout(self, word_embs: T.Tensor):
+        #^ word_embs : [b tw dwe]
+        if self.training:
+            q = 1.0 - self.sent_dropout_p
+            sdo = T.bernoulli(T.full(word_embs.shape[:2], q))
+            sdo = sdo.detach().unsqueeze(-1).to(word_embs)
+            word_embs = word_embs * sdo
+            # toke_x[sdo == 0] = self.tokenizer.pad_token_id
+        return word_embs
+
+    def embed_tokens(self, input_ids: T.Tensor, attention_mask: T.Tensor):
+        y: BaseModelOutputWithPoolingAndCrossAttentions
+        y = self.token_model(input_ids, attention_mask=attention_mask)
+        z = y.last_hidden_state
+        return z
+
+    def forward(
+        self,
+        toke_x : T.Tensor,
+        char_x : T.Tensor,
+        diac_x : T.Tensor,
+        subword_lengths : T.Tensor,
+    ):
+        #^ toke_x : [b tt]
+        #^ char_x : [b tw tc]
+        #^ diac_x/labels : [b tw tc]
+        #^ subword_lengths : [b, tw]
+        # !TODO Use `subword_lengths` to aggregate subword embeddings first before ...
+        # ... passing concatenated contextual embedding to chars in diac_model
+
+        token_nonpad_mask = toke_x.ne(self.tokenizer.pad_token_id)
+        char_nonpad_mask = char_x.ne(0)
+
+        Nb, Tw, Tc = char_x.shape
+        # assert Tw == Tw_0 and Tc == Tc_0, f"{Tw=} {Tw_0=}, {Tc=} {Tc_0=}"
+
+        # toke_x = self.token_dropout(toke_x)
+        token_embs = self.embed_tokens(toke_x, attention_mask=token_nonpad_mask)
+        # token_embs = self.sentence_dropout(token_embs)
+        #? Strip BOS,EOS
+        token_embs = token_embs[:, 1:-1, ...]
+
+        sent_word_strides = subword_lengths.cumsum(1)
+        sent_enc: T.Tensor = T.zeros(Nb, Tw, token_embs.shape[-1]).to(token_embs)
+        for i_b in range(Nb):
+            token_embs_ib = token_embs[i_b]
+            start_iw = 0
+            for i_word, end_iw in enumerate(sent_word_strides[i_b]):
+                if end_iw == start_iw: break
+                word_emb = token_embs_ib[start_iw : end_iw].sum(0) / (end_iw - start_iw)
+                sent_enc[i_b, i_word] = word_emb
+                start_iw = end_iw
+        #^ sent_enc: [b tw dwe]
+
+        char_x_flat = char_x.reshape(Nb*Tw, Tc)
+        char_nonpad_mask = char_x_flat.gt(0)
+        # ^ char_nonpad_mask [b*tw tc]
+
+        char_x_flat = char_x_flat * char_nonpad_mask
+
+        cembs = self.char_embs(char_x_flat)
+
+        #^ cembs: [b*tw tc dce]
+        wembs = sent_enc.unsqueeze(-2).expand(Nb, Tw, Tc, -1).view(Nb*Tw, Tc, -1)
+        #^ wembs: [b tw dwe] => [b tw _ dwe] => [b*tw tc dwe]
+        cw_embs = T.cat([cembs, wembs], dim=-1)
+        #^ char_embs : [b*tw tc dcw] ; dcw = dc + dwe
+        cw_embs = self.dropout(cw_embs)
+
+        cw_embs_ = cw_embs
+        if self.down_project_token_embeds is not None:
+            cw_embs_ = self.down_project_token_embeds(cw_embs)
+        if self.down_project_token_embeds_deep is not None:
+            cw_embs_ = cw_embs_ + self.down_project_token_embeds_deep(cw_embs)
+        cw_embs = cw_embs_
+
+        diac_enc: BaseModelOutputWithPoolingAndCrossAttentions
+        diac_enc = self.diac_emb_model(inputs_embeds=cw_embs, attention_mask=char_nonpad_mask)
+        diac_emb = diac_enc.last_hidden_state
+        diac_emb = self.dropout(diac_emb)
+        #^ diac_emb: [b*tw tc dce]
+        diac_emb = diac_emb.view(Nb, Tw, Tc, -1)
+
+        sent_residual = sent_enc.unsqueeze(2).expand(-1, -1, Tc, -1)
+        final_feature = T.cat([sent_residual, diac_emb], dim=-1)
+        if self.final_feature_transform is not None:
+            final_feature = self.final_feature_transform(final_feature)
+            final_feature = F.tanh(final_feature)
+            final_feature = self.dropout(final_feature)
+        else:
+            final_feature = diac_emb
+
+        # final_feature = self.feature_layer_norm(final_feature)
+        diac_out = self.classifier(final_feature)
+        # if T.isnan(diac_out).any():
+        #     breakpoint()
+        return diac_out
+
+    def predict(self, dataloader):
+        from tqdm import tqdm
+        import diac_utils as du
+        training = self.training
+        self.eval()
+
+        preds = {'haraka': [], 'shadda': [], 'tanween': []}
+        print("> Predicting...")
+        for inputs, _, subword_lengths in tqdm(dataloader, total=len(dataloader)):
+            inputs[0] = inputs[0].to(self.device)
+            inputs[1] = inputs[1].to(self.device)
+            output = self(*inputs, subword_lengths).detach()
+
+            marks = np.argmax(T.softmax(output, dim=-1).cpu().numpy(), axis=-1)
+            #^ [b ts tw]
+
+            haraka, tanween, shadda = du.flat_2_3head(marks)
+
+            preds['haraka'].extend(haraka)
+            preds['tanween'].extend(tanween)
+            preds['shadda'].extend(shadda)
+
+        self.train(training)
+        return (
+            np.array(preds['haraka']),
+            np.array(preds["tanween"]),
+            np.array(preds["shadda"]),
+        )
+
+if __name__ == "__main__":
+    model = Diacritizer({
+        "num-chars": 36,
+        "hidden_size": 768,
+        "char-embed-dim": 32,
+        "dropout": 0.25,
+        "sentence_dropout": 0.2,
+        "diac_model_config": {
+            "num_layers": 4,
+            "hidden_size": 768 + 32,
+            "intermediate_size": (768 + 32) * 4,
+        },
+    }, load_pretrained=False)
+
+    total_params = sum(p.numel() for p in model.parameters())
+    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+    print(model)
+    print(f"{trainable_params:,}/{total_params:,} Trainable Parameters")
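One step worth spelling out is how `Diacritizer.forward` turns contextual subword embeddings into per-word vectors: it mean-pools each word's subwords using `subword_lengths` (the loop over `sent_word_strides`). A standalone toy version of that pooling, with made-up shapes, not taken from the repo:

import torch as T

token_embs = T.arange(2 * 5 * 4, dtype=T.float32).view(2, 5, 4)   # [b=2, subwords=5, d=4]
subword_lengths = T.tensor([[2, 1, 2, 0], [1, 3, 1, 0]])           # [b, tw] (0 = padding)

Nb, Tw = subword_lengths.shape
sent_enc = T.zeros(Nb, Tw, token_embs.shape[-1])
strides = subword_lengths.cumsum(1)
for i_b in range(Nb):
    start = 0
    for i_word, end in enumerate(strides[i_b]):
        if end == start:
            break                                  # remaining entries are padding
        sent_enc[i_b, i_word] = token_embs[i_b, start:end].mean(0)
        start = end
print(sent_enc.shape)   # torch.Size([2, 4, 4]); the model does the same with BOS/EOS stripped first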
predict.py
CHANGED
@@ -14,9 +14,9 @@ from torch.utils.data import DataLoader
 
 from diac_utils import HARAKAT_MAP, shakkel_char, flat2_3head
 from model_partial import PartialDD
-from model_dd import DiacritizerD2
 from data_utils import DatasetUtils
 from dataloader import DataRetriever
+from dataloader_plm import DataRetriever as DataRetrieverPLM
 from segment import segment
 
 from partial_dd_metrics import (
@@ -105,15 +105,19 @@ class Predictor:
             config['predictor'].get('device', 'cuda:0')
             if T.cuda.is_available() else 'cpu'
         )
-
+
         self.model = PartialDD(config)
-        self.model.sentence_diac.build(word_embeddings, vocab_size)
-        state_dict = T.load(config["paths"]["load"], map_location=T.device(self.device))['state_dict']
-        self.model.load_state_dict(state_dict)
+        if config["model-name"] == "D2":
+            self.model.sentence_diac.build(word_embeddings, vocab_size)
+            state_dict = T.load(config["paths"]["load"], map_location=T.device(self.device))['state_dict']
+        else:
+            state_dict = T.load(config["paths"]["load-td2"], map_location=T.device(self.device))['state_dict']
+
+        self.model.load_state_dict(state_dict, strict=False)
         self.model.to(self.device)
         self.model.eval()
 
-    def create_dataloader(self, text, do_partial, do_hard_mask, threshold):
+    def create_dataloader(self, text, do_partial, do_hard_mask, threshold, model_name):
         self.threshold = threshold
         self.do_hard_mask = do_hard_mask
 
@@ -137,7 +141,12 @@ class Predictor:
         self.original_lines = text.split('\n')
 
         self.data_loader = DataLoader(
-            DataRetriever(self.data_utils, segments),
+            DataRetriever(self.data_utils, segments)
+            if model_name == "D2"
+            else DataRetrieverPLM(segments, self.data_utils,
+                is_test=True,
+                tokenizer=self.model.tokenizer
+            ),
             batch_size=self.config["predictor"].get("batch-size", 32),
             shuffle=False,
             num_workers=self.config['loader'].get('num-workers', 0),
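End to end, app.py drives this predictor roughly as follows (a recap sketch of the calls shown above; it assumes the checkpoints downloaded at the top of app.py are already in place):

import yaml
from predict import PredictTri

with open("config.yaml", "r", encoding="utf-8") as fh:
    config = yaml.safe_load(fh)
config["train"]["max-token-count"] = config["predictor"]["window"] * 3
config["model-name"] = "TD2"

predictor = PredictTri(config)

text = "ولو حمل من مجلس الخيار ، ولم يمنع من الكلام"
# args: text, do_partial, do_hard_mask, threshold, model_name (soft mask, threshold 0.1 here)
predictor.create_dataloader(text, True, False, 0.1, "TD2")
lines = predictor.predict_partial(do_partial=True, lines=text.split("\n"))
print(lines)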