Spaces:
Runtime error
Runtime error
"""Chain that does self ask with search.""" | |
from typing import Any, Optional, Sequence, Tuple, Union | |
from langchain.agents.agent import Agent, AgentExecutor | |
from langchain.agents.self_ask_with_search.prompt import PROMPT | |
from langchain.agents.tools import Tool | |
from langchain.llms.base import BaseLLM | |
from langchain.prompts.base import BasePromptTemplate | |
from langchain.tools.base import BaseTool | |
from langchain.utilities.google_serper import GoogleSerperAPIWrapper | |
from langchain.utilities.serpapi import SerpAPIWrapper | |
class SelfAskWithSearchAgent(Agent):
    """Agent for the self-ask-with-search paper.

    The agent accepts exactly one tool, which must be named
    ``"Intermediate Answer"``. Each LLM completion is parsed by looking at its
    last line:

    * a line containing ``Follow up:`` becomes a call to the
      ``"Intermediate Answer"`` tool with the follow-up question as input;
    * a line containing ``So the final answer is: `` becomes a
      ``"Final Answer"`` carrying the text after that marker;
    * anything else is reported as unparseable (``None``).
    """

    @property
    def _agent_type(self) -> str:
        """Return Identifier of agent type."""
        return "self-ask-with-search"

    @classmethod
    def create_prompt(cls, tools: Sequence[BaseTool]) -> BasePromptTemplate:
        """Return the fixed self-ask prompt; it does not depend on ``tools``."""
        return PROMPT

    @classmethod
    def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
        """Check that exactly one tool, named ``Intermediate Answer``, is given.

        Args:
            tools: The tools the agent is being constructed with.

        Raises:
            ValueError: If ``tools`` does not contain exactly one tool, or if
                that tool is not named ``"Intermediate Answer"``.
        """
        if len(tools) != 1:
            raise ValueError(f"Exactly one tool must be specified, but got {tools}")
        tool_names = {tool.name for tool in tools}
        if tool_names != {"Intermediate Answer"}:
            raise ValueError(
                f"Tool name should be Intermediate Answer, got {tool_names}"
            )

    def _extract_tool_and_input(self, text: str) -> Optional[Tuple[str, str]]:
        """Parse the last line of an LLM completion into an (action, input) pair.

        Args:
            text: Raw LLM output; only its final line is inspected.

        Returns:
            ``("Intermediate Answer", question)`` when the last line contains
            ``Follow up:``; ``("Final Answer", answer)`` when it contains
            ``So the final answer is: ``; otherwise ``None``.
        """
        followup = "Follow up:"
        last_line = text.split("\n")[-1]

        if followup not in last_line:
            finish_string = "So the final answer is: "
            if finish_string not in last_line:
                return None
            # Take everything after the marker itself rather than slicing off a
            # fixed-width prefix, so a marker that does not open the line is
            # still parsed correctly.
            return "Final Answer", last_line.partition(finish_string)[2]

        # Split on the "Follow up:" marker, not on the last ":" in the text, so
        # questions that themselves contain colons are not truncated.
        question = last_line.partition(followup)[2]
        # Drop the single separating space. startswith() avoids the IndexError
        # that indexing [0] raised when the question was empty.
        if question.startswith(" "):
            question = question[1:]
        return "Intermediate Answer", question

    def _fix_text(self, text: str) -> str:
        """Return ``text`` with a ``So the final answer is:`` line appended."""
        return f"{text}\nSo the final answer is:"

    @property
    def observation_prefix(self) -> str:
        """Prefix to append the observation with."""
        return "Intermediate answer: "

    @property
    def llm_prefix(self) -> str:
        """Prefix to append the LLM call with (empty for this agent)."""
        return ""

    @property
    def starter_string(self) -> str:
        """Put this string after user input but before first LLM call."""
        return "Are follow up questions needed here:"
class SelfAskWithSearchChain(AgentExecutor):
    """Executor that runs a ``SelfAskWithSearchAgent`` over a search backend.

    The search wrapper's ``run`` method is exposed to the agent as its single
    tool, named ``"Intermediate Answer"``.

    Example:
        .. code-block:: python

            from langchain import SelfAskWithSearchChain, OpenAI, GoogleSerperAPIWrapper
            search_chain = GoogleSerperAPIWrapper()
            self_ask = SelfAskWithSearchChain(llm=OpenAI(), search_chain=search_chain)
    """

    def __init__(
        self,
        llm: BaseLLM,
        search_chain: Union[GoogleSerperAPIWrapper, SerpAPIWrapper],
        **kwargs: Any,
    ):
        """Wrap the search backend as a tool, build the agent, and initialise.

        Args:
            llm: Language model that drives the self-ask agent.
            search_chain: Search wrapper whose ``run`` method answers the
                agent's follow-up questions.
            **kwargs: Forwarded unchanged to ``AgentExecutor.__init__``.
        """
        intermediate_answer = Tool(
            name="Intermediate Answer",
            func=search_chain.run,
            description="Search",
        )
        self_ask_agent = SelfAskWithSearchAgent.from_llm_and_tools(
            llm, [intermediate_answer]
        )
        super().__init__(
            agent=self_ask_agent,
            tools=[intermediate_answer],
            **kwargs,
        )