"""Conversational RAG chatbot for the Vincent Mary School of Science and
Technology (VMS) at Assumption University, built with LangChain, NVIDIA AI
endpoints, MongoDB Atlas vector search, and a Gradio chat UI."""

import os

import gradio as gr
import pymongo
from langchain.chains import create_history_aware_retriever, create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_community.vectorstores import FAISS, MongoDBAtlasVectorSearch
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.output_parsers import StrOutputParser  # used by the legacy chain below
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_nvidia_ai_endpoints import ChatNVIDIA, NVIDIAEmbeddings

# Embedding model served through the NVIDIA AI endpoints.
embedder = NVIDIAEmbeddings(model="nvolveqa_40k", model_type=None)

# Local FAISS index over the course data; only the legacy chain at the bottom
# of this file uses it. Named faiss_db so the MongoDB handle below does not
# shadow it.
faiss_db = FAISS.load_local("vms_faiss_index", embedder, allow_dangerous_deserialization=True)
# docs = faiss_db.similarity_search(query)

# ChatNVIDIA and NVIDIAEmbeddings read NVIDIA_API_KEY from the environment.
nvidia_api_key = os.environ.get("NVIDIA_API_KEY", "")


def get_mongo_client(mongo_uri):
    """Establish a connection to MongoDB."""
    try:
        client = pymongo.MongoClient(mongo_uri)
        print("Connection to MongoDB successful")
        return client
    except pymongo.errors.ConnectionFailure as e:
        print(f"Connection failed: {e}")
        return None


mongo_uri = os.environ.get("MyCluster_MONGO_URI")
if not mongo_uri:
    print("MyCluster_MONGO_URI not set in environment variables")

mongo_client = get_mongo_client(mongo_uri)

DB_NAME = "vms_courses"
COLLECTION_NAME = "courses"

db = mongo_client[DB_NAME]
collection = db[COLLECTION_NAME]

ATLAS_VECTOR_SEARCH_INDEX_NAME = "vector_index"

# Vector store backed by the Atlas Search index on vms_courses.courses.
vector_search = MongoDBAtlasVectorSearch.from_connection_string(
    mongo_uri,
    DB_NAME + "." + COLLECTION_NAME,
    embedder,
    index_name=ATLAS_VECTOR_SEARCH_INDEX_NAME,
)

llm = ChatNVIDIA(model="mixtral_8x7b")

retriever = vector_search.as_retriever(
    search_type="similarity",
    search_kwargs={"k": 12},
)

### Contextualize question ###
contextualize_q_system_prompt = """Given a chat history and the latest user question \
which might reference context in the chat history, formulate a standalone question \
which can be understood without the chat history. Do NOT answer the question, \
just reformulate it if needed and otherwise return it as is."""
contextualize_q_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", contextualize_q_system_prompt),
        MessagesPlaceholder("chat_history"),
        ("human", "{input}"),
    ]
)
history_aware_retriever = create_history_aware_retriever(
    llm, retriever, contextualize_q_prompt
)

### Answer question ###
qa_system_prompt = """You are a VMS assistant helping students with their academic questions. \
Answer the question using only the context provided, and do not include phrases such as \
"based on the context" or "based on the documents provided" in your answer. \
Remember that your job is to represent the Vincent Mary School of Science and Technology (VMS) \
at Assumption University. Do not hallucinate any details, and do not repeat information. \
If you don't know the answer, just say that you don't know.

{context}"""
qa_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", qa_system_prompt),
        MessagesPlaceholder("chat_history"),
        ("human", "{input}"),
    ]
)
question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)

### Statefully manage chat history ###
store = {}


def get_session_history(session_id: str) -> BaseChatMessageHistory:
    if session_id not in store:
        store[session_id] = ChatMessageHistory()
    return store[session_id]


# Session-keyed alternative to the manual c_history bookkeeping in chat_gen.
conversational_rag_chain = RunnableWithMessageHistory(
    rag_chain,
    get_session_history,
    input_messages_key="input",
    history_messages_key="chat_history",
    output_messages_key="answer",
)

c_history = []


def chat_gen(message, history):
    ai_message = rag_chain.invoke({"input": message, "chat_history": c_history})
    # Record both sides of the turn as message objects so that
    # MessagesPlaceholder("chat_history") receives proper chat messages.
    c_history.extend(
        [HumanMessage(content=message), AIMessage(content=ai_message["answer"])]
    )
    print(c_history)  # debug: inspect the accumulated history
    yield ai_message["answer"]
    # for doc in ai_message["context"]:
    #     yield doc


initial_msg = (
    "Hello! I am the VMS bot, here to help you with your academic issues!"
    "\nHow can I help you?"
)

chatbot = gr.Chatbot(value=[[None, initial_msg]], bubble_full_width=False)
demo = gr.ChatInterface(chat_gen, chatbot=chatbot).queue()

try:
    demo.launch(debug=True, share=True, show_api=False)
    demo.close()
except Exception as e:
    demo.close()
    print(e)
    raise e

# Available model names:
# mixtral_8x7b
# llama2_13b

# --- Legacy FAISS-based chain, kept for reference ---

# llm = ChatNVIDIA(model="mixtral_8x7b") | StrOutputParser()

# initial_msg = (
#     "Hello! I am the VMS bot, here to help you with your academic issues!"
#     "\nHow can I help you?"
# )

# context_prompt = ChatPromptTemplate.from_messages([
#     ('system',
#      "You are a VMS chatbot, and you are helping students with their academic issues. "
#      "Answer the question using only the context provided, and do not include phrases "
#      "such as 'based on the context' or 'based on the documents provided' in your answer. "
#      "Remember that your job is to represent the Vincent Mary School of Science and "
#      "Technology (VMS) at Assumption University. "
#      "Do not hallucinate any details, and do not repeat information. "
#      "Please say you do not know if you do not know or cannot find the information needed."
#      "\n\nQuestion: {question}\n\nContext: {context}"),
#     ('user', "{question}"),
# ])

# chain = (
#     {
#         'context': faiss_db.as_retriever(search_type="similarity"),
#         'question': (lambda x: x),
#     }
#     | context_prompt
#     # | RPrint()
#     | llm
#     | StrOutputParser()
# )

# conv_chain = (
#     context_prompt
#     # | RPrint()
#     | llm
#     | StrOutputParser()
# )

# def chat_gen(message, history, return_buffer=True):
#     buffer = ""
#     doc_retriever = faiss_db.as_retriever(
#         search_type="similarity_score_threshold",
#         search_kwargs={"score_threshold": 0.2},
#     )
#     retrieved_docs = doc_retriever.invoke(message)
#     print(len(retrieved_docs))
#     print(retrieved_docs)
#     if len(retrieved_docs) > 0:
#         state = {
#             'question': message,
#             'context': retrieved_docs,
#         }
#         for token in conv_chain.stream(state):
#             buffer += token
#             yield buffer
#     else:
#         passage = ("I am sorry. I do not have relevant information to answer "
#                    "on that specific topic. Please try another question.")
#         buffer += passage
#         yield buffer if return_buffer else passage

# chatbot = gr.Chatbot(value=[[None, initial_msg]])
# iface = gr.ChatInterface(chat_gen, chatbot=chatbot).queue()
# iface.launch()