# Partial-Arabic-Diacritization / dataloader_plm.py
# Author: bkhmsi — commit d7c4b94 ("support for TD2")
from typing import List, Tuple, Any
import os
from functools import lru_cache
from pyarabic.araby import tokenize, strip_tashkeel
import numpy as np
import torch as T
from torch.utils.data import Dataset
try:
from transformers import PreTrainedTokenizer
except:
from typing import Any as PreTrainedTokenizer
from data_utils import DatasetUtils
import diac_utils as du
class DataRetriever(Dataset):
    """Map-style dataset that turns raw diacritized lines into model inputs.

    Each item pairs the pretrained-LM subword ids of the undiacritized words
    with per-word character ids, decoder inputs and diacritic targets that are
    produced by ``diac_utils`` and padded by ``DatasetUtils``.
    """

    def __init__(
        self,
        lines,
        data_utils: DatasetUtils,
        is_test: bool = False,
        *,
        tokenizer: PreTrainedTokenizer,
        lines_mode: bool = False,
        **kwargs,
    ):
        """Store the lines and cache size limits and padding values.

        Args:
            lines: Indexable collection of raw (diacritized) text lines; item
                ``idx`` of the dataset is built from ``lines[idx]``.
            data_utils: Supplies the sequence limits, pad values and the
                padding / decoder-input helpers.
            is_test: Stored as ``self.is_test``; not read inside this class.
            tokenizer: Pretrained-LM tokenizer used to produce subword ids.
            lines_mode: Accepted for interface compatibility; unused here.
            **kwargs: Accepted and ignored.

        Raises:
            ValueError: If the tokenizer defines neither a BOS nor a CLS token
                id, or neither an EOS nor a SEP token id.
        """
        # ``super(DataRetriever).__init__()`` (the previous form) creates an
        # unbound super object and never reaches ``Dataset.__init__``.
        super().__init__()
        self.data_utils = data_utils
        self.is_test = is_test
        self.tokenizer = tokenizer
        self.stride = data_utils.test_stride
        self.data_points = lines

        # Explicit ``is None`` checks: a special-token id of 0 is valid and
        # must not fall through to the fallback the way ``a or b`` would.
        self.bos_token_id = self._first_token_id(
            "BOS/CLS", tokenizer.bos_token_id, tokenizer.cls_token_id
        )
        self.eos_token_id = self._first_token_id(
            "EOS/SEP", tokenizer.eos_token_id, tokenizer.sep_token_id
        )

        self.max_tokens = self.data_utils.max_token_count
        self.max_slen = self.data_utils.max_sent_len
        self.max_wlen = self.data_utils.max_word_len

        # Subword sequences are padded with the tokenizer's own pad id.
        self.p_val = self.tokenizer.pad_token_id
        self.pc_val = self.data_utils.pad_char_id
        self.pt_val = self.data_utils.pad_target_val

        # Per-word padding rows used when a sentence has fewer than
        # ``max_sent_len`` words.
        self.char_x_padding = [self.pc_val] * self.max_wlen
        # One row of 8 pad ids per character slot. The comprehension builds
        # distinct rows instead of ``max_wlen`` references to a single list.
        # NOTE(review): 8 presumably matches the row width produced by
        # ``create_decoder_input`` — confirm.
        self.diac_x_padding = [[self.pc_val] * 8 for _ in range(self.max_wlen)]
        self.diac_y_padding = [self.pt_val] * self.max_wlen

    @staticmethod
    def _first_token_id(role, *candidates):
        """Return the first candidate token id that is not ``None``, as an int."""
        for token_id in candidates:
            if token_id is not None:
                return int(token_id)
        raise ValueError(f"tokenizer defines no {role} token id")

    def preprocess(self, data, dtype=T.long):
        """Convert each element of ``data`` to a tensor of ``dtype`` (via NumPy)."""
        return [T.tensor(np.array(x), dtype=dtype) for x in data]

    def __len__(self):
        """Return the number of lines in the dataset."""
        return len(self.data_points)

    # NOTE(review): ``lru_cache`` on a method is one cache shared by all
    # instances. It is keyed on ``(self, idx)`` and keeps instances alive
    # while they are cached (ruff B019). Cache hits return the *same* tensor
    # objects, so callers must not modify returned tensors in place. Kept
    # as-is: a per-instance cache stored on ``self`` may not pickle cleanly
    # into DataLoader worker processes.
    @lru_cache(maxsize=1024 * 2)
    def __getitem__(self, idx: int) -> Tuple[List[T.Tensor], T.Tensor, T.Tensor]:
        """Return the tensors for line ``idx``.

        Returns:
            ``([subword_ids, char_ids, decoder_input], diac_targets,
            subword_lengths)``, all as ``torch.long`` tensors.
        """
        subwords_x, char_x, diac_x, diac_y, subword_lengths = self.create_sentence(idx)
        return (
            self.preprocess([subwords_x, char_x, diac_x]),
            T.tensor(diac_y, dtype=T.long),
            T.tensor(subword_lengths, dtype=T.long),
        )

    def create_sentence(self, idx):
        """Build padded (not yet tensorized) inputs and targets for line ``idx``.

        Returns:
            ``(subwords_x, char_x, diac_x, diac_y, subword_lengths)``:

            - ``subwords_x``: the line's subword ids wrapped in BOS/EOS and
              padded/truncated to ``max_token_count``.
            - ``char_x``, ``diac_x``, ``diac_y``, ``subword_lengths``: per-word
              character ids, decoder inputs, diacritic targets and subword
              counts, each padded/truncated to ``max_sent_len`` words.
              Character-level rows are padded/truncated to ``max_word_len``.
        """
        line = self.data_points[idx]
        words: List[str] = tokenize(line.strip())

        pad_target = self.data_utils.pad_target_val
        subwords_x = [self.bos_token_id]
        subword_lengths = []
        char_x = []
        diac_y = []
        diac_y_3head = []
        for word in words:
            word = du.strip_unknown_tashkeel(word)
            word_chars = du.split_word_on_characters_with_diacritics(word)
            cx, cy, cy_3head = du.create_label_for_word(word_chars)
            word_strip = strip_tashkeel(word)

            # Subword ids of the undiacritized word. [1:-1] drops the special
            # tokens the tokenizer wraps around each call.
            # NOTE(review): this assumes exactly one leading and one trailing
            # special token; confirm for every tokenizer in use. A word made
            # only of diacritics strips to "" and presumably yields 0 subwords.
            word_sub_ids = self.tokenizer(word_strip)['input_ids'][1:-1]
            subword_lengths.append(len(word_sub_ids))
            subwords_x.extend(word_sub_ids)

            char_x.append(self.data_utils.pad_and_truncate_sequence(cx, self.max_wlen))
            diac_y.append(
                self.data_utils.pad_and_truncate_sequence(cy, self.max_wlen, pad=pad_target)
            )
            diac_y_3head.append(
                self.data_utils.pad_and_truncate_sequence(
                    cy_3head, self.max_wlen, pad=[pad_target] * 3
                )
            )

        assert len(char_x) == len(subword_lengths), f"{char_x=}; {subword_lengths=} ;;"
        assert len(char_x) == len(words)

        # Decoder inputs are derived from the 3-head labels.
        diac_x = self.data_utils.create_decoder_input(diac_y_3head)
        subwords_x.append(self.eos_token_id)

        # NOTE(review): subwords_x is truncated to ``max_token_count``
        # independently of ``subword_lengths`` (``max_sent_len`` words). For
        # long lines the two can disagree, and the EOS is presumably cut off.
        # Verify that downstream alignment tolerates this.
        pad_seq = self.data_utils.pad_and_truncate_sequence
        subwords_x = pad_seq(subwords_x, self.max_tokens, pad=self.p_val)
        subword_lengths = pad_seq(subword_lengths, self.max_slen, pad=0)
        char_x = pad_seq(char_x, self.max_slen, pad=self.char_x_padding)
        diac_x = pad_seq(diac_x, self.max_slen, pad=self.diac_x_padding)
        diac_y = pad_seq(diac_y, self.max_slen, pad=self.diac_y_padding)
        return subwords_x, char_x, diac_x, diac_y, subword_lengths