SwatGarg commited on
Commit
2250554
1 Parent(s): d159277

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +112 -0
app.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from streamlit_chat import message
3
+ from streamlit_extras.colored_header import colored_header
4
+ from streamlit_extras.add_vertical_space import add_vertical_space
5
+ from streamlit_mic_recorder import speech_to_text
6
+ from model_pipeline import ModelPipeLine
7
+
8
+ from gtts import gTTS
9
+ from io import BytesIO
10
+
11
+ mdl = ModelPipeLine()
12
+ final_chain = mdl.create_final_chain()
13
+
14
+ st.set_page_config(page_title="PeacePal")
15
+
16
+ st.title('Omdena HYD: Mental Health counselor 🌱')
17
+
18
+ ## generated stores AI generated responses
19
+ if 'generated' not in st.session_state:
20
+ st.session_state['generated'] = ["I'm your Mental health Assistant, How may I help you?"]
21
+ ## past stores User's questions
22
+ if 'past' not in st.session_state:
23
+ st.session_state['past'] = ['Hi!']
24
+
25
+ # Layout of input/response containers
26
+
27
+ colored_header(label='', description='', color_name='blue-30')
28
+ response_container = st.container()
29
+ input_container = st.container()
30
+
31
+ # User input
32
+ ## Function for taking user provided prompt as input
33
+ def get_text():
34
+ input_text = st.text_input("You: ", "", key="input")
35
+ return input_text
36
+
37
+ def generate_response(prompt):
38
+ response = mdl.call_conversational_rag(prompt,final_chain)
39
+ return response['answer']
40
+
41
+ def text_to_speech(text):
42
+ # Use gTTS to convert text to speech
43
+ tts = gTTS(text=text, lang='en')
44
+ # Save the speech as bytes in memory
45
+ fp = BytesIO()
46
+ tts.write_to_fp(fp)
47
+ return fp
48
+
49
+ def speech_recognition_callback():
50
+ # Ensure that speech output is available
51
+ if st.session_state.my_stt_output is None:
52
+ st.session_state.p01_error_message = "Please record your response again."
53
+ return
54
+
55
+ # Clear any previous error messages
56
+ st.session_state.p01_error_message = None
57
+
58
+ # Store the speech output in the session state
59
+ st.session_state.speech_input = st.session_state.my_stt_output
60
+
61
+
62
+ ## Applying the user input box
63
+ with input_container:
64
+ # Add a radio button to choose input mode
65
+ input_mode = st.radio("Select input mode:", ["Text", "Speech"])
66
+
67
+ if input_mode == "Speech":
68
+ # Use the speech_to_text function to capture speech input
69
+ speech_input = speech_to_text(
70
+ key='my_stt',
71
+ callback=speech_recognition_callback
72
+ )
73
+
74
+ # Check if speech input is available
75
+ if 'speech_input' in st.session_state and st.session_state.speech_input:
76
+ # Display the speech input
77
+ st.text(f"Speech Input: {st.session_state.speech_input}")
78
+
79
+ # Process the speech input as a query
80
+ query = st.session_state.speech_input
81
+ with st.spinner("processing....."):
82
+ response = generate_response(query)
83
+ st.session_state.past.append(query)
84
+ st.session_state.generated.append(response)
85
+
86
+ # Convert the response to speech
87
+ speech_fp = text_to_speech(response)
88
+ # Play the speech
89
+ st.audio(speech_fp, format='audio/mp3')
90
+ else:
91
+ # Add a text input field for query
92
+ query = st.text_input("Query: ", key="input")
93
+
94
+ # Process the query if it's not empty
95
+ if query:
96
+ with st.spinner("typing....."):
97
+ response = generate_response(query)
98
+ st.session_state.past.append(query)
99
+ st.session_state.generated.append(response)
100
+
101
+ # Convert the response to speech
102
+ speech_fp = text_to_speech(response)
103
+ # Play the speech
104
+ st.audio(speech_fp, format='audio/mp3')
105
+
106
+ ## Conditional display of AI generated responses as a function of user provided prompts
107
+ with response_container:
108
+ if st.session_state['generated']:
109
+ for i in range(len(st.session_state['generated'])):
110
+ message(st.session_state['past'][i], is_user=True, key=str(i) + '_user')
111
+ message(st.session_state["generated"][i], key=str(i))
112
+