"""Chain that does self ask with search."""
from typing import Any, Optional, Sequence, Tuple, Union
from langchain.agents.agent import Agent, AgentExecutor
from langchain.agents.self_ask_with_search.prompt import PROMPT
from langchain.agents.tools import Tool
from langchain.llms.base import BaseLLM
from langchain.prompts.base import BasePromptTemplate
from langchain.tools.base import BaseTool
from langchain.utilities.google_serper import GoogleSerperAPIWrapper
from langchain.utilities.serpapi import SerpAPIWrapper
class SelfAskWithSearchAgent(Agent):
"""Agent for the self-ask-with-search paper."""
@property
def _agent_type(self) -> str:
"""Return Identifier of agent type."""
return "self-ask-with-search"
@classmethod
def create_prompt(cls, tools: Sequence[BaseTool]) -> BasePromptTemplate:
"""Prompt does not depend on tools."""
return PROMPT
@classmethod
def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
if len(tools) != 1:
raise ValueError(f"Exactly one tool must be specified, but got {tools}")
tool_names = {tool.name for tool in tools}
if tool_names != {"Intermediate Answer"}:
raise ValueError(
f"Tool name should be Intermediate Answer, got {tool_names}"
)
    def _extract_tool_and_input(self, text: str) -> Optional[Tuple[str, str]]:
        followup = "Follow up:"
        last_line = text.split("\n")[-1]

        if followup not in last_line:
            # No follow-up question: either the model finished, or the output
            # is malformed and there is nothing to act on.
            finish_string = "So the final answer is: "
            if finish_string not in last_line:
                return None
            return "Final Answer", last_line[len(finish_string) :]

        # The text after the last colon is the follow-up question to search for.
        after_colon = text.split(":")[-1]
        if " " == after_colon[0]:
            after_colon = after_colon[1:]

        return "Intermediate Answer", after_colon
    def _fix_text(self, text: str) -> str:
        return f"{text}\nSo the final answer is:"

    @property
    def observation_prefix(self) -> str:
        """Prefix to append the observation with."""
        return "Intermediate answer: "

    @property
    def llm_prefix(self) -> str:
        """Prefix to append the LLM call with."""
        return ""

    @property
    def starter_string(self) -> str:
        """Put this string after user input but before first LLM call."""
        return "Are follow up questions needed here:"


class SelfAskWithSearchChain(AgentExecutor):
    """Chain that does self ask with search.

    Example:
        .. code-block:: python

            from langchain import SelfAskWithSearchChain, OpenAI, GoogleSerperAPIWrapper
            search_chain = GoogleSerperAPIWrapper()
            self_ask = SelfAskWithSearchChain(llm=OpenAI(), search_chain=search_chain)
    """

    def __init__(
        self,
        llm: BaseLLM,
        search_chain: Union[GoogleSerperAPIWrapper, SerpAPIWrapper],
        **kwargs: Any,
    ):
        """Initialize with just an LLM and a search chain."""
        # Wrap the search chain as the single "Intermediate Answer" tool the
        # agent expects, then build the agent and the executor around it.
        search_tool = Tool(
            name="Intermediate Answer", func=search_chain.run, description="Search"
        )
        agent = SelfAskWithSearchAgent.from_llm_and_tools(llm, [search_tool])
        super().__init__(agent=agent, tools=[search_tool], **kwargs)
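

# A minimal usage sketch (not from the original module). It assumes the
# OPENAI_API_KEY and SERPER_API_KEY environment variables are set and that
# langchain.llms.OpenAI is available; swap in SerpAPIWrapper if you use
# SerpAPI instead of Google Serper.
#
#     from langchain.llms import OpenAI
#     from langchain.utilities.google_serper import GoogleSerperAPIWrapper
#
#     search = GoogleSerperAPIWrapper()
#     self_ask = SelfAskWithSearchChain(llm=OpenAI(temperature=0), search_chain=search)
#     self_ask.run("What is the hometown of the reigning men's U.S. Open champion?")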