# abstract-generator/app/abstract_dataset.py
import random

import torch
from torch.utils.data import Dataset


class AbstractDataset(Dataset):
    """Wraps (title, abstract, keywords) records for causal-LM fine-tuning."""

    # These tokens must also be registered on the tokenizer, e.g. with
    # tokenizer.add_special_tokens(AbstractDataset.special_tokens).
    special_tokens = {"bos_token": "<|BOS|>",
                      "eos_token": "<|EOS|>",
                      "unk_token": "<|UNK|>",
                      "pad_token": "<|PAD|>",
                      "sep_token": "<|SEP|>"}
    max_length = 1024

    def __init__(self, data, tokenizer, randomize=True):
        # `data` maps an id to a (title, abstract, keywords) record.
        title, text, keywords = [], [], []
        for v in data.values():
            title.append(v[0])
            text.append(v[1])
            keywords.append(v[2])
        self.randomize = randomize
        self.tokenizer = tokenizer
        self.title = title
        self.text = text
        self.keywords = keywords

    @staticmethod
    def join_keywords(keywords, randomize=True):
        # Shuffle so the model does not learn a fixed keyword order.
        # (A random subsample could be taken first, e.g.
        #  M = random.choice(range(len(keywords) + 1)); keywords = keywords[:M].)
        if randomize:
            random.shuffle(keywords)
        return ','.join(keywords)

    def __len__(self):
        return len(self.text)

    def __getitem__(self, i):
        # Copy so that shuffling does not mutate the stored keyword list.
        keywords = self.keywords[i].copy()
        kw = self.join_keywords(keywords, self.randomize)

        # Training example layout:
        # <|BOS|> title <|SEP|> keywords <|SEP|> abstract <|EOS|>
        prompt = self.special_tokens['bos_token'] + self.title[i] + \
                 self.special_tokens['sep_token'] + kw + self.special_tokens['sep_token'] + \
                 self.text[i] + self.special_tokens['eos_token']

        encodings_dict = self.tokenizer(prompt,
                                        truncation=True,
                                        max_length=self.max_length,
                                        padding="max_length")
        input_ids = encodings_dict['input_ids']
        attention_mask = encodings_dict['attention_mask']

        # For causal-LM training the labels are the input ids themselves;
        # Hugging Face's default_data_collator renames 'label' to 'labels'.
        return {'label': torch.tensor(input_ids),
                'input_ids': torch.tensor(input_ids),
                'attention_mask': torch.tensor(attention_mask)}
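

# A minimal usage sketch, not part of the original file: it assumes a GPT-2
# checkpoint and a toy `data` dict just for illustration. The class's special
# tokens are registered on the tokenizer, and the model's embedding matrix is
# resized so the new token ids are valid.
if __name__ == "__main__":
    from torch.utils.data import DataLoader
    from transformers import GPT2LMHeadModel, GPT2TokenizerFast

    tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
    tokenizer.add_special_tokens(AbstractDataset.special_tokens)

    model = GPT2LMHeadModel.from_pretrained("gpt2")
    model.resize_token_embeddings(len(tokenizer))

    # Toy record in the expected {id: (title, abstract, keywords)} shape.
    data = {0: ("A Study of Widgets",
                "We study widgets and find them interesting.",
                ["widgets", "study"])}
    dataset = AbstractDataset(data, tokenizer)

    loader = DataLoader(dataset, batch_size=1)
    batch = next(iter(loader))
    print(batch['input_ids'].shape)       # torch.Size([1, 1024])
    print(batch['attention_mask'].shape)  # torch.Size([1, 1024])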