from typing import Dict, Any, Callable

from haystack import Pipeline
from haystack.agents.base import ToolsManager
from haystack.nodes import PromptNode, SentenceTransformersRanker
from haystack.agents import Agent, Tool

from service.utils.memory_node import return_memory_node
from service.utils.prompts import agent_prompt
from service.utils.retriever import return_retriever


def resolver_function(
    query: str,
    agent: Agent,
    agent_step: Callable,
) -> Dict[str, Any]:
    """
    Gather the values that are substituted into the agent prompt template.

    :param query: the query, passed through unchanged
    :param agent: the agent; its tools manager (``agent.tm``) and its memory are read
    :param agent_step: the agent step; only its ``transcript`` attribute is read
    :return: a dictionary with the keys ``query``, ``tool_names_with_descriptions``,
        ``transcript`` and ``memory``
    """
    # Read the three dynamic values first, then assemble the mapping.
    tool_overview = agent.tm.get_tool_names_with_descriptions()
    step_transcript = agent_step.transcript
    remembered = agent.memory.load()

    template_values: Dict[str, Any] = {}
    template_values['query'] = query
    template_values['tool_names_with_descriptions'] = tool_overview
    template_values['transcript'] = step_transcript
    template_values['memory'] = remembered
    return template_values


def define_haystack_doc_searcher_tool() -> Tool:
    """
    Build the agent tool that searches the Haystack documentation.

    The tool wraps a two-stage pipeline: the project retriever is fed by the
    pipeline's ``Query`` input, and its output is passed to a
    sentence-transformers ranker configured with ``top_k=5``.

    :return: the Haystack documentation searcher tool
    """
    reranker = SentenceTransformersRanker(
        model_name_or_path='cross-encoder/ms-marco-MiniLM-L-12-v2',
        top_k=5,
    )
    doc_retriever = return_retriever()

    search_pipeline = Pipeline()
    # (component, node name, upstream inputs), wired in this order.
    stages = (
        (doc_retriever, 'retriever', ['Query']),
        (reranker, 'ranker', ['retriever']),
    )
    for component, node_name, upstream in stages:
        search_pipeline.add_node(component=component, name=node_name, inputs=upstream)

    search_tool = Tool(
        name='haystack_documentation_search_tool',
        pipeline_or_node=search_pipeline,
        description='Searches the Haystack documentation for information.',
        output_variable='documents',
    )
    return search_tool


def return_haystack_documentation_agent(openai_key: str) -> Agent:
    """
    Create the agent that answers questions about the Haystack documentation.

    :param openai_key: the OpenAI key, used for both the prompt node and the memory node
    :return: the agent
    """
    llm_node = PromptNode(
        'gpt-3.5-turbo-16k',
        api_key=openai_key,
        # NOTE(review): presumably stops generation before the model writes a
        # tool observation itself -- confirm against the prompt template.
        stop_words=['Observation:'],
        model_kwargs={'temperature': 0.05},
        max_length=10000,
    )

    # The memory node is created before the tool, as in the original call order.
    conversation_memory = return_memory_node(openai_key)
    doc_tools = ToolsManager([define_haystack_doc_searcher_tool()])

    documentation_agent = Agent(
        llm_node,
        prompt_template=agent_prompt,
        prompt_parameters_resolver=resolver_function,
        memory=conversation_memory,
        tools_manager=doc_tools,
        # (?s) lets '.' match newlines, so a multi-line answer is captured whole.
        final_answer_pattern=r"(?s)Final Answer\s*:\s*(.*)",
    )
    return documentation_agent