Spaces:
Running
Running
Goodnight7
committed on
Update app.py
Browse files
app.py
CHANGED
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
from langchain import memory as lc_memory
|
3 |
+
from langsmith import Client
|
4 |
+
from streamlit_feedback import streamlit_feedback
|
5 |
+
#from utils import get_expression_chain, retriever, get_embeddings, create_qdrant_collection
|
6 |
+
from langchain_core.tracers.context import collect_runs
|
7 |
+
from qdrant_client import QdrantClient
|
8 |
+
from dotenv import load_dotenv
|
9 |
+
import os
|
10 |
+
if "access_granted" not in st.session_state:
|
11 |
+
st.session_state.access_granted = False
|
12 |
+
if "profile" not in st.session_state:
|
13 |
+
st.session_state.profile = None
|
14 |
+
if "name" not in st.session_state:
|
15 |
+
st.session_state.name = None
|
16 |
+
if not st.session_state.access_granted:
|
17 |
+
# Profile input section
|
18 |
+
st.title("User Profile")
|
19 |
+
name = st.text_input("Name")
|
20 |
+
profile_selector = st.selectbox("Profile", options=["Patient", "Doctor"] )
|
21 |
+
|
22 |
+
profile = profile_selector
|
23 |
+
if profile and name:
|
24 |
+
d = False
|
25 |
+
else:
|
26 |
+
d = True
|
27 |
+
|
28 |
+
submission = st.button("Submit", disabled=d)
|
29 |
+
|
30 |
+
if submission:
|
31 |
+
st.session_state.profile = profile
|
32 |
+
st.session_state.name = name
|
33 |
+
st.session_state.access_granted = True # Grant access to main app
|
34 |
+
st.rerun() # Reload the app
|
35 |
+
else:
|
36 |
+
load_dotenv()
|
37 |
+
profile = st.session_state.profile
|
38 |
+
client = Client()
|
39 |
+
qdrant_api=os.getenv("QDRANT_API_KEY")
|
40 |
+
qdrant_url=os.getenv("QDRANT_URL")
|
41 |
+
qdrant_client = QdrantClient(qdrant_url ,api_key=qdrant_api)
|
42 |
+
st.set_page_config(page_title = "MEDICAL CHATBOT")
|
43 |
+
st.subheader(f"Hello {st.session_state.name}! How can I assist you today!")
|
44 |
+
|
45 |
+
memory = lc_memory.ConversationBufferMemory(
|
46 |
+
chat_memory=lc_memory.StreamlitChatMessageHistory(key="langchain_messages"),
|
47 |
+
return_messages=True,
|
48 |
+
memory_key="chat_history",
|
49 |
+
)
|
50 |
+
st.sidebar.markdown("## Feedback Scale")
|
51 |
+
feedback_option = (
|
52 |
+
"thumbs" if st.sidebar.toggle(label="`Faces` β `Thumbs`", value=False) else "faces"
|
53 |
+
)
|
54 |
+
|
55 |
+
with st.sidebar:
|
56 |
+
model_name = st.selectbox("**Model**", options=["llama-3.1-70b-versatile","gemma2-9b-it","gemma-7b-it","llama-3.2-3b-preview", "llama3-70b-8192", "mixtral-8x7b-32768"])
|
57 |
+
temp = st.slider("**Temperature**", min_value=0.0, max_value=1.0, step=0.001)
|
58 |
+
n_docs = st.number_input("**Number of retrieved documents**", min_value=0, max_value=10, value=5, step=1)
|
59 |
+
|
60 |
+
if st.sidebar.button("Clear message history"):
|
61 |
+
print("Clearing message history")
|
62 |
+
memory.clear()
|
63 |
+
|
64 |
+
retriever = retriever(n_docs=n_docs)
|
65 |
+
# Create Chain
|
66 |
+
chain = get_expression_chain(retriever,model_name,temp)
|
67 |
+
|
68 |
+
for msg in st.session_state.langchain_messages:
|
69 |
+
avatar = "π" if msg.type == "ai" else None
|
70 |
+
with st.chat_message(msg.type, avatar=avatar):
|
71 |
+
st.markdown(msg.content)
|
72 |
+
|
73 |
+
|
74 |
+
prompt = st.chat_input(placeholder="What do you need to know in the medical field ?")
|
75 |
+
|
76 |
+
if prompt :
|
77 |
+
with st.chat_message("user"):
|
78 |
+
st.write(prompt)
|
79 |
+
|
80 |
+
with st.chat_message("assistant", avatar="π"):
|
81 |
+
message_placeholder = st.empty()
|
82 |
+
full_response = ""
|
83 |
+
# Define the basic input structure for the chains
|
84 |
+
input_dict = {"input": prompt.lower()}
|
85 |
+
used_docs = retriever.get_relevant_documents(prompt.lower())
|
86 |
+
|
87 |
+
with collect_runs() as cb:
|
88 |
+
for chunk in chain.stream(input_dict, config={"tags": ["MEDICAL CHATBOT"]}):
|
89 |
+
full_response += chunk.content
|
90 |
+
message_placeholder.markdown(full_response + "β")
|
91 |
+
memory.save_context(input_dict, {"output": full_response})
|
92 |
+
st.session_state.run_id = cb.traced_runs[0].id
|
93 |
+
message_placeholder.markdown(full_response)
|
94 |
+
if used_docs :
|
95 |
+
docs_content = "\n\n".join(
|
96 |
+
[
|
97 |
+
f"Doc {i+1}:\n"
|
98 |
+
f"Source: {doc.metadata['source']}\n"
|
99 |
+
f"Title: {doc.metadata['title']}\n"
|
100 |
+
f"Content: {doc.page_content}\n"
|
101 |
+
for i, doc in enumerate(used_docs)
|
102 |
+
]
|
103 |
+
)
|
104 |
+
with st.sidebar:
|
105 |
+
st.download_button(
|
106 |
+
label="Consulted Documents",
|
107 |
+
data=docs_content,
|
108 |
+
file_name="Consulted_documents.txt",
|
109 |
+
mime="text/plain",
|
110 |
+
)
|
111 |
+
|
112 |
+
with st.spinner("Just a sec! Dont enter prompts while loading pelase!"):
|
113 |
+
run_id = st.session_state.run_id
|
114 |
+
question_embedding = get_embeddings(prompt)
|
115 |
+
answer_embedding = get_embeddings(full_response)
|
116 |
+
# Add question and answer to Qdrant
|
117 |
+
qdrant_client.upload_collection(
|
118 |
+
collection_name="chat-history",
|
119 |
+
payload=[
|
120 |
+
{"text": prompt, "type": "question", "question_ID": run_id},
|
121 |
+
{"text": full_response, "type": "answer", "question_ID": run_id, "used_docs":used_docs}
|
122 |
+
],
|
123 |
+
vectors=[
|
124 |
+
question_embedding,
|
125 |
+
answer_embedding,
|
126 |
+
],
|
127 |
+
parallel=4,
|
128 |
+
max_retries=3,
|
129 |
+
)
|
130 |
+
|
131 |
+
|
132 |
+
|
133 |
+
if st.session_state.get("run_id"):
|
134 |
+
run_id = st.session_state.run_id
|
135 |
+
feedback = streamlit_feedback(
|
136 |
+
feedback_type=feedback_option,
|
137 |
+
optional_text_label="[Optional] Please provide an explanation",
|
138 |
+
key=f"feedback_{run_id}",
|
139 |
+
)
|
140 |
+
|
141 |
+
# Define score mappings for both "thumbs" and "faces" feedback systems
|
142 |
+
score_mappings = {
|
143 |
+
"thumbs": {"π": 1, "π": 0},
|
144 |
+
"faces": {"π": 1, "π": 0.75, "π": 0.5, "π": 0.25, "π": 0},
|
145 |
+
}
|
146 |
+
|
147 |
+
# Get the score mapping based on the selected feedback option
|
148 |
+
scores = score_mappings[feedback_option]
|
149 |
+
|
150 |
+
if feedback:
|
151 |
+
# Get the score from the selected feedback option's score mapping
|
152 |
+
score = scores.get(feedback["score"])
|
153 |
+
|
154 |
+
if score is not None:
|
155 |
+
# Formulate feedback type string incorporating the feedback option
|
156 |
+
# and score value
|
157 |
+
feedback_type_str = f"{feedback_option} {feedback['score']}"
|
158 |
+
|
159 |
+
# Record the feedback with the formulated feedback type string
|
160 |
+
# and optional comment
|
161 |
+
with st.spinner("Just a sec! Dont enter prompts while loading pelase!"):
|
162 |
+
feedback_record = client.create_feedback(
|
163 |
+
run_id,
|
164 |
+
feedback_type_str,
|
165 |
+
score=score,
|
166 |
+
comment=feedback.get("text"),
|
167 |
+
source_info={"profile":profile}
|
168 |
+
)
|
169 |
+
st.session_state.feedback = {
|
170 |
+
"feedback_id": str(feedback_record.id),
|
171 |
+
"score": score,
|
172 |
+
}
|
173 |
+
else:
|
174 |
+
st.warning("Invalid feedback score.")
|
175 |
+
|
176 |
+
with st.spinner("Just a sec! Dont enter prompts while loading pelase!"):
|
177 |
+
if feedback.get("text"):
|
178 |
+
comment = feedback.get("text")
|
179 |
+
feedback_embedding = get_embeddings(comment)
|
180 |
+
else:
|
181 |
+
comment = "no comment"
|
182 |
+
feedback_embedding = get_embeddings(comment)
|
183 |
+
|
184 |
+
|
185 |
+
qdrant_client.upload_collection(
|
186 |
+
collection_name="chat-history",
|
187 |
+
payload=[
|
188 |
+
{"text": comment,
|
189 |
+
"Score:":score,
|
190 |
+
"type": "feedback",
|
191 |
+
"question_ID": run_id,
|
192 |
+
"User_profile":profile}
|
193 |
+
],
|
194 |
+
vectors=[
|
195 |
+
feedback_embedding
|
196 |
+
],
|
197 |
+
parallel=4,
|
198 |
+
max_retries=3,
|
199 |
+
)
|
200 |
+
|
201 |
+
|
202 |
+
|