from typing import List, Tuple
from functools import lru_cache

from pyarabic.araby import tokenize, strip_tashkeel
import numpy as np
import torch as T
from torch.utils.data import Dataset

try:
    from transformers import PreTrainedTokenizer
except ImportError:
    # transformers is optional at import time; fall back to a loose type alias.
    from typing import Any as PreTrainedTokenizer

from data_utils import DatasetUtils
import diac_utils as du


class DataRetriever(Dataset):
    """Turns raw (diacritized) text lines into padded tensors for training/inference."""

    def __init__(
        self,
        lines,
        data_utils: DatasetUtils,
        is_test: bool = False,
        *,
        tokenizer: PreTrainedTokenizer,
        lines_mode: bool = False,
        **kwargs,
    ):
        super().__init__()

        self.data_utils = data_utils
        self.is_test = is_test
        self.tokenizer = tokenizer
        self.stride = data_utils.test_stride
        self.data_points = lines

        # Fall back to CLS/SEP only when BOS/EOS are genuinely missing
        # (a plain `or` would misfire for tokenizers whose BOS/EOS id is 0).
        self.bos_token_id = int(
            self.tokenizer.bos_token_id
            if self.tokenizer.bos_token_id is not None
            else self.tokenizer.cls_token_id
        )
        self.eos_token_id = int(
            self.tokenizer.eos_token_id
            if self.tokenizer.eos_token_id is not None
            else self.tokenizer.sep_token_id
        )

        self.max_tokens = self.data_utils.max_token_count
        self.max_slen = self.data_utils.max_sent_len
        self.max_wlen = self.data_utils.max_word_len

        # Padding values for subword ids, character ids, and diacritic targets.
        # self.p_val = self.data_utils.pad_val
        self.p_val = self.tokenizer.pad_token_id
        self.pc_val = self.data_utils.pad_char_id
        self.pt_val = self.data_utils.pad_target_val

        self.char_x_padding = [self.pc_val] * self.max_wlen
        self.diac_x_padding = [[self.pc_val] * 8] * self.max_wlen
        self.diac_y_padding = [self.pt_val] * self.max_wlen

    def preprocess(self, data, dtype=T.long):
        return [T.tensor(np.array(x), dtype=dtype) for x in data]

    def __len__(self):
        return len(self.data_points)

    # Cache built sentences; the cache is keyed on (self, idx) and is
    # per-process, so each DataLoader worker keeps its own copy.
    @lru_cache(maxsize=1024 * 2)
    def __getitem__(self, idx: int) -> Tuple[List[T.Tensor], T.Tensor, T.Tensor]:
        word_x, char_x, diac_x, diac_y, subword_lengths = self.create_sentence(idx)
        return (
            self.preprocess([word_x, char_x, diac_x]),
            T.tensor(diac_y, dtype=T.long),
            T.tensor(subword_lengths, dtype=T.long),
        )

    def create_sentence(self, idx):
        line = self.data_points[idx]
        words: List[str] = tokenize(line.strip())
        # NOTE(disabled): tokens that are purely diacritics (no base letters)
        # could be merged into the preceding word here; currently they are left as-is.

        subwords_x = [self.bos_token_id]
        subword_lengths = []
        char_x = []
        diac_x = []
        diac_y = []
        diac_y_tmp = []

        for word in words:
            word = du.strip_unknown_tashkeel(word)
            word_chars = du.split_word_on_characters_with_diacritics(word)
            cx, cy, cy_3head = du.create_label_for_word(word_chars)

            word_strip = strip_tashkeel(word)
            # Subword ids (List[int]) for the bare word; strip the BOS/EOS
            # that the tokenizer adds around every call.
            word_sub_ids = self.tokenizer(word_strip)['input_ids'][1:-1]
            subword_lengths += [len(word_sub_ids)]
            subwords_x += word_sub_ids

            char_x += [self.data_utils.pad_and_truncate_sequence(cx, self.max_wlen)]
            diac_y += [self.data_utils.pad_and_truncate_sequence(
                cy, self.max_wlen,
                pad=self.data_utils.pad_target_val,
            )]
            diac_y_tmp += [self.data_utils.pad_and_truncate_sequence(
                cy_3head, self.max_wlen,
                pad=[self.data_utils.pad_target_val] * 3,
            )]

        assert len(char_x) == len(subword_lengths), f"{char_x=}; {subword_lengths=} ;;"
        assert len(char_x) == len(words)

        diac_x = self.data_utils.create_decoder_input(diac_y_tmp)

        subwords_x += [self.eos_token_id]
        # len(subwords_x) == sum(subword_lengths) + 2 because of BOS/EOS.
        assert len(subword_lengths) == len(words)

        subwords_x = self.data_utils.pad_and_truncate_sequence(subwords_x, self.max_tokens, pad=self.p_val)
        subword_lengths = self.data_utils.pad_and_truncate_sequence(subword_lengths, self.max_slen, pad=0)
        char_x = self.data_utils.pad_and_truncate_sequence(char_x, self.max_slen, pad=self.char_x_padding)
        diac_x = self.data_utils.pad_and_truncate_sequence(diac_x, self.max_slen, pad=self.diac_x_padding)
        diac_y = self.data_utils.pad_and_truncate_sequence(diac_y, self.max_slen, pad=self.diac_y_padding)

        return subwords_x, char_x, diac_x, diac_y, subword_lengths
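

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the pipeline). It shows how the
# dataset plugs into a DataLoader with the default collate function. The
# checkpoint name, the data file path, and the way DatasetUtils is constructed
# are assumptions; adapt them to the project's actual configuration.
if __name__ == "__main__":
    from torch.utils.data import DataLoader
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("aubmindlab/bert-base-arabertv02")  # assumed checkpoint
    data_utils = DatasetUtils(...)  # hypothetical: build from the project's config object
    with open("train.txt", encoding="utf-8") as f:  # hypothetical data file
        lines = f.read().splitlines()

    dataset = DataRetriever(lines, data_utils, tokenizer=tokenizer)
    loader = DataLoader(dataset, batch_size=16, shuffle=True)

    # Each batch: a list [subword ids, char ids, decoder diacritic input],
    # the diacritic targets, and the per-word subword lengths.
    (subword_x, char_x, diac_x), diac_y, subword_lengths = next(iter(loader))
    print(subword_x.shape, char_x.shape, diac_x.shape, diac_y.shape, subword_lengths.shape)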