Spaces:
Sleeping
Sleeping
from typing import Any, List | |
from canonical_demo_memory.core.debug import FakeChatModel | |
from canonical_demo_memory.core.embedding import FolderIndex | |
from canonical_demo_memory.core.prompts import STUFF_PROMPT | |
from langchain.chains.qa_with_sources import load_qa_with_sources_chain | |
from langchain.chat_models import ChatOpenAI | |
from langchain.docstore.document import Document | |
from pydantic import BaseModel | |
class AnswerWithSources(BaseModel): | |
answer: str | |
sources: List[Document] | |
def query_folder( | |
query: str, | |
folder_index: FolderIndex, | |
return_all: bool = False, | |
model: str = "openai", | |
**model_kwargs: Any, | |
) -> AnswerWithSources: | |
"""Queries a folder index for an answer. | |
Args: | |
query (str): The query to search for. | |
folder_index (FolderIndex): The folder index to search. | |
return_all (bool): Whether to return all the documents from the embedding or | |
just the sources for the answer. | |
model (str): The model to use for the answer generation. | |
**model_kwargs (Any): Keyword arguments for the model. | |
Returns: | |
AnswerWithSources: The answer and the source documents. | |
""" | |
supported_models = { | |
"openai": ChatOpenAI, | |
"debug": FakeChatModel, | |
} | |
if model in supported_models: | |
llm = supported_models[model](**model_kwargs) | |
else: | |
raise ValueError(f"Model {model} not supported.") | |
chain = load_qa_with_sources_chain( | |
llm=llm, | |
chain_type="stuff", | |
prompt=STUFF_PROMPT, | |
) | |
relevant_docs = folder_index.index.similarity_search(query, k=5) | |
result = chain( | |
{"input_documents": relevant_docs, "question": query}, return_only_outputs=True | |
) | |
sources = relevant_docs | |
if not return_all: | |
sources = get_sources(result["output_text"], folder_index) | |
answer = result["output_text"].split("SOURCES: ")[0] | |
return AnswerWithSources(answer=answer, sources=sources) | |
def get_sources(answer: str, folder_index: FolderIndex) -> List[Document]: | |
"""Retrieves the docs that were used to answer the question the generated answer.""" | |
source_keys = [s for s in answer.split("SOURCES: ")[-1].split(", ")] | |
source_docs = [] | |
for file in folder_index.files: | |
for doc in file.docs: | |
if doc.metadata["source"] in source_keys: | |
source_docs.append(doc) | |
return source_docs | |