Spaces:
Runtime error
Runtime error
import os | |
import pickle | |
import numpy as np | |
from tqdm import tqdm | |
from prettytable import PrettyTable | |
from pyarabic.araby import tokenize, strip_tashkeel | |
import diac_utils as du | |
class DatasetUtils: | |
def __init__(self, config):
    """Populate the helper from a nested ``config`` mapping.

    Reads the ``paths``, ``sentence-break``, ``predictor``, ``train``,
    ``debug`` and (only when word embeddings are configured) ``loader``
    sections. When ``predictor.gt-signal-prob`` is positive the matching
    ground-truth mask file is read from ``<base>/test``; when
    ``paths.word-embs`` is a non-blank string the embeddings are loaded
    and normalized.
    """
    paths = config["paths"]
    self.base_path = paths["base"]
    self.special_tokens = ['<pad>', '<unk>', '<num>', '<punc>']

    sentence_break = config["sentence-break"]
    self.delimeters = sentence_break["delimeters"]
    self.load_constants(paths["constants"])
    # ``debug`` must be set before ``load_embeddings`` runs below.
    self.debug = config["debug"]
    self.stride = sentence_break["stride"]
    self.window = sentence_break["window"]
    # Validation stride falls back to the training stride when not given.
    self.val_stride = sentence_break.get("val-stride", self.stride)

    predictor = config["predictor"]
    self.test_stride = predictor["stride"]
    self.test_window = predictor["window"]

    train = config["train"]
    self.max_word_len = train["max-word-len"]
    self.max_sent_len = train["max-sent-len"]
    self.max_token_count = train["max-token-count"]

    self.pad_target_val = -100
    self.pad_char_id = du.DIAC_PAD_IDX
    self.markov_signal = train.get('markov-signal', False)
    self.batch_first = train.get('batch-first', True)

    self.gt_prob = predictor["gt-signal-prob"]
    if self.gt_prob > 0:
        self.s_idx = predictor["seed-idx"]
        mask_name = f"test_gt_mask_{self.gt_prob}_{self.s_idx}.txt"
        with open(os.path.join(self.base_path, "test", mask_name), 'r') as fin:
            # Lines keep their trailing newline characters.
            self.gt_mask = fin.readlines()

    if "word-embs" in paths and paths["word-embs"].strip():
        self.pad_val = self.special_tokens.index("<pad>")
        loaded = self.load_embeddings(paths["word-embs"], config["loader"]["wembs-limit"])
        self.embeddings, self.vocab = loaded
        self.embeddings = self.normalize(self.embeddings, ["unit", "centeremb", "unit"])
        self.w2idx = {word: position for position, word in enumerate(self.vocab)}
def load_file(self, path):
    """Unpickle the object stored at ``path`` and return it as a list.

    NOTE: ``pickle.load`` can execute arbitrary code; only use this on
    files from a trusted source.
    """
    with open(path, 'rb') as handle:
        payload = pickle.load(handle)
    return list(payload)
def normalize(self, matrix, actions, mean=None):
    """Apply a sequence of normalization steps to a 2-D matrix.

    Args:
        matrix: 2-D NumPy array (rows are reduced with ``axis=1``,
            columns with ``axis=0``).
        actions: iterable of step names applied in order:
            ``'unit'``      - scale every row to unit L2 norm;
            ``'center'``    - subtract ``mean`` (or, when ``mean`` is
                              ``None``, the column-wise mean of the
                              matrix at that point);
            ``'unitdim'``   - scale every column to unit L2 norm;
            ``'centeremb'`` - subtract each row's own mean from it.
            Unrecognized names are skipped.
        mean: optional value subtracted by the ``'center'`` step.

    Returns:
        The transformed matrix (the input object itself when no step
        applied).
    """
    def length_normalize(m):
        norms = np.sqrt(np.sum(m ** 2, axis=1))
        norms[norms == 0] = 1  # leave all-zero rows untouched
        return m / norms[:, np.newaxis]

    def mean_center(m):
        # Previously ``mean=None`` made this step fail with a TypeError;
        # fall back to the matrix's own column-wise mean instead.
        center = np.mean(m, axis=0) if mean is None else mean
        return m - center

    def length_normalize_dimensionwise(m):
        norms = np.sqrt(np.sum(m ** 2, axis=0))
        norms[norms == 0] = 1  # leave all-zero columns untouched
        return m / norms

    def mean_center_embeddingwise(m):
        avg = np.mean(m, axis=1)
        return m - avg[:, np.newaxis]

    steps = {
        'unit': length_normalize,
        'center': mean_center,
        'unitdim': length_normalize_dimensionwise,
        'centeremb': mean_center_embeddingwise,
    }
    for action in actions:
        step = steps.get(action)
        if step is not None:
            matrix = step(matrix)
    return matrix
def load_constants(self, path):
    """Bind the character inventories exposed by ``diac_utils``.

    ``path`` is accepted but not used: the values come from the ``du``
    module rather than from pickled files under ``path``.
    """
    self.numbers, self.letter_list, self.diacritic_list = (
        du.NUMBERS,
        du.LETTER_LIST,
        du.DIACRITICS_SHORT,
    )
