brandonongsc commited on
Commit
4bc5561
1 Parent(s): b45dd59

Upload model.py

Browse files
Files changed (1) hide show
  1. model.py +167 -0
model.py ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from langchain.document_loaders import PyPDFLoader, DirectoryLoader
3
+ from langchain import PromptTemplate
4
+ from langchain.embeddings import HuggingFaceEmbeddings
5
+ from langchain.vectorstores import FAISS
6
+ from langchain.llms import CTransformers
7
+ from langchain.chains import RetrievalQA
8
+ import geocoder
9
+ from geopy.distance import geodesic #help calc nearest clinic
10
+ import pandas as pd
11
+ import folium #using folium maps allows us to show richer details on the map like tooltips
12
+ from streamlit_folium import folium_static
13
+
14
+ DB_FAISS_PATH = 'vectorstores/db_faiss'
15
+
16
+ custom_prompt_template = """Use the following pieces of information to answer the user's question.
17
+ If you don't know the answer, just say that you don't know, don't try to make up an answer.
18
+
19
+ Context: {context}
20
+ Question: {question}
21
+
22
+ Only return the helpful answer below and nothing else.
23
+ Helpful answer:
24
+ """
25
+
26
+ def set_custom_prompt():
27
+ prompt = PromptTemplate(template=custom_prompt_template,
28
+ input_variables=['context', 'question'])
29
+ return prompt
30
+
31
+ def retrieval_qa_chain(llm, prompt, db):
32
+ qa_chain = RetrievalQA.from_chain_type(llm=llm,
33
+ chain_type='stuff',
34
+ retriever=db.as_retriever(search_kwargs={'k': 2}),
35
+ return_source_documents=True,
36
+ chain_type_kwargs={'prompt': prompt}
37
+ )
38
+ return qa_chain
39
+
40
+ def load_llm():
41
+ llm = CTransformers(
42
+ model="TheBloke/Llama-2-7B-Chat-GGML",
43
+ model_type="llama",
44
+ max_new_tokens=512,
45
+ temperature=0.5
46
+ )
47
+ return llm
48
+
49
+ import folium
50
+
51
+ def qa_bot(query):
52
+ if 'nearest TCM clinic' in query:
53
+ # Get user's current location
54
+ g = geocoder.ip('me')
55
+ user_lat, user_lon = g.latlng
56
+
57
+ # Load locations from the CSV file
58
+ locations_df = pd.read_csv("dataset/locations.csv")
59
+
60
+ # Filter locations within 5km from user's current location
61
+ filtered_locations_df = locations_df[locations_df.apply(lambda row: geodesic((user_lat, user_lon), (row['latitude'], row['longitude'])).kilometers <= 5, axis=1)]
62
+
63
+ # Create map centered at user's location
64
+ my_map = folium.Map(location=[user_lat, user_lon], zoom_start=12)
65
+
66
+ # Add markers with custom tooltips for filtered locations
67
+ for index, location in filtered_locations_df.iterrows():
68
+ folium.Marker(location=[location['latitude'], location['longitude']], tooltip=f"{location['name']}<br>Reviews: {location['Stars_review']}<br>Avg Price $: {location['Price']}<br>Contact No: {location['Contact']}").add_to(my_map)
69
+
70
+ # Display map
71
+ folium_static(my_map)
72
+
73
+ return "Displaying locations within 5km from your current location."
74
+
75
+ else:
76
+ embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2",
77
+ model_kwargs={'device': 'cpu'})
78
+ db = FAISS.load_local(DB_FAISS_PATH, embeddings)
79
+ llm = load_llm()
80
+ qa_prompt = set_custom_prompt()
81
+ qa = retrieval_qa_chain(llm, qa_prompt, db)
82
+
83
+ # Implement the question-answering logic here
84
+ response = qa({'query': query})
85
+ return response['result']
86
+
87
+
88
+
89
+
90
+
91
+ def add_vertical_space(spaces=1):
92
+ for _ in range(spaces):
93
+ st.markdown("---")
94
+
95
+ def main():
96
+ st.set_page_config(page_title="Ask me anything about TCM")
97
+
98
+ with st.sidebar:
99
+ st.title('Welcome to Nexus AI TCM!')
100
+
101
+ st.markdown('''
102
+
103
+ <style>
104
+ [data-testid=stSidebar] {
105
+ background-color: #ffffff;
106
+ }
107
+ </style>
108
+ <img src="http://40.90.239.142/bongvm/img/nexus_logo4.png" width=200>
109
+
110
+ ''', unsafe_allow_html=True)
111
+ add_vertical_space(1) # Adjust the number of spaces as needed
112
+ #st.write('Made by [@ThisIs-Developer](https://huggingface.co/ThisIs-Developer)')
113
+
114
+ st.title("Nexus AI TCM")
115
+ st.markdown(
116
+ """
117
+ <style>
118
+ .chat-container {
119
+ display: flex;
120
+ flex-direction: column;
121
+ height: 400px;
122
+ overflow-y: auto;
123
+ padding: 10px;
124
+ color: white; /* Font color */
125
+ }
126
+ .user-bubble {
127
+ background-color: #007bff; /* Blue color for user */
128
+ align-self: flex-end;
129
+ border-radius: 10px;
130
+ padding: 8px;
131
+ margin: 5px;
132
+ max-width: 70%;
133
+ word-wrap: break-word;
134
+ }
135
+ .bot-bubble {
136
+ background-color: #363636; /* Slightly lighter background color */
137
+ align-self: flex-start;
138
+ border-radius: 10px;
139
+ padding: 8px;
140
+ margin: 5px;
141
+ max-width: 70%;
142
+ word-wrap: break-word;
143
+ }
144
+ </style>
145
+ """
146
+ , unsafe_allow_html=True)
147
+
148
+ conversation = st.session_state.get("conversation", [])
149
+
150
+ query = st.text_input("Ask your question here:", key="user_input")
151
+ if st.button("Get Answer"):
152
+ if query:
153
+ with st.spinner("Processing your question..."): # Display the processing message
154
+ conversation.append({"role": "user", "message": query})
155
+ # Call your QA function
156
+ answer = qa_bot(query)
157
+ conversation.append({"role": "bot", "message": answer})
158
+ st.session_state.conversation = conversation
159
+ else:
160
+ st.warning("Please input a question.")
161
+
162
+ chat_container = st.empty()
163
+ chat_bubbles = ''.join([f'<div class="{c["role"]}-bubble">{c["message"]}</div>' for c in conversation])
164
+ chat_container.markdown(f'<div class="chat-container">{chat_bubbles}</div>', unsafe_allow_html=True)
165
+
166
+ if __name__ == "__main__":
167
+ main()