Goodnight7 commited on
Commit
d9bfe6a
Β·
verified Β·
1 Parent(s): 4de9a89

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +202 -0
app.py CHANGED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from langchain import memory as lc_memory
3
+ from langsmith import Client
4
+ from streamlit_feedback import streamlit_feedback
5
+ #from utils import get_expression_chain, retriever, get_embeddings, create_qdrant_collection
6
+ from langchain_core.tracers.context import collect_runs
7
+ from qdrant_client import QdrantClient
8
+ from dotenv import load_dotenv
9
+ import os
10
+ if "access_granted" not in st.session_state:
11
+ st.session_state.access_granted = False
12
+ if "profile" not in st.session_state:
13
+ st.session_state.profile = None
14
+ if "name" not in st.session_state:
15
+ st.session_state.name = None
16
+ if not st.session_state.access_granted:
17
+ # Profile input section
18
+ st.title("User Profile")
19
+ name = st.text_input("Name")
20
+ profile_selector = st.selectbox("Profile", options=["Patient", "Doctor"] )
21
+
22
+ profile = profile_selector
23
+ if profile and name:
24
+ d = False
25
+ else:
26
+ d = True
27
+
28
+ submission = st.button("Submit", disabled=d)
29
+
30
+ if submission:
31
+ st.session_state.profile = profile
32
+ st.session_state.name = name
33
+ st.session_state.access_granted = True # Grant access to main app
34
+ st.rerun() # Reload the app
35
+ else:
36
+ load_dotenv()
37
+ profile = st.session_state.profile
38
+ client = Client()
39
+ qdrant_api=os.getenv("QDRANT_API_KEY")
40
+ qdrant_url=os.getenv("QDRANT_URL")
41
+ qdrant_client = QdrantClient(qdrant_url ,api_key=qdrant_api)
42
+ st.set_page_config(page_title = "MEDICAL CHATBOT")
43
+ st.subheader(f"Hello {st.session_state.name}! How can I assist you today!")
44
+
45
+ memory = lc_memory.ConversationBufferMemory(
46
+ chat_memory=lc_memory.StreamlitChatMessageHistory(key="langchain_messages"),
47
+ return_messages=True,
48
+ memory_key="chat_history",
49
+ )
50
+ st.sidebar.markdown("## Feedback Scale")
51
+ feedback_option = (
52
+ "thumbs" if st.sidebar.toggle(label="`Faces` ⇄ `Thumbs`", value=False) else "faces"
53
+ )
54
+
55
+ with st.sidebar:
56
+ model_name = st.selectbox("**Model**", options=["llama-3.1-70b-versatile","gemma2-9b-it","gemma-7b-it","llama-3.2-3b-preview", "llama3-70b-8192", "mixtral-8x7b-32768"])
57
+ temp = st.slider("**Temperature**", min_value=0.0, max_value=1.0, step=0.001)
58
+ n_docs = st.number_input("**Number of retrieved documents**", min_value=0, max_value=10, value=5, step=1)
59
+
60
+ if st.sidebar.button("Clear message history"):
61
+ print("Clearing message history")
62
+ memory.clear()
63
+
64
+ retriever = retriever(n_docs=n_docs)
65
+ # Create Chain
66
+ chain = get_expression_chain(retriever,model_name,temp)
67
+
68
+ for msg in st.session_state.langchain_messages:
69
+ avatar = "πŸ’" if msg.type == "ai" else None
70
+ with st.chat_message(msg.type, avatar=avatar):
71
+ st.markdown(msg.content)
72
+
73
+
74
+ prompt = st.chat_input(placeholder="What do you need to know in the medical field ?")
75
+
76
+ if prompt :
77
+ with st.chat_message("user"):
78
+ st.write(prompt)
79
+
80
+ with st.chat_message("assistant", avatar="πŸ’"):
81
+ message_placeholder = st.empty()
82
+ full_response = ""
83
+ # Define the basic input structure for the chains
84
+ input_dict = {"input": prompt.lower()}
85
+ used_docs = retriever.get_relevant_documents(prompt.lower())
86
+
87
+ with collect_runs() as cb:
88
+ for chunk in chain.stream(input_dict, config={"tags": ["MEDICAL CHATBOT"]}):
89
+ full_response += chunk.content
90
+ message_placeholder.markdown(full_response + "β–Œ")
91
+ memory.save_context(input_dict, {"output": full_response})
92
+ st.session_state.run_id = cb.traced_runs[0].id
93
+ message_placeholder.markdown(full_response)
94
+ if used_docs :
95
+ docs_content = "\n\n".join(
96
+ [
97
+ f"Doc {i+1}:\n"
98
+ f"Source: {doc.metadata['source']}\n"
99
+ f"Title: {doc.metadata['title']}\n"
100
+ f"Content: {doc.page_content}\n"
101
+ for i, doc in enumerate(used_docs)
102
+ ]
103
+ )
104
+ with st.sidebar:
105
+ st.download_button(
106
+ label="Consulted Documents",
107
+ data=docs_content,
108
+ file_name="Consulted_documents.txt",
109
+ mime="text/plain",
110
+ )
111
+
112
+ with st.spinner("Just a sec! Dont enter prompts while loading pelase!"):
113
+ run_id = st.session_state.run_id
114
+ question_embedding = get_embeddings(prompt)
115
+ answer_embedding = get_embeddings(full_response)
116
+ # Add question and answer to Qdrant
117
+ qdrant_client.upload_collection(
118
+ collection_name="chat-history",
119
+ payload=[
120
+ {"text": prompt, "type": "question", "question_ID": run_id},
121
+ {"text": full_response, "type": "answer", "question_ID": run_id, "used_docs":used_docs}
122
+ ],
123
+ vectors=[
124
+ question_embedding,
125
+ answer_embedding,
126
+ ],
127
+ parallel=4,
128
+ max_retries=3,
129
+ )
130
+
131
+
132
+
133
+ if st.session_state.get("run_id"):
134
+ run_id = st.session_state.run_id
135
+ feedback = streamlit_feedback(
136
+ feedback_type=feedback_option,
137
+ optional_text_label="[Optional] Please provide an explanation",
138
+ key=f"feedback_{run_id}",
139
+ )
140
+
141
+ # Define score mappings for both "thumbs" and "faces" feedback systems
142
+ score_mappings = {
143
+ "thumbs": {"πŸ‘": 1, "πŸ‘Ž": 0},
144
+ "faces": {"πŸ˜€": 1, "πŸ™‚": 0.75, "😐": 0.5, "πŸ™": 0.25, "😞": 0},
145
+ }
146
+
147
+ # Get the score mapping based on the selected feedback option
148
+ scores = score_mappings[feedback_option]
149
+
150
+ if feedback:
151
+ # Get the score from the selected feedback option's score mapping
152
+ score = scores.get(feedback["score"])
153
+
154
+ if score is not None:
155
+ # Formulate feedback type string incorporating the feedback option
156
+ # and score value
157
+ feedback_type_str = f"{feedback_option} {feedback['score']}"
158
+
159
+ # Record the feedback with the formulated feedback type string
160
+ # and optional comment
161
+ with st.spinner("Just a sec! Dont enter prompts while loading pelase!"):
162
+ feedback_record = client.create_feedback(
163
+ run_id,
164
+ feedback_type_str,
165
+ score=score,
166
+ comment=feedback.get("text"),
167
+ source_info={"profile":profile}
168
+ )
169
+ st.session_state.feedback = {
170
+ "feedback_id": str(feedback_record.id),
171
+ "score": score,
172
+ }
173
+ else:
174
+ st.warning("Invalid feedback score.")
175
+
176
+ with st.spinner("Just a sec! Dont enter prompts while loading pelase!"):
177
+ if feedback.get("text"):
178
+ comment = feedback.get("text")
179
+ feedback_embedding = get_embeddings(comment)
180
+ else:
181
+ comment = "no comment"
182
+ feedback_embedding = get_embeddings(comment)
183
+
184
+
185
+ qdrant_client.upload_collection(
186
+ collection_name="chat-history",
187
+ payload=[
188
+ {"text": comment,
189
+ "Score:":score,
190
+ "type": "feedback",
191
+ "question_ID": run_id,
192
+ "User_profile":profile}
193
+ ],
194
+ vectors=[
195
+ feedback_embedding
196
+ ],
197
+ parallel=4,
198
+ max_retries=3,
199
+ )
200
+
201
+
202
+