import random

import torch
from torch.utils.data import Dataset


class AbstractDataset(Dataset):
    special_tokens = {"bos_token": "<|BOS|>",
                      "eos_token": "<|EOS|>",
                      "unk_token": "<|UNK|>",
                      "pad_token": "<|PAD|>",
                      "sep_token": "<|SEP|>"}
    max_length = 1024

    def __init__(self, data, tokenizer, randomize=True):
        # Each value in `data` is a (title, text, keywords) triple.
        title, text, keywords = [], [], []
        for k, v in data.items():
            title.append(v[0])
            text.append(v[1])
            keywords.append(v[2])

        self.randomize = randomize
        self.tokenizer = tokenizer
        self.title = title
        self.text = text
        self.keywords = keywords

    @staticmethod
    def join_keywords(keywords, randomize=True):
        N = len(keywords)

        # Random sampling and shuffle: shuffling keeps the model from
        # learning a fixed keyword order (subsampling is left disabled).
        if randomize:
            # M = random.choice(range(N + 1))
            # keywords = keywords[:M]
            random.shuffle(keywords)

        return ','.join(keywords)

    def __len__(self):
        return len(self.text)

    def __getitem__(self, i):
        # Copy so shuffling does not mutate the stored keyword list.
        keywords = self.keywords[i].copy()
        kw = self.join_keywords(keywords, self.randomize)

        # Training string: <|BOS|> title <|SEP|> keywords <|SEP|> text <|EOS|>
        # (renamed from `input`, which shadows the Python builtin).
        prompt = self.special_tokens['bos_token'] + self.title[i] + \
            self.special_tokens['sep_token'] + kw + self.special_tokens['sep_token'] + \
            self.text[i] + self.special_tokens['eos_token']

        encodings_dict = self.tokenizer(prompt,
                                        truncation=True,
                                        max_length=self.max_length,
                                        padding="max_length")

        input_ids = encodings_dict['input_ids']
        attention_mask = encodings_dict['attention_mask']

        # For causal language modeling the labels are the input ids themselves;
        # the model shifts them internally when computing the loss.
        return {'label': torch.tensor(input_ids),
                'input_ids': torch.tensor(input_ids),
                'attention_mask': torch.tensor(attention_mask)}
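
# --- Usage sketch (not from the original source) ---------------------------
# Assumptions: Hugging Face `transformers` is installed, and `data` maps ids
# to (title, text, keyword_list) triples, as __init__ expects. The custom
# special tokens must be registered with the tokenizer so each one encodes
# to a single id; when fine-tuning, remember to call
# model.resize_token_embeddings(len(tokenizer)) to account for the new tokens.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(
        "gpt2", **AbstractDataset.special_tokens)

    data = {
        0: ("A sample title", "The body of the article ...", ["sample", "demo"]),
    }
    dataset = AbstractDataset(data, tokenizer, randomize=True)

    sample = dataset[0]
    print(sample["input_ids"].shape)       # torch.Size([1024]) after padding
    print(sample["attention_mask"].sum())  # number of non-pad tokens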