File size: 4,433 Bytes
7bf4b88
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import ast
import copy
import os.path as osp
from typing import Union

import pandas as pd
import torch

from stark_qa.tools.download_hf import download_hf_folder


# Hugging Face location of the STaRK QA data: the dataset repository id and
# the folder inside it that contains the per-dataset QA files. Consumed by
# STaRKDataset._download via download_hf_folder.
STARK_QA_DATASET = {
    "repo": "snap-stanford/stark",
    "folder": "qa"
}

class STaRKDataset:
    """Loader for a STaRK question-answering benchmark dataset.

    Downloads the QA CSV (and split index files) from the Hugging Face
    repository described by ``STARK_QA_DATASET``, then exposes the queries
    by position: ``dataset[i]`` returns the i-th query (ordered by query id)
    together with its id and answer ids.
    """

    def __init__(self, 
                 name: str, 
                 root: Union[str, None] = None, 
                 human_generated_eval: bool = False):
        """
        Initialize the STaRK dataset.

        Args:
            name (str): Name of the dataset.
            root (Union[str, None]): Root directory to store the dataset. If None, default HF cache paths will be used.
            human_generated_eval (bool): Whether to use human-generated evaluation data.
        """
        self.name = name
        self.root = root
        self.dataset_root = osp.join(self.root, name) if self.root is not None else None
        # _download resolves self.dataset_root when it was None (HF cache path).
        self._download()
        self.split_dir = osp.join(self.dataset_root, 'split')
        self.query_dir = osp.join(self.dataset_root, 'stark_qa')
        self.human_generated_eval = human_generated_eval

        self.qa_csv_path = osp.join(
            self.query_dir, 
            'stark_qa_human_generated_eval.csv' if human_generated_eval else 'stark_qa.csv'
        )

        self.data = pd.read_csv(self.qa_csv_path)
        # Sorted query ids; positional index -> query id mapping for __getitem__.
        self.indices = sorted(self.data['id'].tolist())
        self.split_indices = self.get_idx_split()

    def __len__(self) -> int:
        """
        Return the number of queries in the dataset.

        Returns:
            int: Number of queries.
        """
        return len(self.indices)

    def __getitem__(self, idx: int):
        """
        Get the query, id, answer ids, and meta information for a given index.

        Args:
            idx (int): Index of the query.

        Returns:
            tuple: Query, query id, answer ids, and meta information.
        """
        q_id = self.indices[idx]
        row = self.data[self.data['id'] == q_id].iloc[0]
        query = row['query']
        # literal_eval parses the stored "[1, 2, ...]" string safely;
        # eval() would execute arbitrary expressions embedded in the CSV.
        answer_ids = ast.literal_eval(row['answer_ids'])
        meta_info = None  # Replace with actual meta information if available
        return query, q_id, answer_ids, meta_info

    def _download(self):
        """
        Download the dataset from the Hugging Face repository.

        Sets self.dataset_root to the local folder the data was saved to.
        """
        self.dataset_root = download_hf_folder(
            STARK_QA_DATASET["repo"],
            osp.join(STARK_QA_DATASET["folder"], self.name),
            repo_type="dataset",
            save_as_folder=self.dataset_root,
        )

    def _load_split_ids(self, split: str) -> list:
        """
        Read the newline-separated query ids for a split from disk.

        Args:
            split (str): Split name ('train', 'val', 'test').

        Returns:
            list: Query ids (ints) listed in ``<split_dir>/<split>.index``.
        """
        indices_file = osp.join(self.split_dir, f'{split}.index')
        with open(indices_file, 'r') as f:
            return [int(idx) for idx in f.read().strip().split('\n')]

    def get_idx_split(self, test_ratio: float = 1.0) -> dict:
        """
        Return the indices of train/val/test split in a dictionary.

        Args:
            test_ratio (float): Ratio of test data to include.

        Returns:
            dict: Dictionary with split indices for train, val, and test sets.
        """
        if self.human_generated_eval:
            return {'human_generated_eval': torch.LongTensor(self.indices)}

        # Build the query-id -> position map once; repeated list.index()
        # calls would make this loop O(n^2).
        id_to_pos = {q_id: pos for pos, q_id in enumerate(self.indices)}

        split_idx = {}
        for split in ['train', 'val', 'test']:
            query_ids = self._load_split_ids(split)
            split_idx[split] = torch.LongTensor([id_to_pos[q_id] for q_id in query_ids])

        if test_ratio < 1.0:
            split_idx['test'] = split_idx['test'][:int(len(split_idx['test']) * test_ratio)]
        return split_idx

    def get_query_by_qid(self, q_id: int) -> str:
        """
        Return the query by query id.

        Args:
            q_id (int): Query id.

        Returns:
            str: Query string.
        """
        row = self.data[self.data['id'] == q_id].iloc[0]
        return row['query']

    def get_subset(self, split: str):
        """
        Return a subset of the dataset.

        Args:
            split (str): Split type ('train', 'val', 'test').

        Returns:
            STaRKDataset: Subset of the dataset.
        """
        assert split in ['train', 'val', 'test'], "Invalid split specified."
        subset = copy.deepcopy(self)
        # Restrict the positional index to the split's query ids; the full
        # DataFrame is kept so row lookups by id still work.
        subset.indices = self._load_split_ids(split)
        return subset