|
from typing import Any, Dict, List, Optional, Type |
|
|
|
from langchain_core.document_loaders import BaseLoader |
|
from langchain_core.documents import Document |
|
from langchain_core.embeddings import Embeddings |
|
from langchain_core.language_models import BaseLanguageModel |
|
from langchain_core.vectorstores import VectorStore |
|
from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter |
|
from pydantic import BaseModel, ConfigDict, Field |
|
|
|
from langchain.chains.qa_with_sources.retrieval import RetrievalQAWithSourcesChain |
|
from langchain.chains.retrieval_qa.base import RetrievalQA |
|
|
|
|
|
def _get_default_text_splitter() -> TextSplitter: |
|
return RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=0) |
|
|
|
|
|
class VectorStoreIndexWrapper(BaseModel): |
|
"""Wrapper around a vectorstore for easy access.""" |
|
|
|
vectorstore: VectorStore |
|
|
|
model_config = ConfigDict( |
|
arbitrary_types_allowed=True, |
|
extra="forbid", |
|
) |
|
|
|
def query( |
|
self, |
|
question: str, |
|
llm: Optional[BaseLanguageModel] = None, |
|
retriever_kwargs: Optional[Dict[str, Any]] = None, |
|
**kwargs: Any, |
|
) -> str: |
|
"""Query the vectorstore.""" |
|
if llm is None: |
|
raise NotImplementedError( |
|
"This API has been changed to require an LLM. " |
|
"Please provide an llm to use for querying the vectorstore.\n" |
|
"For example,\n" |
|
"from langchain_openai import OpenAI\n" |
|
"llm = OpenAI(temperature=0)" |
|
) |
|
retriever_kwargs = retriever_kwargs or {} |
|
chain = RetrievalQA.from_chain_type( |
|
llm, retriever=self.vectorstore.as_retriever(**retriever_kwargs), **kwargs |
|
) |
|
return chain.invoke({chain.input_key: question})[chain.output_key] |
|
|
|
async def aquery( |
|
self, |
|
question: str, |
|
llm: Optional[BaseLanguageModel] = None, |
|
retriever_kwargs: Optional[Dict[str, Any]] = None, |
|
**kwargs: Any, |
|
) -> str: |
|
"""Query the vectorstore.""" |
|
if llm is None: |
|
raise NotImplementedError( |
|
"This API has been changed to require an LLM. " |
|
"Please provide an llm to use for querying the vectorstore.\n" |
|
"For example,\n" |
|
"from langchain_openai import OpenAI\n" |
|
"llm = OpenAI(temperature=0)" |
|
) |
|
retriever_kwargs = retriever_kwargs or {} |
|
chain = RetrievalQA.from_chain_type( |
|
llm, retriever=self.vectorstore.as_retriever(**retriever_kwargs), **kwargs |
|
) |
|
return (await chain.ainvoke({chain.input_key: question}))[chain.output_key] |
|
|
|
def query_with_sources( |
|
self, |
|
question: str, |
|
llm: Optional[BaseLanguageModel] = None, |
|
retriever_kwargs: Optional[Dict[str, Any]] = None, |
|
**kwargs: Any, |
|
) -> dict: |
|
"""Query the vectorstore and get back sources.""" |
|
if llm is None: |
|
raise NotImplementedError( |
|
"This API has been changed to require an LLM. " |
|
"Please provide an llm to use for querying the vectorstore.\n" |
|
"For example,\n" |
|
"from langchain_openai import OpenAI\n" |
|
"llm = OpenAI(temperature=0)" |
|
) |
|
retriever_kwargs = retriever_kwargs or {} |
|
chain = RetrievalQAWithSourcesChain.from_chain_type( |
|
llm, retriever=self.vectorstore.as_retriever(**retriever_kwargs), **kwargs |
|
) |
|
return chain.invoke({chain.question_key: question}) |
|
|
|
async def aquery_with_sources( |
|
self, |
|
question: str, |
|
llm: Optional[BaseLanguageModel] = None, |
|
retriever_kwargs: Optional[Dict[str, Any]] = None, |
|
**kwargs: Any, |
|
) -> dict: |
|
"""Query the vectorstore and get back sources.""" |
|
if llm is None: |
|
raise NotImplementedError( |
|
"This API has been changed to require an LLM. " |
|
"Please provide an llm to use for querying the vectorstore.\n" |
|
"For example,\n" |
|
"from langchain_openai import OpenAI\n" |
|
"llm = OpenAI(temperature=0)" |
|
) |
|
retriever_kwargs = retriever_kwargs or {} |
|
chain = RetrievalQAWithSourcesChain.from_chain_type( |
|
llm, retriever=self.vectorstore.as_retriever(**retriever_kwargs), **kwargs |
|
) |
|
return await chain.ainvoke({chain.question_key: question}) |
|
|
|
|
|
def _get_in_memory_vectorstore() -> Type[VectorStore]: |
|
"""Get the InMemoryVectorStore.""" |
|
import warnings |
|
|
|
try: |
|
from langchain_community.vectorstores.inmemory import InMemoryVectorStore |
|
except ImportError: |
|
raise ImportError( |
|
"Please install langchain-community to use the InMemoryVectorStore." |
|
) |
|
warnings.warn( |
|
"Using InMemoryVectorStore as the default vectorstore." |
|
"This memory store won't persist data. You should explicitly" |
|
"specify a vectorstore when using VectorstoreIndexCreator" |
|
) |
|
return InMemoryVectorStore |
|
|
|
|
|
class VectorstoreIndexCreator(BaseModel): |
|
"""Logic for creating indexes.""" |
|
|
|
vectorstore_cls: Type[VectorStore] = Field( |
|
default_factory=_get_in_memory_vectorstore |
|
) |
|
embedding: Embeddings |
|
text_splitter: TextSplitter = Field(default_factory=_get_default_text_splitter) |
|
vectorstore_kwargs: dict = Field(default_factory=dict) |
|
|
|
model_config = ConfigDict( |
|
arbitrary_types_allowed=True, |
|
extra="forbid", |
|
) |
|
|
|
def from_loaders(self, loaders: List[BaseLoader]) -> VectorStoreIndexWrapper: |
|
"""Create a vectorstore index from loaders.""" |
|
docs = [] |
|
for loader in loaders: |
|
docs.extend(loader.load()) |
|
return self.from_documents(docs) |
|
|
|
async def afrom_loaders(self, loaders: List[BaseLoader]) -> VectorStoreIndexWrapper: |
|
"""Create a vectorstore index from loaders.""" |
|
docs = [] |
|
for loader in loaders: |
|
async for doc in loader.alazy_load(): |
|
docs.append(doc) |
|
return await self.afrom_documents(docs) |
|
|
|
def from_documents(self, documents: List[Document]) -> VectorStoreIndexWrapper: |
|
"""Create a vectorstore index from documents.""" |
|
sub_docs = self.text_splitter.split_documents(documents) |
|
vectorstore = self.vectorstore_cls.from_documents( |
|
sub_docs, self.embedding, **self.vectorstore_kwargs |
|
) |
|
return VectorStoreIndexWrapper(vectorstore=vectorstore) |
|
|
|
async def afrom_documents( |
|
self, documents: List[Document] |
|
) -> VectorStoreIndexWrapper: |
|
"""Create a vectorstore index from documents.""" |
|
sub_docs = self.text_splitter.split_documents(documents) |
|
vectorstore = await self.vectorstore_cls.afrom_documents( |
|
sub_docs, self.embedding, **self.vectorstore_kwargs |
|
) |
|
return VectorStoreIndexWrapper(vectorstore=vectorstore) |
|
|