Spaces:
Runtime error
Runtime error
File size: 5,305 Bytes
6bc94ac 15303cb 436ce71 6bc94ac 436ce71 6bc94ac b7b070d 436ce71 e916883 6bc94ac fb6ade2 6bc94ac 15303cb fb6ade2 15303cb 325f09c 15303cb e916883 15303cb 436ce71 15303cb 436ce71 15303cb 325f09c 15303cb 325f09c 15303cb 6bc94ac 15303cb a0194f4 15303cb fb6ade2 a0194f4 15303cb 6bc94ac |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 |
from celebbot import CelebBot
import streamlit as st
import time
from streamlit_mic_recorder import speech_to_text
from utils import *
def main():
hide_footer()
model_list = ["flan-t5-xl"]
celeb_data = get_celeb_data(f'data.json')
st.sidebar.header("CelebChat")
expander = st.sidebar.expander('About the app')
with expander:
st.markdown("Experience the ultimate celebrity chat demo with this app!")
expander = st.sidebar.expander('Disclaimer')
with expander:
st.markdown("CelebChat may produce inaccurate information about people, places, or facts.")
if "messages" not in st.session_state:
st.session_state["messages"] = []
if "QA_model_path" not in st.session_state:
st.session_state["QA_model_path"] = "google/flan-t5-xl"
if "sentTr_model_path" not in st.session_state:
st.session_state["sentTr_model_path"] = "sentence-transformers/all-mpnet-base-v2"
if "start_chat" not in st.session_state:
st.session_state["start_chat"] = False
if "prompt_from_audio" not in st.session_state:
st.session_state["prompt_from_audio"] = ""
if "prompt_from_text" not in st.session_state:
st.session_state["prompt_from_text"] = ""
if "celeb_bot" not in st.session_state:
st.session_state["celeb_bot"] = None
def text_submit():
st.session_state["prompt_from_text"] = st.session_state.text_input
st.session_state.text_input = ''
def example_submit(text):
st.session_state["prompt_from_text"] = text
st.session_state["celeb_name"] = st.sidebar.selectbox('Choose a celebrity', options=list(celeb_data.keys()))
model_id=st.sidebar.selectbox("Choose Your Flan-T5 model",options=model_list)
st.session_state["QA_model_path"] = f"google/{model_id}" if "flan-t5" in model_id else model_id
celeb_gender = celeb_data[st.session_state["celeb_name"]]["gender"]
knowledge = celeb_data[st.session_state["celeb_name"]]["knowledge"]
st.session_state["celeb_bot"] = CelebBot(st.session_state["celeb_name"],
get_tokenizer(st.session_state["QA_model_path"]),
get_seq2seq_model(st.session_state["QA_model_path"]) if "flan-t5" in st.session_state["QA_model_path"] else get_causal_model(st.session_state["QA_model_path"]),
get_tokenizer(st.session_state["sentTr_model_path"]),
get_auto_model(st.session_state["sentTr_model_path"]),
*preprocess_text(st.session_state["celeb_name"], celeb_gender, knowledge, "en_core_web_sm")
)
dialogue_container = st.container()
with dialogue_container:
for message in st.session_state["messages"]:
with st.chat_message(message["role"]):
st.markdown(message["content"])
if "_last_audio_id" not in st.session_state:
st.session_state["_last_audio_id"] = 0
with st.sidebar:
st.write("You can record your question...")
st.session_state["prompt_from_audio"] = speech_to_text(start_prompt="Start Recording",stop_prompt="Stop Recording",language='en',use_container_width=True, just_once=True,key='STT')
st.text_input('Or write something...', key='text_input', on_change=text_submit)
st.write("Example questions:")
example1 = "Hello! Did you win an Oscar?"
st.button(example1, on_click=example_submit, args=[example1])
example2 = "Hi! What is your profession?"
st.button(example2, on_click=example_submit, args=[example2])
example3 = "Can you tell me about your family background?"
st.button(example3, on_click=example_submit, args=[example3])
if st.session_state["prompt_from_audio"] != None:
prompt = st.session_state["prompt_from_audio"]
elif st.session_state["prompt_from_text"] != None:
prompt = st.session_state["prompt_from_text"]
if prompt != None and prompt != '':
st.session_state["celeb_bot"].text = prompt
# Display user message in chat message container
with dialogue_container:
st.chat_message("user").markdown(prompt)
# Add user message to chat history
st.session_state["messages"].append({"role": "user", "content": prompt})
# Add assistant response to chat history
response = st.session_state["celeb_bot"].question_answer()
# disable autoplay to play in HTML
b64 = st.session_state["celeb_bot"].text_to_speech(autoplay=False)
md = f"""
<p>{response}</p>
<audio controls controlsList="autoplay nodownload">
<source src="data:audio/wav;base64,{b64}" type="audio/wav">
Your browser does not support the audio element.
</audio>
"""
with dialogue_container:
st.chat_message("assistant").markdown(
md,
unsafe_allow_html=True,
)
# Display assistant response in chat message container
st.session_state["messages"].append({"role": "assistant", "content": response})
st.session_state["prompt_from_audio"] = ""
st.session_state["prompt_from_text"] = ""
if __name__ == "__main__":
    # Entry guard: `streamlit run` executes this file as the __main__ module,
    # so main() runs on every rerun, while a plain `import` stays side-effect free.
    main()
|