import random

import torch
from torch.utils.data import Dataset


class AbstractDataset(Dataset):
    """Packs (title, keywords, text) triples into a single token sequence
    bounded by special tokens, for causal language-model fine-tuning."""

    special_tokens = {"bos_token": "<|BOS|>",
                      "eos_token": "<|EOS|>",
                      "unk_token": "<|UNK|>",
                      "pad_token": "<|PAD|>",
                      "sep_token": "<|SEP|>"}
    max_length = 1024

    def __init__(self, data, tokenizer, randomize=True):
        # Each value in `data` is expected to be a (title, text, keywords) triple.
        title, text, keywords = [], [], []
        for v in data.values():
            title.append(v[0])
            text.append(v[1])
            keywords.append(v[2])

        self.randomize = randomize
        self.tokenizer = tokenizer
        self.title = title
        self.text = text
        self.keywords = keywords

    @staticmethod
    def join_keywords(keywords, randomize=True):
        N = len(keywords)

        # Shuffle the keyword order so the model does not learn a fixed
        # ordering; random subsampling is left commented out.
        if randomize:
            # M = random.choice(range(N + 1))
            # keywords = keywords[:M]
            random.shuffle(keywords)

        return ','.join(keywords)

    def __len__(self):
        return len(self.text)

    def __getitem__(self, i):
        keywords = self.keywords[i].copy()
        kw = self.join_keywords(keywords, self.randomize)

        # Assemble the training prompt: <|BOS|> title <|SEP|> keywords <|SEP|> text <|EOS|>
        input_text = self.special_tokens['bos_token'] + self.title[i] + \
                     self.special_tokens['sep_token'] + kw + self.special_tokens['sep_token'] + \
                     self.text[i] + self.special_tokens['eos_token']

        encodings_dict = self.tokenizer(input_text,
                                        truncation=True,
                                        max_length=self.max_length,
                                        padding="max_length")

        input_ids = encodings_dict['input_ids']
        attention_mask = encodings_dict['attention_mask']

        # For causal LM fine-tuning the labels mirror the input ids.
        return {'label': torch.tensor(input_ids),
                'input_ids': torch.tensor(input_ids),
                'attention_mask': torch.tensor(attention_mask)}
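

# --- Usage sketch (illustrative, not part of the original file) ---
# A minimal example of constructing the dataset, assuming the Hugging Face
# `transformers` GPT-2 tokenizer. The `data` dict below is hypothetical and
# only shows the expected {key: (title, text, keywords)} layout that
# __init__ consumes.
if __name__ == "__main__":
    from transformers import GPT2Tokenizer

    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    # Register the dataset's special tokens so each maps to a single id
    # and padding with <|PAD|> is possible.
    tokenizer.add_special_tokens(AbstractDataset.special_tokens)

    data = {
        0: ("A sample title",
            "The body text of the first abstract.",
            ["keyword one", "keyword two"]),
    }
    dataset = AbstractDataset(data, tokenizer, randomize=True)

    sample = dataset[0]
    print(sample['input_ids'].shape, sample['attention_mask'].shape)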