from typing import Any, List

from langchain.chains.qa_with_sources import load_qa_with_sources_chain
from langchain.chat_models import ChatOpenAI
from langchain.docstore.document import Document
from pydantic import BaseModel

from canonical_demo_memory.core.debug import FakeChatModel
from canonical_demo_memory.core.embedding import FolderIndex
from canonical_demo_memory.core.prompts import STUFF_PROMPT

# Marker this module splits the model output on: text before it is treated as
# the answer, text after it as the comma-separated list of cited source ids.
_SOURCES_MARKER = "SOURCES: "


class AnswerWithSources(BaseModel):
    """An answer together with the documents returned as its sources."""

    answer: str
    sources: List[Document]


def query_folder(
    query: str,
    folder_index: FolderIndex,
    return_all: bool = False,
    model: str = "openai",
    **model_kwargs: Any,
) -> AnswerWithSources:
    """Queries a folder index for an answer.

    Args:
        query (str): The query to search for.
        folder_index (FolderIndex): The folder index to search.
        return_all (bool): Whether to return all the documents from the
            embedding or just the sources for the answer.
        model (str): The model to use for the answer generation.
        **model_kwargs (Any): Keyword arguments for the model.

    Returns:
        AnswerWithSources: The answer and the source documents.

    Raises:
        ValueError: If ``model`` is not one of the supported model names.
    """
    supported_models = {
        "openai": ChatOpenAI,
        "debug": FakeChatModel,
    }
    if model not in supported_models:
        raise ValueError(f"Model {model} not supported.")
    llm = supported_models[model](**model_kwargs)

    chain = load_qa_with_sources_chain(
        llm=llm,
        chain_type="stuff",
        prompt=STUFF_PROMPT,
    )

    relevant_docs = folder_index.index.similarity_search(query, k=5)
    result = chain(
        {"input_documents": relevant_docs, "question": query},
        return_only_outputs=True,
    )
    output_text = result["output_text"]

    # Either hand back every retrieved document, or only those the model cited.
    if return_all:
        sources = relevant_docs
    else:
        sources = get_sources(output_text, folder_index)

    # Everything before the first marker is the answer shown to the caller.
    answer = output_text.split(_SOURCES_MARKER)[0]
    return AnswerWithSources(answer=answer, sources=sources)


def get_sources(answer: str, folder_index: FolderIndex) -> List[Document]:
    """Retrieves the docs that are cited as sources in the generated answer.

    The cited ids are read from the text after the last ``"SOURCES: "`` marker
    in ``answer``, split on ``", "``, and matched against each document's
    ``metadata["source"]``.

    Args:
        answer (str): The raw model output, including the sources section.
        folder_index (FolderIndex): The folder index whose files' documents
            are searched for the cited ids.

    Returns:
        List[Document]: The matching documents, in the order they appear in
        ``folder_index.files``. Empty if the answer has no sources section.
    """
    _, marker, cited = answer.rpartition(_SOURCES_MARKER)
    if not marker:
        # No sources section at all: do not misread the answer text as ids.
        return []

    # Strip each id so a trailing newline/space after the last one (or stray
    # padding around any id) does not prevent it from matching.
    source_keys = {key.strip() for key in cited.split(", ")}
    source_keys.discard("")

    source_docs: List[Document] = []
    for file in folder_index.files:
        source_docs.extend(
            doc for doc in file.docs if doc.metadata["source"] in source_keys
        )
    return source_docs