jfeng1115's picture
init commit
58d33f0
raw
history blame contribute delete
No virus
3.08 kB
"""Tools for interacting with vectorstores."""
import json
from typing import Any, Dict
from pydantic import BaseModel, Field
from langchain.chains.qa_with_sources.vector_db import VectorDBQAWithSourcesChain
from langchain.chains.retrieval_qa.base import VectorDBQA
from langchain.llms.base import BaseLLM
from langchain.llms.openai import OpenAI
from langchain.tools.base import BaseTool
from langchain.vectorstores.base import VectorStore
class BaseVectorStoreTool(BaseModel):
"""Base class for tools that use a VectorStore."""
vectorstore: VectorStore = Field(exclude=True)
llm: BaseLLM = Field(default_factory=lambda: OpenAI(temperature=0))
class Config(BaseTool.Config):
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
def _create_description_from_template(values: Dict[str, Any]) -> Dict[str, Any]:
values["description"] = values["template"].format(name=values["name"])
return values
class VectorStoreQATool(BaseVectorStoreTool, BaseTool):
"""Tool for the VectorDBQA chain. To be initialized with name and chain."""
@staticmethod
def get_description(name: str, description: str) -> str:
template: str = (
"Useful for when you need to answer questions about {name}. "
"Whenever you need information about {description} "
"you should ALWAYS use this. "
"Input should be a fully formed question."
)
return template.format(name=name, description=description)
def _run(self, query: str) -> str:
"""Use the tool."""
chain = VectorDBQA.from_chain_type(self.llm, vectorstore=self.vectorstore)
return chain.run(query)
async def _arun(self, query: str) -> str:
"""Use the tool asynchronously."""
raise NotImplementedError("VectorDBQATool does not support async")
class VectorStoreQAWithSourcesTool(BaseVectorStoreTool, BaseTool):
"""Tool for the VectorDBQAWithSources chain."""
@staticmethod
def get_description(name: str, description: str) -> str:
template: str = (
"Useful for when you need to answer questions about {name} and the sources "
"used to construct the answer. "
"Whenever you need information about {description} "
"you should ALWAYS use this. "
" Input should be a fully formed question. "
"Output is a json serialized dictionary with keys `answer` and `sources`. "
"Only use this tool if the user explicitly asks for sources."
)
return template.format(name=name, description=description)
def _run(self, query: str) -> str:
"""Use the tool."""
chain = VectorDBQAWithSourcesChain.from_chain_type(
self.llm, vectorstore=self.vectorstore
)
return json.dumps(chain({chain.question_key: query}, return_only_outputs=True))
async def _arun(self, query: str) -> str:
"""Use the tool asynchronously."""
raise NotImplementedError("VectorDBQATool does not support async")