Spaces:
Runtime error
Runtime error
File size: 1,958 Bytes
9242ed0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 |
import torch
import random
from torch.utils.data import Dataset
class AbstractDataset(Dataset):
    """Dataset that serializes (title, keywords, text) records into a single
    delimited prompt string and tokenizes it for language-model fine-tuning.

    Each item is rendered as:
        <|BOS|> title <|SEP|> keyword1,keyword2,... <|SEP|> text <|EOS|>
    and tokenized to a fixed ``max_length`` with padding/truncation.
    """

    # Marker tokens that delimit the prompt sections; the tokenizer is
    # assumed to have been extended with these (TODO confirm at call site).
    special_tokens = {"bos_token": "<|BOS|>",
                      "eos_token": "<|EOS|>",
                      "unk_token": "<|UNK|>",
                      "pad_token": "<|PAD|>",
                      "sep_token": "<|SEP|>"}
    # Fixed sequence length for truncation and padding.
    max_length = 1024

    def __init__(self, data, tokenizer, randomize=True):
        """Build the dataset from a mapping of records.

        Args:
            data: mapping whose values are (title, text, keywords) triples;
                keys are ignored.
            tokenizer: callable with a HuggingFace-style ``__call__`` that
                accepts ``truncation``, ``max_length`` and ``padding`` and
                returns a dict with ``input_ids`` and ``attention_mask``.
            randomize: if True, shuffle the keyword order for each item.
        """
        self.randomize = randomize
        self.tokenizer = tokenizer
        # Comprehensions instead of a manual append loop; keys are unused.
        self.title = [v[0] for v in data.values()]
        self.text = [v[1] for v in data.values()]
        self.keywords = [v[2] for v in data.values()]

    @staticmethod
    def join_keywords(keywords, randomize=True):
        """Return the keywords joined with commas, optionally shuffled.

        Note: ``random.shuffle`` mutates ``keywords`` in place; callers
        should pass a copy (as ``__getitem__`` does).
        """
        if randomize:
            random.shuffle(keywords)
        return ','.join(keywords)

    def __len__(self):
        return len(self.text)

    def __getitem__(self, i):
        """Return the tokenized prompt for record ``i`` as tensors."""
        # Copy so shuffling inside join_keywords never mutates stored state.
        keywords = self.keywords[i].copy()
        kw = self.join_keywords(keywords, self.randomize)
        # Renamed from `input` to avoid shadowing the builtin.
        prompt = (self.special_tokens['bos_token'] + self.title[i]
                  + self.special_tokens['sep_token'] + kw
                  + self.special_tokens['sep_token'] + self.text[i]
                  + self.special_tokens['eos_token'])
        encodings_dict = self.tokenizer(prompt,
                                        truncation=True,
                                        max_length=self.max_length,
                                        padding="max_length")
        input_ids = encodings_dict['input_ids']
        attention_mask = encodings_dict['attention_mask']
        # 'label' duplicates input_ids for causal-LM style training.
        return {'label': torch.tensor(input_ids),
                'input_ids': torch.tensor(input_ids),
                'attention_mask': torch.tensor(attention_mask)}
|