# nyp_chatbot / app.py
# Uploaded by brandonongsc ("Upload app.py"), commit 79d5986 (verified); 8.47 kB; virus scan: no virus found.
import streamlit as st
from langchain.document_loaders import PyPDFLoader, DirectoryLoader
from langchain import PromptTemplate
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.llms import CTransformers
from langchain.chains import RetrievalQA
import geocoder
from geopy.distance import geodesic
import pandas as pd
import folium
from streamlit_folium import folium_static
from transformers import pipeline
import logging
#-----------------
# demonstrating use of a Vectordb store
#-----------------
DB_FAISS_PATH = 'vectorstores/db_faiss'
#-----------------
# Detecting whether the query is a normal textual chat, a nearest-clinic (map) request or a shopping request
#-----------------
# Zero-shot classifier used by infer_context() to pick a routing label for each query.
# The model is pinned explicitly (facebook/bart-large-mnli is the pipeline's default for this
# task) so the routing does not change silently if the library default changes, and the
# "no model was supplied" warning is avoided.
# NOTE(review): this runs at import time, so the model is (re)loaded whenever the script is executed.
classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
#-----------------
# Set up logging. mostly for debugging purposes only
#-----------------
logging.basicConfig(filename='app.log', level=logging.DEBUG, format='%(asctime)s %(message)s')
# Prompt for the RetrievalQA chain; {context} and {question} are filled in via PromptTemplate.
custom_prompt_template = """Use the following pieces of information to answer the user's question.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
Context: {context}
Question: {question}
Only return the helpful answer below and nothing else.
Helpful answer:
"""
def set_custom_prompt():
    """Return a PromptTemplate built from ``custom_prompt_template``.

    The template expects two variables: ``context`` (retrieved text) and
    ``question`` (the user's query).
    """
    return PromptTemplate(
        template=custom_prompt_template,
        input_variables=['context', 'question'],
    )
def retrieval_qa_chain(llm, prompt, db):
    """Build a 'stuff'-type RetrievalQA chain over the vector store ``db``.

    The retriever fetches the 2 most similar documents (k=2). ``prompt`` is used
    as the chain prompt, and the source documents are returned with the answer.
    """
    top_two_retriever = db.as_retriever(search_kwargs={'k': 2})
    return RetrievalQA.from_chain_type(
        llm=llm,
        chain_type='stuff',
        retriever=top_two_retriever,
        return_source_documents=True,
        chain_type_kwargs={'prompt': prompt},
    )
#-----------------
# function to load LLM from huggingface
#-----------------
def load_llm():
    """Load the Llama-2 7B chat model (GGML, "TheBloke/Llama-2-7B-Chat-GGML") via CTransformers.

    Generation settings (up to 512 new tokens, temperature 0.5) go in ``config``.
    LangChain's CTransformers wrapper hands only that dict to the underlying
    ctransformers model, so top-level ``max_new_tokens``/``temperature`` keyword
    arguments would not take effect.
    """
    llm = CTransformers(
        model="TheBloke/Llama-2-7B-Chat-GGML",
        model_type="llama",
        config={'max_new_tokens': 512, 'temperature': 0.5},
    )
    return llm
#-----------------
# function that does 3 things
# 1. loads maps using Folium if Context is nearest clinic (maps loads dataset from csv)
# 2. loads a shopee link if Context is to buy things
# 3. otherwise answers in a normal chat bubble using the RetrievalQA chain
#-----------------
# Context labels that route a query to the clinic map / to the Shopee link.
_CLINIC_CONTEXTS = ("nearest clinic", "nearest TCM clinic", "nearest TCM doctor", "near me", "nearest to me")
_SHOPPING_CONTEXTS = ("buy", "Ointment", "Hong You", "Feng You", "Fengyou", "Po chai pills")
# Only clinics within this many kilometres of the user are shown on the map.
_CLINIC_RADIUS_KM = 5


