3Represents commited on
Commit
951d8ce
1 Parent(s): 131901a
Files changed (2) hide show
  1. ReAct.py +411 -0
  2. ReAct.yaml +50 -0
ReAct.py ADDED
@@ -0,0 +1,411 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import time
3
+ from typing import Any, Dict, List, Optional, Tuple, Union
4
+
5
+ import hydra
6
+ from pydantic import root_validator
7
+
8
+ from langchain import LLMChain, PromptTemplate
9
+ from langchain.agents import AgentExecutor, BaseMultiActionAgent, ZeroShotAgent
10
+ from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS, PREFIX, SUFFIX
11
+ from langchain.chat_models import ChatOpenAI
12
+ from langchain.schema import (
13
+ AgentAction,
14
+ AgentFinish,
15
+ OutputParserException,
16
+ )
17
+
18
+ from flows.base_flows import Flow, CompositeFlow, GenericLCTool
19
+ from flows.messages import OutputMessage, UpdateMessage_Generic
20
+ from flows.utils.caching_utils import flow_run_cache
21
+
22
+
23
+ class GenericZeroShotAgent(ZeroShotAgent):
24
+ @classmethod
25
+ def create_prompt(
26
+ cls,
27
+ tools: Dict[str, Flow],
28
+ prefix: str = PREFIX,
29
+ suffix: str = SUFFIX,
30
+ format_instructions: str = FORMAT_INSTRUCTIONS,
31
+ input_variables: Optional[List[str]] = None,
32
+ ) -> PromptTemplate:
33
+ """Create prompt in the style of the zero shot agent.
34
+
35
+ Args:
36
+ tools: List of tools the agent will have access to, used to format the
37
+ prompt.
38
+ prefix: String to put before the list of tools.
39
+ suffix: String to put after the list of tools.
40
+ input_variables: List of input variables the final prompt will expect.
41
+
42
+ Returns:
43
+ A PromptTemplate with the template assembled from the pieces here.
44
+ """
45
+ # tool_strings = "\n".join([f"{tool.name}: {tool.description}" for tool in tools])
46
+ # tool_names = ", ".join([tool.name for tool in tools])
47
+ tool_strings = "\n".join([f"{tool_name}: {tool.flow_config['description']}" for tool_name, tool in tools.items()])
48
+ tool_names = ", ".join(tools.keys())
49
+ format_instructions = format_instructions.format(tool_names=tool_names)
50
+ template = "\n\n".join([prefix, tool_strings, format_instructions, suffix])
51
+ if input_variables is None:
52
+ input_variables = ["input", "agent_scratchpad"]
53
+ return PromptTemplate(template=template, input_variables=input_variables)
54
+
55
+
56
+ class GenericAgentExecutor(AgentExecutor):
57
+ tools: Dict[str, Flow]
58
+
59
+ @root_validator()
60
+ def validate_tools(cls, values: Dict) -> Dict:
61
+ """Validate that tools are compatible with agent."""
62
+ agent = values["agent"]
63
+ tools = values["tools"]
64
+ allowed_tools = agent.get_allowed_tools()
65
+ if allowed_tools is not None:
66
+ if set(allowed_tools) != set(tools.keys()):
67
+ raise ValueError(
68
+ f"Allowed tools ({allowed_tools}) different than "
69
+ f"provided tools ({tools.keys()})"
70
+ )
71
+ return values
72
+
73
+ @root_validator()
74
+ def validate_return_direct_tool(cls, values: Dict) -> Dict:
75
+ """Validate that tools are compatible with agent."""
76
+ agent = values["agent"]
77
+ tools = values["tools"]
78
+ if isinstance(agent, BaseMultiActionAgent):
79
+ for tool in tools:
80
+ if tool.flow_config["return_direct"]:
81
+ raise ValueError(
82
+ "Tools that have `return_direct=True` are not allowed "
83
+ "in multi-action agents"
84
+ )
85
+ return values
86
+
87
+ def _get_tool_return(
88
+ self, next_step_output: Tuple[AgentAction, str]
89
+ ) -> Optional[AgentFinish]:
90
+ """Check if the tool is a returning tool."""
91
+ agent_action, observation = next_step_output
92
+ # name_to_tool_map = {tool.name: tool for tool in self.tools}
93
+ # Invalid tools won't be in the map, so we return False.
94
+ if agent_action.tool in self.tools:
95
+ if self.tools[agent_action.tool].flow_config["return_direct"]:
96
+ return AgentFinish(
97
+ {self.agent.return_values[0]: observation},
98
+ "",
99
+ )
100
+ return None
101
+
102
+
103
+ class ReActFlow(CompositeFlow):
104
+ EXCEPTION_FLOW_CONFIG = {
105
+ "_target_": "flows.base_flows.GenericLCTool.instantiate_from_config",
106
+ "config": {
107
+ "name": "_Exception",
108
+ "description": "Exception tool",
109
+
110
+ "tool_type": "exception",
111
+ "input_keys": ["query"],
112
+ "output_keys": ["raw_response"],
113
+
114
+ "verbose": False,
115
+ "clear_flow_namespace_on_run_end": False,
116
+
117
+ "input_data_transformations": [],
118
+ "output_data_transformations": [],
119
+ "keep_raw_response": True
120
+ }
121
+ }
122
+
123
+ INVALID_FLOW_CONFIG = {
124
+ "_target_": "flows.base_flows.GenericLCTool.instantiate_from_config",
125
+ "config": {
126
+ "name": "invalid_tool",
127
+ "description": "Called when tool name is invalid.",
128
+
129
+ "tool_type": "invalid",
130
+ "input_keys": ["tool_name"],
131
+ "output_keys": ["raw_response"],
132
+
133
+ "verbose": False,
134
+ "clear_flow_namespace_on_run_end": False,
135
+
136
+ "input_data_transformations": [],
137
+ "output_data_transformations": [],
138
+ "keep_raw_response": True
139
+ }
140
+ }
141
+
142
+ SUPPORTS_CACHING: bool = True
143
+
144
+ api_keys: Dict[str, str]
145
+
146
+ backend: GenericAgentExecutor
147
+ react_prompt_template: PromptTemplate
148
+
149
+ exception_flow: GenericLCTool
150
+ invalid_flow: GenericLCTool
151
+
152
+ def __init__(self, **kwargs):
153
+ super().__init__(**kwargs)
154
+
155
+ self.api_keys = None
156
+ self.backend = None
157
+ self.react_prompt_template = GenericZeroShotAgent.create_prompt(
158
+ tools=self.subflows,
159
+ **self.flow_config.get("prompt_config", {})
160
+ )
161
+
162
+ self._set_up_necessary_subflows()
163
+
164
+ def set_up_flow_state(self):
165
+ super().set_up_flow_state()
166
+ self.flow_state["intermediate_steps"]: List[Tuple[AgentAction, str]] = []
167
+
168
+ def _set_up_necessary_subflows(self):
169
+ self.exception_flow = hydra.utils.instantiate(
170
+ self.EXCEPTION_FLOW_CONFIG, _convert_="partial", _recursive_=False
171
+ )
172
+ self.invalid_flow = hydra.utils.instantiate(
173
+ self.INVALID_FLOW_CONFIG, _convert_="partial", _recursive_=False
174
+ )
175
+
176
+ def _get_prompt_message(self, input_data: Dict[str, Any]) -> str:
177
+ data = copy.deepcopy(input_data)
178
+ data["agent_scratchpad"] = "{agent_scratchpad}" # dummy value for agent scratchpad
179
+
180
+ return self.react_prompt_template.format(**data)
181
+
182
+ @staticmethod
183
+ def get_raw_response(output: OutputMessage) -> str:
184
+ key = output.data["output_keys"][0]
185
+ return output.data["output_data"]["raw_response"][key]
186
+
187
+ def _take_next_step(
188
+ self,
189
+ # name_to_tool_map: Dict[str, BaseTool],
190
+ # color_mapping: Dict[str, str],
191
+ inputs: Dict[str, str],
192
+ intermediate_steps: List[Tuple[AgentAction, str]],
193
+ # run_manager: Optional[CallbackManagerForChainRun] = None,
194
+ # input_data: Dict[str, Any],
195
+ private_keys: Optional[List[str]] = [],
196
+ keys_to_ignore_for_hash: Optional[List[str]] = []
197
+ ) -> Union[AgentFinish, List[Tuple[AgentAction, str]]]:
198
+ """Take a single step in the thought-action-observation loop.
199
+
200
+ Override this to take control of how the agent makes and acts on choices.
201
+ """
202
+ try:
203
+ # Call the LLM to see what to do.
204
+ output = self.backend.agent.plan(
205
+ intermediate_steps,
206
+ # callbacks=run_manager.get_child() if run_manager else None,
207
+ **inputs,
208
+ )
209
+ except OutputParserException as e:
210
+ if isinstance(self.backend.handle_parsing_errors, bool):
211
+ raise_error = not self.backend.handle_parsing_errors
212
+ else:
213
+ raise_error = False
214
+ if raise_error:
215
+ raise e
216
+ text = str(e)
217
+
218
+ if isinstance(self.backend.handle_parsing_errors, bool):
219
+ if e.send_to_llm:
220
+ observation = str(e.observation)
221
+ text = str(e.llm_output)
222
+ else:
223
+ observation = "Invalid or incomplete response"
224
+ elif isinstance(self.backend.handle_parsing_errors, str):
225
+ observation = self.backend.handle_parsing_errors
226
+ elif callable(self.backend.handle_parsing_errors):
227
+ observation = self.backend.handle_parsing_errors(e)
228
+ else:
229
+ raise ValueError("Got unexpected type of `handle_parsing_errors`")
230
+
231
+ output = AgentAction("_Exception", observation, text)
232
+ # if run_manager:
233
+ # run_manager.on_agent_action(output, color="green")
234
+ # tool_run_kwargs = self.backend.agent.tool_run_logging_kwargs()
235
+ # observation = ExceptionTool().run(
236
+ # output.tool_input,
237
+ # verbose=self.verbose,
238
+ # color=None,
239
+ # callbacks=run_manager.get_child() if run_manager else None,
240
+ # **tool_run_kwargs,
241
+ # )
242
+ self._state_update_dict({"query": output.tool_input})
243
+ tool_output = self._call_flow_from_state(
244
+ self.exception_flow,
245
+ private_keys=private_keys,
246
+ keys_to_ignore_for_hash=keys_to_ignore_for_hash,
247
+ search_class_namespace_for_inputs=False
248
+ )
249
+ observation = self.get_raw_response(tool_output)
250
+ return [(output, observation)]
251
+
252
+ # If the tool chosen is the finishing tool, then we end and return.
253
+ if isinstance(output, AgentFinish):
254
+ return output
255
+
256
+ actions: List[AgentAction]
257
+ if isinstance(output, AgentAction):
258
+ actions = [output]
259
+ else:
260
+ actions = output
261
+ result = []
262
+ for agent_action in actions:
263
+ # if run_manager:
264
+ # run_manager.on_agent_action(agent_action, color="green")
265
+ # Otherwise we lookup the tool
266
+ if agent_action.tool in self.subflows:
267
+ tool = self.subflows[agent_action.tool]
268
+
269
+ if isinstance(agent_action.tool_input, dict):
270
+ self._state_update_dict(agent_action.tool_input)
271
+ else:
272
+ self._state_update_dict({tool.flow_config["input_keys"][0]:agent_action.tool_input})
273
+
274
+ tool_output = self._call_flow_from_state(
275
+ tool,
276
+ private_keys=private_keys,
277
+ keys_to_ignore_for_hash=keys_to_ignore_for_hash,
278
+ search_class_namespace_for_inputs=False
279
+ )
280
+ observation = self.get_raw_response(tool_output)
281
+ # return_direct = tool.return_direct
282
+ # color = color_mapping[agent_action.tool]
283
+ # tool_run_kwargs = self.backend.agent.tool_run_logging_kwargs()
284
+ # if return_direct:
285
+ # tool_run_kwargs["llm_prefix"] = ""
286
+ # We then call the tool on the tool input to get an observation
287
+ # observation = tool.run(
288
+ # agent_action.tool_input,
289
+ # verbose=self.verbose,
290
+ # color=color,
291
+ # callbacks=run_manager.get_child() if run_manager else None,
292
+ # **tool_run_kwargs,
293
+ # )
294
+ else:
295
+ # tool_run_kwargs = self.backend.agent.tool_run_logging_kwargs()
296
+ # observation = InvalidTool().run(
297
+ # agent_action.tool,
298
+ # verbose=self.verbose,
299
+ # color=None,
300
+ # callbacks=run_manager.get_child() if run_manager else None,
301
+ # **tool_run_kwargs,
302
+ # )
303
+ self._state_update_dict({"tool_name": agent_action.tool})
304
+ tool_output = self._call_flow_from_state(
305
+ self.invalid_flow,
306
+ private_keys=private_keys,
307
+ keys_to_ignore_for_hash=keys_to_ignore_for_hash,
308
+ search_class_namespace_for_inputs=False
309
+ )
310
+ observation = self.get_raw_response(tool_output)
311
+ result.append((agent_action, observation))
312
+ return result
313
+
314
+ def _run(
315
+ self,
316
+ input_data: Dict[str, Any],
317
+ private_keys: Optional[List[str]] = [],
318
+ keys_to_ignore_for_hash: Optional[List[str]] = []
319
+ ) -> str:
320
+ """Run text through and get agent response."""
321
+ # Construct a mapping of tool name to tool for easy lookup
322
+ # name_to_tool_map = {tool.name: tool for tool in self.tools}
323
+ # We construct a mapping from each tool to a color, used for logging.
324
+ # color_mapping = get_color_mapping(
325
+ # [tool.name for tool in self.tools], excluded_colors=["green", "red"]
326
+ # )
327
+ self.flow_state["intermediate_steps"] = []
328
+ intermediate_steps = self.flow_state["intermediate_steps"]
329
+ # Let's start tracking the number of iterations and time elapsed
330
+ iterations = 0
331
+ time_elapsed = 0.0
332
+ start_time = time.time()
333
+ # We now enter the agent loop (until it returns something).
334
+ while self.backend._should_continue(iterations, time_elapsed):
335
+ # next_step_output = self._take_next_step(
336
+ # name_to_tool_map,
337
+ # color_mapping,
338
+ # inputs,
339
+ # intermediate_steps,
340
+ # run_manager=run_manager,
341
+ # )
342
+ next_step_output = self._take_next_step(
343
+ input_data,
344
+ intermediate_steps,
345
+ private_keys,
346
+ keys_to_ignore_for_hash
347
+ )
348
+ if isinstance(next_step_output, AgentFinish):
349
+ # TODO: f"{self.backend.agent.llm_prefix} {next_step_output.log}"
350
+ return next_step_output.return_values["output"]
351
+
352
+ intermediate_steps.extend(next_step_output)
353
+ for act, obs in next_step_output:
354
+ pass # TODO
355
+ # f"{self.backend.agent.llm_prefix} {act.log}"
356
+ # f"{self.backend.agent.observation_prefix}{obs}"
357
+
358
+ if len(next_step_output) == 1:
359
+ next_step_action = next_step_output[0]
360
+ # See if tool should return directly
361
+ tool_return = self.backend._get_tool_return(next_step_action)
362
+ if tool_return is not None:
363
+ # same as the observation
364
+ return tool_return.return_values["output"]
365
+
366
+ iterations += 1
367
+ time_elapsed = time.time() - start_time
368
+
369
+ output = self.backend.agent.return_stopped_response(
370
+ self.backend.early_stopping_method, intermediate_steps, **input_data
371
+ )
372
+ return output.return_values["output"]
373
+
374
+ @flow_run_cache()
375
+ def run(
376
+ self,
377
+ input_data: Dict[str, Any],
378
+ private_keys: Optional[List[str]] = [],
379
+ keys_to_ignore_for_hash: Optional[List[str]] = []
380
+ ) -> Dict[str, Any]:
381
+ self.api_keys = input_data["api_keys"]
382
+ del input_data["api_keys"]
383
+
384
+ llm = ChatOpenAI(
385
+ model_name=self.flow_config["model_name"],
386
+ openai_api_key=self.api_keys["openai"],
387
+ **self.flow_config["generation_parameters"],
388
+ )
389
+ llm_chain = LLMChain(llm=llm, prompt=self.react_prompt_template)
390
+ agent = GenericZeroShotAgent(llm_chain=llm_chain, allowed_tools=list(self.subflows.keys()))
391
+
392
+ self.backend = GenericAgentExecutor.from_agent_and_tools(
393
+ agent=agent,
394
+ tools=self.subflows,
395
+ max_iterations=self.flow_config.get("max_iterations", 15),
396
+ max_execution_time=self.flow_config.get("max_execution_time")
397
+ )
398
+
399
+ data = {k: input_data[k] for k in self.get_input_keys(input_data)}
400
+
401
+ # TODO
402
+ # prompt = UpdateMessage_Generic(
403
+ # created_by=self.flow_config["name"],
404
+ # updated_flow=self.flow_config["name"],
405
+ # content=self._get_prompt_message(data)
406
+ # )
407
+ # self._log_message(prompt)
408
+
409
+ output = self._run(data, private_keys, keys_to_ignore_for_hash)
410
+
411
+ return {input_data["output_keys"][0]: output}
ReAct.yaml ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: "ReAct_Flow"
2
+ verbose: True
3
+ description: "Flow that implements ReAct logic"
4
+
5
+ model_name: "gpt-4"
6
+ generation_parameters:
7
+ n: 1
8
+ max_tokens: 3000
9
+ temperature: 0.3
10
+
11
+ model_kwargs:
12
+ top_p: 0.2
13
+ frequency_penalty: 0
14
+ presence_penalty: 0
15
+
16
+ max_iterations: 3
17
+ keep_raw_response: True
18
+ clear_flow_namespace_on_run_end: False
19
+
20
+ input_data_transformations: []
21
+ input_keys:
22
+ - "input"
23
+
24
+ output_data_transformations: []
25
+ output_keys:
26
+ - "answer"
27
+
28
+ prompt_config:
29
+ suffix: "Begin! Remember to answer succinctly. The response should include the prefix 'Final Answer: <response>'.\n\nQuestion: {input}\n{agent_scratchpad}"
30
+
31
+ subflows_config:
32
+ - _target_: flows.base_flows.GenericLCTool.instantiate_from_config
33
+ config:
34
+ name: "Search"
35
+ verbose: True
36
+ description: "useful when you need to answer questions about current events"
37
+
38
+ tool_type: "wikipedia"
39
+ return_direct: False
40
+
41
+ keep_raw_response: True
42
+ clear_flow_namespace_on_run_end: False
43
+
44
+ input_data_transformations: []
45
+ input_keys:
46
+ - "tool_input"
47
+
48
+ output_data_transformations: []
49
+ output_keys:
50
+ - "observation"