|
""" |
|
Loading the diacritization dataset |
|
""" |
|
|
|
import os |
|
|
|
from diacritization_evaluation import util |
|
import pandas as pd |
|
import torch |
|
from torch.utils.data import DataLoader, Dataset |
|
|
|
from .config_manager import ConfigManager |
|
|
|
# Arabic diacritic marks (haraqat) mapped to a transliterated name.
# The keys are written as Unicode escapes so that each of the eight combining
# marks stays a distinct dictionary key regardless of how this file is
# encoded or re-encoded.
BASIC_HARAQAT = {
    "\u064e": "Fatha ",  # U+064E ARABIC FATHA
    "\u064b": "Fathatah ",  # U+064B ARABIC FATHATAN (value spelling kept as-is)
    "\u064f": "Damma ",  # U+064F ARABIC DAMMA
    "\u064c": "Dammatan ",  # U+064C ARABIC DAMMATAN
    "\u0650": "Kasra ",  # U+0650 ARABIC KASRA
    "\u064d": "Kasratan ",  # U+064D ARABIC KASRATAN
    "\u0652": "Sukun ",  # U+0652 ARABIC SUKUN
    "\u0651": "Shaddah ",  # U+0651 ARABIC SHADDA
}
|
|
|
|
|
class DiacritizationDataset(Dataset):
    """
    The diacritization dataset.

    Indexing yields a tuple ``(inputs, targets, text)``: two tensors built
    from the sequences produced by the configured text encoder, plus a third
    value whose origin depends on the data mode (see ``__getitem__``).
    """

    def __init__(self, config_manager: ConfigManager, list_ids, data):
        """
        Store the sample ids, the backing data and the encoder/config.

        Args:
            config_manager: object exposing ``text_encoder`` and ``config``.
            list_ids: indexable collection of sample ids; ``list_ids[i]`` is
                the key used to fetch sample ``i`` from ``data``.
            data: when ``config["is_data_preprocessed"]`` is truthy, an
                object indexed through ``.iloc`` (the loaders in this module
                pass a pandas DataFrame); otherwise an indexable collection
                of raw text lines.
        """
        self.list_ids = list_ids
        self.data = data
        self.text_encoder = config_manager.text_encoder
        self.config = config_manager.config

    def __len__(self):
        """Return the total number of samples (the number of ids)."""
        return len(self.list_ids)

    def preprocess(self, book):
        """
        Drop every haraqa that is directly followed by another haraqa.

        Only the last mark of each run of consecutive ``BASIC_HARAQAT``
        characters is kept; every other character is copied unchanged.

        Args:
            book: the text to filter.

        Returns:
            The filtered text as a string.
        """
        last = len(book) - 1
        kept = [
            char
            for pos, char in enumerate(book)
            # The final character has no successor, so it is always kept.
            if not (
                pos < last
                and char in BASIC_HARAQAT
                and book[pos + 1] in BASIC_HARAQAT
            )
        ]
        return "".join(kept)

    def __getitem__(self, index):
        """
        Generate one sample of data.

        Args:
            index: position inside ``list_ids``.

        Returns:
            ``(inputs, targets, text)``. ``inputs`` and ``targets`` are
            created with ``torch.Tensor(...)`` from the encoder's sequences.
            In preprocessed mode ``text`` is column ``0`` of the row; in raw
            mode it is the first value returned by ``util.extract_haraqat``.
        """
        sample_id = self.list_ids[index]

        if self.config["is_data_preprocessed"]:
            # Positional row lookup: column 1 is encoded as the input,
            # column 2 is split on the configured separator and encoded as
            # the target, column 0 is returned untouched.
            row = self.data.iloc[sample_id]
            inputs = torch.Tensor(self.text_encoder.input_to_sequence(row[1]))
            targets = torch.Tensor(
                self.text_encoder.target_to_sequence(
                    row[2].split(self.config["diacritics_separator"])
                )
            )
            return inputs, targets, row[0]

        # Raw mode: clean the line, truncate it to ``max_sen_len`` and let
        # ``extract_haraqat`` split it into its three returned parts.
        line = self.text_encoder.clean(self.data[sample_id])
        line = line[: self.config["max_sen_len"]]
        text, inputs, diacritics = util.extract_haraqat(line)

        inputs = torch.Tensor(self.text_encoder.input_to_sequence("".join(inputs)))
        diacritics = torch.Tensor(self.text_encoder.target_to_sequence(diacritics))

        return inputs, diacritics, text
