# support for TD2 (commit d7c4b94, by bkhmsi)
from typing import Iterable, Union, Tuple
from collections import Counter
import argparse
import os
import yaml
from pyarabic.araby import tokenize, strip_tatweel, strip_tashkeel
from tqdm import tqdm
import numpy as np
import torch as T
from torch.utils.data import DataLoader
from diac_utils import HARAKAT_MAP, shakkel_char, flat2_3head
from model_partial import PartialDD
from data_utils import DatasetUtils
from dataloader import DataRetriever
from dataloader_plm import DataRetriever as DataRetrieverPLM
from segment import segment
from partial_dd_metrics import (
parse_data,
load_data,
make_mask_hard,
make_mask_logits,
)
def apply_tashkeel(
    line: str,
    diacs: Union[np.ndarray, T.Tensor]
):
    """Interleave diacritic marks into `line`.

    `diacs` is a [n_words, n_chars] grid of flat diacritic class ids; it
    is expanded into the three-head (haraka, tanween, shadda) form and a
    mark is emitted after every non-space character of `line`.
    """
    n_words, n_chars = diacs.shape
    heads = flat2_3head(diacs.flatten())
    haraka, tanween, shadda = (h.reshape(n_words, n_chars) for h in heads)

    pieces = []
    word_i = 0  # word index into the diacritics grid
    char_i = 0  # character index within the current word
    for ch in line:
        pieces.append(ch)
        if ch == " ":
            # Word boundary: advance to the next word's diacritics row.
            char_i = 0
            word_i += 1
        else:
            pieces.append(shakkel_char(
                haraka[word_i][char_i],
                tanween[word_i][char_i],
                shadda[word_i][char_i],
            ))
            char_i += 1
    return "".join(pieces)
def diac_text(data, model_output_base, model_output_ctxt, selection_mode='contrastive-hard', threshold=0.1):
    """Diacritize `data` lines by contrasting context vs. base model outputs.

    Args:
        data: iterable of raw text lines.
        model_output_base: base-model output array, one row per line.
        model_output_ctxt: context-model output array (logits over classes).
        selection_mode: 'contrastive-hard' uses a hard disagreement mask;
            any other value is forwarded to `make_mask_logits` as `version`.
        threshold: threshold for the logits-based mask.

    Returns:
        List of diacritized lines.
    """
    if selection_mode == 'contrastive-hard':
        keep_mask = make_mask_hard(model_output_ctxt, model_output_base)
    else:
        keep_mask = make_mask_logits(
            model_output_ctxt, model_output_base,
            version=selection_mode, threshold=threshold,
        )
    # Keep the context model's argmax class where the mask fires, 0 elsewhere.
    diacritics = np.where(keep_mask, model_output_ctxt.argmax(-1), 0).astype(int)
    #^ shape: [b, tc | ClassId]

    assert len(model_output_base) == len(data)

    # Normalize input: strip existing diacritics/tatweel and re-tokenize.
    stripped = [
        ' '.join(tokenize(
            line.strip(),
            morphs=[strip_tashkeel, strip_tatweel]
        ))
        for line in data
    ]
    return [
        apply_tashkeel(line, line_diacs)
        for line, line_diacs in zip(tqdm(stripped), diacritics)
    ]
