from typing import List
import warnings

import numpy as np
import torch as T
from pyarabic.araby import (
    tokenize, strip_tashkeel, strip_tatweel, DIACRITICS
)

# Ids of the individual harakat as used by shakkel_char and the first slot of
# the (haraka, tanween, shadda) triples in HARAKAT_MAP.
SEPARATE_DIACRITICS = {
    "FATHA": 1,
    "KASRA": 2,
    "DAMMA": 3,
    "SUKUN": 4,
}

# Flat diacritic class id -> (haraka, tanween, shadda) triple.
HARAKAT_MAP = [
    #^ (haraka, tanween, shadda)
    (0, 0, 0),  #< No diacs on char
    (1, 0, 0),
    (1, 1, 0),  #< Tanween on 2nd slot
    (2, 0, 0),
    (2, 1, 0),
    (3, 0, 0),
    (3, 1, 0),
    (4, 0, 0),
    (0, 0, 1),  #< shadda on 3rd slot
    (1, 0, 1),
    (1, 1, 1),
    (2, 0, 1),
    (2, 1, 1),
    (3, 0, 1),
    (3, 1, 1),
    (0, 0, 0),  #< Padding == -1 (also for spaces)
]

DIAC_PAD_IDX = -1

# NOTE(review): in the reviewed copy these four literals were all empty
# strings, which made every special-token lookup in char_type return index 0.
# They look like angle-bracket tokens stripped by a markup pass and are
# restored here; confirm the names/order against the trained vocabulary.
SPECIAL_TOKENS = ['<pad>', '<unk>', '<num>', '<punc>']
LETTER_LIST = SPECIAL_TOKENS + list("ءآأؤإئابةتثجحخدذرزسشصضطظعغفقكلمنهوىي")
# NOTE(review): this list orders damma/dammatan before kasra/kasratan, whereas
# DIACRITICS_SHORT and HARAKAT_MAP order kasra first; if it is indexed by
# HARAKAT_MAP ids elsewhere, confirm the ordering is intended.
CLASSES_LIST = [' ', 'َ', 'ً', 'ُ', 'ٌ', 'ِ', 'ٍ', 'ْ', 'ّ', 'َّ', 'ًّ', 'ُّ', 'ٌّ', 'ِّ', 'ٍّ']
DIACRITICS_SHORT = [' ', 'َ', 'ً', 'ِ', 'ٍ', 'ُ', 'ٌ', 'ْ', 'ّ']
NUMBERS = list("0123456789")
DELIMITERS = ["،","؛",",",";","«","»","{","}","(",")","[","]",".","*","-",":","?","!","؟"]

UNKNOWN_DIACRITICS = list(set(DIACRITICS).difference(set(DIACRITICS_SHORT)))

# Positions inside DIACRITICS_SHORT used by create_labels.
_TANWEEN_IDS = (2, 4, 6)
_SHADDA_IDX = 8
SHADDA = DIACRITICS_SHORT[_SHADDA_IDX]


def shakkel_char(diac: int, tanween: bool, shadda: bool) -> str:
    """Render one (haraka, tanween, shadda) triple as Arabic diacritic marks.

    Args:
        diac: haraka id from SEPARATE_DIACRITICS (0 means no haraka).
        tanween: use the tanween form of fatha/kasra/damma.
        shadda: prepend a shadda (ignored when combined with sukun).

    Returns:
        The diacritic string (possibly empty).
    """
    returned_text = ""
    if shadda and diac != SEPARATE_DIACRITICS["SUKUN"]:
        returned_text += "\u0651"
    if diac == SEPARATE_DIACRITICS["FATHA"]:
        returned_text += "\u064E" if not tanween else "\u064B"
    elif diac == SEPARATE_DIACRITICS["KASRA"]:
        returned_text += "\u0650" if not tanween else "\u064D"
    elif diac == SEPARATE_DIACRITICS["DAMMA"]:
        returned_text += "\u064F" if not tanween else "\u064C"
    elif diac == SEPARATE_DIACRITICS["SUKUN"]:
        returned_text += "\u0652"
    return returned_text


