Spaces:
Running
Running
File size: 5,393 Bytes
d26c057 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 |
import os
import sys
import nest_asyncio
import Stemmer
from llama_index.core import (
PromptTemplate,
Settings,
SimpleDirectoryReader,
StorageContext,
VectorStoreIndex,
load_index_from_storage,
)
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.query_engine import CitationQueryEngine
from llama_index.core.retrievers import QueryFusionRetriever
from llama_index.core.schema import NodeWithScore as NodeWithScore
from llama_index.embeddings.google import GeminiEmbedding
from llama_index.llms.gemini import Gemini
from llama_index.retrievers.bm25 import BM25Retriever
import mesop as me
# Patch asyncio so a running event loop can be re-entered. Presumably needed
# because the QueryFusionRetriever below is created with use_async=True and may
# be invoked from code that is already inside an event loop — confirm.
nest_asyncio.apply()
# Answer-synthesis prompt for CitationQueryEngine: {context_str} is filled with
# the numbered source chunks and {query_str} with the user's question.
# Python concatenates adjacent string literals with no separator, so every
# sentence below must end with its own trailing space or newline.
CITATION_QA_TEMPLATE = PromptTemplate(
    "Please provide an answer based solely on the provided sources. "
    "When referencing information from a source, "
    "cite the appropriate source(s) using their corresponding numbers. "
    "Every answer should include at least one source citation. "
    "Only cite a source when you are explicitly referencing it. "
    "If you are sure NONE of the sources are helpful, then say: 'Sorry, I didn't find any docs about this.' "
    "If you are not sure if any of the sources are helpful, then say: 'You might find this helpful', where 'this' is the source's title. "
    "DO NOT say Source 1, Source 2, etc. Only reference sources like this: [1], [2], etc. "
    "I want you to pick just ONE source to answer the question. "
    "For example:\n"
    "Source 1:\n"
    "The sky is red in the evening and blue in the morning.\n"
    "Source 2:\n"
    "Water is wet when the sky is red.\n"
    "Query: When is water wet?\n"
    "Answer: Water will be wet when the sky is red [2], "
    "which occurs in the evening [1].\n"
    "Now it's your turn. Below are several numbered sources of information:"
    "\n------\n"
    "{context_str}"
    "\n------\n"
    "Query: {query_str}\n"
    "Answer: "
)
# Copy the Gemini key into GOOGLE_API_KEY, which is the variable read when the
# embedding model is constructed further down. Raises KeyError at import time
# if GEMINI_API_KEY is not set.
os.environ.update(GOOGLE_API_KEY=os.environ["GEMINI_API_KEY"])
def get_meta(file_path: str) -> dict[str, str]:
    """Build the metadata attached to one Markdown docs page.

    Used as SimpleDirectoryReader's ``file_metadata`` callback.

    Args:
        file_path: Path to a ``.md`` file. It must contain the segment
            ``"../../docs/"``; everything after that segment, minus the
            trailing ``.md``, becomes the page's path on the docs site.

    Returns:
        ``{"url": ..., "title": ...}`` where ``url`` is the published docs URL
        and ``title`` is the text of a leading ``# `` heading, or, when the
        first line is not such a heading, the file name with ``-`` turned into
        spaces and capitalized (``getting-started.md`` -> ``Getting started``).

    Raises:
        ValueError: If ``file_path`` does not contain ``"../../docs/"``.
    """
    # Decode explicitly instead of relying on the locale default. utf-8-sig
    # reads plain UTF-8 unchanged but also drops a leading BOM, which would
    # otherwise hide the "# " heading marker.
    with open(file_path, encoding="utf-8-sig") as f:
        first_line = f.readline().strip()

    if first_line.startswith("# "):
        title = first_line[2:]
    else:
        title = (
            file_path.split("/")[-1]
            .removesuffix(".md")
            .replace("-", " ")
            .capitalize()
        )

    # Strip only the trailing extension; str.replace would also remove ".md"
    # appearing earlier in the path (e.g. in a directory name).
    page_path = file_path.removesuffix(".md")
    docs_root = "../../docs/"
    start = page_path.find(docs_root)
    if start == -1:
        raise ValueError(f"{file_path!r} is not under {docs_root!r}")
    docs_path = page_path[start + len(docs_root) :]
    url = "https://mesop-dev.github.io/mesop/" + docs_path
    print(f"URL: {url}")
    return {
        "url": url,
        "title": title,
    }
# Gemini embedding model. It is passed explicitly when the vector index is
# built and to the query engines, and is also installed as llama_index's
# global default via Settings.
embed_model = GeminiEmbedding(
    api_key=os.environ["GOOGLE_API_KEY"],
    model_name="models/text-embedding-004",
)
Settings.embed_model = embed_model

# Directory holding the persisted vector index and the BM25 retriever.
PERSIST_DIR = "./gen"
# The BM25 retriever is persisted last during a build, so the presence of this
# directory means a build ran to completion.
_BM25_DIR = os.path.join(PERSIST_DIR, "bm25_retriever")


