SwatGarg commited on
Commit
74a34d5
1 Parent(s): 7cfc4fe

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +148 -0
app.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from streamlit_chat import message
3
+ from streamlit_extras.colored_header import colored_header
4
+ from streamlit_extras.add_vertical_space import add_vertical_space
5
+ from streamlit_mic_recorder import speech_to_text
6
+ from model_pipelineV2 import ModelPipeLine
7
+ import pandas as pd
8
+ from gtts import gTTS
9
+ from io import BytesIO
10
+
11
+ mdl = ModelPipeLine()
12
+ final_chain = mdl.create_final_chain()
13
+
14
+ st.set_page_config(page_title="PeacePal")
15
+
16
+ st.title('PeacePal 🌱')
17
+
18
+ states = [
19
+ "Negative",
20
+ "Neutral",
21
+ "Positive",
22
+ ]
23
+
24
+
25
+ def display_q_table(states):
26
+ values = [0,1,2]
27
+ q_table_dict = {"State": states,
28
+ "values":values}
29
+ q_table_df = pd.DataFrame(q_table_dict)
30
+ return q_table_df
31
+
32
+ ## generated stores AI generated responses
33
+ if 'generated' not in st.session_state:
34
+ st.session_state['generated'] = ["I'm your Mental health Assistant, How may I help you?"]
35
+ ## past stores User's questions
36
+ if 'past' not in st.session_state:
37
+ st.session_state['past'] = ['Hi']
38
+
39
+ if "user_sentiment" not in st.session_state:
40
+ st.session_state.user_sentiment = "Neutral"
41
+
42
+ # Layout of input/response containers
43
+
44
+ colored_header(label='', description='', color_name='blue-30')
45
+ response_container = st.container()
46
+ input_container = st.container()
47
+
48
+ # User input
49
+ ## Function for taking user provided prompt as input
50
+ def get_text():
51
+ input_text = st.text_input("You: ", "", key="input")
52
+ return input_text
53
+
54
+ def generate_response(prompt):
55
+ sentiment = mdl.predict_classification(prompt)
56
+ response = mdl.call_conversational_rag(prompt,final_chain)
57
+ return response['answer'],sentiment
58
+
59
+ def text_to_speech(text):
60
+ # Use gTTS to convert text to speech
61
+ tts = gTTS(text=text, lang='en')
62
+ # Save the speech as bytes in memory
63
+ fp = BytesIO()
64
+ tts.write_to_fp(fp)
65
+ return fp
66
+
67
+ def speech_recognition_callback():
68
+ # Ensure that speech output is available
69
+ if st.session_state.my_stt_output is None:
70
+ st.session_state.p01_error_message = "Please record your response again."
71
+ return
72
+
73
+ # Clear any previous error messages
74
+ st.session_state.p01_error_message = None
75
+
76
+ # Store the speech output in the session state
77
+ st.session_state.speech_input = st.session_state.my_stt_output
78
+
79
+
80
+ input_mode = st.sidebar.radio("Select input mode:", ["Text", "Speech"])
81
+ ## Applying the user input box
82
+ query = None
83
+ with input_container:
84
+ detected_sentiment = None
85
+ if input_mode == "Speech":
86
+ # Use the speech_to_text function to capture speech input
87
+ speech_input = speech_to_text(
88
+ key='my_stt',
89
+ callback=speech_recognition_callback
90
+ )
91
+
92
+ # Check if speech input is available
93
+ if 'speech_input' in st.session_state and st.session_state.speech_input:
94
+ # Display the speech input
95
+ # st.text(f"Speech Input: {st.session_state.speech_input}")
96
+
97
+ # Process the speech input as a query
98
+ query = st.session_state.speech_input
99
+ with st.spinner("processing....."):
100
+ response,detected_sentiment = generate_response(query)
101
+ st.session_state.past.append(query)
102
+ st.session_state.generated.append(response)
103
+ st.session_state.speech_input = None
104
+ # Convert the response to speech
105
+ speech_fp = text_to_speech(response)
106
+ # Play the speech
107
+ st.audio(speech_fp, format='audio/mp3')
108
+
109
+ else:
110
+ # Add a text input field for query
111
+ query = st.text_input("Query: ", key="input")
112
+
113
+ # Process the query if it's not empty
114
+ if query:
115
+ with st.spinner("processing....."):
116
+ response,detected_sentiment = generate_response(query)
117
+ st.session_state.past.append(query)
118
+ st.session_state.generated.append(response)
119
+ query = None
120
+ # Convert the response to speech
121
+ speech_fp = text_to_speech(response)
122
+ # Play the speech
123
+ st.audio(speech_fp, format='audio/mp3')
124
+ if detected_sentiment == 0:
125
+ st.session_state.user_sentiment = 'Negative'
126
+ elif detected_sentiment == 1:
127
+ st.session_state.user_sentiment = 'Neutral'
128
+ elif detected_sentiment == 1:
129
+ st.session_state.user_sentiment = 'Positive'
130
+ else:
131
+ st.session_state.user_sentiment = 'Neutral'
132
+
133
+
134
+ ## Conditional display of AI generated responses as a function of user provided prompts
135
+ with response_container:
136
+ if st.session_state['generated']:
137
+ for i in range(len(st.session_state['generated'])):
138
+ message(st.session_state['past'][i], is_user=True, key=str(i) + '_user')
139
+ message(st.session_state["generated"][i], key=str(i))
140
+
141
+ with st.sidebar.expander("Sentiment Analysis"):
142
+ # Use the values stored in session state
143
+
144
+ st.write(
145
+ f"- Detected User Tone: {st.session_state.user_sentiment}")
146
+
147
+ # Display Q-table
148
+ st.dataframe(display_q_table(states))