def diac_ids_of_line(line: str):
    """Return the flat HARAKAT_MAP class id of every character in *line*.

    Consecutive words produced by ``tokenize`` are separated by DIAC_PAD_IDX;
    no separator follows the last word. An empty line yields an empty array.
    """
    diacs = []
    for word in tokenize(line):
        word_chars = split_word_on_characters_with_diacritics(word)
        _cx, cy, _cy_3head = create_label_for_word(word_chars)
        diacs.extend(cy)
        diacs.append(DIAC_PAD_IDX)
    # Drop the trailing separator appended after the last word.
    return np.array(diacs[:-1])


def strip_unknown_tashkeel(word: str):
    """Return *word* unchanged: stripping unknown tashkeel is disabled.

    The intended behaviour is to drop every character listed in
    UNKNOWN_DIACRITICS. While disabled, each call emits a warning.
    """
    #! FIXME! Re-enable with:
    #   return ''.join(c for c in word if c not in UNKNOWN_DIACRITICS)
    warnings.warn("Stripping unknown tashkeel is disabled.", stacklevel=2)
    return word


def create_gt_labels(lines):
    """Return the ground-truth diacritic id array (diac_ids_of_line) of each line."""
    return [diac_ids_of_line(line) for line in lines]


def split_word_on_characters_with_diacritics(word: str):
    '''
    Group *word* into one list per base character, followed by its diacritics.

    Returns:
        List[List[str]]: e.g. [[letter, diac, ...], [letter], ...]; an empty
        word yields an empty list.
    '''
    #! FIXME! DIACRITICS_SHORT is missing a lot of less common diacritics ...
    #! which are then treated as letters during splitting.
    chars_w_diac = []
    i_start = 0
    for i_c, c in enumerate(word):
        if c not in DIACRITICS_SHORT:
            # A new base character closes the previous group.
            chars_w_diac.append(list(word[i_start:i_c]))
            i_start = i_c
    tail = list(word[i_start:])
    if tail:
        chars_w_diac.append(tail)
    # When the word starts with a base character the first flush is empty.
    if chars_w_diac and not chars_w_diac[0]:
        chars_w_diac = chars_w_diac[1:]
    return chars_w_diac


def load_lines(path: str, *, strip: bool):
    """Read a UTF-8 text file, normalizing whitespace on every line.

    Args:
        path: file to read.
        strip: if true, also remove tashkeel from each line.
    """
    with open(path, 'r', encoding="utf-8", newline='\n') as fin:
        original_lines = [normalize_spaces(line) for line in fin]
    if strip:
        original_lines = [strip_tashkeel(line) for line in original_lines]
    return original_lines


def normalize_spaces(line: str):
    """Re-join the ``tokenize`` tokens of *line* with single spaces."""
    return ' '.join(tokenize(line.strip()))


def char_type(char: str):
    """Map a character to its LETTER_LIST index (or a special-token index)."""
    if char in LETTER_LIST:
        return LETTER_LIST.index(char)
    if char in NUMBERS:
        return LETTER_LIST.index('<num>')
    if char in DELIMITERS:
        return LETTER_LIST.index('<punc>')
    return LETTER_LIST.index('<unk>')


