Spaces:
Runtime error
Runtime error
import random

import torch
from torch.utils.data import Dataset
class AbstractDataset(Dataset):
    """Dataset formatting (title, text, keywords) records for causal-LM fine-tuning.

    Each item is rendered as one string:

        <|BOS|> title <|SEP|> kw1,kw2,... <|SEP|> text <|EOS|>

    then tokenized to a fixed length of ``max_length`` tokens.
    """

    # Marker tokens delimiting the title / keyword / body sections.
    special_tokens = {"bos_token": "<|BOS|>",
                      "eos_token": "<|EOS|>",
                      "unk_token": "<|UNK|>",
                      "pad_token": "<|PAD|>",
                      "sep_token": "<|SEP|>"}
    # Fixed tokenized sequence length: truncate and pad to this many tokens.
    max_length = 1024

    def __init__(self, data, tokenizer, randomize=True):
        """Build the dataset.

        Args:
            data: mapping of key -> (title, text, keywords) triples; keywords
                is a list of strings.
            tokenizer: callable (e.g. a HuggingFace tokenizer) accepting
                ``truncation``, ``max_length`` and ``padding`` keyword args and
                returning a dict with 'input_ids' and 'attention_mask'.
            randomize: if True, shuffle each item's keyword order on access.
        """
        title, text, keywords = [], [], []
        for v in data.values():  # keys are unused; iterate values only
            title.append(v[0])
            text.append(v[1])
            keywords.append(v[2])
        self.randomize = randomize
        self.tokenizer = tokenizer
        self.title = title
        self.text = text
        self.keywords = keywords

    @staticmethod
    def join_keywords(keywords, randomize=True):
        """Join keywords into a comma-separated string, optionally shuffled.

        BUGFIX: the original defined this without ``self`` and without
        ``@staticmethod``, so the ``self.join_keywords(...)`` call in
        ``__getitem__`` passed three positional arguments to a two-parameter
        function and raised TypeError. ``@staticmethod`` fixes the call while
        keeping ``AbstractDataset.join_keywords(...)`` usable externally.
        """
        if randomize:
            # Shuffles in place; __getitem__ passes a copy, so the stored
            # keyword lists are never mutated.
            random.shuffle(keywords)
        return ','.join(keywords)

    def __len__(self):
        """Number of records in the dataset."""
        return len(self.text)

    def __getitem__(self, i):
        """Return tokenized tensors for record ``i``.

        Returns:
            dict with 'label', 'input_ids' (identical to 'label' —
            causal-LM style) and 'attention_mask', each a torch tensor.
        """
        # Copy before (possibly) shuffling so the source list keeps its order.
        keywords = self.keywords[i].copy()
        kw = self.join_keywords(keywords, self.randomize)

        # Renamed from ``input`` to avoid shadowing the builtin.
        sample = (self.special_tokens['bos_token'] + self.title[i]
                  + self.special_tokens['sep_token'] + kw
                  + self.special_tokens['sep_token'] + self.text[i]
                  + self.special_tokens['eos_token'])

        encodings_dict = self.tokenizer(sample,
                                        truncation=True,
                                        max_length=self.max_length,
                                        padding="max_length")
        input_ids = encodings_dict['input_ids']
        attention_mask = encodings_dict['attention_mask']
        return {'label': torch.tensor(input_ids),
                'input_ids': torch.tensor(input_ids),
                'attention_mask': torch.tensor(attention_mask)}