|
|
|
|
|
def collate_fn(data):
    """
    Padding the input and output sequences.

    Samples are ordered by decreasing input length, then the inputs and the
    targets are each right-padded with zeros up to the longest sequence of
    their own kind. The incoming batch is not modified.

    Args:
        data: sequence of ``(inputs, targets, original)`` samples, where
            ``inputs`` and ``targets`` support ``len()`` and slicing and can
            be assigned into a tensor row (the dataset in this module yields
            1-D tensors).

    Returns:
        dict with keys ``"original"`` (tuple of the third sample elements),
        ``"src"`` and ``"target"`` (zero-padded long tensors) and
        ``"lengths"`` (``torch.LongTensor`` of the input lengths), all in
        the sorted order.
    """

    def merge(sequences):
        # Zero-filled (batch, longest) long tensor; each row receives its
        # sequence and the remainder stays 0 as padding.
        lengths = [len(seq) for seq in sequences]
        padded_seqs = torch.zeros(len(sequences), max(lengths)).long()
        for row, (seq, end) in enumerate(zip(sequences, lengths)):
            padded_seqs[row, :end] = seq[:end]
        return padded_seqs, lengths

    # sorted() returns a new list, so the caller's batch keeps its order and
    # any sequence type (not only list) is accepted.
    ordered = sorted(data, key=lambda sample: len(sample[0]), reverse=True)

    src_seqs, trg_seqs, original = zip(*ordered)

    src_seqs, src_lengths = merge(src_seqs)
    # Only the input lengths are reported; target lengths are not needed.
    trg_seqs, _ = merge(trg_seqs)

    batch = {
        "original": original,
        "src": src_seqs,
        "target": trg_seqs,
        "lengths": torch.LongTensor(src_lengths),
    }
    return batch
|
|
|
|
|
def load_training_data(config_manager: ConfigManager, loader_parameters):
    """
    Build the training ``DataLoader`` from ``train.csv`` in the data dir.

    Args:
        config_manager: supplies ``config`` and ``data_dir``.
        loader_parameters: keyword arguments forwarded to ``DataLoader``.

    Returns:
        ``[]`` when ``config["load_training_data"]`` is falsy, otherwise the
        training ``DataLoader`` (its length is printed before returning).
    """
    config = config_manager.config
    if not config["load_training_data"]:
        return []

    path = os.path.join(config_manager.data_dir, "train.csv")

    if not config["is_data_preprocessed"]:
        # Raw text mode: one example per line of the file.
        with open(path, encoding="utf8") as file:
            lines = file.readlines()
        # NOTE(review): readlines() keeps the trailing newline, so the
        # "> 0" half of this filter presumably never rejects a line - confirm
        # whether blank lines were meant to be dropped.
        train_data = [text for text in lines if 0 < len(text) <= config["max_len"]]
        training_set = DiacritizationDataset(
            config_manager, list(range(len(train_data))), train_data
        )
    else:
        # Preprocessed mode: headerless CSV read with pandas; the frame's
        # index supplies the sample ids.
        train_data = pd.read_csv(
            path,
            encoding="utf-8",
            sep=config["data_separator"],
            nrows=config["n_training_examples"],
            header=None,
        )
        training_set = DiacritizationDataset(
            config_manager, train_data.index, train_data
        )

    train_iterator = DataLoader(
        training_set, collate_fn=collate_fn, **loader_parameters
    )

    print(f"Length of training iterator = {len(train_iterator)}")
    return train_iterator