@st.cache_resource
def _load_qa_chain():
    """Build the RetrievalQA chain once and return ``(chain, lock)``.

    Loading the embedding model, the FAISS index and the LLM is expensive, so it is
    cached with ``st.cache_resource`` instead of being repeated for every query.
    The cached chain is shared, so callers hold the returned lock while running it.
    This keeps concurrent sessions from driving the same model at the same time.
    """
    import threading

    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2",
                                       model_kwargs={'device': 'cpu'})
    db = FAISS.load_local(DB_FAISS_PATH, embeddings)
    qa = retrieval_qa_chain(load_llm(), set_custom_prompt(), db)
    return qa, threading.Lock()


def _render_clinic_map():
    """Render a Folium map of nearby clinics and return the chat reply text.

    The user's position comes from IP geolocation (``geocoder.ip('me')``).
    Clinics are read from ``dataset/locations.csv`` and kept when their geodesic
    distance from the user is at most ``_CLINIC_RADIUS_KM`` kilometres.
    """
    g = geocoder.ip('me')
    if not g.latlng:
        # The IP lookup can come back without coordinates; unpacking None would raise.
        logging.warning("Could not determine user location from IP address.")
        return "Sorry, I could not determine your current location to find nearby clinics."
    user_lat, user_lon = g.latlng
    user_pos = (user_lat, user_lon)

    locations_df = pd.read_csv("dataset/locations.csv")
    within_radius = [
        geodesic(user_pos, (lat, lon)).kilometers <= _CLINIC_RADIUS_KM
        for lat, lon in zip(locations_df['latitude'], locations_df['longitude'])
    ]
    # .loc with a boolean list also works when the CSV has no rows (empty mask).
    nearby_df = locations_df.loc[within_radius]

    my_map = folium.Map(location=[user_lat, user_lon], zoom_start=12)
    for _, location in nearby_df.iterrows():
        folium.Marker(
            location=[location['latitude'], location['longitude']],
            tooltip=f"{location['name']}<br>Reviews: {location['Stars_review']}<br>Avg Price $: {location['Price']}<br>Contact No: {location['Contact']}",
        ).add_to(my_map)
    folium_static(my_map)
    return f"[Map of Clinic Locations {_CLINIC_RADIUS_KM}km from your current location]"


def _shopee_reply(context):
    """Return an HTML reply that links to a shopee.sg search for ``context``."""
    from urllib.parse import quote_plus

    # URL-encode the keyword: labels such as "Hong You" or "Po chai pills" contain spaces.
    shopee_link = f"<a href='https://shopee.sg/search?keyword={quote_plus(context)}'>at this Shopee link!</a>"
    return f"You may visit this page to purchase {context} {shopee_link}!"


def qa_bot(query, context=""):
    """Answer ``query``, routing on the inferred ``context`` label.

    Args:
        query: The user's question.
        context: Routing label, e.g. one returned by infer_context().

    Returns:
        The text for the bot's chat bubble. For clinic contexts this is a
        placeholder, and the map itself is drawn as a side effect. For shopping
        contexts it is HTML containing a Shopee link. Otherwise it is the
        RetrievalQA answer.
    """
    logging.info("Received query: %s, Context: %s", query, context)
    if context in _CLINIC_CONTEXTS:
        logging.info("Context matched for nearest TCM clinic.")
        return _render_clinic_map()
    if context in _SHOPPING_CONTEXTS:
        logging.info("Context matched for buying.")
        return _shopee_reply(context)

    logging.info("Context not matched for nearest TCM clinic or buying.")
    qa, qa_lock = _load_qa_chain()
    with qa_lock:
        response = qa({'query': query})
    return response['result']
def add_vertical_space(spaces=1):
    """Draw ``spaces`` markdown horizontal rules ("---") as visual separators."""
    divider = "---"
    for _ in range(spaces):
        st.markdown(divider)
