"""Tools for interacting with vectorstores.""" import json from typing import Any, Dict from pydantic import BaseModel, Field from langchain.chains.qa_with_sources.vector_db import VectorDBQAWithSourcesChain from langchain.chains.retrieval_qa.base import VectorDBQA from langchain.llms.base import BaseLLM from langchain.llms.openai import OpenAI from langchain.tools.base import BaseTool from langchain.vectorstores.base import VectorStore class BaseVectorStoreTool(BaseModel): """Base class for tools that use a VectorStore.""" vectorstore: VectorStore = Field(exclude=True) llm: BaseLLM = Field(default_factory=lambda: OpenAI(temperature=0)) class Config(BaseTool.Config): """Configuration for this pydantic object.""" arbitrary_types_allowed = True def _create_description_from_template(values: Dict[str, Any]) -> Dict[str, Any]: values["description"] = values["template"].format(name=values["name"]) return values class VectorStoreQATool(BaseVectorStoreTool, BaseTool): """Tool for the VectorDBQA chain. To be initialized with name and chain.""" @staticmethod def get_description(name: str, description: str) -> str: template: str = ( "Useful for when you need to answer questions about {name}. " "Whenever you need information about {description} " "you should ALWAYS use this. " "Input should be a fully formed question." ) return template.format(name=name, description=description) def _run(self, query: str) -> str: """Use the tool.""" chain = VectorDBQA.from_chain_type(self.llm, vectorstore=self.vectorstore) return chain.run(query) async def _arun(self, query: str) -> str: """Use the tool asynchronously.""" raise NotImplementedError("VectorDBQATool does not support async") class VectorStoreQAWithSourcesTool(BaseVectorStoreTool, BaseTool): """Tool for the VectorDBQAWithSources chain.""" @staticmethod def get_description(name: str, description: str) -> str: template: str = ( "Useful for when you need to answer questions about {name} and the sources " "used to construct the answer. " "Whenever you need information about {description} " "you should ALWAYS use this. " " Input should be a fully formed question. " "Output is a json serialized dictionary with keys `answer` and `sources`. " "Only use this tool if the user explicitly asks for sources." ) return template.format(name=name, description=description) def _run(self, query: str) -> str: """Use the tool.""" chain = VectorDBQAWithSourcesChain.from_chain_type( self.llm, vectorstore=self.vectorstore ) return json.dumps(chain({chain.question_key: query}, return_only_outputs=True)) async def _arun(self, query: str) -> str: """Use the tool asynchronously.""" raise NotImplementedError("VectorDBQATool does not support async")