def _build_index():
    """Build the vector index and BM25 retriever from ../../docs and persist both.

    Returns:
        ``(index, bm25_retriever)``, both built from the same
        ``SentenceSplitter(chunk_size=512)`` nodes.
    """
    import shutil

    print("Building index")
    # Remove the completion marker first, so an interrupted rebuild is retried
    # on the next start instead of loading a mix of old and new data.
    shutil.rmtree(_BM25_DIR, ignore_errors=True)

    documents = SimpleDirectoryReader(
        "../../docs/",
        required_exts=[".md"],
        exclude=["showcase.md", "demo.md", "blog", "internal"],
        file_metadata=get_meta,
        recursive=True,
    ).load_data()
    for doc in documents:
        # Keep the page URL out of the text the LLM sees; it stays in the
        # metadata (presumably for rendering citation links elsewhere).
        doc.excluded_llm_metadata_keys = ["url"]

    splitter = SentenceSplitter(chunk_size=512)
    nodes = splitter.get_nodes_from_documents(documents)

    # Index exactly the nodes BM25 uses. from_documents would re-chunk the
    # documents with the default node parser, so the two retrievers fused by
    # QueryFusionRetriever would return different chunks.
    index = VectorStoreIndex(nodes=nodes, embed_model=embed_model)
    index.storage_context.persist(persist_dir=PERSIST_DIR)

    bm25_retriever = BM25Retriever.from_defaults(
        nodes=nodes,
        similarity_top_k=5,
        # Stemmer and stopword language applied to both the indexed text and
        # queries (English is also the library default).
        stemmer=Stemmer.Stemmer("english"),
        language="english",
    )
    bm25_retriever.persist(_BM25_DIR)
    return index, bm25_retriever


def _load_index():
    """Load the vector index and BM25 retriever persisted by ``_build_index``."""
    print("Loading index")
    bm25_retriever = BM25Retriever.from_persist_dir(_BM25_DIR)
    storage_context = StorageContext.from_defaults(persist_dir=PERSIST_DIR)
    index = load_index_from_storage(storage_context)
    return index, bm25_retriever


def build_or_load_index():
    """Return ``(index, bm25_retriever)``, building them when necessary.

    Rebuilds when ``--build-index`` is on the command line or when no completed
    build is found on disk (the BM25 directory, written last, is missing).
    Otherwise loads both from PERSIST_DIR.
    """
    if "--build-index" in sys.argv or not os.path.exists(_BM25_DIR):
        return _build_index()
    return _load_index()
def _make_citation_engine(index, retriever, llm, *, streaming: bool):
    """Create a CitationQueryEngine over ``retriever`` using CITATION_QA_TEMPLATE.

    Only ``streaming`` differs between the engines this module builds, so the
    streamed and blocking engines share every other setting.
    """
    return CitationQueryEngine.from_args(
        index,
        retriever=retriever,
        llm=llm,
        citation_qa_template=CITATION_QA_TEMPLATE,
        # NOTE(review): with an explicit retriever these two kwargs appear to
        # only matter for the default index.as_retriever() path — confirm.
        similarity_top_k=5,
        embedding_model=embed_model,
        streaming=streaming,
    )


if me.runtime().is_hot_reload_in_progress:
    # Reuse the objects stashed on the mesop module by the initial load
    # instead of rebuilding the index on every code change.
    print("Hot reload - skip building index!")
    query_engine = me._query_engine
    bm25_retriever = me._bm25_retriever
    # getattr: a server started on code that never stashed this attribute may
    # hot-reload into this version.
    blocking_query_engine = getattr(me, "_blocking_query_engine", None)
else:
    index, bm25_retriever = build_or_load_index()
    llm = Gemini(model="models/gemini-1.5-flash-latest")
    # Hybrid retrieval: fuse dense-vector and BM25 results. num_queries=1 means
    # only the original query is used (no LLM-generated query variants).
    retriever = QueryFusionRetriever(
        [
            index.as_retriever(similarity_top_k=5),
            bm25_retriever,
        ],
        llm=llm,
        num_queries=1,
        use_async=True,
        similarity_top_k=5,
    )
    query_engine = _make_citation_engine(index, retriever, llm, streaming=True)
    blocking_query_engine = _make_citation_engine(
        index, retriever, llm, streaming=False
    )

# TODO: replace with proper mechanism for persisting objects
# across hot reloads
me._query_engine = query_engine
me._bm25_retriever = bm25_retriever
me._blocking_query_engine = blocking_query_engine
# Newline as a named constant. It is not referenced in this section;
# presumably it is used elsewhere inside f-string expressions, which cannot
# contain a backslash before Python 3.12 — confirm.
NEWLINE = "\n"


def ask(query: str):
    """Run ``query`` through the module-level ``query_engine`` and return its response."""
    engine = query_engine
    return engine.query(query)
def retrieve(query: str):
    """Return the BM25 (keyword-only) retriever's results for ``query``."""
    keyword_retriever = bm25_retriever
    return keyword_retriever.retrieve(query)
|