# (Removed non-source residue from the scraped page header: "Spaces: Running Running".)
| import os | |
| import sys | |
| import nest_asyncio | |
| import Stemmer | |
| from llama_index.core import ( | |
| PromptTemplate, | |
| Settings, | |
| SimpleDirectoryReader, | |
| StorageContext, | |
| VectorStoreIndex, | |
| load_index_from_storage, | |
| ) | |
| from llama_index.core.node_parser import SentenceSplitter | |
| from llama_index.core.query_engine import CitationQueryEngine | |
| from llama_index.core.retrievers import QueryFusionRetriever | |
| from llama_index.core.schema import NodeWithScore as NodeWithScore | |
| from llama_index.embeddings.google import GeminiEmbedding | |
| from llama_index.llms.gemini import Gemini | |
| from llama_index.retrievers.bm25 import BM25Retriever | |
| import mesop as me | |
| nest_asyncio.apply() | |
| CITATION_QA_TEMPLATE = PromptTemplate( | |
| "Please provide an answer based solely on the provided sources. " | |
| "When referencing information from a source, " | |
| "cite the appropriate source(s) using their corresponding numbers. " | |
| "Every answer should include at least one source citation. " | |
| "Only cite a source when you are explicitly referencing it. " | |
| "If you are sure NONE of the sources are helpful, then say: 'Sorry, I didn't find any docs about this.'" | |
| "If you are not sure if any of the sources are helpful, then say: 'You might find this helpful', where 'this' is the source's title.'" | |
| "DO NOT say Source 1, Source 2, etc. Only reference sources like this: [1], [2], etc." | |
| "I want you to pick just ONE source to answer the question." | |
| "For example:\n" | |
| "Source 1:\n" | |
| "The sky is red in the evening and blue in the morning.\n" | |
| "Source 2:\n" | |
| "Water is wet when the sky is red.\n" | |
| "Query: When is water wet?\n" | |
| "Answer: Water will be wet when the sky is red [2], " | |
| "which occurs in the evening [1].\n" | |
| "Now it's your turn. Below are several numbered sources of information:" | |
| "\n------\n" | |
| "{context_str}" | |
| "\n------\n" | |
| "Query: {query_str}\n" | |
| "Answer: " | |
| ) | |
| os.environ["GOOGLE_API_KEY"] = os.environ["GEMINI_API_KEY"] | |
| def get_meta(file_path: str) -> dict[str, str]: | |
| with open(file_path) as f: | |
| title = f.readline().strip() | |
| if title.startswith("# "): | |
| title = title[2:] | |
| else: | |
| title = ( | |
| file_path.split("/")[-1] | |
| .replace(".md", "") | |
| .replace("-", " ") | |
| .capitalize() | |
| ) | |
| file_path = file_path.replace(".md", "") | |
| CONST = "../../docs/" | |
| docs_index = file_path.index(CONST) | |
| docs_path = file_path[docs_index + len(CONST) :] | |
| url = "https://mesop-dev.github.io/mesop/" + docs_path | |
| print(f"URL: {url}") | |
| return { | |
| "url": url, | |
| "title": title, | |
| } | |
| embed_model = GeminiEmbedding( | |
| model_name="models/text-embedding-004", api_key=os.environ["GOOGLE_API_KEY"] | |
| ) | |
| Settings.embed_model = embed_model | |
| PERSIST_DIR = "./gen" | |
| def build_or_load_index(): | |
| if not os.path.exists(PERSIST_DIR) or "--build-index" in sys.argv: | |
| print("Building index") | |
| documents = SimpleDirectoryReader( | |
| "../../docs/", | |
| required_exts=[ | |
| ".md", | |
| ], | |
| exclude=[ | |
| "showcase.md", | |
| "demo.md", | |
| "blog", | |
| "internal", | |
| ], | |
| file_metadata=get_meta, | |
| recursive=True, | |
| ).load_data() | |
| for doc in documents: | |
| doc.excluded_llm_metadata_keys = ["url"] | |
| splitter = SentenceSplitter(chunk_size=512) | |
| nodes = splitter.get_nodes_from_documents(documents) | |
| bm25_retriever = BM25Retriever.from_defaults( | |
| nodes=nodes, | |
| similarity_top_k=5, | |
| # Optional: We can pass in the stemmer and set the language for stopwords | |
| # This is important for removing stopwords and stemming the query + text | |
| # The default is english for both | |
| stemmer=Stemmer.Stemmer("english"), | |
| language="english", | |
| ) | |
| bm25_retriever.persist(PERSIST_DIR + "/bm25_retriever") | |
| index = VectorStoreIndex.from_documents(documents, embed_model=embed_model) | |
| index.storage_context.persist(persist_dir=PERSIST_DIR) | |
| return index, bm25_retriever | |
| else: | |
| print("Loading index") | |
| bm25_retriever = BM25Retriever.from_persist_dir( | |
| PERSIST_DIR + "/bm25_retriever" | |
| ) | |
| storage_context = StorageContext.from_defaults(persist_dir=PERSIST_DIR) | |
| index = load_index_from_storage(storage_context) | |
| return index, bm25_retriever | |
| if me.runtime().is_hot_reload_in_progress: | |
| print("Hot reload - skip building index!") | |
| query_engine = me._query_engine | |
| bm25_retriever = me._bm25_retriever | |
| else: | |
| index, bm25_retriever = build_or_load_index() | |
| llm = Gemini(model="models/gemini-flash-latest") | |
| retriever = QueryFusionRetriever( | |
| [ | |
| index.as_retriever(similarity_top_k=5), | |
| bm25_retriever, | |
| ], | |
| llm=llm, | |
| num_queries=1, | |
| use_async=True, | |
| similarity_top_k=5, | |
| ) | |
| query_engine = CitationQueryEngine.from_args( | |
| index, | |
| retriever=retriever, | |
| llm=llm, | |
| citation_qa_template=CITATION_QA_TEMPLATE, | |
| similarity_top_k=5, | |
| embedding_model=embed_model, | |
| streaming=True, | |
| ) | |
| blocking_query_engine = CitationQueryEngine.from_args( | |
| index, | |
| retriever=retriever, | |
| llm=llm, | |
| citation_qa_template=CITATION_QA_TEMPLATE, | |
| similarity_top_k=5, | |
| embedding_model=embed_model, | |
| streaming=False, | |
| ) | |
| # TODO: replace with proper mechanism for persisting objects | |
| # across hot reloads | |
| me._query_engine = query_engine | |
| me._bm25_retriever = bm25_retriever | |
| NEWLINE = "\n" | |
| def ask(query: str): | |
| return query_engine.query(query) | |
| def retrieve(query: str): | |
| return bm25_retriever.retrieve(query) | |