"""Prediction utilities for partial Arabic diacritization models."""
from typing import Iterable, Union, Tuple | |
from collections import Counter | |
import argparse | |
import os | |
import yaml | |
from pyarabic.araby import tokenize, strip_tatweel, strip_tashkeel | |
from tqdm import tqdm | |
import numpy as np | |
import torch as T | |
from torch.utils.data import DataLoader | |
from diac_utils import HARAKAT_MAP, shakkel_char, flat2_3head | |
from model_partial import PartialDD | |
from data_utils import DatasetUtils | |
from dataloader import DataRetriever | |
from dataloader_plm import DataRetriever as DataRetrieverPLM | |
from segment import segment | |
from partial_dd_metrics import ( | |
parse_data, | |
load_data, | |
make_mask_hard, | |
make_mask_logits, | |
) | |
def apply_tashkeel(
    line: str,
    diacs: Union[np.ndarray, T.Tensor]
):
    """Interleave predicted diacritic marks into a plain-text line.

    ``diacs`` is a (words, chars) table of flat diacritic class ids.  It is
    expanded into the three-head representation (haraka, tanween, shadda)
    and a mark is emitted after every non-space character; the word/char
    cursors reset at each space, mirroring the tokenization that produced
    ``diacs``.
    """
    n_words, n_chars = diacs.shape
    heads = tuple(
        head.reshape(n_words, n_chars)
        for head in flat2_3head(diacs.flatten())
    )
    pieces = []
    word_i = 0
    char_i = 0
    for ch in line:
        pieces.append(ch)
        if ch == " ":
            # Word boundary: move to the next row of the diacritic table.
            char_i = 0
            word_i += 1
        else:
            mark = shakkel_char(
                heads[0][word_i][char_i],
                heads[1][word_i][char_i],
                heads[2][word_i][char_i],
            )
            char_i += 1
            pieces.append(mark)
    return "".join(pieces)
def diac_text(data, model_output_base, model_output_ctxt, selection_mode='contrastive-hard', threshold=0.1):
    """Diacritize ``data`` using contextual predictions filtered by a mask.

    Args:
        data: iterable of raw text lines.
        model_output_base: per-character output of the context-free (base)
            model; used to build the selection mask and to sanity-check that
            the batch covers every input line.
        model_output_ctxt: per-character output of the contextual model; its
            argmax supplies the diacritic class wherever the mask is set.
        selection_mode: 'contrastive-hard' keeps characters where the two
            models disagree (hard mask); any other value is forwarded to
            ``make_mask_logits`` as its ``version``.
        threshold: confidence threshold used by the logits-based mask.

    Returns:
        list[str]: one diacritized line per input line.
    """
    if selection_mode == 'contrastive-hard':
        mask = make_mask_hard(model_output_ctxt, model_output_base)
    else:
        mask = make_mask_logits(
            model_output_ctxt, model_output_base,
            version=selection_mode, threshold=threshold,
        )
    # Keep the contextual class id where the mask selects it, class 0
    # ("no diacritic") everywhere else.  Shape: [batch, chars | ClassId].
    diacritics = np.where(mask, model_output_ctxt.argmax(-1), 0).astype(int)
    assert len(model_output_base) == len(data)
    # Normalize the input: strip existing diacritics/tatweel and re-tokenize
    # so character indices line up with the model output.
    data = [
        ' '.join(tokenize(
            line.strip(),
            morphs=[strip_tashkeel, strip_tatweel]
        ))
        for line in data
    ]
    return [
        apply_tashkeel(line, line_diacs)
        for line, line_diacs in zip(tqdm(data), diacritics)
    ]
