import os
import sys

import nest_asyncio
import Stemmer
from llama_index.core import (
  PromptTemplate,
  Settings,
  SimpleDirectoryReader,
  StorageContext,
  VectorStoreIndex,
  load_index_from_storage,
)
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.query_engine import CitationQueryEngine
from llama_index.core.retrievers import QueryFusionRetriever
from llama_index.core.schema import NodeWithScore as NodeWithScore
from llama_index.embeddings.google import GeminiEmbedding
from llama_index.llms.gemini import Gemini
from llama_index.retrievers.bm25 import BM25Retriever

import mesop as me

nest_asyncio.apply()
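# nest_asyncio allows nested event loops, which LlamaIndex's async retrieval
# (use_async=True below) needs when the server already has a loop running.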
CITATION_QA_TEMPLATE = PromptTemplate(
  "Please provide an answer based solely on the provided sources. "
  "When referencing information from a source, "
  "cite the appropriate source(s) using their corresponding numbers. "
  "Every answer should include at least one source citation. "
  "Only cite a source when you are explicitly referencing it. "
  "If you are sure NONE of the sources are helpful, then say: "
  "'Sorry, I didn't find any docs about this.'\n"
  "If you are not sure whether any of the sources are helpful, then say: "
  "'You might find this helpful', where 'this' is the source's title.\n"
  "DO NOT say Source 1, Source 2, etc. "
  "Only reference sources like this: [1], [2], etc.\n"
  "I want you to pick just ONE source to answer the question.\n"
  "For example:\n"
  "Source 1:\n"
  "The sky is red in the evening and blue in the morning.\n"
  "Source 2:\n"
  "Water is wet when the sky is red.\n"
  "Query: When is water wet?\n"
  "Answer: Water will be wet when the sky is red [2], "
  "which occurs in the evening [1].\n"
  "Now it's your turn. Below are several numbered sources of information:"
  "\n------\n"
  "{context_str}"
  "\n------\n"
  "Query: {query_str}\n"
  "Answer: "
)
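# CitationQueryEngine formats each retrieved chunk as a numbered "Source N:"
# block and substitutes them for {context_str}, which is what makes the
# [1], [2] citation style requested above work.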
# The Gemini LLM and embedding classes read GOOGLE_API_KEY, so mirror it
# from GEMINI_API_KEY.
os.environ["GOOGLE_API_KEY"] = os.environ["GEMINI_API_KEY"]
def get_meta(file_path: str) -> dict[str, str]:
  # Use the first markdown heading as the doc title; otherwise fall back to
  # a prettified version of the filename.
  with open(file_path) as f:
    title = f.readline().strip()
  if title.startswith("# "):
    title = title[2:]
  else:
    title = (
      file_path.split("/")[-1]
      .replace(".md", "")
      .replace("-", " ")
      .capitalize()
    )
  # Map the local docs path to its published URL on the docs site.
  file_path = file_path.replace(".md", "")
  CONST = "../../docs/"
  docs_index = file_path.index(CONST)
  docs_path = file_path[docs_index + len(CONST) :]
  url = "https://mesop-dev.github.io/mesop/" + docs_path
  print(f"URL: {url}")
  return {
    "url": url,
    "title": title,
  }
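# For illustration (hypothetical path): a file "../../docs/guides/deployment.md"
# whose first line is "# Deployment" would yield
# {"url": "https://mesop-dev.github.io/mesop/guides/deployment",
#  "title": "Deployment"}.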
embed_model = GeminiEmbedding(
  model_name="models/text-embedding-004", api_key=os.environ["GOOGLE_API_KEY"]
)
Settings.embed_model = embed_model

PERSIST_DIR = "./gen"
def build_or_load_index():
  if not os.path.exists(PERSIST_DIR) or "--build-index" in sys.argv:
    print("Building index")
    documents = SimpleDirectoryReader(
      "../../docs/",
      required_exts=[
        ".md",
      ],
      exclude=[
        "showcase.md",
        "demo.md",
        "blog",
        "internal",
      ],
      file_metadata=get_meta,
      recursive=True,
    ).load_data()
    for doc in documents:
      doc.excluded_llm_metadata_keys = ["url"]
    splitter = SentenceSplitter(chunk_size=512)
    nodes = splitter.get_nodes_from_documents(documents)
    bm25_retriever = BM25Retriever.from_defaults(
      nodes=nodes,
      similarity_top_k=5,
      # Optional: we can pass in the stemmer and set the language for stopwords.
      # This is important for removing stopwords and stemming the query + text.
      # The default is English for both.
      stemmer=Stemmer.Stemmer("english"),
      language="english",
    )
    bm25_retriever.persist(PERSIST_DIR + "/bm25_retriever")
    index = VectorStoreIndex.from_documents(documents, embed_model=embed_model)
    index.storage_context.persist(persist_dir=PERSIST_DIR)
    return index, bm25_retriever
  else:
    print("Loading index")
    bm25_retriever = BM25Retriever.from_persist_dir(
      PERSIST_DIR + "/bm25_retriever"
    )
    storage_context = StorageContext.from_defaults(persist_dir=PERSIST_DIR)
    index = load_index_from_storage(storage_context)
    return index, bm25_retriever
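# The index is rebuilt when ./gen is missing or when the app is started with
# the --build-index flag, e.g. `python main.py --build-index` (script name is
# illustrative); otherwise both retrievers are loaded from disk.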
if me.runtime().is_hot_reload_in_progress:
  print("Hot reload - skip building index!")
  query_engine = me._query_engine
  blocking_query_engine = me._blocking_query_engine
  bm25_retriever = me._bm25_retriever
else:
  index, bm25_retriever = build_or_load_index()
  llm = Gemini(model="models/gemini-1.5-flash-latest")
  # Hybrid retrieval: fuse the vector index results with BM25 keyword results.
  # num_queries=1 disables LLM-generated query variations.
  retriever = QueryFusionRetriever(
    [
      index.as_retriever(similarity_top_k=5),
      bm25_retriever,
    ],
    llm=llm,
    num_queries=1,
    use_async=True,
    similarity_top_k=5,
  )
  query_engine = CitationQueryEngine.from_args(
    index,
    retriever=retriever,
    llm=llm,
    citation_qa_template=CITATION_QA_TEMPLATE,
    similarity_top_k=5,
    embedding_model=embed_model,
    streaming=True,
  )
  blocking_query_engine = CitationQueryEngine.from_args(
    index,
    retriever=retriever,
    llm=llm,
    citation_qa_template=CITATION_QA_TEMPLATE,
    similarity_top_k=5,
    embedding_model=embed_model,
    streaming=False,
  )
  # TODO: replace with proper mechanism for persisting objects
  # across hot reloads
  me._query_engine = query_engine
  # Persist the blocking engine too so the hot-reload branch can restore it.
  me._blocking_query_engine = blocking_query_engine
  me._bm25_retriever = bm25_retriever
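# Usage sketch (assumes LlamaIndex's streaming API): query_engine streams
# tokens for the chat UI, while blocking_query_engine returns the complete
# response in one call, e.g.
#
#   response = query_engine.query("How do I deploy a Mesop app?")
#   for token in response.response_gen:
#     ...  # append each token to the UI as it arrives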
NEWLINE = "\n"


def ask(query: str):
  return query_engine.query(query)


def retrieve(query: str):
  return bm25_retriever.retrieve(query)
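

# Minimal smoke-test sketch; the queries below are illustrative, not from the app.
if __name__ == "__main__":
  # Keyword-only retrieval via BM25; retrieve() returns NodeWithScore objects.
  for node_with_score in retrieve("components"):
    print(node_with_score.node.metadata.get("title"), node_with_score.score)
  # Full RAG answer; ask() returns a streaming response.
  ask("How do I create a custom component?").print_response_stream()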