# OpenWormLLM / corpus_query.py
# Author: pgleeson — "Formatting updates" (commit 988a46d)
from llama_index.core import PromptTemplate
from llama_index.core import get_response_synthesizer
from llama_index.core import load_index_from_storage
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.core.storage.docstore import SimpleDocumentStore
from llama_index.core.storage.index_store import SimpleIndexStore
from llama_index.core.storage.storage_context import StorageContext
from llama_index.core.vector_stores import SimpleVectorStore

# Required by get_query_engine() when an "Ollama:<model>" identifier is used;
# without these the Ollama branch raises NameError.
from llama_index.embeddings.ollama import OllamaEmbedding
from llama_index.llms.ollama import Ollama
# Root folder holding the persisted llama-index store (docstore, vector
# store and index store); per-Ollama-model stores live in subfolders.
STORE_DIR = "openworm.ai_store"
# Metadata key under which each retrieved node records its origin document.
SOURCE_DOCUMENT = "source document"
# Identifier selecting the OpenAI GPT-4o configuration (the default).
LLM_GPT4o = "GPT4o"
def print_(text):
    """Emit *text* on standard output.

    Single choke point for all of this script's console output, so the
    destination (e.g. a log file) can be swapped in one place.
    """
    print(text)
def load_index(model):
    """Reload a persisted llama-index from the local store for *model*.

    Parameters
    ----------
    model : str
        Either ``LLM_GPT4o`` or an ``"Ollama:<model_name>"`` identifier.

    Returns
    -------
    The index object rebuilt by ``load_index_from_storage``.
    """
    # Use equality, not identity: "is"/"is not" on strings only works by
    # accident of CPython string interning and is unreliable.
    ollama_model = None if model == LLM_GPT4o else model.replace("Ollama:", "")

    print_("Creating a storage context for %s" % model)

    # GPT4o uses the store root; each Ollama model has its own subfolder
    # (":" is not filesystem-safe, so it is replaced with "_").
    store_subfolder = (
        "" if ollama_model is None else "/%s" % ollama_model.replace(":", "_")
    )
    persist_dir = STORE_DIR + store_subfolder

    # All three stores were persisted side by side in the same directory.
    storage_context = StorageContext.from_defaults(
        docstore=SimpleDocumentStore.from_persist_dir(persist_dir=persist_dir),
        vector_store=SimpleVectorStore.from_persist_dir(persist_dir=persist_dir),
        index_store=SimpleIndexStore.from_persist_dir(persist_dir=persist_dir),
    )

    print_("Reloading index for %s" % model)
    return load_index_from_storage(storage_context)
def get_query_engine(index_reloaded, model, similarity_top_k=4):
    """Build a query engine over *index_reloaded* for the given model.

    Parameters
    ----------
    index_reloaded :
        Index previously returned by :func:`load_index`.
    model : str
        ``LLM_GPT4o`` for the OpenAI path, or ``"Ollama:<name>"`` for a
        local Ollama model.
    similarity_top_k : int
        Number of top-scoring chunks the retriever should return.

    Returns
    -------
    A llama-index query engine configured with the custom QA/refine prompts.
    """
    # Equality, not identity: "is not" on strings only works by accident
    # of CPython string interning.
    ollama_model = None if model == LLM_GPT4o else model.replace("Ollama:", "")
    print_("Creating query engine for %s" % model)

    # Based on: https://docs.llamaindex.ai/en/stable/examples/customization/prompts/completion_prompts/
    text_qa_template_str = (
        "Context information is"
        " below.\n---------------------\n{context_str}\n---------------------\nUsing"
        " both the context information and also using your own knowledge, answer"
        " the question: {query_str}\nIf the context isn't helpful, you can also"
        " answer the question on your own.\n"
    )
    text_qa_template = PromptTemplate(text_qa_template_str)

    refine_template_str = (
        "The original question is as follows: {query_str}\nWe have provided an"
        " existing answer: {existing_answer}\nWe have the opportunity to refine"
        " the existing answer (only if needed) with some more context"
        " below.\n------------\n{context_msg}\n------------\nUsing both the new"
        " context and your own knowledge, update or repeat the existing answer.\n"
    )
    refine_template = PromptTemplate(refine_template_str)

    # create a query engine for the index
    if ollama_model is not None:
        # Local Ollama path: the same model serves generation and embeddings.
        llm = Ollama(model=ollama_model)
        ollama_embedding = OllamaEmbedding(
            model_name=ollama_model,
        )
        query_engine = index_reloaded.as_query_engine(
            llm=llm,
            text_qa_template=text_qa_template,
            refine_template=refine_template,
            embed_model=ollama_embedding,
        )
        # as_query_engine() does not take top_k directly; set it afterwards.
        query_engine.retriever.similarity_top_k = similarity_top_k
    else:  # use OpenAI...
        # configure retriever
        retriever = VectorIndexRetriever(
            index=index_reloaded,
            similarity_top_k=similarity_top_k,
        )
        # configure response synthesizer
        response_synthesizer = get_response_synthesizer(
            response_mode="refine",
            text_qa_template=text_qa_template,
            refine_template=refine_template,
        )
        query_engine = RetrieverQueryEngine(
            retriever=retriever,
            response_synthesizer=response_synthesizer,
        )
    return query_engine