class Predictor:
    """Loads a PartialDD model from ``config`` and builds inference loaders.

    The checkpoint path and data pipeline depend on the model name: "D2"
    uses the word-embedding pipeline (``DataRetriever``); anything else uses
    the pretrained-LM pipeline (``DataRetrieverPLM``).
    """

    def __init__(self, config):
        self.data_utils = DatasetUtils(config)
        vocab_size = len(self.data_utils.letter_list)
        word_embeddings = self.data_utils.embeddings
        self.config = config
        # Fall back to CPU when CUDA is unavailable, regardless of config.
        self.device = T.device(
            config['predictor'].get('device', 'cuda:0')
            if T.cuda.is_available() else 'cpu'
        )
        self.model = PartialDD(config)
        if config["model-name"] == "D2":
            # D2 builds its sentence-level sub-module before weights load.
            self.model.sentence_diac.build(word_embeddings, vocab_size)
            ckpt_path = config["paths"]["load"]
        else:
            ckpt_path = config["paths"]["load-td2"]
        # self.device is already a torch.device, so it can be passed to
        # map_location directly (no redundant T.device(...) re-wrap).
        state_dict = T.load(ckpt_path, map_location=self.device)['state_dict']
        self.model.load_state_dict(state_dict, strict=False)
        self.model.to(self.device)
        self.model.eval()

    def create_dataloader(self, text, do_partial, do_hard_mask, threshold, model_name):
        """Segment ``text`` and build ``self.data_loader`` for inference.

        Hard-mask (or non-partial) mode slices the text into overlapping
        windows and records a character-level mapping used later for
        majority voting; otherwise each input line is its own segment.
        """
        self.threshold = threshold
        self.do_hard_mask = do_hard_mask
        stride = self.config["segment"]["stride"]
        window = self.config["segment"]["window"]
        min_window = self.config["segment"]["min-window"]
        if self.do_hard_mask or not do_partial:
            segments, mapping = segment([text], stride, window, min_window)
            mapping_lines = [
                f"{sent_idx}, {seg_idx}, {word_idx}, {char_idx}"
                for sent_idx, seg_idx, word_idx, char_idx in mapping
            ]
            self.mapping = self.data_utils.load_mapping_v3_from_list(mapping_lines)
            self.original_lines = [text]
            self.segments = segments
        else:
            segments = text.split('\n')
            self.segments = segments
            self.original_lines = text.split('\n')
        self.data_loader = DataLoader(
            DataRetriever(self.data_utils, segments)
            if model_name == "D2"
            else DataRetrieverPLM(segments, self.data_utils,
                                  is_test=True,
                                  tokenizer=self.model.tokenizer
                                  ),
            batch_size=self.config["predictor"].get("batch-size", 32),
            shuffle=False,
            num_workers=self.config['loader'].get('num-workers', 0),
        )
