Spaces:
Runtime error
Runtime error
File size: 5,556 Bytes
129cd69 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 |
"""Chain that implements the ReAct paper from https://arxiv.org/pdf/2210.03629.pdf."""
from typing import Any, List, Optional, Sequence
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import BasePromptTemplate
from langchain_core.pydantic_v1 import Field
from langchain.agents.agent import Agent, AgentExecutor, AgentOutputParser
from langchain.agents.agent_types import AgentType
from langchain.agents.react.output_parser import ReActOutputParser
from langchain.agents.react.textworld_prompt import TEXTWORLD_PROMPT
from langchain.agents.react.wiki_prompt import WIKI_PROMPT
from langchain.agents.tools import Tool
from langchain.agents.utils import validate_tools_single_input
from langchain.docstore.base import Docstore
from langchain.docstore.document import Document
from langchain.tools.base import BaseTool
class ReActDocstoreAgent(Agent):
    """Agent for the ReAct chain.

    Drives the wiki-style ReAct prompt and requires exactly two single-input
    tools, named ``Lookup`` and ``Search``.
    """

    output_parser: AgentOutputParser = Field(default_factory=ReActOutputParser)

    @classmethod
    def _get_default_output_parser(cls, **kwargs: Any) -> AgentOutputParser:
        """Return a fresh ``ReActOutputParser``; ``kwargs`` are ignored."""
        parser = ReActOutputParser()
        return parser

    @property
    def _agent_type(self) -> str:
        """Return Identifier of an agent type."""
        return AgentType.REACT_DOCSTORE

    @classmethod
    def create_prompt(cls, tools: Sequence[BaseTool]) -> BasePromptTemplate:
        """Return default prompt (``WIKI_PROMPT``); ``tools`` is not consulted."""
        return WIKI_PROMPT

    @classmethod
    def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
        """Check that exactly the ``Lookup`` and ``Search`` tools were supplied.

        Also runs ``validate_tools_single_input`` and the parent class's
        validation first.

        Raises:
            ValueError: if the number of tools is not two, or if their names
                are not exactly ``{"Lookup", "Search"}``.
        """
        validate_tools_single_input(cls.__name__, tools)
        super()._validate_tools(tools)
        if len(tools) != 2:
            raise ValueError(f"Exactly two tools must be specified, but got {tools}")
        names = {t.name for t in tools}
        if names == {"Lookup", "Search"}:
            return
        raise ValueError(
            f"Tool names should be Lookup and Search, got {names}"
        )

    @property
    def observation_prefix(self) -> str:
        """Prefix to append the observation with."""
        return "Observation: "

    @property
    def _stop(self) -> List[str]:
        """Stop sequences: halt generation before an observation line."""
        return ["\nObservation:"]

    @property
    def llm_prefix(self) -> str:
        """Prefix to append the LLM call with."""
        return "Thought:"