# Module-level defaults: build the index and query engine once at import
# time so process_query()/run_query() can reuse a single shared engine.
llm_ver = LLM_GPT4o
index_reloaded = load_index(llm_ver)
query_engine = get_query_engine(index_reloaded, llm_ver)
def process_query(query, model=llm_ver):
    """Run *query* through the module-level query engine and print a report.

    Parameters
    ----------
    query : str
        Natural-language question to answer.
    model : str
        Model label used only in the printed report (the engine itself is
        fixed at module import time).

    Returns
    -------
    tuple
        ``(response_text, metadata)`` — the answer text (with any
        ``<think>...</think>`` reasoning block removed) and the response
        metadata dict.
    """
    response = query_engine.query(query)
    response_text = str(response)

    # Strip the reasoning block emitted by "thinking" models (deepseek).
    # Require BOTH tags: the original raised ValueError when "<think>"
    # appeared without a matching "</think>".
    open_tag, close_tag = "<think>", "</think>"
    if open_tag in response_text and close_tag in response_text:
        response_text = (
            response_text[: response_text.index(open_tag)]
            + response_text[response_text.index(close_tag) + len(close_tag) :]
        )

    metadata = response.metadata

    # Only list sources scoring above this cutoff — but always keep the
    # first source so the report is never empty.
    cutoff = 0.2
    files_used = []
    seen_docs = set()  # dedup on the document name, not the formatted entry
    for sn in response.source_nodes:
        sd = sn.metadata[SOURCE_DOCUMENT]
        # Papers were ingested under the WormAtlas Handbook label; relabel.
        if "et_al_" in sd:
            sd = sd.replace("WormAtlas Handbook:", "Paper: ")
        # Fix: the original compared the bare name against entries that
        # include the score suffix, so duplicates were never filtered.
        if sd not in seen_docs:
            if len(files_used) == 0 or sn.score >= cutoff:
                seen_docs.add(sd)
                files_used.append(f"{sd} (score: {sn.score})")

    file_info = ",\n ".join(files_used)
    print_(f"""
===============================================================================
QUERY: {query}
MODEL: {model}
-------------------------------------------------------------------------------
RESPONSE: {response_text}
SOURCES:
{file_info}
===============================================================================
""")
    return response_text, metadata
def run_query(query):
    """Answer *query* with the GPT4o engine and append a source listing.

    Returns the model's answer followed by a bulleted list of the documents
    the retrieved context came from.
    """
    response_text, metadata = process_query(query, "GPT4o")

    files_used = []
    for info in metadata.values():
        if SOURCE_DOCUMENT not in info:
            continue
        raw_name = info[SOURCE_DOCUMENT]
        # NOTE: dedup is checked against the raw name before relabelling,
        # matching the pre-existing behaviour.
        if raw_name in files_used:
            continue
        display_name = raw_name
        if "et_al_" in display_name:
            display_name = display_name.replace("WormAtlas Handbook:", "Paper: ")
        files_used.append(display_name)

    srcs = "\n- ".join(files_used)
    answer = f"""{response_text}
SOURCES OF ANSWER:
- {srcs}"""
    return answer
if __name__ == "__main__":
    print("Running queries")
    # Smoke-test questions about C. elegans anatomy.
    example_queries = (
        "What are the main types of neurons and muscles in the C. elegans pharynx?",
        "Tell me about the egg laying apparatus",
    )
    for query in example_queries:
        print("-------------------")
        print("Q: %s" % query)
        answer = run_query(query)
        print("A: %s" % answer)