# AES-EXP-3 / AES_EXP_1_DATASET.py
# jaytonde05's picture
# Upload AES_EXP_1_DATASET.py with huggingface_hub
# b3345c6 verified
import torch
from torch.utils.data import Dataset
def prepare_input(cfg, text, tokenizer):
    """
    Tokenize one text and return its encoding with every field as a tensor.

    The text is encoded with special tokens, truncated to ``cfg.MAX_LEN`` and
    padded with the ``'max_length'`` strategy, so (with a standard Hugging
    Face tokenizer) every field has ``cfg.MAX_LEN`` entries.

    :param cfg: configuration object; only its ``MAX_LEN`` attribute is read.
    :param text: the text to encode (a single string, not a batch).
    :param tokenizer: callable Hugging Face-style tokenizer.
    :return inputs: the mapping returned by the tokenizer (e.g. "input_ids",
        "attention_mask" and, depending on the tokenizer, "token_type_ids"),
        with each value replaced in place by a ``torch.long`` tensor.
    """
    # Call the tokenizer directly: `encode_plus` is deprecated in favour of
    # `__call__`, which yields the same encoding for a single string.
    inputs = tokenizer(
        text,
        return_tensors=None,
        add_special_tokens=True,
        max_length=cfg.MAX_LEN,
        padding='max_length',  # TODO: check padding to max sequence in batch
        truncation=True,
    )
    # Only existing keys are reassigned, so updating the mapping while
    # iterating over it is safe and keeps the tokenizer's own mapping type.
    for k, v in inputs.items():
        inputs[k] = torch.tensor(v, dtype=torch.long)  # TODO: check dtypes
    return inputs
def collate(inputs):
    """
    Truncate a padded batch to the longest real sequence it contains.

    The longest sequence is taken as the largest row sum of
    ``inputs["attention_mask"]``; every value in the mapping is then sliced
    to its first ``mask_len`` columns.

    NOTE(review): keeping the *leading* columns assumes the padding sits on
    the right of each sequence -- confirm against the tokenizer settings.

    :param inputs: mapping of 2-D (batch, sequence) tensors that includes an
        "attention_mask" entry.
    :return inputs: the same mapping object, modified in place, with every
        value sliced to ``[:, :mask_len]``.
    """
    mask_len = int(inputs["attention_mask"].sum(axis=1).max())  # Get batch's max sequence length
    # Only existing keys are reassigned, so mutating during iteration is safe.
    for k, v in inputs.items():
        inputs[k] = v[:, :mask_len]
    return inputs
class CustomDataset(Dataset):
    """
    Map-style dataset that yields one tokenized text with its label and id.

    :param cfg: configuration object, passed through to ``prepare_input``.
    :param df: table with 'full_text', 'score' and 'essay_id' columns; each
        column's ``.values`` array is stored on the instance.
    :param tokenizer: tokenizer, passed through to ``prepare_input``.
    """

    def __init__(self, cfg, df, tokenizer):
        self.cfg = cfg
        self.tokenizer = tokenizer
        # Keep the raw column arrays so __getitem__ is a plain index lookup.
        self.texts = df['full_text'].values
        self.labels = df['score'].values
        self.essay_ids = df['essay_id'].values

    def __len__(self):
        """Return the number of texts in the dataset."""
        return len(self.texts)

    def __getitem__(self, item):
        """
        Build the sample stored at position ``item``.

        :param item: index into the stored column arrays.
        :return output: dictionary with "inputs" (the result of
            ``prepare_input`` for the text), "labels" (the score as a
            ``torch.long`` tensor) and "essay_ids" (the id, unconverted).
        """
        return {
            "inputs": prepare_input(self.cfg, self.texts[item], self.tokenizer),
            "labels": torch.tensor(self.labels[item], dtype=torch.long),  # TODO: check dtypes
            "essay_ids": self.essay_ids[item],
        }