def split_word_on_characters_with_diacritics(self, word: str):
    """Delegate to ``du.split_word_on_characters_with_diacritics`` for ``word``."""
    pieces = du.split_word_on_characters_with_diacritics(word)
    return pieces
def load_mapping_v3(self, dtype, file_ext=None):
    """Read a ``.map`` file into a nested index dictionary.

    The file is ``<base_path>/<dtype>/<dtype><file_ext>``; when
    ``file_ext`` is ``None`` it defaults to
    ``-<test_stride>-<test_window>.map``. Each line holds four
    comma-separated integers ``sent_idx,seg_idx,t_idx,c_idx``.

    Returns:
        ``{sent_idx: {seg_idx: {t_idx: [c_idx, ...]}}}`` with ``c_idx``
        values in file order.
    """
    if file_ext is None:
        file_ext = f"-{self.test_stride}-{self.test_window}.map"
    f_name = os.path.join(self.base_path, dtype, dtype + file_ext)

    mapping = {}
    with open(f_name, 'r') as fin:
        for row in fin:
            sent_idx, seg_idx, t_idx, c_idx = map(int, row.split(','))
            by_segment = mapping.setdefault(sent_idx, {})
            by_token = by_segment.setdefault(seg_idx, {})
            by_token.setdefault(t_idx, []).append(c_idx)
    return mapping
def load_mapping_v3_from_list(self, mapping_list):
    """Build the nested index dictionary from in-memory mapping lines.

    Args:
        mapping_list: iterable of strings, each holding four
            comma-separated integers ``sent_idx,seg_idx,t_idx,c_idx``.

    Returns:
        ``{sent_idx: {seg_idx: {t_idx: [c_idx, ...]}}}`` with ``c_idx``
        values in input order.
    """
    mapping = {}
    for row in mapping_list:
        sent_idx, seg_idx, t_idx, c_idx = map(int, row.split(','))
        by_segment = mapping.setdefault(sent_idx, {})
        by_token = by_segment.setdefault(seg_idx, {})
        by_token.setdefault(t_idx, []).append(c_idx)
    return mapping
def load_embeddings(self, embs_path, limit=-1):
    """Load word vectors from a text embedding file.

    The first line of the file is ``"<n> <d>"``; every following line is
    a word followed by its vector components. Row 0 of the returned
    matrix is left as zeros and paired with ``special_tokens[0]``.

    In debug mode nothing is read: an all-zero matrix with 300 columns
    and a placeholder vocabulary are returned.

    Args:
        embs_path: path of the embedding text file.
        limit: maximum number of vectors to read; values ``<= 0`` mean
            "use ``n`` from the file header".

    Returns:
        ``(embeddings, words)`` - a ``(limit + 1, d)`` array and the list
        of words read (prefixed by ``special_tokens[0]``).
    """
    if self.debug:
        placeholder_vocab = self.special_tokens + ["c"] * 200
        return np.zeros((200 + len(self.special_tokens), 300)), placeholder_vocab

    vocab = [self.special_tokens[0]]
    print(f"[INFO] Reading Embeddings from {embs_path}")
    with open(embs_path, encoding='utf-8', mode='r') as fin:
        header = fin.readline().split()
        n, d = map(int, header)
        if limit <= 0:
            limit = n
        vectors = np.zeros((limit + 1, d))
        for row, line in tqdm(enumerate(fin), total=limit):
            if row >= limit:
                break
            fields = line.rstrip().split()
            vocab.append(fields[0])
            # Row 0 is reserved, so file row ``row`` lands at ``row + 1``.
            vectors[row + 1] = [float(value) for value in fields[1:]]
    return vectors, vocab
def load_file_clean(self, dtype, strip=False):
    """Read ``<base_path>/<dtype>/<dtype>.txt`` and preprocess every line.

    Args:
        dtype: name used for both the sub-directory and the file stem.
        strip: when true, ``strip_tashkeel`` is applied to each
            preprocessed line.

    Returns:
        List with one processed string per line of the file.
    """
    def clean(line):
        processed = self.preprocess(line)
        return strip_tashkeel(processed) if strip else processed

    f_name = os.path.join(self.base_path, dtype, dtype + ".txt")
    with open(f_name, 'r', encoding="utf-8", newline='\n') as fin:
        return [clean(line) for line in fin]
def preprocess(self, line):
    """Tokenize ``line`` and re-join the tokens with single spaces."""
    tokens = tokenize(line)
    return ' '.join(tokens)
def pad_and_truncate_sequence(self, tokens, max_len, pad=None):
    """Force ``tokens`` to exactly ``max_len`` items.

    Shorter sequences are right-padded with ``pad``; longer (or equal)
    ones are cut with ``tokens[:max_len]``. When ``pad`` is ``None`` the
    index of ``'<pad>'`` in ``self.special_tokens`` is used.
    """
    if pad is None:
        pad = self.special_tokens.index("<pad>")
    missing = max_len - len(tokens)
    if missing <= 0:
        return tokens[:max_len]
    return tokens + [pad] * missing
