Pratik Dwivedi committed on
Commit 729a463
1 Parent(s): 4e20e20
Files changed (1)
  1. app.py +44 -12
app.py CHANGED
@@ -1,22 +1,54 @@
 import streamlit as st
+from langchain.memory import ConversationBufferMemory
+from langchain.chains import ConversationalRetrievalChain
+from langchain_community.llms import HuggingFaceHub
 
+def get_conversation(model):
+
+    memory = ConversationBufferMemory(memory_key="messages", return_messages=True)
+
+    conversation_chain = ConversationalRetrievalChain.from_llm(
+        llm=model,
+        # retriever=vectorstore.as_retriever(),
+        memory=memory)
+
+    return conversation_chain
+
+def get_response(conversation_chain, query):
+    # get the response
+    response = conversation_chain.invoke(query)
+    response = response["result"]
+    answer = response.split('\nHelpful Answer: ')[1]
+    return answer
 
 def main():
     st.title("Health Chatbot")
 
-    if 'messages' not in st.session_state:
+    print("Loading LLM from HuggingFace")
+    with st.spinner('Loading LLM from HuggingFace...'):
+        llm = HuggingFaceHub(repo_id="HuggingFaceH4/zephyr-7b-beta", model_kwargs={"temperature":0.7, "max_new_tokens":512, "top_p":0.95, "top_k":50})
+
+    if "messages" not in st.session_state:
+        st.session_state.messages = []
+
+    if st.button("Clear Chat"):
         st.session_state.messages = []
 
-
-    user_input = st.chat_input("You", key="user_input")
-    if user_input:
-        st.session_state.messages.append({"sender": "user", "message": user_input})
-    # display all the messages already in the session state
-    for message in st.session_state.messages:
-        if message["sender"] == "user":
-            st.write("You: ", message["message"])
-        else:
-            st.write("Bot: ", message["message"])
-
+    for message in st.session_state.messages:
+        if message["role"] == "user":
+            st.chat_message("user").markdown(message["content"])
+        else:
+            st.chat_message("bot").markdown(message["content"])
+
+    conversation_chain = get_conversation(llm)
+
+    user_prompt = st.chat_input("ask a question", key="user")
+    if user_prompt:
+        st.chat_message("user").markdown(user_prompt)
+        st.session_state.messages.append({"role": "user", "content": user_prompt})
+        response = get_response(conversation_chain, user_prompt)
+        st.chat_message("bot").markdown(response)
+        st.session_state.messages.append({"role": "bot", "content": response})
+
 if __name__ == "__main__":
     main()
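
Note on the new get_conversation(): ConversationalRetrievalChain.from_llm expects a retriever, which this commit leaves commented out (retriever=vectorstore.as_retriever()). A minimal sketch of how one could be supplied, assuming a local FAISS index over a hypothetical health_faq.txt and an illustrative sentence-transformers embedding model (none of these names come from the commit):

# Hypothetical helper, not part of the committed app.py: builds the
# retriever that get_conversation() currently leaves out.
from langchain_community.document_loaders import TextLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.text_splitter import RecursiveCharacterTextSplitter

def build_retriever(doc_path="health_faq.txt"):
    # Load the source document and split it into overlapping chunks
    docs = TextLoader(doc_path).load()
    chunks = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50).split_documents(docs)
    # Embed the chunks and index them in an in-memory FAISS store
    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    vectorstore = FAISS.from_documents(chunks, embeddings)
    return vectorstore.as_retriever()

With such a retriever passed as retriever=build_retriever() inside get_conversation(), note also that ConversationalRetrievalChain returns its output under the "answer" key by default, so get_response() may need response["answer"] rather than response["result"].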