class DocstoreExplorer:
    """Class to assist with exploration of a document store.

    ``search`` stores the most recent successful result as ``document``.
    ``lookup`` then pages through that document's paragraphs (split on blank
    lines) that contain a term.
    """

    def __init__(self, docstore: Docstore):
        """Initialize with a docstore, and set initial document to None."""
        self.docstore = docstore
        self.document: Optional[Document] = None
        # Current lower-cased lookup term, and the index (into the list of
        # matching paragraphs) of the match most recently returned for it.
        self.lookup_str = ""
        self.lookup_index = 0

    def _reset_lookup(self) -> None:
        """Forget the current lookup term and position."""
        self.lookup_str = ""
        self.lookup_index = 0

    def search(self, term: str) -> str:
        """Search for a term in the docstore, and if found save.

        Lookup progress belongs to the previously saved document, so it is
        reset here. Without the reset, repeating a lookup term after a new
        search would skip the first match(es) in the new document.

        Returns:
            The first paragraph of the found document. If the docstore returns
            something other than a ``Document``, that value is returned as-is
            (presumably a "not found" message string; confirm against the
            ``Docstore`` implementation).
        """
        result = self.docstore.search(term)
        self._reset_lookup()
        if isinstance(result, Document):
            self.document = result
            return self._summary
        self.document = None
        return result

    def lookup(self, term: str) -> str:
        """Lookup a term in document (if saved).

        Matching is a case-insensitive substring test per paragraph. Calling
        again with the same term (case-insensitively) advances to the next
        matching paragraph; a different term restarts at the first match.

        Returns:
            ``"(Result i/n) <paragraph>"`` for a match, ``"No Results"`` if no
            paragraph contains the term, or ``"No More Results"`` once the
            matches are exhausted.

        Raises:
            ValueError: if there is no saved document from a successful search.
        """
        if self.document is None:
            raise ValueError("Cannot lookup without a successful search first")
        needle = term.lower()
        if needle != self.lookup_str:
            self.lookup_str = needle
            self.lookup_index = 0
        else:
            self.lookup_index += 1
        lookups = [p for p in self._paragraphs if self.lookup_str in p.lower()]
        if not lookups:
            return "No Results"
        if self.lookup_index >= len(lookups):
            return "No More Results"
        result_prefix = f"(Result {self.lookup_index + 1}/{len(lookups)})"
        return f"{result_prefix} {lookups[self.lookup_index]}"

    @property
    def _summary(self) -> str:
        """First paragraph of the saved document."""
        return self._paragraphs[0]

    @property
    def _paragraphs(self) -> List[str]:
        """Saved document's content split on blank lines (``"\\n\\n"``).

        Raises:
            ValueError: if no document is saved.
        """
        if self.document is None:
            raise ValueError("Cannot get paragraphs without a document")
        return self.document.page_content.split("\n\n")
class ReActTextWorldAgent(ReActDocstoreAgent):
    """Agent for the ReAct TextWorld chain.

    Uses the TextWorld prompt and requires exactly one single-input tool,
    named ``Play``.
    """

    @classmethod
    def create_prompt(cls, tools: Sequence[BaseTool]) -> BasePromptTemplate:
        """Return default prompt."""
        return TEXTWORLD_PROMPT

    @classmethod
    def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
        """Check that exactly one tool, named ``Play``, was supplied.

        Raises:
            ValueError: if the number of tools is not one, or its name is not
                ``Play``.
        """
        validate_tools_single_input(cls.__name__, tools)
        # Deliberately bypass ReActDocstoreAgent._validate_tools. That method
        # demands exactly two tools named "Lookup" and "Search", so calling it
        # here would reject every valid single-"Play" configuration. Delegate
        # to the next class in the MRO after ReActDocstoreAgent instead.
        super(ReActDocstoreAgent, cls)._validate_tools(tools)
        if len(tools) != 1:
            raise ValueError(f"Exactly one tool must be specified, but got {tools}")
        tool_names = {tool.name for tool in tools}
        if tool_names != {"Play"}:
            raise ValueError(f"Tool name should be Play, got {tool_names}")
class ReActChain(AgentExecutor):
    """[Deprecated] Chain that implements the ReAct paper.

    Pairs a ``ReActDocstoreAgent`` with ``Search`` and ``Lookup`` tools, both
    backed by one ``DocstoreExplorer`` over the given docstore.
    """

    def __init__(self, llm: BaseLanguageModel, docstore: Docstore, **kwargs: Any):
        """Initialize with the LLM and a docstore.

        Args:
            llm: Language model used to build the agent.
            docstore: Document store that the tools explore.
            **kwargs: Passed through unchanged to ``AgentExecutor.__init__``.
        """
        # A single shared explorer: "Lookup" reads the document saved by the
        # most recent "Search".
        explorer = DocstoreExplorer(docstore)
        search_tool = Tool(
            name="Search",
            func=explorer.search,
            description="Search for a term in the docstore.",
        )
        lookup_tool = Tool(
            name="Lookup",
            func=explorer.lookup,
            description="Lookup a term in the docstore.",
        )
        tools = [search_tool, lookup_tool]
        super().__init__(
            agent=ReActDocstoreAgent.from_llm_and_tools(llm, tools),
            tools=tools,
            **kwargs,
        )
|