"""Data utilities for Arabic diacritization: word splitting, diacritic label
encoding (flat ids and 3-head haraka/tanween/shadda), and decoding."""
from typing import List

import numpy as np
import torch as T
from pyarabic.araby import (
    DIACRITICS,
    strip_tashkeel,
    strip_tatweel,
    tokenize,
)
# Haraka slot ids for the 3-head representation (0 means "no haraka").
SEPARATE_DIACRITICS = {
    "FATHA": 1,
    "KASRA": 2,
    "DAMMA": 3,
    "SUKUN": 4
}
# Flat class id -> (haraka, tanween, shadda) triple.
HARAKAT_MAP = [
    #^ (haraka, tanween, shadda)
    (0, 0, 0),  #< no diacritics on char
    (1, 0, 0),
    (1, 1, 0),  #< tanween flagged in 2nd slot
    (2, 0, 0),
    (2, 1, 0),
    (3, 0, 0),
    (3, 1, 0),
    (4, 0, 0),
    (0, 0, 1),  #< shadda flagged in 3rd slot
    (1, 0, 1),
    (1, 1, 1),
    (2, 0, 1),
    (2, 1, 1),
    (3, 0, 1),
    (3, 1, 1),
    (0, 0, 0),  #< padding == -1 (also used for spaces)
]
DIAC_PAD_IDX = -1
SPECIAL_TOKENS = ['<pad>', '<unk>', '<num>', '<punc>']
# Special tokens followed by the Arabic alphabet (hamza forms first).
LETTER_LIST = SPECIAL_TOKENS + list("ءآأؤإئابةتثجحخدذرزسشصضطظعغفقكلمنهوىي")
# Printable form of each flat class id:
# 0, F, FF, K, KK, D, DD, S, Sh, ShF, ShFF, ShK, ShKK, ShD, ShDD
# (written as \u escapes since combining marks are easy to corrupt in editors)
CLASSES_LIST = [
    ' ',
    '\u064E',          # fatha
    '\u064B',          # fathatan
    '\u0650',          # kasra
    '\u064D',          # kasratan
    '\u064F',          # damma
    '\u064C',          # dammatan
    '\u0652',          # sukun
    '\u0651',          # shadda
    '\u0651\u064E',    # shadda + fatha
    '\u0651\u064B',    # shadda + fathatan
    '\u0651\u0650',    # shadda + kasra
    '\u0651\u064D',    # shadda + kasratan
    '\u0651\u064F',    # shadda + damma
    '\u0651\u064C',    # shadda + dammatan
]
# The diacritics the model knows; index order is relied on by create_labels
# (1=F, 2=FF, 3=K, 4=KK, 5=D, 6=DD, 7=S, 8=Sh).
DIACRITICS_SHORT = [
    ' ',
    '\u064E',  # fatha
    '\u064B',  # fathatan
    '\u0650',  # kasra
    '\u064D',  # kasratan
    '\u064F',  # damma
    '\u064C',  # dammatan
    '\u0652',  # sukun
    '\u0651',  # shadda
]
NUMBERS = list("0123456789")
DELIMITERS = ["\u060C", "\u061B", ",", ";", "\u00AB", "\u00BB", "{", "}",
              "(", ")", "[", "]", ".", "*", "-", ":", "?", "!", "\u061F"]
# Diacritics known to pyarabic but not modeled here.
UNKNOWN_DIACRITICS = list(set(DIACRITICS).difference(set(DIACRITICS_SHORT)))
def shakkel_char(diac: int, tanween: bool, shadda: bool) -> str:
    """Compose the Unicode combining-mark string for one letter.

    diac: a SEPARATE_DIACRITICS value (anything else yields no haraka).
    tanween: use the doubled (tanween) form of fatha/kasra/damma.
    shadda: prepend the shadda mark — except together with sukun.
    """
    plain = {
        SEPARATE_DIACRITICS["FATHA"]: "\u064E",
        SEPARATE_DIACRITICS["KASRA"]: "\u0650",
        SEPARATE_DIACRITICS["DAMMA"]: "\u064F",
        SEPARATE_DIACRITICS["SUKUN"]: "\u0652",
    }
    doubled = {
        SEPARATE_DIACRITICS["FATHA"]: "\u064B",
        SEPARATE_DIACRITICS["KASRA"]: "\u064D",
        SEPARATE_DIACRITICS["DAMMA"]: "\u064C",
    }
    marks = ""
    if shadda and diac != SEPARATE_DIACRITICS["SUKUN"]:
        marks += "\u0651"  # shadda goes first, then the haraka
    if diac in plain:
        # sukun has no tanween form, so it always falls back to `plain`
        marks += doubled[diac] if (tanween and diac in doubled) else plain[diac]
    return marks