class PredictTri(Predictor):
    """Predictor that resolves per-character diacritics by majority vote
    across overlapping segments, with optional contrastive selection."""

    def __init__(self, config):
        super().__init__(config)
        # Flat class ids for the four primary marks; 0 means "no diacritic".
        self.diacritics = {
            "FATHA": 1,
            "KASRA": 2,
            "DAMMA": 3,
            "SUKUN": 4
        }
        # Reusable tally for per-character voting; cleared by count_votes.
        self.votes: Union[Counter[int], Counter[bool]] = Counter()

    def count_votes(
        self,
        things: Union[Iterable[int], Iterable[bool]]
    ):
        """Return the most common element of ``things`` (ties broken by
        first-seen order, per Counter.most_common)."""
        self.votes.clear()
        self.votes.update(things)
        return self.votes.most_common(1)[0][0]

    def predict_majority_vote(self):
        """Run full (non-partial) prediction and coalesce segments by vote."""
        y_gen_diac, y_gen_tanween, y_gen_shadda = self.model.predict(self.data_loader)
        diacritized_lines, _ = self.coalesce_votes_by_majority(y_gen_diac, y_gen_tanween, y_gen_shadda)
        return diacritized_lines

    def predict_partial(self, do_partial, lines):
        """Partially diacritize ``lines`` and return them joined by newlines.

        Hard-mask (or non-partial) mode votes across segments; otherwise the
        base/contextual outputs are combined via diac_text.  Assumes
        create_dataloader was called first (sets do_hard_mask/threshold).
        """
        outputs = self.model.predict_partial(self.data_loader, return_extra=True, eval_only='both', do_partial=do_partial)
        if self.do_hard_mask or not do_partial:
            y_gen_diac, y_gen_tanween, y_gen_shadda = outputs['diacritics']
            diac_lines, _ = self.coalesce_votes_by_majority(y_gen_diac, y_gen_tanween, y_gen_shadda)
        else:
            # NOTE(review): outputs["other"] appears to be (ctxt, base) —
            # index 1 is passed as the base model output; confirm in
            # PartialDD.predict_partial.
            diac_lines = diac_text(lines, outputs["other"][1], outputs["other"][0], selection_mode='1', threshold=self.threshold)
        return '\n'.join(diac_lines)

    def predict_majority_vote_context_contrastive(self, overwrite_cache=False):
        """Predict with caching of raw segment outputs under dataset/cache.

        Returns (diacritized_lines, extra) where ``extra`` carries both the
        per-line vote matrix and the raw per-segment outputs.
        """
        assert isinstance(self.model, PartialDD)
        # Cache-miss check is keyed on y_gen_diac.npy but the payload is
        # stored as cache.pt; overwrite_cache forces recomputation.
        if not os.path.exists("dataset/cache/y_gen_diac.npy") or overwrite_cache:
            if not os.path.exists("dataset/cache"):
                os.mkdir("dataset/cache")
            # segment_outputs = self.model.predict_partial(self.data_loader, return_extra=True)
            segment_outputs = self.model.predict_partial(self.data_loader, return_extra=False, eval_only='both')
            T.save(segment_outputs, "dataset/cache/cache.pt")
        else:
            segment_outputs = T.load("dataset/cache/cache.pt")
        y_gen_diac, y_gen_tanween, y_gen_shadda = segment_outputs['diacritics']
        diacritized_lines, extra_for_lines = self.coalesce_votes_by_majority(
            y_gen_diac, y_gen_tanween, y_gen_shadda,
        )
        extra_out = {
            'line_data': {
                **extra_for_lines,
            },
            'segment_data': {
                **segment_outputs,
                # 'logits': segment_outputs['logits'],
            }
        }
        return diacritized_lines, extra_out

    def coalesce_votes_by_majority(
        self,
        y_gen_diac: np.ndarray,
        y_gen_tanween: np.ndarray,
        y_gen_shadda: np.ndarray,
    ):
        """Merge per-segment predictions back into per-line diacritized text.

        For each character of each original line, gathers the predictions of
        every segment that covered it (via ``self.mapping``) and keeps the
        majority vote.  The three arrays are indexed as
        [segment, token, char] — TODO confirm against the model output.

        Returns:
            (diacritized_lines, extra): the diacritized strings plus
            ``extra['diac_pred']``, an int matrix [lines, max_line_chars] of
            voted class ids (-1 where the line has no character).
        """
        prepped_lines_og = [' '.join(tokenize(strip_tatweel(line))) for line in self.original_lines]
        max_line_chars = max(len(line) for line in prepped_lines_og)
        # -1 marks positions past the end of a line.
        diacritics_pred = np.full((len(self.original_lines), max_line_chars), fill_value=-1, dtype=int)
        count_processed_sents = 0
        do_break = False
        diacritized_lines = []
        for sent_idx, line in enumerate(tqdm(prepped_lines_og)):
            count_processed_sents = sent_idx + 1
            line = line.strip()
            diacritized_line = ""
            for char_idx, char in enumerate(line):
                diacritized_line += char
                char_vote_diacritic = []
                # ? This is the voting part
                # Lines with no segment mapping get no diacritics at all.
                if sent_idx not in self.mapping:
                    continue
                mapping_s_i = self.mapping[sent_idx]
                for seg_idx in mapping_s_i:
                    # In debug mode only the first 256 segments were scored;
                    # stop once we run past them.
                    if self.data_utils.debug and seg_idx >= 256:
                        do_break = True
                        break
                    mapping_g_i = mapping_s_i[seg_idx]
                    for t_idx in mapping_g_i:
                        mapping_t_i = mapping_g_i[t_idx]
                        if char_idx in mapping_t_i:
                            c_idx = mapping_t_i.index(char_idx)
                            output_idx = np.s_[seg_idx, t_idx, c_idx]
                            diac_h3 = (y_gen_diac[output_idx], y_gen_tanween[output_idx], y_gen_shadda[output_idx])
                            # Map the (haraka, tanween, shadda) triple back
                            # to its flat class id; 0 = no diacritic.
                            diac_char_i = HARAKAT_MAP.index(diac_h3)
                            # NOTE(review): 13 looks like a max word-length
                            # cap — confirm against the model/data config.
                            if c_idx < 13 and diac_char_i != 0:
                                char_vote_diacritic.append(diac_char_i)
                if do_break:
                    break
                if len(char_vote_diacritic) > 0:
                    char_mv_diac = self.count_votes(char_vote_diacritic)
                    diacritized_line += shakkel_char(*HARAKAT_MAP[char_mv_diac])
                    diacritics_pred[sent_idx, char_idx] = char_mv_diac
                else:
                    # No votes collected: record explicit "no diacritic".
                    diacritics_pred[sent_idx, char_idx] = 0
            if do_break:
                break
            diacritized_lines += [diacritized_line.strip()]
        print(f'[INFO] Cutting stats from {len(diacritics_pred)} to {count_processed_sents}')
        extra = {
            'diac_pred': diacritics_pred[:count_processed_sents],
        }
        return diacritized_lines, extra