def main():
    """Render the Nexus AI TCM Streamlit page and answer one question per "Ask" click.

    The chat history is kept in ``st.session_state.conversation`` as a list of
    ``{"role": "user" | "bot", "message": str}`` dicts. Each entry is rendered as
    a ``<role>-bubble`` div styled by the CSS injected below.
    """
    import html

    st.set_page_config(page_title="Ask me anything about TCM")
    with st.sidebar:
        st.title('Welcome to Nexus AI TCM!')
        st.markdown('''
<style>
[data-testid=stSidebar] {
background-color: #ffffff;
}
</style>
<img src="https://huggingface.co/spaces/mathslearn/chatbot_test_streamlit/resolve/main/logo.jpeg" width=200>
''', unsafe_allow_html=True)
        add_vertical_space(1)  # Adjust the number of spaces as needed
    st.title("Nexus AI TCM")
    st.markdown(
        """
<style>
.chat-container {
display: flex;
flex-direction: column;
height: 400px;
overflow-y: auto;
padding: 10px;
color: white; /* Font color */
}
.user-bubble {
background-color: #007bff; /* Blue color for user */
align-self: flex-end;
border-radius: 10px;
padding: 8px;
margin: 5px;
max-width: 70%;
word-wrap: break-word;
}
.bot-bubble {
background-color: #363636; /* Slightly lighter background color */
align-self: flex-start;
border-radius: 10px;
padding: 8px;
margin: 5px;
max-width: 70%;
word-wrap: break-word;
}
</style>
"""
        , unsafe_allow_html=True)

    conversation = st.session_state.get("conversation", [])
    if "my_text" not in st.session_state:
        st.session_state.my_text = ""
    # submit() copies the typed text into st.session_state.my_text and clears the box.
    st.text_input("Enter text here", key="widget", on_change=submit)
    query = st.session_state.my_text
    if st.button("Ask"):
        # Whitespace-only input is treated the same as no input.
        if query.strip():
            with st.spinner("Processing your question..."):  # Display the processing message
                conversation.append({"role": "user", "message": query})
                # Route the query via the zero-shot context label, then answer it.
                answer = qa_bot(query, infer_context(query))
                conversation.append({"role": "bot", "message": answer})
                st.session_state.conversation = conversation
        else:
            st.warning("Please input a question.")

    def _bubble(turn):
        # User-typed text is rendered with unsafe_allow_html=True, so escape it; otherwise
        # input such as "<b>" or "a < b" would be interpreted as HTML. Bot replies are left
        # as-is because qa_bot may intentionally return HTML (the Shopee <a> link).
        text = html.escape(turn["message"]) if turn["role"] == "user" else turn["message"]
        return f'<div class="{turn["role"]}-bubble">{text}</div>'

    # Display the conversation history
    chat_container = st.empty()
    chat_bubbles = ''.join(_bubble(c) for c in conversation)
    chat_container.markdown(f'<div class="chat-container">{chat_bubbles}</div>', unsafe_allow_html=True)
def submit():
    """on_change callback: move the text box contents into ``my_text`` and clear the box."""
    state = st.session_state
    # The right-hand side is evaluated first, so my_text gets the typed text before widget is reset.
    state.my_text, state.widget = state.widget, ""
#-----------
# Setting the Context
#-----------
def infer_context(query):
    """Classify ``query`` into one of a fixed set of routing labels.

    Runs the module-level zero-shot ``classifier`` over ``candidate_labels`` and
    returns the first label of its ranking. qa_bot() uses this label to choose
    between the clinic map, the shopping link and a normal chat answer.
    Extend ``candidate_labels`` to support new kinds of context.
    """
    candidate_labels = [
        "TCM", "sick", "herbs", "traditional",
        "nearest clinic", "nearest TCM clinic", "nearest TCM doctor", "near me", "nearest to me",
        "Ointment", "Hong You", "Feng You", "Fengyou", "Po chai pills",
    ]
    ranking = classifier(query, candidate_labels)
    ordered_labels = ranking["labels"]
    return ordered_labels[0]
# Entry point: build the UI by calling main() when this file is executed as the top-level script.
if __name__ == "__main__":
    main()