def stats(self, freq, percentile=90, name="stats"):
    """Print a one-row summary table for the values in ``freq``.

    The row holds ``name`` followed by mean, standard deviation, minimum,
    maximum and the requested percentile of the values.
    """
    values = np.array(sorted(freq))
    header = ["Dataset", "Mean", "Std", "Min", "Max", f"{percentile}th Percentile"]
    summary = [
        name,
        values.mean(),
        values.std(),
        values.min(),
        values.max(),
        np.percentile(values, percentile),
    ]
    table = PrettyTable(header)
    table.add_row(summary)
    print(table)
def create_gt_mask(self, lines, prob, idx, seed=1111):
    """Sample a random 0/1 mask per character and write it to disk.

    For every token of every line, ``len(token)`` Bernoulli(``prob``)
    draws are rendered as a digit string; token masks are joined by
    single spaces and lines by newlines. The result is written to
    ``<base_path>/test/test_gt_mask_<prob>_<idx>.txt``.

    The global NumPy random state is re-seeded with ``seed`` first.
    """
    np.random.seed(seed)
    gt_masks = []
    for line in lines:
        token_masks = [
            ''.join(str(bit) for bit in np.random.binomial(1, prob, len(token)))
            for token in tokenize(line.strip())
        ]
        gt_masks.append(' '.join(token_masks))

    mask_path = os.path.join(self.base_path, "test", f"test_gt_mask_{prob}_{idx}.txt")
    with open(mask_path, 'w') as fout:
        fout.write('\n'.join(gt_masks))
def create_gt_labels(self, lines):
    """Build one flat label list per input line.

    Each line is tokenized; for every token the second element returned
    by ``du.create_label_for_word`` is appended, and a single ``0`` is
    inserted between consecutive tokens (not after the last one).

    Returns:
        List of label lists, one per entry of ``lines``.
    """
    gt_labels = []
    for line in lines:
        words = tokenize(line.strip())
        last_position = len(words) - 1
        line_labels = []
        for position, word in enumerate(words):
            pieces = self.split_word_on_characters_with_diacritics(word)
            _, flat_labels, _ = du.create_label_for_word(pieces)
            line_labels.extend(flat_labels)
            if position < last_position:
                line_labels.append(0)  # separator between tokens
        gt_labels.append(line_labels)
    return gt_labels
def get_ce(self, diac_word_y, e_idx=None, return_idx=False):
    """Find the last entry before ``e_idx`` that differs from ``[0, 0, 0]``.

    Args:
        diac_word_y: sequence of per-character label triples ([Tw 3]).
        e_idx: exclusive upper bound of the search; defaults to
            ``len(diac_word_y)``.
        return_idx: return the position instead of the entry itself.

    Returns:
        The matching entry (or its index). When nothing matches, the
        fallback is position ``e_idx - 1``; note that for ``e_idx == 0``
        this is ``-1``, i.e. the last entry of the sequence.
    """
    stop = len(diac_word_y) if e_idx is None else e_idx
    found = stop - 1  # fallback when every inspected entry is [0, 0, 0]
    for position in range(stop - 1, -1, -1):
        if diac_word_y[position] != [0, 0, 0]:
            found = position
            break
    return found if return_idx else diac_word_y[found]
def create_decoder_input(self, diac_code_y, prob=0):
    """Build the 8-wide per-character input that mirrors ``diac_code_y``.

    Args:
        diac_code_y: nested list of label triples, [Ts Tw 3].
        prob: currently unused.

    Returns:
        ``list`` over the first axis of a zero-initialised array shaped
        like ``diac_code_y`` with the last axis replaced by 8. When
        ``self.markov_signal`` is false the array stays all zeros.
        Otherwise slot 0 of every word receives the carried-over code
        (initially one-hot index 5 plus ``[0, 0]``), and slot ``c + 1``
        receives the one-hot of ``char[0]`` followed by ``char[1:]`` for
        every character ``c`` before the first padded one.
    """
    leading_shape = np.array(diac_code_y).shape[:-1]
    decoder_in = np.zeros((*leading_shape, 8))
    if self.markov_signal:
        one_hot = np.eye(6)
        carried = list(one_hot[-1]) + [0, 0]  # bos tag
        for w_idx, word in enumerate(diac_code_y):
            decoder_in[w_idx, 0, :] = carried
            for c_idx, char in enumerate(word[:-1]):
                if char[0] == self.pad_target_val:
                    break
                decoder_in[w_idx, c_idx + 1, :] = list(one_hot[char[0]]) + char[1:]
                # NOTE(review): get_ce treats its second argument as an
                # exclusive bound, so the current character is not
                # inspected, and for c_idx == 0 it falls back to the
                # word's last entry - confirm this is intended.
                ending = self.get_ce(diac_code_y[w_idx], c_idx)
                carried = list(one_hot[ending[0]]) + ending[1:]
    return list(decoder_in)