brandonongsc committed on
Commit
79d5986
1 Parent(s): c3b45b1

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +260 -0
app.py ADDED
@@ -0,0 +1,260 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from langchain.document_loaders import PyPDFLoader, DirectoryLoader
3
+ from langchain import PromptTemplate
4
+ from langchain.embeddings import HuggingFaceEmbeddings
5
+ from langchain.vectorstores import FAISS
6
+ from langchain.llms import CTransformers
7
+ from langchain.chains import RetrievalQA
8
+ import geocoder
9
+ from geopy.distance import geodesic
10
+ import pandas as pd
11
+ import folium
12
+ from streamlit_folium import folium_static
13
+ from transformers import pipeline
14
+ import logging
15
+
16
#-----------------
# Demonstrates use of a vector DB store: path to the pre-built FAISS index on disk.
#-----------------

DB_FAISS_PATH = 'vectorstores/db_faiss'

#-----------------
# Detects the context: whether a query should get a normal textual chat answer,
# a map of the nearest clinics, or a shopping link.
# NOTE(review): no model name is pinned, so transformers downloads its default
# zero-shot checkpoint at startup — confirm this is intentional.
#-----------------
classifier = pipeline("zero-shot-classification")

#-----------------
# Set up logging; mostly for debugging purposes only.
#-----------------
logging.basicConfig(filename='app.log', level=logging.DEBUG, format='%(asctime)s %(message)s')

# Prompt passed to the retrieval-QA chain; {context} and {question} are
# filled in by LangChain at query time.
custom_prompt_template = """Use the following pieces of information to answer the user's question.
If you don't know the answer, just say that you don't know, don't try to make up an answer.

Context: {context}
Question: {question}

Only return the helpful answer below and nothing else.
Helpful answer:
"""
42
+
43
def set_custom_prompt():
    """Build the PromptTemplate used by the retrieval-QA chain.

    Wraps the module-level ``custom_prompt_template`` with the two
    variables the chain substitutes: ``context`` and ``question``.
    """
    return PromptTemplate(
        template=custom_prompt_template,
        input_variables=['context', 'question'],
    )
47
+
48
def retrieval_qa_chain(llm, prompt, db):
    """Assemble a RetrievalQA chain over *db* using *llm* and *prompt*.

    The retriever returns the top-2 nearest documents ('k': 2), the chain
    stuffs them into the prompt, and source documents are included in the
    chain's response.
    """
    retriever = db.as_retriever(search_kwargs={'k': 2})
    return RetrievalQA.from_chain_type(
        llm=llm,
        chain_type='stuff',
        retriever=retriever,
        return_source_documents=True,
        chain_type_kwargs={'prompt': prompt},
    )
56
+
57
+
58
#-----------------
# Loads the LLM from Hugging Face.
#-----------------
def load_llm():
    """Return a CTransformers-backed Llama-2-7B chat model.

    The GGML checkpoint is fetched from Hugging Face on first use; output
    is capped at 512 new tokens with a moderate temperature of 0.5.
    """
    return CTransformers(
        model="TheBloke/Llama-2-7B-Chat-GGML",
        model_type="llama",
        max_new_tokens=512,
        temperature=0.5,
    )
70
+
71
+
72
#-----------------
# qa_bot does 3 things, depending on the inferred context:
# 1. loads a Folium map of nearby clinics if the context is "nearest clinic"
#    (clinic data loaded from a CSV dataset)
# 2. returns a Shopee shopping link if the context is a product to buy
# 3. otherwise answers through the retrieval-QA chain (normal chat bubble)
#-----------------

def qa_bot(query, context=""):
    """Answer *query* according to the zero-shot-classified *context*.

    Args:
        query: The user's free-text question.
        context: Label produced by ``infer_context``; the default ""
            falls through to the normal retrieval-QA chat path.

    Returns:
        A string to render in the bot's chat bubble (for the map path,
        the map itself is rendered as a side effect via folium_static).
    """
    from urllib.parse import quote_plus  # stdlib; local import keeps module imports unchanged

    logging.info(f"Received query: {query}, Context: {context}")

    if context in ["nearest clinic", "nearest TCM clinic", "nearest TCM doctor", "near me", "nearest to me"]:
        #-----------
        # Loads map
        #-----------
        logging.info("Context matched for nearest TCM clinic.")

        # Get the user's current location from their IP address.
        g = geocoder.ip('me')
        if not g.latlng:
            # geocoder returns a None latlng when IP geolocation fails;
            # the original code crashed unpacking it.
            logging.warning("Could not determine user location from IP.")
            return "Sorry, I could not determine your current location."
        user_lat, user_lon = g.latlng

        # Load clinic locations from the CSV dataset.
        locations_df = pd.read_csv("dataset/locations.csv")

        # Keep only locations within 5 km of the user's position.
        within_5km = locations_df.apply(
            lambda row: geodesic((user_lat, user_lon),
                                 (row['latitude'], row['longitude'])).kilometers <= 5,
            axis=1,
        )
        filtered_locations_df = locations_df[within_5km]

        # Create a map centred at the user's location.
        my_map = folium.Map(location=[user_lat, user_lon], zoom_start=12)

        # Add a marker with a details tooltip for every nearby clinic.
        for _, location in filtered_locations_df.iterrows():
            folium.Marker(
                location=[location['latitude'], location['longitude']],
                tooltip=(f"{location['name']}<br>Reviews: {location['Stars_review']}"
                         f"<br>Avg Price $: {location['Price']}"
                         f"<br>Contact No: {location['Contact']}"),
            ).add_to(my_map)

        # Render the map into the Streamlit page.
        folium_static(my_map)

        return "[Map of Clinic Locations 5km from your current location]"

    elif context in ["buy", "Ointment", "Hong You", "Feng You", "Fengyou", "Po chai pills"]:
        #-----------
        # Loads shopee link
        #-----------
        logging.info("Context matched for buying.")
        # URL-encode the product name so multi-word products such as
        # "Po chai pills" produce a valid search URL (the original
        # embedded raw spaces in the href).
        shopee_link = f"<a href='https://shopee.sg/search?keyword={quote_plus(context)}'>at this Shopee link!</a>"
        return f"You may visit this page to purchase {context} {shopee_link}!"

    else:
        #-----------
        # Loads normal chat bubble
        #-----------
        logging.info("Context not matched for nearest TCM clinic or buying.")

        # Building the embeddings, FAISS store and LLM is expensive; cache
        # the assembled chain on the function so it is constructed once per
        # process instead of on every question (the original rebuilt it for
        # each query).
        qa = getattr(qa_bot, "_qa_chain", None)
        if qa is None:
            embeddings = HuggingFaceEmbeddings(
                model_name="sentence-transformers/all-MiniLM-L6-v2",
                model_kwargs={'device': 'cpu'},
            )
            db = FAISS.load_local(DB_FAISS_PATH, embeddings)
            llm = load_llm()
            qa_prompt = set_custom_prompt()
            qa = retrieval_qa_chain(llm, qa_prompt, db)
            qa_bot._qa_chain = qa

        # Run retrieval QA and return only the answer text.
        response = qa({'query': query})
        return response['result']
