File size: 3,286 Bytes
58d33f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
"""Chain that does self ask with search."""
from typing import Any, Optional, Sequence, Tuple, Union

from langchain.agents.agent import Agent, AgentExecutor
from langchain.agents.self_ask_with_search.prompt import PROMPT
from langchain.agents.tools import Tool
from langchain.llms.base import BaseLLM
from langchain.prompts.base import BasePromptTemplate
from langchain.tools.base import BaseTool
from langchain.utilities.google_serper import GoogleSerperAPIWrapper
from langchain.utilities.serpapi import SerpAPIWrapper


class SelfAskWithSearchAgent(Agent):
    """Agent for the self-ask-with-search paper.

    The agent accepts exactly one tool, which must be named
    ``"Intermediate Answer"``. The last line of each LLM output is parsed:

    * a ``"Follow up:"`` line yields an ``"Intermediate Answer"`` action whose
      input is the follow-up question;
    * a ``"So the final answer is: "`` line yields a ``"Final Answer"`` action;
    * anything else is treated as unparseable (``None``).
    """

    @property
    def _agent_type(self) -> str:
        """Return Identifier of agent type."""
        return "self-ask-with-search"

    @classmethod
    def create_prompt(cls, tools: Sequence[BaseTool]) -> BasePromptTemplate:
        """Return the fixed self-ask prompt.

        The prompt does not depend on tools, so ``tools`` is ignored.
        """
        return PROMPT

    @classmethod
    def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
        """Ensure ``tools`` is a single tool named ``"Intermediate Answer"``.

        Raises:
            ValueError: If the number of tools is not exactly one, or if the
                single tool has a different name.
        """
        if len(tools) != 1:
            raise ValueError(f"Exactly one tool must be specified, but got {tools}")
        tool_names = {tool.name for tool in tools}
        if tool_names != {"Intermediate Answer"}:
            raise ValueError(
                f"Tool name should be Intermediate Answer, got {tool_names}"
            )

    def _extract_tool_and_input(self, text: str) -> Optional[Tuple[str, str]]:
        """Parse LLM output into an ``(action, action_input)`` pair.

        Only the last line of ``text`` is inspected.

        Args:
            text: Raw LLM output.

        Returns:
            ``("Intermediate Answer", question)`` if the last line contains
            ``"Follow up:"``. Here ``question`` is everything after that marker,
            with at most one leading space removed.
            ``("Final Answer", answer)`` if the last line contains
            ``"So the final answer is: "``. Here ``answer`` is everything after
            that marker.
            ``None`` otherwise.
        """
        followup = "Follow up:"
        last_line = text.split("\n")[-1]

        if followup not in last_line:
            finish_string = "So the final answer is: "
            if finish_string not in last_line:
                return None
            # Take the text after the marker instead of slicing by the marker's
            # length, so text preceding the marker on the same line does not
            # shift the extracted answer.
            return "Final Answer", last_line.partition(finish_string)[2]

        # Take the question from the last line, right after the marker. The
        # question may itself contain colons (times, URLs, ...), so splitting
        # the whole output on ":" would truncate it.
        after_colon = last_line.partition(followup)[2]

        # Drop a single separating space. startswith() also handles an empty
        # question, where indexing [0] would raise IndexError.
        if after_colon.startswith(" "):
            after_colon = after_colon[1:]

        return "Intermediate Answer", after_colon

    def _fix_text(self, text: str) -> str:
        """Append a line prompting the LLM to state its final answer."""
        return f"{text}\nSo the final answer is:"

    @property
    def observation_prefix(self) -> str:
        """Prefix to append the observation with."""
        return "Intermediate answer: "

    @property
    def llm_prefix(self) -> str:
        """Prefix to append the LLM call with (empty for this agent)."""
        return ""

    @property
    def starter_string(self) -> str:
        """Put this string after user input but before first LLM call."""
        return "Are follow up questions needed here:"


class SelfAskWithSearchChain(AgentExecutor):
    """Agent executor running self-ask-with-search over a single search tool.

    The given search wrapper is exposed to the agent as a tool named
    ``"Intermediate Answer"``, which is the only tool name that
    ``SelfAskWithSearchAgent._validate_tools`` accepts.

    Example:
        .. code-block:: python

            from langchain import SelfAskWithSearchChain, OpenAI, GoogleSerperAPIWrapper
            search_chain = GoogleSerperAPIWrapper()
            self_ask = SelfAskWithSearchChain(llm=OpenAI(), search_chain=search_chain)
    """

    def __init__(
        self,
        llm: BaseLLM,
        search_chain: Union[GoogleSerperAPIWrapper, SerpAPIWrapper],
        **kwargs: Any,
    ):
        """Wire an LLM and a search wrapper into a self-ask agent executor.

        Args:
            llm: Language model that drives the self-ask agent.
            search_chain: Search wrapper; its ``run`` method is what the
                ``"Intermediate Answer"`` tool calls.
            **kwargs: Passed through unchanged to ``AgentExecutor.__init__``.
        """
        intermediate_answer = Tool(
            description="Search",
            func=search_chain.run,
            name="Intermediate Answer",
        )
        self_ask_agent = SelfAskWithSearchAgent.from_llm_and_tools(
            llm, [intermediate_answer]
        )
        super().__init__(
            agent=self_ask_agent,
            tools=[intermediate_answer],
            **kwargs,
        )