def diac_ids_of_line(line: str):
    """Flatten a line into a 1-D array of per-character diacritic class ids.

    Consecutive words are separated by a single DIAC_PAD_IDX entry;
    no separator trails the last word.
    """
    per_word = []
    for word in tokenize(line):
        char_groups = split_word_on_characters_with_diacritics(word)
        _cx, diac_ids, _cy3 = create_label_for_word(char_groups)
        per_word.append(diac_ids)
    flat = []
    for w_i, ids in enumerate(per_word):
        if w_i:  # separator between words only
            flat.append(DIAC_PAD_IDX)
        flat.extend(ids)
    return np.array(flat)
def strip_unknown_tashkeel(word: str):
    """Return `word` unchanged — stripping of rare diacritics is disabled.

    The intended behavior (kept below, unreachable) was to drop every
    diacritic not in DIACRITICS_SHORT. It was deliberately short-circuited;
    remove the early return to re-enable it.
    """
    #! FIXME! warnings.warn("Stripping unknown tashkeel is disabled.")
    return word
    return ''.join(c for c in word if c not in UNKNOWN_DIACRITICS)
def create_gt_labels(lines):
    """Ground-truth diacritic class ids, one array per line.

    Thin per-line wrapper over diac_ids_of_line; see that function for the
    id/separator layout. (A commented-out legacy implementation that did the
    same tokenize/split/label loop inline was removed.)
    """
    return [diac_ids_of_line(line) for line in lines]
def split_word_on_characters_with_diacritics(word: str):
    """Split a diacritized word into per-letter groups.

    Each group is [letter, diacritic, diacritic, ...]: a new group starts at
    every character that is not in DIACRITICS_SHORT, and subsequent
    diacritics attach to the current group.

    Returns: List[List[str]] — empty list for an empty word (the original
    raised IndexError on "").

    #! FIXME! DIACRITICS_SHORT is missing many less common diacritics,
    #! which are then treated as letters and start their own groups.
    """
    if not word:
        return []
    groups: List[List[str]] = []
    current: List[str] = []
    for ch in word:
        if ch not in DIACRITICS_SHORT:
            # letter (or unknown mark): close the previous group
            if current:
                groups.append(current)
            current = [ch]
        else:
            current.append(ch)
    if current:
        groups.append(current)
    return groups
def load_lines(path: str, *, strip: bool):
    """Read a UTF-8 text file into a list of space-normalized lines.

    strip=True additionally removes all tashkeel from each line.
    """
    with open(path, 'r', encoding="utf-8", newline='\n') as fin:
        lines = [normalize_spaces(raw) for raw in fin]
    if strip:
        lines = [strip_tashkeel(line) for line in lines]
    return lines
def normalize_spaces(line: str):
    """Re-join the pyarabic tokens of a stripped line with single spaces."""
    tokens = tokenize(line.strip())
    return ' '.join(tokens)
def char_type(char: str):
    """Map a character to its index in LETTER_LIST.

    Digits collapse to '<num>', delimiters to '<punc>', and anything else
    unrecognized to '<unk>'.
    """
    if char in LETTER_LIST:
        token = char
    elif char in NUMBERS:
        token = '<num>'
    elif char in DELIMITERS:
        token = '<punc>'
    else:
        token = '<unk>'
    return LETTER_LIST.index(token)