|
|
|
|
|
def load_test_data(config_manager: ConfigManager, loader_parameters):
    """
    Build the test ``DataLoader`` from the configured test file.

    The file name comes from ``config["test_file_name"]`` and falls back to
    ``"test.csv"``; it is resolved relative to ``config_manager.data_dir``.

    Args:
        config_manager: supplies ``config`` and ``data_dir``.
        loader_parameters: keyword arguments forwarded to ``DataLoader``.

    Returns:
        ``[]`` when ``config["load_test_data"]`` is falsy, otherwise the test
        ``DataLoader`` (its length is printed before returning).
    """
    config = config_manager.config
    if not config["load_test_data"]:
        return []

    file_name = config.get("test_file_name", "test.csv")
    path = os.path.join(config_manager.data_dir, file_name)

    if not config["is_data_preprocessed"]:
        # Raw text mode: every line is kept, truncated to max_len characters.
        with open(path, encoding="utf8") as file:
            lines = file.readlines()
        max_len = config["max_len"]
        test_data = [line[:max_len] for line in lines]
        test_dataset = DiacritizationDataset(
            config_manager, list(range(len(test_data))), test_data
        )
    else:
        # Preprocessed mode: headerless CSV read with pandas; the frame's
        # index supplies the sample ids.
        test_data = pd.read_csv(
            path,
            encoding="utf-8",
            sep=config["data_separator"],
            nrows=config["n_test_examples"],
            header=None,
        )
        test_dataset = DiacritizationDataset(config_manager, test_data.index, test_data)

    test_iterator = DataLoader(test_dataset, collate_fn=collate_fn, **loader_parameters)

    print(f"Length of test iterator = {len(test_iterator)}")
    return test_iterator
|
|
|
|
|
def load_validation_data(config_manager: ConfigManager, loader_parameters):
    """
    Loading the validation data using pandas.

    Reads ``eval.csv`` from ``config_manager.data_dir`` and wraps it in a
    ``DataLoader``.

    Args:
        config_manager: supplies ``config`` and ``data_dir``.
        loader_parameters: keyword arguments forwarded to ``DataLoader``.

    Returns:
        ``[]`` when ``config["load_validation_data"]`` is falsy, otherwise
        the validation ``DataLoader`` (its length is printed before
        returning).
    """
    if not config_manager.config["load_validation_data"]:
        return []

    path = os.path.join(config_manager.data_dir, "eval.csv")
    if config_manager.config["is_data_preprocessed"]:
        valid_data = pd.read_csv(
            path,
            encoding="utf-8",
            sep=config_manager.config["data_separator"],
            nrows=config_manager.config["n_validation_examples"],
            header=None,
        )
        # NOTE(review): column 0 is compared against max_len here, while
        # DiacritizationDataset returns column 0 as the sample's third
        # element - confirm what column 0 holds in the preprocessed files.
        keep = valid_data[0] <= config_manager.config["max_len"]
        # Boolean filtering keeps the original row labels, but the dataset
        # fetches rows by position (.iloc). Renumber the rows so the ids
        # taken from the index are valid positions in the filtered frame.
        valid_data = valid_data[keep].reset_index(drop=True)
        valid_dataset = DiacritizationDataset(
            config_manager, valid_data.index, valid_data
        )
    else:
        # Raw text mode: every line is kept, truncated to max_len characters.
        with open(path, encoding="utf8") as file:
            valid_data = file.readlines()

        max_len = config_manager.config["max_len"]
        valid_data = [text[:max_len] for text in valid_data]
        valid_dataset = DiacritizationDataset(
            config_manager, list(range(len(valid_data))), valid_data
        )

    valid_iterator = DataLoader(
        valid_dataset, collate_fn=collate_fn, **loader_parameters
    )

    print(f"Length of valid iterator = {len(valid_iterator)}")
    return valid_iterator
|
|
|
|
|
def load_iterators(config_manager: ConfigManager):
    """
    Load the data iterators.

    The same loader settings are used for all three splits: the configured
    ``batch_size``, shuffling enabled and two worker processes.

    Args:
        config_manager: supplies ``config["batch_size"]`` and is forwarded
            to each split loader.

    Returns:
        Tuple ``(train_iterator, test_iterator, valid_iterator)`` - note
        that the test iterator precedes the validation iterator.
    """
    loader_kwargs = dict(
        batch_size=config_manager.config["batch_size"],
        shuffle=True,
        num_workers=2,
    )
    # Loading order is train, validation, test; the return order differs.
    train_iterator = load_training_data(config_manager, loader_parameters=loader_kwargs)
    valid_iterator = load_validation_data(config_manager, loader_parameters=loader_kwargs)
    test_iterator = load_test_data(config_manager, loader_parameters=loader_kwargs)
    return train_iterator, test_iterator, valid_iterator
|
|