class Predictor:
    """Loads a PartialDD model from a checkpoint and builds inference loaders."""

    def __init__(self, config):
        self.config = config
        self.data_utils = DatasetUtils(config)

        wanted_device = config['predictor'].get('device', 'cuda:0')
        self.device = T.device(wanted_device if T.cuda.is_available() else 'cpu')

        self.model = PartialDD(config)
        if config["model-name"] == "D2":
            # D2 needs its embedding table built before weights can load.
            self.model.sentence_diac.build(
                self.data_utils.embeddings,
                len(self.data_utils.letter_list),
            )
            ckpt_path = config["paths"]["load"]
        else:
            ckpt_path = config["paths"]["load-td2"]
        state_dict = T.load(ckpt_path, map_location=T.device(self.device))['state_dict']
        self.model.load_state_dict(state_dict, strict=False)
        self.model.to(self.device)
        self.model.eval()

    def create_dataloader(self, text, do_partial, do_hard_mask, threshold, model_name):
        """Prepare `self.data_loader` (and segment mapping) for `text`.

        Hard-mask / non-partial runs segment the text into overlapping
        windows with a char-level mapping; otherwise lines are fed as-is.
        """
        self.threshold = threshold
        self.do_hard_mask = do_hard_mask
        seg_cfg = self.config["segment"]
        if self.do_hard_mask or not do_partial:
            segments, mapping = segment(
                [text], seg_cfg["stride"], seg_cfg["window"], seg_cfg["min-window"]
            )
            mapping_lines = [
                f"{sent_idx}, {seg_idx}, {word_idx}, {char_idx}"
                for sent_idx, seg_idx, word_idx, char_idx in mapping
            ]
            self.mapping = self.data_utils.load_mapping_v3_from_list(mapping_lines)
            self.original_lines = [text]
        else:
            segments = text.split('\n')
            self.original_lines = text.split('\n')
        self.segments = segments

        if model_name == "D2":
            dataset = DataRetriever(self.data_utils, segments)
        else:
            dataset = DataRetrieverPLM(
                segments, self.data_utils,
                is_test=True,
                tokenizer=self.model.tokenizer
            )
        self.data_loader = DataLoader(
            dataset,
            batch_size=self.config["predictor"].get("batch-size", 32),
            shuffle=False,
            num_workers=self.config['loader'].get('num-workers', 0),
        )
