File size: 7,089 Bytes
d195d4f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import asyncio
import copy
import pdb

from factool.knowledge_qa.pipeline import knowledge_qa_pipeline
from factool.code.pipeline import code_pipeline
from factool.math.pipeline import math_pipeline
from factool.scientific.pipeline import scientific_pipeline

class Factool():
    """Route factuality-check samples to per-category pipelines and aggregate results.

    Each sample is a dict carrying at least ``'category'``, ``'prompt'`` and
    ``'response'``. Samples with category ``'code'`` additionally carry
    ``'entry_point'``. Samples with category ``'kbqa'`` may carry the optional
    keys ``'search_type'``, ``'data_link'`` and ``'embedding_link'``.
    """

    def __init__(self, foundation_model):
        """Build the long-lived pipelines, one per supported category.

        Args:
            foundation_model: Passed through unchanged to every pipeline
                constructor; also kept so that local-search kbqa pipelines
                can be built on demand in ``run``.
        """
        self.foundation_model = foundation_model
        # NOTE(review): the numeric positional arguments are forwarded as-is;
        # their meaning is defined by the pipeline constructors, not here.
        self.pipelines = {
            "kbqa_online": knowledge_qa_pipeline(
                foundation_model, 10, "online"
            ),
            "code": code_pipeline(
                foundation_model, 3, 3
            ),
            "math": math_pipeline(
                foundation_model
            ),
            "scientific": scientific_pipeline(
                foundation_model
            ),
        }

    @staticmethod
    def _batch_key(sample):
        """Return the key that decides whether adjacent samples share a batch in ``run``."""
        category = sample['category']
        if category != 'kbqa':
            return (category,)
        search_type = sample.get('search_type', None)
        # A missing search_type is dispatched to the online pipeline, so it
        # must batch together with explicit "online" samples.
        if search_type is None or search_type == "online":
            return (category, "online")
        # Non-online kbqa builds a pipeline from the links of the first sample
        # in the batch, so every sample in the batch must share those links.
        return (
            category,
            search_type,
            sample.get('data_link', None),
            sample.get('embedding_link', None),
        )

    @staticmethod
    def _group_consecutive(samples, key):
        """Split ``samples`` into runs of adjacent items with equal ``key(sample)``.

        Order is preserved and no empty batch is ever produced, so the
        concatenation of the returned batches equals ``samples``.
        """
        batches = []
        previous_key = None
        for sample in samples:
            sample_key = key(sample)
            if batches and sample_key == previous_key:
                batches[-1].append(sample)
            else:
                batches.append([sample])
                previous_key = sample_key
        return batches

    def _pipeline_for_batch(self, batch):
        """Pick (or build) the pipeline that ``run`` uses for ``batch``."""
        first = batch[0]
        category = first['category']
        if category != 'kbqa':
            return self.pipelines[category]
        search_type = first.get('search_type', None)
        if search_type is None or search_type == "online":
            return self.pipelines["kbqa_online"]
        # Local search depends on per-sample links, so it cannot be prebuilt.
        return knowledge_qa_pipeline(
            self.foundation_model, 2, "local",
            first.get("data_link"), first.get("embedding_link")
        )

    @staticmethod
    def _invoke(pipeline, category, batch):
        """Call ``run_with_tool_api_call`` on ``pipeline`` and return its awaitable.

        The 'code' category passes the entry points as a third argument;
        every other category passes prompts and responses only.
        """
        prompts = [sample['prompt'] for sample in batch]
        responses = [sample['response'] for sample in batch]
        if category == 'code':
            entry_points = [sample['entry_point'] for sample in batch]
            return pipeline.run_with_tool_api_call(prompts, responses, entry_points)
        return pipeline.run_with_tool_api_call(prompts, responses)

    @staticmethod
    def _average_claim_factuality(outputs):
        """Average the claim-level factuality over all claims in ``outputs``.

        How claims are counted depends on the category: 'kbqa' and
        'scientific' read a ``'factuality'`` entry from each claim, 'math'
        sums the claim values directly, and 'code' counts the whole response
        as one claim. Other categories contribute nothing.

        Returns 0.0 when there are no claims at all.
        """
        num_claims = 0
        total_claim_factuality = 0
        for output in outputs:
            category = output['category']
            claims = output['claim_level_factuality'] if category in (
                'kbqa', 'code', 'math', 'scientific') else None
            if category in ('kbqa', 'scientific'):
                num_claims += len(claims)
                total_claim_factuality += sum(claim['factuality'] for claim in claims)
            elif category == 'code':
                num_claims += 1
                total_claim_factuality += claims
            elif category == 'math':
                num_claims += len(claims)
                total_claim_factuality += sum(claims)
        if not num_claims:
            return 0.0
        return total_claim_factuality / num_claims

    def run(self, inputs):
        """Check every sample synchronously and return results plus averages.

        Adjacent samples that target the same pipeline are sent in one call;
        each batch is driven with its own ``asyncio.run``.

        Args:
            inputs: Sequence of sample dicts (see the class docstring). The
                sequence and its dicts are not modified; results are merged
                into a deep copy.

        Returns:
            A dict with ``'average_claim_level_factuality'``,
            ``'average_response_level_factuality'`` and
            ``'detailed_information'`` (the copied samples updated with each
            pipeline result). Both averages are 0.0 when there is nothing to
            average (no inputs, or no claims).

        Raises:
            KeyError: If a sample lacks a required key or names a category
                with no registered pipeline.
        """
        outputs = copy.deepcopy(inputs)

        index = 0
        for batch in self._group_consecutive(inputs, self._batch_key):
            category = batch[0]['category']
            pipeline = self._pipeline_for_batch(batch)
            batch_results = asyncio.run(self._invoke(pipeline, category, batch))
            # Results are merged positionally: batches preserve input order.
            for result in batch_results:
                outputs[index].update(result)
                index += 1

        if outputs:
            total_response_factuality = sum(
                output['response_level_factuality'] for output in outputs
            )
            avg_response_level_factuality = total_response_factuality / len(outputs)
        else:
            avg_response_level_factuality = 0.0

        avg_claim_level_factuality = self._average_claim_factuality(outputs)

        return {
            "average_claim_level_factuality": avg_claim_level_factuality,
            "average_response_level_factuality": avg_response_level_factuality,
            "detailed_information": outputs,
        }

    async def run_for_plugin(self, inputs):
        """Check every sample from within an already-running event loop.

        Adjacent samples of the same category are sent in one call. The
        ``'search_type'`` key is not consulted here: 'kbqa' samples always go
        to the prebuilt online pipeline.

        Args:
            inputs: Sequence of sample dicts (see the class docstring). Not
                modified; results are merged into a deep copy.

        Returns:
            The copied samples, each updated with its pipeline result. An
            empty list when ``inputs`` is empty.

        Raises:
            KeyError: If a sample lacks a required key or names a category
                with no registered pipeline.
        """
        outputs = copy.deepcopy(inputs)

        batches = self._group_consecutive(inputs, lambda sample: sample['category'])

        index = 0
        for batch in batches:
            category = batch[0]['category']
            # The online kbqa pipeline is registered as "kbqa_online", so a
            # plain lookup by category would fail for 'kbqa'.
            pipeline_name = "kbqa_online" if category == 'kbqa' else category
            batch_results = await self._invoke(
                self.pipelines[pipeline_name], category, batch
            )
            for result in batch_results:
                outputs[index].update(result)
                index += 1

        return outputs