File size: 1,417 Bytes
6c25ddb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import os 
import json 

import torch 
from torch.utils.data import DataLoader, Dataset
# Import-time trace: writes the literal "data.py" (presumably this module's
# file name -- confirm) to stdout every time the module is loaded.
print("data.py")
def my_collate(batch):
    """Combine a list of example dicts into a single batch dict.

    Each example is expected to provide the keys ``doc_key``,
    ``input_token_ids``, ``input_attn_mask``, ``tgt_token_ids`` and
    ``tgt_attn_mask``.

    Args:
        batch: Sequence of example dicts, one per sample.

    Returns:
        A dict with the same five keys.  The token-id fields are stacked into
        ``torch.LongTensor`` batches, the attention-mask fields into
        ``torch.BoolTensor`` batches, and ``doc_key`` is a plain list holding
        each example's value in batch order.

    Note:
        ``torch.stack`` is used, so a given field must have the same length
        in every example of the batch.
    """
    def _stacked(field, tensor_type):
        # One tensor per example, joined along a new leading batch dimension.
        return torch.stack([tensor_type(item[field]) for item in batch])

    keys = [item['doc_key'] for item in batch]

    collated = {
        'input_token_ids': _stacked('input_token_ids', torch.LongTensor),
        'input_attn_mask': _stacked('input_attn_mask', torch.BoolTensor),
        'tgt_token_ids': _stacked('tgt_token_ids', torch.LongTensor),
        'tgt_attn_mask': _stacked('tgt_attn_mask', torch.BoolTensor),
    }
    collated['doc_key'] = keys
    return collated


class IEDataset(Dataset):
    """In-memory dataset backed by a JSON Lines file.

    Every non-blank line of the input file is parsed as one JSON document and
    kept, in file order, in ``self.examples``.
    """

    def __init__(self, input_file):
        """Load all examples from ``input_file``.

        Args:
            input_file: Path to a text file holding one JSON document per line.

        Raises:
            OSError: If the file cannot be opened.
            json.JSONDecodeError: If a non-blank line is not valid JSON.
        """
        super().__init__()
        self.examples = []
        # Decode explicitly as UTF-8 (the JSON interchange encoding) instead
        # of relying on the platform's locale-dependent default.
        with open(input_file, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                # Blank / whitespace-only lines (e.g. trailing empty lines)
                # are not JSON documents; json.loads would raise on them.
                if not line:
                    continue
                self.examples.append(json.loads(line))

    def __len__(self):
        """Return the number of loaded examples."""
        return len(self.examples)

    def __getitem__(self, idx):
        """Return the parsed example stored at position ``idx``."""
        return self.examples[idx]