import streamlit as st
from streamlit_chat import message
from streamlit_extras.colored_header import colored_header
from streamlit_extras.add_vertical_space import add_vertical_space
from streamlit_mic_recorder import speech_to_text
from model_pipeline import ModelPipeLine

from gtts import gTTS
from io import BytesIO

# Build the conversational RAG pipeline and the chain used to answer queries.
mdl = ModelPipeLine()
final_chain = mdl.create_final_chain()
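
# Note: the pipeline and chain above are rebuilt on every Streamlit rerun. If construction
# is slow, one option (a sketch, assuming the objects are safe to reuse across sessions)
# is to wrap them in a cached loader:
#
#     @st.cache_resource
#     def load_pipeline():
#         pipeline = ModelPipeLine()
#         return pipeline, pipeline.create_final_chain()
#
#     mdl, final_chain = load_pipeline()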

st.set_page_config(page_title="PeacePal")
st.title('Omdena HYD: Mental Health Counselor 🌱')

# Initialise the chat history on first load.
if 'generated' not in st.session_state:
    st.session_state['generated'] = ["I'm your mental health assistant. How may I help you?"]

if 'past' not in st.session_state:
    st.session_state['past'] = ['Hi!']
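
# 'past' holds the user's turns and 'generated' holds the assistant's replies; the two
# lists stay the same length so the chat renderer at the bottom can pair them by index.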

colored_header(label='', description='', color_name='blue-30')
response_container = st.container()
input_container = st.container()
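
# response_container is declared before input_container, so the chat history renders
# above the input widgets even though it is only populated at the end of the script.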


def get_text():
    """Return the user's typed query (helper; not called in the current flow)."""
    input_text = st.text_input("You: ", "", key="input")
    return input_text


def generate_response(prompt):
    """Run the conversational RAG chain and return only the answer text."""
    response = mdl.call_conversational_rag(prompt, final_chain)
    return response['answer']


def text_to_speech(text):
    """Synthesise the reply with gTTS and return it as an in-memory MP3 buffer."""
    tts = gTTS(text=text, lang='en')
    fp = BytesIO()
    tts.write_to_fp(fp)
    fp.seek(0)  # rewind so the audio player reads from the start of the buffer
    return fp
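
# text_to_speech note: gTTS sends the text to Google's TTS service, so it needs network
# access; the BytesIO it returns can be passed directly to st.audio(..., format='audio/mp3').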


def speech_recognition_callback():
    """Store the mic recorder's transcript in session state, or flag an error if it is empty."""
    if st.session_state.my_stt_output is None:
        st.session_state.p01_error_message = "Please record your response again."
        return

    # Clear any previous error and keep the recognised text for the main script body.
    st.session_state.p01_error_message = None
    st.session_state.speech_input = st.session_state.my_stt_output


with input_container:
    input_mode = st.radio("Select input mode:", ["Text", "Speech"])

    if input_mode == "Speech":
        # Record from the microphone; the callback stores the transcript in
        # st.session_state.speech_input (or an error message if nothing was recognised).
        speech_input = speech_to_text(
            key='my_stt',
            callback=speech_recognition_callback
        )
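
        # Surface any recognition error recorded by speech_recognition_callback.
        if st.session_state.get('p01_error_message'):
            st.error(st.session_state.p01_error_message)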

        if 'speech_input' in st.session_state and st.session_state.speech_input:
            st.text(f"Speech Input: {st.session_state.speech_input}")

            query = st.session_state.speech_input
            with st.spinner("Processing..."):
                response = generate_response(query)
                st.session_state.past.append(query)
                st.session_state.generated.append(response)

            # Speak the reply as well as showing it in the chat history below.
            speech_fp = text_to_speech(response)
            st.audio(speech_fp, format='audio/mp3')
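
            # Clear the stored transcript so the same utterance is not reprocessed
            # (and appended to the history again) on the next Streamlit rerun.
            st.session_state.speech_input = None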

    else:
        query = st.text_input("Query: ", key="input")

        if query:
            with st.spinner("Typing..."):
                response = generate_response(query)
                st.session_state.past.append(query)
                st.session_state.generated.append(response)

            # Speak the reply here too, mirroring the speech branch.
            speech_fp = text_to_speech(response)
            st.audio(speech_fp, format='audio/mp3')


# Render the full conversation, pairing each user turn with the assistant's reply.
with response_container:
    if st.session_state['generated']:
        for i in range(len(st.session_state['generated'])):
            message(st.session_state['past'][i], is_user=True, key=str(i) + '_user')
            message(st.session_state["generated"][i], key=str(i))