File size: 2,520 Bytes
0c3992e
 
 
 
 
 
 
a00d62c
0c3992e
 
a00d62c
 
 
 
 
 
0c3992e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a00d62c
 
 
0c3992e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import ast
import copy
import os.path as osp

import pandas as pd
import torch


class STaRKDataset:
    """Question-answering dataset backed by a STaRK QA CSV file.

    The CSV must contain at least the columns ``id``, ``query`` and
    ``answer_ids``; ``answer_ids`` holds a Python list literal stored as a
    string (e.g. ``"[1, 2]"``). Dataset positions follow the sorted order of
    the ``id`` column.

    Args:
        query_dir: Directory containing ``stark_qa.csv`` (or
            ``stark_qa_human_generated_eval.csv`` when
            ``human_generated_eval`` is true).
        split_dir: Directory containing ``train.index``, ``val.index`` and
            ``test.index``, each listing one query id per line. Not read when
            ``human_generated_eval`` is true.
        human_generated_eval: Load the human-generated evaluation CSV instead
            of the default one.
    """

    def __init__(self, query_dir, split_dir, human_generated_eval=False):
        self.query_dir = query_dir
        self.split_dir = split_dir
        self.human_generated_eval = human_generated_eval
        if human_generated_eval:
            self.qa_csv_path = osp.join(query_dir, 'stark_qa_human_generated_eval.csv')
        else:
            self.qa_csv_path = osp.join(query_dir, 'stark_qa.csv')
        print('Loading QA dataset from', self.qa_csv_path)
        self.data = pd.read_csv(self.qa_csv_path)

        # Sorted query ids: dataset position ``i`` maps to ``self.indices[i]``.
        self.indices = sorted(self.data['id'])
        self.split_indices = self.get_idx_split()

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        """Return ``(query, query_id, answer_ids, meta_info)`` for position ``idx``.

        ``meta_info`` is always ``None``.
        """
        q_id = self.indices[idx]
        row = self.data[self.data['id'] == q_id].iloc[0]
        # ``answer_ids`` is a list literal stored as text. ``literal_eval``
        # parses it without executing arbitrary code from the CSV, unlike ``eval``.
        answer_ids = ast.literal_eval(row['answer_ids'])
        return row['query'], q_id, answer_ids, None

    def _read_split_query_ids(self, split):
        """Read ``{split}.index`` from ``split_dir`` and return its query ids as ints.

        Blank lines are ignored.
        """
        indices_file = osp.join(self.split_dir, f'{split}.index')
        with open(indices_file, 'r') as f:
            return [int(line) for line in f if line.strip()]

    def get_idx_split(self, test_ratio=1.0):
        """Return the dataset positions of each split as a dict of LongTensors.

        For the human-generated eval set, the dict has a single key
        ``'human_generated_eval'`` covering every position. Otherwise it has
        the keys ``'train'``, ``'val'`` and ``'test'``, built from the split
        files in ``split_dir``.

        Args:
            test_ratio: When below 1.0, keep only the first
                ``int(len(test) * test_ratio)`` entries of the test split.

        Raises:
            ValueError: A split file lists a query id that is absent from the CSV.
        """
        if self.human_generated_eval:
            # Return positions, not raw query ids, so the values can be passed
            # to ``__getitem__`` exactly like the train/val/test positions below.
            return {'human_generated_eval': torch.arange(len(self.indices), dtype=torch.long)}

        # query id -> first position, built once instead of ``list.index`` per id.
        position_of = {}
        for pos, q_id in enumerate(self.indices):
            position_of.setdefault(q_id, pos)

        split_idx = {}
        for split in ['train', 'val', 'test']:
            # `{split}.index` stores query ids, not the index in the dataset.
            query_ids = self._read_split_query_ids(split)
            try:
                positions = [position_of[q_id] for q_id in query_ids]
            except KeyError as err:
                raise ValueError(
                    f'query id {err.args[0]} from {split}.index is not in {self.qa_csv_path}'
                ) from err
            split_idx[split] = torch.LongTensor(positions)
        if test_ratio < 1.0:
            split_idx['test'] = split_idx['test'][:int(len(split_idx['test']) * test_ratio)]
        return split_idx

    def get_query_by_qid(self, q_id):
        """Return the query text of the first CSV row whose ``id`` equals ``q_id``."""
        row = self.data[self.data['id'] == q_id].iloc[0]
        return row['query']

    def get_subset(self, split):
        """Return a deep copy of the dataset restricted to the query ids of ``split``.

        The subset's ``indices`` lists the split file's query ids in file
        order. NOTE: ``split_indices`` of the copy still describes the full
        dataset.
        """
        assert split in ['train', 'val', 'test']
        query_ids = self._read_split_query_ids(split)
        subset = copy.deepcopy(self)
        subset.indices = query_ids
        return subset