aerospace_chatbot_ams / queries.py
dsmueller's picture
Adding rag study updates
148b409
raw
history blame
11.5 kB
import os
import logging
import re
from dotenv import load_dotenv, find_dotenv
import openai
import pinecone
import chromadb
from langchain_community.vectorstores import Pinecone
from langchain_community.vectorstores import Chroma
from langchain.memory import ConversationBufferMemory
from operator import itemgetter
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
from langchain.schema import format_document
from langchain_core.messages import get_buffer_string
from prompts import CONDENSE_QUESTION_PROMPT, QA_PROMPT, DEFAULT_DOCUMENT_PROMPT, TEST_QUERY_PROMPT
# Set secrets from environment file
OPENAI_API_KEY=os.getenv('OPENAI_API_KEY')
VOYAGE_API_KEY=os.getenv('VOYAGE_API_KEY')
PINECONE_API_KEY=os.getenv('PINECONE_API_KEY')
HUGGINGFACEHUB_API_TOKEN=os.getenv('HUGGINGFACEHUB_API_TOKEN')
# Class and functions
class QA_Model:
def __init__(self,
index_type,
index_name,
query_model,
llm,
k=6,
search_type='similarity',
fetch_k=50,
temperature=0,
chain_type='stuff',
filter_arg=False):
self.index_type=index_type
self.index_name=index_name
self.query_model=query_model
self.llm=llm
self.k=k
self.search_type=search_type
self.fetch_k=fetch_k
self.temperature=temperature
self.chain_type=chain_type
self.filter_arg=filter_arg
self.sources=[]
load_dotenv(find_dotenv(),override=True)
# Define retriever search parameters
search_kwargs = _process_retriever_args(self.filter_arg,
self.sources,
self.search_type,
self.k,
self.fetch_k)
# Read in from the vector database
if index_type=='Pinecone':
pinecone.init(
api_key=PINECONE_API_KEY
)
logging.info('Chat pinecone index name: '+str(index_name))
logging.info('Chat query model: '+str(query_model))
index = pinecone.Index(index_name)
self.vectorstore = Pinecone(index,query_model,'page_content')
logging.info('Chat vectorstore: '+str(self.vectorstore))
# Test query
test_query = self.vectorstore.similarity_search(TEST_QUERY_PROMPT)
logging.info('Test query: '+str(test_query))
if not test_query:
raise ValueError("Pinecone vector database is not configured properly. Test query failed.")
else:
logging.info('Test query succeeded!')
self.retriever=self.vectorstore.as_retriever(search_type=search_type,
search_kwargs=search_kwargs)
logging.info('Chat retriever: '+str(self.retriever))
elif index_type=='ChromaDB':
logging.info('Chat chroma index name: '+str(index_name))
logging.info('Chat query model: '+str(query_model))
persistent_client = chromadb.PersistentClient(path='../db/chromadb')
self.vectorstore = Chroma(client=persistent_client,
collection_name=index_name,
embedding_function=query_model)
logging.info('Chat vectorstore: '+str(self.vectorstore))
# Test query
test_query = self.vectorstore.similarity_search(TEST_QUERY_PROMPT)
logging.info('Test query: '+str(test_query))
if not test_query:
raise ValueError("Chroma vector database is not configured properly. Test query failed.")
else:
logging.info('Test query succeeded!')
self.retriever=self.vectorstore.as_retriever(search_type=search_type,
search_kwargs=search_kwargs)
logging.info('Chat retriever: '+str(self.retriever))
elif index_type=='RAGatouille':
# Easy because the index is picked up directly.
self.vectorstore=query_model
logging.info('Chat query model:'+str(query_model))
# Test query
test_query = self.vectorstore.search(TEST_QUERY_PROMPT)
logging.info('Test query: '+str(test_query))
if not test_query:
raise ValueError("Chroma vector database is not configured properly. Test query failed.")
else:
logging.info('Test query succeeded!')
self.retriever=self.vectorstore.as_langchain_retriever()
logging.info('Chat retriever: '+str(self.retriever))
# Intialize memory
self.memory = ConversationBufferMemory(
return_messages=True, output_key='answer', input_key='question')
logging.info('Memory: '+str(self.memory))
# Assemble main chain
self.conversational_qa_chain=_define_qa_chain(self.llm,
self.retriever,
self.memory,
self.search_type,
search_kwargs)
def query_docs(self,query):
self.memory.load_memory_variables({})
logging.info('Memory content before qa result: '+str(self.memory))
logging.info('Query: '+str(query))
self.result = self.conversational_qa_chain.invoke({'question': query})
logging.info('QA result: '+str(self.result))
if self.index_type!='RAGatouille':
self.sources = '\n'.join(str(data.metadata) for data in self.result['references'])
self.result['answer'].content += '\nSources: \n'+self.sources
logging.info('Sources: '+str(self.sources))
logging.info('Response with sources: '+str(self.result['answer'].content))
else:
# RAGatouille doesn't have metadata, need to extract from context first.
extracted_metadata = []
pattern = r'\{([^}]*)\}(?=[^{}]*$)' # Regular expression pattern to match the last curly braces
for ref in self.result['references']:
match = re.search(pattern, ref.page_content)
if match:
extracted_metadata.append("{"+match.group(1)+"}")
self.sources = '\n'.join(extracted_metadata)
self.result['answer'].content += '\nSources: \n'+self.sources
logging.info('Sources: '+str(self.sources))
logging.info('Response with sources: '+str(self.result['answer'].content))
self.memory.save_context({'question': query}, {'answer': self.result['answer'].content})
logging.info('Memory content after qa result: '+str(self.memory))
def update_model(self,
llm,
k=6,
search_type='similarity',
fetch_k=50,
filter_arg=False):
self.llm=llm
self.k=k
self.search_type=search_type
self.fetch_k=fetch_k
self.filter_arg=filter_arg
# Define retriever search parameters
search_kwargs = _process_retriever_args(self.filter_arg,
self.sources,
self.search_type,
self.k,
self.fetch_k)
# Update conversational retrieval chain
self.conversational_qa_chain=_define_qa_chain(self.llm,
self.retriever,
self.memory,
self.search_type,
search_kwargs)
logging.info('Updated qa chain: '+str(self.conversational_qa_chain))
# Internal functions
def _combine_documents(docs,
document_prompt=DEFAULT_DOCUMENT_PROMPT,
document_separator='\n\n'):
'''
Combine a list of documents into a single string.
'''
# TODO: this would be where stuff, map reduce, etc. would go
doc_strings = [format_document(doc, document_prompt) for doc in docs]
return document_separator.join(doc_strings)
def _define_qa_chain(llm,
retriever,
memory,
search_type,
search_kwargs):
'''
Define the conversational QA chain.
'''
# This adds a 'memory' key to the input object
loaded_memory = RunnablePassthrough.assign(
chat_history=RunnableLambda(memory.load_memory_variables)
| itemgetter('history'))
logging.info('Loaded memory: '+str(loaded_memory))
# Assemble main chain
standalone_question = {
'standalone_question': {
'question': lambda x: x['question'],
'chat_history': lambda x: get_buffer_string(x['chat_history'])}
| CONDENSE_QUESTION_PROMPT
| llm
| StrOutputParser()}
logging.info('Condense inputs as a standalong question: '+str(standalone_question))
retrieved_documents = {
'source_documents': itemgetter('standalone_question')
| retriever,
'question': lambda x: x['standalone_question']}
logging.info('Retrieved documents: '+str(retrieved_documents))
# Now we construct the inputs for the final prompt
final_inputs = {
'context': lambda x: _combine_documents(x['source_documents']),
'question': itemgetter('question')}
logging.info('Combined documents: '+str(final_inputs))
# And finally, we do the part that returns the answers
answer = {
'answer': final_inputs
| QA_PROMPT
| llm,
'references': itemgetter('source_documents')}
conversational_qa_chain = loaded_memory | standalone_question | retrieved_documents | answer
logging.info('Conversational QA chain: '+str(conversational_qa_chain))
return conversational_qa_chain
def _process_retriever_args(filter_arg,
sources,
search_type,
k,
fetch_k):
'''
Process arguments for retriever.
'''
# Implement filter
if filter_arg:
filter_list = list(set(item['source'] for item in sources[-1]))
filter_items=[]
for item in filter_list:
filter_item={'source': item}
filter_items.append(filter_item)
filter={'$or':filter_items}
else:
filter=None
# Impement filtering and number of documents to return
if search_type=='mmr':
search_kwargs={'k':k,'fetch_k':fetch_k,'filter':filter} # See as_retriever docs for parameters
else:
search_kwargs={'k':k,'filter':filter} # See as_retriever docs for parameters
return search_kwargs