lwant committed on
Commit
88b51a3
Β·
1 Parent(s): fa2bac9

Make `run_and_submit_all` asynchronous, update imports, and add telemetry initialization

Browse files
Files changed (1) hide show
  1. src/gaia_solving_agent/agent.py +112 -8
src/gaia_solving_agent/agent.py CHANGED
@@ -1,8 +1,112 @@
1
- class BasicAgent:
2
- def __init__(self):
3
- print("BasicAgent initialized.")
4
- def __call__(self, question: str) -> str:
5
- print(f"Agent received question (first 50 chars): {question[:50]}...")
6
- fixed_answer = "This is a default answer."
7
- print(f"Agent returning fixed answer: {fixed_answer}")
8
- return fixed_answer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+
3
+ from llama_index.core.agent.workflow import FunctionAgent, AgentWorkflow
4
+ from llama_index.core.prompts import RichPromptTemplate
5
+ from llama_index.llms.nebius import NebiusLLM
6
+ from llama_index.tools.requests import RequestsToolSpec
7
+ from workflows import Workflow, step
8
+ from workflows.events import StartEvent, Event, StopEvent
9
+
10
+ from gaia_solving_agent import NEBIUS_API_KEY
11
+ from gaia_solving_agent.prompts import PLANING_PROMPT, FORMAT_ANSWER
12
+ from gaia_solving_agent.tools import tavily_search_web
13
+
14
# Choice of the model
model_name = "meta-llama/Meta-Llama-3.1-8B-Instruct"
# model_name = "deepseek-ai/DeepSeek-R1-0528"


def get_llm(model_name=model_name):
    """Build a Nebius-hosted LLM client configured for function calling.

    Args:
        model_name: Nebius model identifier; defaults to the module-level
            model choice above.

    Returns:
        A configured ``NebiusLLM`` instance.
    """
    llm_settings = dict(
        model=model_name,
        api_key=NEBIUS_API_KEY,
        is_function_calling_model=True,
        max_completion_tokens=10000,
        # max = 128000 for "meta-llama/Meta-Llama-3.1-8B-Instruct"
        context_window=80000,
        temperature=0.1,
        max_retries=5,
    )
    return NebiusLLM(**llm_settings)
28
+
29
+
30
class QueryEvent(Event):
    """Workflow event emitted after planning.

    Carries the original user query together with the plan text produced
    by the planning LLM in ``GaiaWorkflow.setup``.
    """
    # Original user request, forwarded unchanged from the StartEvent.
    query: str
    # Step-by-step plan text generated by the planning prompt.
    plan: str


class AnswerEvent(Event):
    """Workflow event emitted after the multi-agent run.

    Carries the plan that was executed and the raw answer string produced
    by the agent workflow in ``GaiaWorkflow.multi_agent_process``.
    """
    # Plan that was handed to the multi-agent workflow.
    plan: str
    # Stringified output of the agent workflow run.
    answer: str
37
+
38
+
39
class GaiaWorkflow(Workflow):
    """Three-step workflow for solving a GAIA question.

    1. ``setup`` — turn the raw user query into a step-by-step plan.
    2. ``multi_agent_process`` — execute the plan with the multi-agent system.
    3. ``parse_answer`` — format the agent's answer into the final result.
    """

    @step
    async def setup(self, ev: StartEvent) -> QueryEvent:
        """Generate a plan for the incoming query with the planning prompt."""
        llm = get_llm()
        prompt_template = RichPromptTemplate(
            PLANING_PROMPT,
            template_var_mappings={"query": "user_request"},
        )
        plan = llm.complete(prompt_template.format(query=ev.query))
        return QueryEvent(query=ev.query, plan=plan.text)

    @step
    async def multi_agent_process(self, ev: QueryEvent) -> AnswerEvent:
        """Run the plan through the multi-agent workflow and capture its answer."""
        # Cheap trick to avoid Error 400 errors from OpenAPI: give every run a
        # fresh, very large memory buffer instead of reusing a shared one.
        from llama_index.core.memory import ChatMemoryBuffer
        memory = ChatMemoryBuffer.from_defaults(token_limit=100000)

        agent_output = await gaia_solving_agent.run(user_msg=ev.plan, memory=memory)
        return AnswerEvent(plan=ev.plan, answer=str(agent_output))

    @step
    async def parse_answer(self, ev: AnswerEvent) -> StopEvent:
        """Extract the question from the plan and produce the final answer."""
        llm = get_llm()
        prompt_template = RichPromptTemplate(FORMAT_ANSWER)
        # Fix: the original pattern ended in the character class [\n$], which
        # matches a literal '$' character, not end-of-string. Use an explicit
        # alternation so a question on the last line (no trailing newline) is
        # still captured.
        pattern = r"Question :\s*(.*)(?:\n|$)"
        search = re.search(pattern, ev.plan)
        question = search.group(1) if search else ""
        # Fix: also pass the agent's answer to the formatting prompt — the
        # original supplied only the question, so the computed answer was
        # silently dropped when rendering FORMAT_ANSWER. (Jinja-based
        # RichPromptTemplate ignores variables the template does not use.)
        result = llm.complete(
            prompt_template.format(question=question, answer=ev.answer)
        )
        # NOTE(review): result is a CompletionResponse, not a plain string —
        # confirm downstream consumers of StopEvent.result expect that.
        return StopEvent(result=result)
68
+
69
+
70
# Web-search agent: converts the user's need into one or more Tavily web
# searches (one topic per search) and can hand off to the page-visiting agent.
tavily_search_engine = FunctionAgent(
    # Single tool: the project's Tavily search wrapper.
    tools=[tavily_search_web],
    llm=get_llm(),
    system_prompt="""
    You are a helpful assistant that does web searches.
    Convert the user need into one or multiple web searches.
    Each web search should aim for one specific topic.
    A topic is defined as one to few words.
    If the user needs to search for multiple topics, make multiple searches.
    """,
    name="search_engine_agent",
    # Must match visit_website's FunctionAgent name below.
    can_handoff_to = ["visit_web_page_agent"],
    description="Agent that makes web searches to answer questions."
)
84
+
85
# Page-visiting agent: fetches a URL with the generic HTTP request tools and
# summarizes the page content for the user's concern.
visit_website = FunctionAgent(
    tools=[
        # Generic HTTP GET/POST/PATCH tools from llama-index's requests spec.
        *RequestsToolSpec().to_tool_list(),
    ],
    llm=get_llm(),
    system_prompt="""
    You are a helpful assistant that visit a website.
    Given a url, you should visit the web page and return a summary of the page.
    The summary should answer the concerns of the user.

    If the url is invalid, return "Invalid URL".
    If the url is not a web page, return "Not a web page".
    If the url is not reachable, return "Not reachable".
    """,
    # Name referenced by tavily_search_engine's can_handoff_to list.
    name="visit_web_page_agent",
    description="Agent that visit a web page and return a summary of the page."
)
102
+
103
+
104
# Multi-agent workflow used by GaiaWorkflow.multi_agent_process: starts with
# the search agent, which may hand off to the page-visiting agent.
gaia_solving_agent = AgentWorkflow(
    agents = [tavily_search_engine, visit_website],
    # No shared state is carried between agents.
    initial_state = dict(),
    # Entry point: the web-search agent.
    root_agent = tavily_search_engine.name,
    # None keeps the library's default handoff/state prompts.
    handoff_prompt = None,
    handoff_output_prompt = None,
    state_prompt = None,
    # Serialize runs — one workflow execution at a time.
    num_concurrent_runs=1,
)