140
+
141
+
142
+
143
+
144
+
145
def add_vertical_space(spaces=1):
    """Insert *spaces* horizontal-rule separators into the Streamlit page."""
    count = 0
    while count < spaces:
        st.markdown("---")
        count += 1
148
+
149
+
150
+
151
def main():
    """Render the Streamlit page: sidebar, chat styles, input box and history.

    Flow: configure the page, draw the branded sidebar, inject the chat CSS,
    then on "Ask" classify the query's context, answer it via qa_bot, and
    re-render the full conversation as styled chat bubbles.
    """
    st.set_page_config(page_title="Ask me anything about TCM")

    # Sidebar: white background override plus the Nexus AI TCM logo.
    with st.sidebar:
        st.title('Welcome to Nexus AI TCM!')

        st.markdown('''

        <style>
        [data-testid=stSidebar] {
        background-color: #ffffff;
        }
        </style>
        <img src="https://huggingface.co/spaces/mathslearn/chatbot_test_streamlit/resolve/main/logo.jpeg" width=200>

        ''', unsafe_allow_html=True)
        add_vertical_space(1)  # Adjust the number of spaces as needed

    st.title("Nexus AI TCM")

    # Inject CSS for the scrollable chat container and the user/bot bubbles;
    # class names match the "{role}-bubble" pattern used below.
    st.markdown(
        """
        <style>
        .chat-container {
            display: flex;
            flex-direction: column;
            height: 400px;
            overflow-y: auto;
            padding: 10px;
            color: white; /* Font color */
        }
        .user-bubble {
            background-color: #007bff; /* Blue color for user */
            align-self: flex-end;
            border-radius: 10px;
            padding: 8px;
            margin: 5px;
            max-width: 70%;
            word-wrap: break-word;
        }
        .bot-bubble {
            background-color: #363636; /* Slightly lighter background color */
            align-self: flex-start;
            border-radius: 10px;
            padding: 8px;
            margin: 5px;
            max-width: 70%;
            word-wrap: break-word;
        }
        </style>
        """
        , unsafe_allow_html=True)

    # Conversation history persists across Streamlit reruns in session state.
    conversation = st.session_state.get("conversation", [])

    # "my_text" holds the submitted query; the widget itself is cleared by
    # submit() so the input box empties after each question.
    if "my_text" not in st.session_state:
        st.session_state.my_text = ""

    st.text_input("Enter text here", key="widget", on_change=submit)
    query = st.session_state.my_text

    if st.button("Ask"):
        if query:
            with st.spinner("Processing your question..."):  # Display the processing message
                conversation.append({"role": "user", "message": query})
                # Classify the query's context, then answer accordingly.
                answer = qa_bot(query, infer_context(query))
                conversation.append({"role": "bot", "message": answer})
                st.session_state.conversation = conversation
        else:
            st.warning("Please input a question.")

    # Display the conversation history as HTML chat bubbles.
    chat_container = st.empty()
    chat_bubbles = ''.join([f'<div class="{c["role"]}-bubble">{c["message"]}</div>' for c in conversation])
    chat_container.markdown(f'<div class="chat-container">{chat_bubbles}</div>', unsafe_allow_html=True)
237
+
238
+
239
+
240
def submit():
    """Copy the text_input widget's value into session state and clear the box.

    Runs as the widget's on_change callback so the input field empties
    after every submission while the query text survives the rerun.
    """
    st.session_state.my_text, st.session_state.widget = st.session_state.widget, ""
243
+
244
+
245
#-----------
# Setting the Context
#-----------
def infer_context(query):
    """
    Infer the conversational context of *query* via zero-shot classification.

    Returns the single highest-scoring label from the candidate set; qa_bot
    uses that label to choose the map, shopping or normal-chat path.
    Modify the label list to suit your context detection needs.
    """
    candidate_labels = ["TCM","sick","herbs","traditional","nearest clinic","nearest TCM clinic","nearest TCM doctor","near me","nearest to me", "Ointment", "Hong You", "Feng You", "Fengyou", "Po chai pills"]
    scores = classifier(query, candidate_labels)
    return scores["labels"][0]
257
+
258
+
259
# Script entry point: launch the Streamlit app.
if __name__ == "__main__":
    main()