def create_labels(char_w_diac: List[str]):
    """Encode one character group [letter, diac, ...] into model labels.

    Returns (char_idx, diacritic_index, diac_h3):
      char_idx        — index into LETTER_LIST (via char_type),
      diac_h3         — [haraka, tanween, shadda] 3-head label,
      diacritic_index — flat class id, i.e. diac_h3's position in HARAKAT_MAP.

    Raises AssertionError if sukun is combined with tanween or shadda.

    NOTE(review): an older, unreachable duplicate implementation (len-based
    branching after the unconditional return) was removed; for every input it
    produced the same labels as the set-based path kept here.
    """
    # DIACRITICS_SHORT indices 1..8 = F, FF, K, KK, D, DD, S, Sh;
    # remap a plain-mark index to its haraka slot: F->1, K->2, D->3, S->4.
    remap_dict = {0: 0, 1: 1, 3: 2, 5: 3, 7: 4}
    # De-duplicate the marks, keeping the letter in front.
    char_w_diac = [char_w_diac[0]] + list(set(char_w_diac[1:]))
    if len(char_w_diac) > 3:
        # Too many marks: keep shadda + one mark if shadda is present,
        # otherwise just one mark.
        char_w_diac = char_w_diac[:2] if DIACRITICS_SHORT[8] not in char_w_diac else char_w_diac[:3]
    char_idx = char_type(char_w_diac[0])
    diac_h3 = [0, 0, 0]  # [haraka, tanween, shadda]
    for diac in set(char_w_diac[1:]):
        if diac not in DIACRITICS_SHORT:
            continue  # unknown mark: ignored, as before
        diac_idx = DIACRITICS_SHORT.index(diac)
        if diac_idx in (2, 4, 6):  # tanween forms
            diac_h3[0] = remap_dict[diac_idx - 1]
            diac_h3[1] = 1
        elif diac_idx == 8:  # shadda
            diac_h3[2] = 1
        else:  # plain haraka or sukun
            diac_h3[0] = remap_dict[diac_idx]
    # Sukun cannot carry tanween or shadda.
    assert not (diac_h3[0] == 4 and (diac_h3[1] or diac_h3[2]))
    diacritic_index = HARAKAT_MAP.index(tuple(diac_h3))
    return char_idx, diacritic_index, diac_h3
def create_label_for_word(split_word: List[List[str]]):
    """Label every character group of a split word.

    split_word: output of split_word_on_characters_with_diacritics.

    Returns three parallel lists: letter indices, flat diacritic class ids,
    and 3-head [haraka, tanween, shadda] labels.

    Raises ValueError if a group could not be encoded (defensive — char_type
    always returns an index for any input character).
    """
    word_char_indices = []
    word_diac_indices = []
    word_diac_indices_h3 = []
    for char_w_diac in split_word:
        char_idx, diac_idx, diac_h3 = create_labels(char_w_diac)
        if char_idx is None:
            raise ValueError(
                f"could not encode character group {char_w_diac!r} in word {split_word!r}"
            )
        word_char_indices.append(char_idx)
        word_diac_indices.append(diac_idx)
        word_diac_indices_h3.append(diac_h3)
    return word_char_indices, word_diac_indices, word_diac_indices_h3
def flat_2_3head(output: T.Tensor):
    """Decode a batch of flat diacritic class ids into three nested lists.

    output: integer tensor of shape [batch, words, chars]; each value is a
    flat class id, i.e. an index into HARAKAT_MAP:
      0, 1,  2, 3,  4, 5,  6, 7, 8,  9,  10,  11,  12,  13,  14
      0, F, FF, K, KK, D, DD, S, Sh, ShF, ShFF, ShK, ShKK, ShD, ShDD

    Returns (haraka, tanween, shadda), each nested as [batch][word][char].
    """
    n_batch, n_words, n_chars = output.shape
    haraka, tanween, shadda = [], [], []
    for b_i in range(n_batch):
        decoded = [
            [HARAKAT_MAP[int(output[b_i, w_i, c_i])] for c_i in range(n_chars)]
            for w_i in range(n_words)
        ]
        haraka.append([[trip[0] for trip in word] for word in decoded])
        tanween.append([[trip[1] for trip in word] for word in decoded])
        shadda.append([[trip[2] for trip in word] for word in decoded])
    return haraka, tanween, shadda
def flat2_3head(diac_idx):
    """Decode a flat sequence of diacritic class ids into three arrays.

    diac_idx: iterable of flat class ids (indices into HARAKAT_MAP):
      0, 1,  2, 3,  4, 5,  6, 7, 8,  9,  10,  11,  12,  13,  14
      0, F, FF, K, KK, D, DD, S, Sh, ShF, ShFF, ShK, ShKK, ShD, ShDD

    Returns (haraka, tanween, shadda) as parallel numpy arrays.
    """
    triples = [HARAKAT_MAP[ix] for ix in diac_idx]
    haraka = np.array([trip[0] for trip in triples])
    tanween = np.array([trip[1] for trip in triples])
    shadda = np.array([trip[2] for trip in triples])
    return haraka, tanween, shadda