File size: 1,661 Bytes
8a7cb66
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import json

# Default ground-truth JSONL file used by Verifier when no path is given
# (presumably the GAIA 2023 validation split, judging by the name — confirm).
# This is a relative path, so open() resolves it against the current working
# directory, not against this module's location.
DATA_PATH = "groundtruths/gaia-2023-validation.jsonl"

class Verifier:
    def __init__(self, data_path: str | None=None):
        if data_path is None:
            data_path = DATA_PATH
        self.data_path = data_path
        self.data: dict = self.load_data()
        self.correct_cnt = 0
        self.total_cnt = 0

    def load_data(self) -> dict:
        data = {}
        with open(self.data_path, "r") as f:
            for line in f:
                record = json.loads(line)
                data[record["task_id"]] = record
        return data

    def verify(self, task_id: str, answer: str) -> None | tuple[bool, str]:
        record = self.data.get(task_id)
        if not record:
            print(f"Task ID {task_id} not found.")
            return None
        self.total_cnt += 1
        if record["Final answer"] == answer:
            self.correct_cnt += 1
            return True, record["Final answer"]
        return False, record["Final answer"]

    def get_answer(self, task_id: str) -> str:
        record = self.data.get(task_id)
        if not record:
            print(f"Task ID {task_id} not found.")
            return ""
        return record["Final answer"]

    def get_accuracy(self) -> float:
        if self.total_cnt == 0:
            return 0.0
        return self.correct_cnt / self.total_cnt

    def get_output(self) -> str:
        return f"**Correct:** {self.correct_cnt}/{self.total_cnt} ({self.get_accuracy():.2%})"

if __name__ == "__main__":
    # Smoke check: load the default ground-truth file and list its task IDs.
    verifier = Verifier()
    record_count = len(verifier.data)
    print(f"Loaded {record_count} records from {verifier.data_path}")
    # Verifier.data is keyed by task_id, so iterating it yields task IDs
    # (not the full records).
    for task_id in verifier.data:
        print(task_id)