unofficial-haystack-documentation-agent / service /haystack_documentation_pipeline.py
tsadoq's picture
Upload 10 files
40072a5
raw
history blame
No virus
2.46 kB
from typing import Dict, Any, Callable

from haystack import Pipeline
from haystack.agents import Agent, Tool
from haystack.agents.agent_step import AgentStep
from haystack.agents.base import ToolsManager
from haystack.nodes import PromptNode, SentenceTransformersRanker

from service.utils.memory_node import return_memory_node
from service.utils.prompts import agent_prompt
from service.utils.retriever import return_retriever
def resolver_function(
    query: str,
    agent: Agent,
    agent_step: AgentStep,
) -> Dict[str, Any]:
    """
    Resolve the parameters used to fill in the agent's prompt template.

    :param query: the query, forwarded unchanged under the ``query`` key
    :param agent: the agent; its tools manager (``agent.tm``) provides the tool
        names with descriptions and its ``memory`` provides the loaded memory
    :param agent_step: the agent step whose ``transcript`` is forwarded
    :return: a dictionary with the keys ``query``,
        ``tool_names_with_descriptions``, ``transcript`` and ``memory``
    """
    # ``agent_step`` is never called here; only its ``transcript`` attribute is
    # read, which is why it is annotated as an agent step and not a callable.
    return {
        'query': query,
        'tool_names_with_descriptions': agent.tm.get_tool_names_with_descriptions(),
        'transcript': agent_step.transcript,
        'memory': agent.memory.load(),
    }
def define_haystack_doc_searcher_tool(
    *,
    ranker_model: str = 'cross-encoder/ms-marco-MiniLM-L-12-v2',
    top_k: int = 5,
) -> Tool:
    """
    Defines the tool for searching the Haystack documentation.

    The tool wraps a two-node pipeline: the retriever obtained from
    ``return_retriever()`` receives the query, and its output is passed to a
    ``SentenceTransformersRanker``.

    :param ranker_model: model name or path handed to the ranker; the default
        is the cross-encoder this function previously hard-coded
    :param top_k: ``top_k`` value handed to the ranker (default 5, the value
        previously hard-coded)
    :return: the Haystack documentation searcher tool
    """
    ranker = SentenceTransformersRanker(model_name_or_path=ranker_model, top_k=top_k)
    retriever = return_retriever()

    # Wire the pipeline as Query -> retriever -> ranker.
    haystack_docs = Pipeline()
    haystack_docs.add_node(component=retriever, name='retriever', inputs=['Query'])
    haystack_docs.add_node(component=ranker, name='ranker', inputs=['retriever'])

    return Tool(
        name='haystack_documentation_search_tool',
        pipeline_or_node=haystack_docs,
        description='Searches the Haystack documentation for information.',
        output_variable='documents',
    )
def return_haystack_documentation_agent(openai_key: str) -> Agent:
    """
    Build the agent that answers questions about the Haystack documentation.

    The same OpenAI key is handed to both the prompt node and the memory node.

    :param openai_key: the OpenAI key
    :return: the configured agent
    """
    # Prompt node for the agent; 'Observation:' is passed as its stop word.
    llm_node = PromptNode(
        'gpt-3.5-turbo-16k',
        api_key=openai_key,
        stop_words=['Observation:'],
        model_kwargs={'temperature': 0.05},
        max_length=10000,
    )

    # Memory is built before the tools, matching the original argument order.
    conversation_memory = return_memory_node(openai_key)
    doc_search_tools = ToolsManager([define_haystack_doc_searcher_tool()])

    return Agent(
        llm_node,
        prompt_template=agent_prompt,
        prompt_parameters_resolver=resolver_function,
        memory=conversation_memory,
        tools_manager=doc_search_tools,
        final_answer_pattern=r"(?s)Final Answer\s*:\s*(.*)",
    )