from typing import List, Dict, TypedDict, Annotated, Callable, Optional, Any
from langgraph.graph import StateGraph, END
from langchain_core.messages import HumanMessage
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
from answering import gen_question_answer

import requests
import os
from dotenv import load_dotenv
load_dotenv()


class AgentState(TypedDict):
    """Shared state passed between the workflow nodes."""
    context: Dict[str, str]                           # GAIA task metadata, e.g. task_id and file_name
    question: str
    answer: Annotated[str, lambda x, y: y]            # reducer: overwrite with the new value
    formatted_answer: Annotated[str, lambda x, y: y]  # reducer: overwrite with the new value
    format_requirement: str
    uris: List[str]                                   # local file paths and URLs gathered during preparations
    reasoning: List[str]                              # action plans produced by the triage step
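# Note: each node below returns a partial state dict that LangGraph merges into AgentState;
# the Annotated reducers above make "answer" and "formatted_answer" explicitly overwrite
# with the latest value rather than accumulate.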

class GAIAAnsweringWorkflow:
    def __init__(
        self,
        qa_function: Optional[Callable[[str], str]] = None,
        formatter: Optional[Callable[[str], str]] = None
    ):
        """
        Initialize the GAIA agent workflow
        
        Args: 
            qa_function: Core question answering function (gen_question_answer)
            formatter: Answer formatting function (default: GAIA boxed format)
        """
        self.qa_function = qa_function or gen_question_answer
        self.formatter = formatter or self.default_formatter
        self.workflow = self.build_workflow()
        # Reasoning model served via the HF Inference API (requires HF_TOKEN in the environment)
        llm_endpoint = HuggingFaceEndpoint(
            model="deepseek-ai/DeepSeek-R1",
            huggingfacehub_api_token=os.getenv("HF_TOKEN"),
            task="text-generation",
        )
        self.reasoning_llm = ChatHuggingFace(llm=llm_endpoint)

        # Lighter general-purpose model for extraction, formatting and other auxiliary prompts
        self.llm = ChatHuggingFace(
            llm=HuggingFaceEndpoint(
                model="mistralai/Mistral-7B-Instruct-v0.3",
                huggingfacehub_api_token=os.getenv("HF_TOKEN"),
                task="text-generation",
            )
        )

    def ask_llm(self, question: str, do_reasoning: bool = False) -> str:
        """Send a single prompt to the reasoning model (do_reasoning=True) or the general-purpose model."""
        messages = [HumanMessage(content=question)]
        llm = self.reasoning_llm if do_reasoning else self.llm
        response = llm.invoke(messages)
        return str(response.content)


    def extract_noted_urls_with_llm(self, question: str) -> List[str]:
        """Use LLM to extract URLs specifically noted in the question"""
        prompt = f"""
        Analyze the following question and extract ONLY URLs that are explicitly noted or referenced.
        Return each URL on a separate line. If no URLs are noted, return an empty string.
        
        QUESTION: {question}
        
        Respond ONLY with the URLs, one per line, with no additional text or formatting.
        """
        
        try:
            # Ask the general-purpose LLM to list the noted URLs
            response = self.ask_llm(prompt)

            # Parse the response, keeping only lines that actually look like URLs
            urls = []
            for line in response.split('\n'):
                line = line.strip()
                if line.startswith(('http://', 'https://', 'www.')):
                    urls.append(line)
            
            return urls
        except Exception as e:
            print(f"LLM-based URL extraction failed: {str(e)}")
            return []

    def download_file(self, task_id: str, file_name: str) -> str:
        """Download file from API and return local path"""
        try:
            file_path = f"./{file_name}"
            api_base_url: str = "https://agents-course-unit4-scoring.hf.space"
            api_endpoint = f"{api_base_url}/files/{task_id}"
            response = requests.get(api_endpoint, timeout=60)
            response.raise_for_status()
            
            with open(file_path, "wb") as f:
                f.write(response.content)
                
            print(f"File saved: {file_path}")
            return file_path
        except Exception as e:
            print(f"File download failed: {str(e)}")
            return ""

    
    def check_context_independent(self, state: AgentState) -> bool:
        """Return True if the question can be answered directly, False if it needs deep processing."""
        if ctx := state.get("context"):
            if ctx.get("file_name"):
                return False
        prompt = f"""
        
        I have a CodeAgent based on the text-to-text model that can use Internet search and parse the information found. 
        If this approach is enough to successfully cope with the task, then we will call such a task an "easy question"

        AS AN ERUDITE PERSON  YOU must analyze how difficult it will be to solve the next question 
        <<{state["question"]}>>
        
        If you think that the question is easy, then return an empty string. Important! You should NOT add any symbols to the output in this case!
        If the question concerns the use of additional resources such as complex analysis of downloaded files or resources on the Internet, then return an action plan

        """
        reply = self.ask_llm(prompt, True)
        prompt = f""" The reasonings from other LLM is provided: <<{reply}>>
        You have to Summarize: 
        output either empty string ('') for easy question 
        or extract action plan for non-easy question
        """
        reply = self.ask_llm(prompt, False)
        if reply:
            state["reasoning"].append(reply)
            return False
        return True         

    def preparations_node(self, state: AgentState) -> dict:
        """Node that prepares resources: downloads the attached file and extracts URLs noted in the question."""
        if not state["context"]:
            return {}
        context = state["context"]
        question = state["question"]
        uris = state["uris"]
        
        # 1. Download the file attached to the task, if any
        if file_name := context.get("file_name"):
            file_path = self.download_file(context["task_id"], file_name)
            if file_path:
                uris.append(file_path)
        
        # 2. Extract URLs from question
        found_urls = self.extract_noted_urls_with_llm(question)
        if found_urls:
            uris.extend(found_urls)
            print(f"Added {len(found_urls)} URL(s) from question")
        
        return {"uris": uris}
        

    def triage_node(self, state: AgentState) -> dict:
        """Pass-through node; routing happens in the conditional edge (check_context_independent)."""
        return {}
    
    def deep_processing_node(self, state: AgentState) -> dict:
        """Node that answers questions which need extra resources or an action plan."""
        question = f"""
        question: \n <<{state["question"]}>> \n
        resources: {str(state["uris"])} \n
        reasoning: <<{state["reasoning"]}>>
        """
        answer = gen_question_answer(question)
        return {"answer": answer}

    def generate_answer_node(self, state: AgentState) -> dict:
        """Node that executes the question answering tool"""
        try:
            answer = self.qa_function(state["question"])
            return {"answer": answer}
        except Exception as e:
            print(str(e))
            return {"answer": f"Error: {str(e)}"}
    
    def format_output_node(self, state: AgentState) -> dict:
        """Node that formats the answer for GAIA benchmark"""
        prompt = f"""
        
        As a very smart person, you should formulate what should be the output format of the answer to the question:
        <<{state["question"]}>>

        You must formulate it very briefly and clearly!
        The common requirements is:
        << 
    OUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings.

    If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise, and don't include additional text. 
    If the answer is a number, represent it with digits.
    If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. 
    If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.
    >>
        But you have to figure out how the answer should looks like in the given case and reformulate requirement according to the specified question

        """
        format_requirement = self.ask_llm(prompt, True)

        prompt = f"""
Your attentiveness and responsibility are very much needed! We are solving a strict test that is automatically checked, so we must formulate the answer in strict accordance with the task and the required format! Even one incorrect symbol in the answer can fail the task! Pull yourself together!

You will be required to produce an output of the answer, but formatted in accordance with the task

Received answer: <<{state['answer']}>>

Format requirements: <<{format_requirement}>>   

Do NOT include << >> in your answer! Don't use full answer formulations! If you are asked about a number it MUST be just a number, nothing more! Each time it should be a clear answer (checked automatically) 
        """ 
        try:
            formatted = self.ask_llm(prompt)
            return {"formatted_answer": formatted.strip()}
        except Exception as e:
            return {"formatted_answer": f"\\boxed{{\\text{{Formatting error: {str(e)}}}}}"}
    
    def build_workflow(self) -> Any:
        """Construct and compile the LangGraph workflow"""
        # Create graph
        workflow = StateGraph(AgentState)
        
        # Add nodes
        workflow.add_node("preparations", self.preparations_node)
        workflow.add_node("triage", self.triage_node)
        workflow.add_node("deep_processing", self.deep_processing_node)
        workflow.add_node("generate_answer", self.generate_answer_node)
        workflow.add_node("format_output", self.format_output_node)
        
        # Define edges
        workflow.set_entry_point("preparations")
        workflow.add_edge("preparations", "triage")
        workflow.add_conditional_edges("triage"
                                        , self.check_context_independent
                                        , {
                                            True: "generate_answer",
                                            False: "deep_processing"
                                        })
        workflow.add_edge("deep_processing", "format_output")

        workflow.add_edge("generate_answer", "format_output")
        workflow.add_edge("format_output", END)
        
        return workflow.compile()


    
    def __call__(self, question: str, context: Dict | None = None) -> str:
        """
        Execute the agent workflow for a given question.

        Args:
            question: Input question string
            context: Optional GAIA task metadata (e.g. task_id, file_name)

        Returns:
            Formatted GAIA answer
        """
        # Initialize state
        initial_state = {
            "context": context,
            "question": question,
            "answer": "",
            "formatted_answer": "",
            "format_requirement": "",    
            "uris": [],
            "reasoning": []            
        }
        
        # Execute workflow
        result = self.workflow.invoke(initial_state)
        return result["answer"]#["formatted_answer"]

    @staticmethod
    def default_qa_function(question: str) -> str:
        """Placeholder QA function (override with your CodeAgent)"""
        return "42"
    
    @staticmethod
    def default_formatter(answer: str) -> str:
        """Default GAIA formatting"""
        return answer #f"\\boxed{{{answer}}}"

# Example usage with custom QA function
if __name__ == "__main__":
    # Custom QA function (replace with your CodeAgent integration)
    def custom_qa(question: str) -> str:
        if "life" in question:
            return "42"
        elif "prime" in question:
            return "101"
        return "unknown"
    
    # Create agent instance
    agent = GAIAAnsweringWorkflow(
        qa_function=gen_question_answer,
        formatter=lambda ans: ans  # identity formatter; GAIA scoring expects the bare answer
    )
    
    # Test cases
    questions = [
        "What is the answer to life, the universe, and everything?",
        "What is the smallest 3-digit prime number?",
        "Unknown question type?"
    ]
    
    for q in questions:
        result = agent(q)
        print(f"Question: {q}\nAnswer: {result}\n")