import os

import gradio as gr
import pymongo
from langchain.chains import create_history_aware_retriever, create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_community.vectorstores import FAISS, MongoDBAtlasVectorSearch
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_nvidia_ai_endpoints import ChatNVIDIA, NVIDIAEmbeddings
embedder = NVIDIAEmbeddings(model="nvolveqa_40k", model_type=None)

# Local FAISS index, kept for reference; the app's active retriever below is
# backed by MongoDB Atlas, and `db` is reused for the Mongo database further down.
faiss_db = FAISS.load_local("vms_faiss_index", embedder, allow_dangerous_deserialization=True)
# docs = faiss_db.similarity_search(query)

nvidia_api_key = os.environ.get("NVIDIA_API_KEY", "")
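
# NVIDIAEmbeddings and ChatNVIDIA read NVIDIA_API_KEY from the environment on
# their own, so the lookup above is mainly a convenience; a minimal sanity
# check (an illustrative sketch):
if not nvidia_api_key:
    print("Warning: NVIDIA_API_KEY is not set; calls to NVIDIA endpoints will fail to authenticate.")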
def get_mongo_client(mongo_uri):
    """Establish a connection to MongoDB and verify it with a ping."""
    try:
        client = pymongo.MongoClient(mongo_uri)
        # MongoClient connects lazily, so issue a ping to actually test the connection.
        client.admin.command("ping")
        print("Connection to MongoDB successful")
        return client
    except pymongo.errors.ConnectionFailure as e:
        print(f"Connection failed: {e}")
        return None

mongo_uri = os.environ.get("MyCluster_MONGO_URI")
if not mongo_uri:
    raise ValueError("MyCluster_MONGO_URI not set in environment variables")
mongo_client = get_mongo_client(mongo_uri)
DB_NAME = "vms_courses"
COLLECTION_NAME = "courses"
db = mongo_client[DB_NAME]
collection = db[COLLECTION_NAME]

ATLAS_VECTOR_SEARCH_INDEX_NAME = "vector_index"
vector_search = MongoDBAtlasVectorSearch.from_connection_string(
    mongo_uri,
    DB_NAME + "." + COLLECTION_NAME,
    embedder,
    index_name=ATLAS_VECTOR_SEARCH_INDEX_NAME,
)
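
# Illustrative smoke test of the Atlas vector index (the query string is a
# hypothetical example; uncomment to verify the index is reachable):
# sample_docs = vector_search.similarity_search("What courses does VMS offer?", k=3)
# print(f"Retrieved {len(sample_docs)} documents")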
llm = ChatNVIDIA(model="mixtral_8x7b")

retriever = vector_search.as_retriever(
    search_type="similarity",
    search_kwargs={"k": 12},
)
### Contextualize question ###
contextualize_q_system_prompt = """Given a chat history and the latest user question \
which might reference context in the chat history, formulate a standalone question \
which can be understood without the chat history. Do NOT answer the question, \
just reformulate it if needed and otherwise return it as is."""
contextualize_q_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", contextualize_q_system_prompt),
        MessagesPlaceholder("chat_history"),
        ("human", "{input}"),
    ]
)
history_aware_retriever = create_history_aware_retriever(
    llm, retriever, contextualize_q_prompt
)
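
# history_aware_retriever first rewrites the latest question into a standalone
# one using the prompt above, then runs the retriever on it. Illustrative
# direct call with an empty history (the question is hypothetical):
# docs = history_aware_retriever.invoke({"input": "What courses does VMS offer?", "chat_history": []})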
### Answer question ###
qa_system_prompt = """You are a VMS assistant helping students with their academic questions. \
Answer the question using only the context provided, and do not use phrases such as \
"based on the context" or "based on the documents provided" in your answer. \
Remember that your job is to represent the Vincent Mary School of Science and Technology (VMS) \
at Assumption University. Do not hallucinate any details, and do not repeat information. \
If you don't know the answer, just say that you don't know. \
{context}"""
qa_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", qa_system_prompt),
        MessagesPlaceholder("chat_history"),
        ("human", "{input}"),
    ]
)
question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)
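
# rag_chain takes {"input", "chat_history"} and returns a dict containing the
# retrieved "context" documents alongside the generated "answer". Illustrative
# call (the question is hypothetical):
# result = rag_chain.invoke({"input": "What is VMS?", "chat_history": []})
# print(result["answer"], len(result["context"]))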
### Statefully manage chat history ###
store = {}

def get_session_history(session_id: str) -> BaseChatMessageHistory:
    if session_id not in store:
        store[session_id] = ChatMessageHistory()
    return store[session_id]

conversational_rag_chain = RunnableWithMessageHistory(
    rag_chain,
    get_session_history,
    input_messages_key="input",
    history_messages_key="chat_history",
    output_messages_key="answer",
)
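
# conversational_rag_chain is not wired into the Gradio handler below, but it
# could replace the manual history handling; illustrative per-session call
# (the session id is hypothetical):
# answer = conversational_rag_chain.invoke(
#     {"input": "What majors does VMS offer?"},
#     config={"configurable": {"session_id": "demo-session"}},
# )["answer"]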
c_history = []

def chat_gen(message, history):
    # Drives rag_chain directly with a module-level history list; note that
    # conversational_rag_chain above would manage this per session instead.
    ai_message = rag_chain.invoke({"input": message, "chat_history": c_history})
    c_history.extend([HumanMessage(content=message), AIMessage(content=ai_message["answer"])])
    print(c_history)
    yield ai_message["answer"]
    # for doc in ai_message["context"]:
    #     yield doc
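
# A token-streaming variant is possible because create_retrieval_chain streams
# partial dicts; a minimal sketch, assuming "answer" chunks arrive incrementally
# (function name and behavior are illustrative):
# def chat_gen_streaming(message, history):
#     buffer = ""
#     for chunk in rag_chain.stream({"input": message, "chat_history": c_history}):
#         buffer += chunk.get("answer", "")
#         yield buffer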
initial_msg = (
    "Hello! I am VMS bot here to help you with your academic issues!"
    "\nHow can I help you?"
)

chatbot = gr.Chatbot(value=[[None, initial_msg]], bubble_full_width=False)
demo = gr.ChatInterface(chat_gen, chatbot=chatbot).queue()
try:
    demo.launch(debug=True, share=True, show_api=False)
    demo.close()
except Exception as e:
    demo.close()
    print(e)
    raise e
# available model names
# mixtral_8x7b
# llama2_13b

# llm = ChatNVIDIA(model="mixtral_8x7b") | StrOutputParser()

# initial_msg = (
#     "Hello! I am VMS bot here to help you with your academic issues!"
#     "\nHow can I help you?"
# )

# context_prompt = ChatPromptTemplate.from_messages([
#     ('system',
#      "You are a VMS chatbot, and you are helping students with their academic issues."
#      "Answer the question using only the context provided, and do not use phrases such as "
#      "'based on the context' or 'based on the documents provided' in your answer."
#      "Remember that your job is to represent the Vincent Mary School of Science and Technology (VMS) at Assumption University."
#      "Do not hallucinate any details, and do not repeat information."
#      "Please say you do not know if you do not know or you cannot find the information needed."
#      "\n\nQuestion: {question}\n\nContext: {context}"),
#     ('user', "{question}")
# ])

# chain = (
#     {
#         'context': faiss_db.as_retriever(search_type="similarity"),
#         'question': (lambda x: x)
#     }
#     | context_prompt
#     # | RPrint()
#     | llm
#     | StrOutputParser()
# )

# conv_chain = (
#     context_prompt
#     # | RPrint()
#     | llm
#     | StrOutputParser()
# )

# def chat_gen(message, history, return_buffer=True):
#     buffer = ""
#     doc_retriever = faiss_db.as_retriever(search_type="similarity_score_threshold", search_kwargs={"score_threshold": 0.2})
#     retrieved_docs = doc_retriever.invoke(message)
#     print(len(retrieved_docs))
#     print(retrieved_docs)
#     if len(retrieved_docs) > 0:
#         state = {
#             'question': message,
#             'context': retrieved_docs
#         }
#         for token in conv_chain.stream(state):
#             buffer += token
#             yield buffer
#     else:
#         passage = "I am sorry. I do not have relevant information to answer on that specific topic. Please try another question."
#         buffer += passage
#         yield buffer if return_buffer else passage

# chatbot = gr.Chatbot(value=[[None, initial_msg]])
# iface = gr.ChatInterface(chat_gen, chatbot=chatbot).queue()
# iface.launch()