def create_labels(char_w_diac: List[str]):
    '''
    Label one ``[base_char, *diacritics]`` group.

    Returns:
        (char_idx, diacritic_index, diac_h3): the char_type id of the base
        character, the flat HARAKAT_MAP index, and the [haraka, tanween,
        shadda] triple as a list.

    Raises:
        ValueError: if sukun is combined with tanween or shadda, or the marks
        form a triple that is not in HARAKAT_MAP.
    '''
    remap_dict = {0: 0, 1: 1, 3: 2, 5: 3, 7: 4}

    # Deduplicate in order of appearance. A ``set`` here made the kept and
    # overriding diacritics depend on string hash randomization (per-run).
    diacs = list(dict.fromkeys(char_w_diac[1:]))
    if len(diacs) > 2:
        # Too many marks for one character: keep the first non-shadda mark
        # and, if present, the shadda.
        has_shadda = SHADDA in diacs
        diacs = [d for d in diacs if d != SHADDA][:1] + ([SHADDA] if has_shadda else [])

    char_idx = char_type(char_w_diac[0])
    diac_h3 = [0, 0, 0]
    for diac in diacs:
        if diac not in DIACRITICS_SHORT:
            continue
        diac_idx = DIACRITICS_SHORT.index(diac)
        if diac_idx in _TANWEEN_IDS:
            # Tanween: the preceding DIACRITICS_SHORT entry is its plain haraka.
            diac_h3[0] = remap_dict[diac_idx - 1]
            diac_h3[1] = 1
        elif diac_idx == _SHADDA_IDX:
            diac_h3[2] = 1
        else:
            # Plain haraka or sukun (a later one overrides an earlier one).
            diac_h3[0] = remap_dict[diac_idx]

    if diac_h3[0] == SEPARATE_DIACRITICS["SUKUN"] and (diac_h3[1] or diac_h3[2]):
        raise ValueError(f"Sukun cannot be combined with tanween/shadda: {char_w_diac!r}")
    diacritic_index = HARAKAT_MAP.index(tuple(diac_h3))
    return char_idx, diacritic_index, diac_h3


def create_label_for_word(split_word: List[List[str]]):
    """Label every character group of a word (see create_labels).

    Returns:
        Three parallel lists: char ids, flat diacritic ids, and
        [haraka, tanween, shadda] triples.
    """
    word_char_indices = []
    word_diac_indices = []
    word_diac_indices_h3 = []
    for char_w_diac in split_word:
        char_idx, diac_idx, diac_h3 = create_labels(char_w_diac)
        if char_idx is None:
            raise ValueError(f"No character id for {char_w_diac!r} in {split_word!r}")
        word_char_indices.append(char_idx)
        word_diac_indices.append(diac_idx)
        word_diac_indices_h3.append(diac_h3)
    return word_char_indices, word_diac_indices, word_diac_indices_h3


def flat_2_3head(output: T.Tensor):
    '''
    Split a batch of flat diacritic ids into haraka/tanween/shadda heads.

    output: 3-D [batch, words, chars_per_word] of HARAKAT_MAP indices.

    Returns:
        (haraka, tanween, shadda): nested lists with the same 3-D layout.
    '''
    # 0, 1, 2,  3, 4,  5, 6,  7, 8,  9,   10,   11,  12,   13,  14
    # 0, F, FF, K, KK, D, DD, S, Sh, ShF, ShFF, ShK, ShKK, ShD, ShDD
    if len(output.shape) != 3:
        raise ValueError(f"expected a 3-D input, got shape {tuple(output.shape)}")
    haraka, tanween, shadda = [], [], []
    for sentence in output:
        h_s, t_s, s_s = [], [], []
        for word in sentence:
            triples = [HARAKAT_MAP[int(c)] for c in word]
            h_s.append([h for h, _, _ in triples])
            t_s.append([t for _, t, _ in triples])
            s_s.append([s for _, _, s in triples])
        haraka.append(h_s)
        tanween.append(t_s)
        shadda.append(s_s)
    return haraka, tanween, shadda


def flat2_3head(diac_idx):
    '''
    Split a 1-D sequence of flat diacritic ids into three numpy arrays.

    diac_idx: [tw] HARAKAT_MAP indices.

    Returns:
        (haraka, tanween, shadda) as np.ndarray.
    '''
    # 0, 1, 2,  3, 4,  5, 6,  7, 8,  9,   10,   11,  12,   13,  14
    # 0, F, FF, K, KK, D, DD, S, Sh, ShF, ShFF, ShK, ShKK, ShD, ShDD
    haraka, tanween, shadda = [], [], []
    for diac in diac_idx:
        h, t, s = HARAKAT_MAP[diac]
        haraka.append(h)
        tanween.append(t)
        shadda.append(s)
    return np.array(haraka), np.array(tanween), np.array(shadda)