File size: 5,305 Bytes
6bc94ac
 
15303cb
436ce71
6bc94ac
 
 
 
436ce71
6bc94ac
b7b070d
436ce71
 
 
 
 
e916883
 
 
 
 
6bc94ac
 
 
fb6ade2
6bc94ac
 
 
 
15303cb
 
 
 
fb6ade2
 
15303cb
 
325f09c
 
 
 
 
15303cb
 
 
e916883
15303cb
436ce71
15303cb
 
 
 
 
 
 
 
 
436ce71
15303cb
 
 
 
 
 
 
 
 
325f09c
15303cb
325f09c
 
 
 
 
 
 
 
 
15303cb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6bc94ac
15303cb
a0194f4
15303cb
 
fb6ade2
a0194f4
 
 
15303cb
 
 
 
 
 
 
 
 
 
 
6bc94ac
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
from celebbot import CelebBot
import streamlit as st
import time
from streamlit_mic_recorder import speech_to_text
from utils import *


def main():
    """Render the CelebChat page and answer at most one prompt per script run.

    The function:
      1. draws the sidebar (about/disclaimer, celebrity and model selectors,
         speech recorder, text box and example-question buttons);
      2. seeds ``st.session_state`` with the keys used below;
      3. builds a ``CelebBot`` for the selected celebrity and model;
      4. replays the stored chat history;
      5. if a non-empty prompt arrived (speech first, otherwise text), shows it,
         asks the bot for an answer, renders the answer with an audio player,
         and appends both turns to the history.
    """
    # Local import: only needed to escape the model's reply before it is
    # embedded in raw HTML further down.
    import html

    hide_footer()
    model_list = ["flan-t5-xl"]
    celeb_data = get_celeb_data('data.json')

    st.sidebar.header("CelebChat")
    with st.sidebar.expander('About the app'):
        st.markdown("Experience the ultimate celebrity chat demo with this app!")

    with st.sidebar.expander('Disclaimer'):
        st.markdown("CelebChat may produce inaccurate information about people, places, or facts.")

    # Seed every session key this function reads so later lookups never fail.
    # Existing values are left alone.
    session_defaults = {
        "messages": [],
        "QA_model_path": "google/flan-t5-xl",
        "sentTr_model_path": "sentence-transformers/all-mpnet-base-v2",
        "start_chat": False,
        "prompt_from_audio": "",
        "prompt_from_text": "",
        "celeb_bot": None,
        "_last_audio_id": 0,
    }
    for key, default in session_defaults.items():
        if key not in st.session_state:
            st.session_state[key] = default

    def text_submit():
        # on_change callback: move the typed text into the pending prompt and
        # clear the widget so the box is empty on the next run.
        st.session_state["prompt_from_text"] = st.session_state.text_input
        st.session_state.text_input = ''

    def example_submit(text):
        # on_click callback for the example buttons.
        st.session_state["prompt_from_text"] = text

    st.session_state["celeb_name"] = st.sidebar.selectbox('Choose a celebrity', options=list(celeb_data.keys()))
    model_id = st.sidebar.selectbox("Choose Your Flan-T5 model", options=model_list)

    # Flan-T5 ids are prefixed with "google/"; any other id is used verbatim.
    st.session_state["QA_model_path"] = f"google/{model_id}" if "flan-t5" in model_id else model_id

    celeb_name = st.session_state["celeb_name"]
    qa_model_path = st.session_state["QA_model_path"]
    sent_tr_model_path = st.session_state["sentTr_model_path"]
    celeb_gender = celeb_data[celeb_name]["gender"]
    knowledge = celeb_data[celeb_name]["knowledge"]

    # NOTE(review): the bot is rebuilt on every run; whether the loaders in
    # `utils` cache the models is not visible from this file.
    st.session_state["celeb_bot"] = CelebBot(
        celeb_name,
        get_tokenizer(qa_model_path),
        get_seq2seq_model(qa_model_path) if "flan-t5" in qa_model_path else get_causal_model(qa_model_path),
        get_tokenizer(sent_tr_model_path),
        get_auto_model(sent_tr_model_path),
        *preprocess_text(celeb_name, celeb_gender, knowledge, "en_core_web_sm")
    )

    # Replay the conversation so far.
    dialogue_container = st.container()
    with dialogue_container:
        for message in st.session_state["messages"]:
            with st.chat_message(message["role"]):
                st.markdown(message["content"])

    with st.sidebar:
        st.write("You can record your question...")
        st.session_state["prompt_from_audio"] = speech_to_text(start_prompt="Start Recording",stop_prompt="Stop Recording",language='en',use_container_width=True, just_once=True,key='STT')
        st.text_input('Or write something...', key='text_input', on_change=text_submit)
        st.write("Example questions:")

        example1 = "Hello! Did you win an Oscar?"
        st.button(example1, on_click=example_submit, args=[example1])
        example2 = "Hi! What is your profession?"
        st.button(example2, on_click=example_submit, args=[example2])
        example3 = "Can you tell me about your family background?"
        st.button(example3, on_click=example_submit, args=[example3])

    # Speech wins when it produced text; otherwise fall back to the typed or
    # example prompt. `or` keeps `prompt` always bound, even when a source is
    # None or empty.
    prompt = st.session_state["prompt_from_audio"] or st.session_state["prompt_from_text"]

    if prompt:
        bot = st.session_state["celeb_bot"]
        bot.text = prompt
        # Display user message in chat message container
        with dialogue_container:
            st.chat_message("user").markdown(prompt)
        # Add user message to chat history
        st.session_state["messages"].append({"role": "user", "content": prompt})

        response = bot.question_answer()

        # disable autoplay to play in HTML
        b64 = bot.text_to_speech(autoplay=False)
        # The reply is model output rendered with unsafe_allow_html=True, so
        # escape it to keep stray "<", ">" or "&" from being parsed as markup.
        safe_response = html.escape(str(response), quote=False)
        # NOTE(review): "autoplay" is not a standard controlsList token
        # (nodownload / nofullscreen / noremoteplayback) - confirm intent.
        md = f"""
        <p>{safe_response}</p>
        <audio controls controlsList="autoplay nodownload">
        <source src="data:audio/wav;base64,{b64}" type="audio/wav">
        Your browser does not support the audio element.
        </audio>
        """
        with dialogue_container:
            st.chat_message("assistant").markdown(
                md,
                unsafe_allow_html=True,
            )
        # Add assistant response to chat history (raw text; the replay above
        # renders it without HTML enabled).
        st.session_state["messages"].append({"role": "assistant", "content": response})

        # Clear both sources so the same prompt is not answered again on the
        # next run.
        st.session_state["prompt_from_audio"] = ""
        st.session_state["prompt_from_text"] = ""


# Run the app only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()