class PredictTri(Predictor):
    """Predictor using three-head (haraka / tanween / shadda) outputs.

    Supports majority voting over overlapping segments as well as
    partial-diacritization modes driven by a base/context output pair.
    """

    def __init__(self, config):
        super().__init__(config)
        # Name -> class-id map for the four primary diacritics.
        self.diacritics = {
            "FATHA": 1,
            "KASRA": 2,
            "DAMMA": 3,
            "SUKUN": 4
        }
        # Reusable scratch counter for per-character majority voting.
        self.votes: Union[Counter[int], Counter[bool]] = Counter()

    def count_votes(
        self,
        things: Union[Iterable[int], Iterable[bool]]
    ):
        """Return the most common element of `things` (ties: first counted)."""
        self.votes.clear()
        self.votes.update(things)
        return self.votes.most_common(1)[0][0]

    def predict_majority_vote(self):
        """Run the model on self.data_loader and coalesce segment outputs
        into diacritized lines by per-character majority vote."""
        y_gen_diac, y_gen_tanween, y_gen_shadda = self.model.predict(self.data_loader)
        diacritized_lines, _ = self.coalesce_votes_by_majority(y_gen_diac, y_gen_tanween, y_gen_shadda)
        return diacritized_lines

    def predict_partial(self, do_partial, lines):
        """Diacritize via the partial-DD path.

        With hard masking (or full diacritization) segment votes are
        coalesced by majority; otherwise the base/context outputs are
        combined by `diac_text`. Returns a single newline-joined string.
        """
        outputs = self.model.predict_partial(self.data_loader, return_extra=True, eval_only='both', do_partial=do_partial)
        if self.do_hard_mask or not do_partial:
            y_gen_diac, y_gen_tanween, y_gen_shadda = outputs['diacritics']
            diac_lines, _ = self.coalesce_votes_by_majority(y_gen_diac, y_gen_tanween, y_gen_shadda)
        else:
            # NOTE(review): outputs["other"][1] is passed as base and [0] as
            # ctxt — presumably (ctxt, base) ordering; confirm in PartialDD.
            diac_lines = diac_text(lines, outputs["other"][1], outputs["other"][0], selection_mode='1', threshold=self.threshold)
        return '\n'.join(diac_lines)

    def predict_majority_vote_context_contrastive(self, overwrite_cache=False):
        """Predict with on-disk caching of raw segment outputs.

        Caches model outputs under dataset/cache/cache.pt unless
        `overwrite_cache` forces recomputation. Returns
        (diacritized_lines, extra) where extra bundles per-line and
        per-segment data for downstream analysis.
        """
        assert isinstance(self.model, PartialDD)
        # NOTE(review): cache freshness is keyed on y_gen_diac.npy but the
        # payload is stored in cache.pt — confirm this is intentional.
        if not os.path.exists("dataset/cache/y_gen_diac.npy") or overwrite_cache:
            if not os.path.exists("dataset/cache"):
                os.mkdir("dataset/cache")
            # segment_outputs = self.model.predict_partial(self.data_loader, return_extra=True)
            segment_outputs = self.model.predict_partial(self.data_loader, return_extra=False, eval_only='both')
            T.save(segment_outputs, "dataset/cache/cache.pt")
        else:
            segment_outputs = T.load("dataset/cache/cache.pt")
        y_gen_diac, y_gen_tanween, y_gen_shadda = segment_outputs['diacritics']
        diacritized_lines, extra_for_lines = self.coalesce_votes_by_majority(
            y_gen_diac, y_gen_tanween, y_gen_shadda,
        )
        extra_out = {
            'line_data': {
                **extra_for_lines,
            },
            'segment_data': {
                **segment_outputs,
                # 'logits': segment_outputs['logits'],
            }
        }
        return diacritized_lines, extra_out

    def coalesce_votes_by_majority(
        self,
        y_gen_diac: np.ndarray,
        y_gen_tanween: np.ndarray,
        y_gen_shadda: np.ndarray,
    ):
        """Merge overlapping-segment predictions into per-line diacritics.

        For each character of each original line, collect the predictions of
        every segment covering it (via self.mapping) and keep the majority
        class. Returns (diacritized_lines, extra) where extra['diac_pred']
        is an [n_lines, max_chars] matrix of chosen class ids (-1 where a
        character was never reached).
        """
        prepped_lines_og = [' '.join(tokenize(strip_tatweel(line))) for line in self.original_lines]
        max_line_chars = max(len(line) for line in prepped_lines_og)
        diacritics_pred = np.full((len(self.original_lines), max_line_chars), fill_value=-1, dtype=int)
        count_processed_sents = 0
        do_break = False
        diacritized_lines = []
        for sent_idx, line in enumerate(tqdm(prepped_lines_og)):
            count_processed_sents = sent_idx + 1
            line = line.strip()
            diacritized_line = ""
            for char_idx, char in enumerate(line):
                diacritized_line += char
                char_vote_diacritic = []
                # ? This is the voting part
                # Lines with no segment mapping are copied through undiacritized.
                if sent_idx not in self.mapping:
                    continue
                mapping_s_i = self.mapping[sent_idx]
                for seg_idx in mapping_s_i:
                    # In debug mode only the first 256 segments are considered.
                    if self.data_utils.debug and seg_idx >= 256:
                        do_break = True
                        break
                    mapping_g_i = mapping_s_i[seg_idx]
                    for t_idx in mapping_g_i:
                        mapping_t_i = mapping_g_i[t_idx]
                        if char_idx in mapping_t_i:
                            c_idx = mapping_t_i.index(char_idx)
                            output_idx = np.s_[seg_idx, t_idx, c_idx]
                            diac_h3 = (y_gen_diac[output_idx], y_gen_tanween[output_idx], y_gen_shadda[output_idx])
                            # Map the 3-head tuple back to a flat class id.
                            diac_char_i = HARAKAT_MAP.index(diac_h3)
                            # 13 is presumably the model's max chars-per-word
                            # window — TODO confirm against model config.
                            if c_idx < 13 and diac_char_i != 0:
                                char_vote_diacritic.append(diac_char_i)
                if do_break:
                    break
                if len(char_vote_diacritic) > 0:
                    char_mv_diac = self.count_votes(char_vote_diacritic)
                    diacritized_line += shakkel_char(*HARAKAT_MAP[char_mv_diac])
                    diacritics_pred[sent_idx, char_idx] = char_mv_diac
                else:
                    diacritics_pred[sent_idx, char_idx] = 0
            if do_break:
                break
            diacritized_lines += [diacritized_line.strip()]
        print(f'[INFO] Cutting stats from {len(diacritics_pred)} to {count_processed_sents}')
        extra = {
            'diac_pred': diacritics_pred[:count_processed_sents],
        }
        return diacritized_lines, extra