lhzstar committed
Commit: 15303cb
Parent(s): 436ce71

new commits

Files changed (4):
  1. app.py +79 -72
  2. celebbot.py +3 -3
  3. data.json +0 -0
  4. run_tts.py +1 -5
app.py CHANGED
@@ -1,5 +1,6 @@
 from celebbot import CelebBot
 import streamlit as st
+import time
 from streamlit_mic_recorder import speech_to_text
 from utils import *
 
@@ -7,7 +8,7 @@ from utils import *
 def main():
 
     hide_footer()
-    model_list = ["flan-t5-large", "flan-t5-xl", "Falcon-7b-instruct"]
+    model_list = ["flan-t5-xl"]
     celeb_data = get_celeb_data(f'data.json')
 
     st.sidebar.header("CelebChat")
@@ -22,80 +23,86 @@ def main():
     st.session_state["sentTr_model_path"] = "sentence-transformers/all-mpnet-base-v2"
     if "start_chat" not in st.session_state:
         st.session_state["start_chat"] = False
-    if "prompt" not in st.session_state:
-        st.session_state["prompt"] = None
-
-    def start_chat(name, model_id):
-        print(name, model_id)
-        if name != '' and model_id != '':
-            st.session_state["start_chat"] = True
-        else:
-            st.session_state["start_chat"] = False
-
-    with st.sidebar.form("my_form"):
-        print("enter form")
-        st.session_state["celeb_name"] = st.selectbox('Choose a celebrity', options=list(celeb_data.keys()))
-        model_id=st.selectbox("Choose Your Flan-T5 model",options=model_list)
-        st.session_state["QA_model_path"] = f"google/{model_id}" if "flan-t5" in model_id else model_id
-
-        st.form_submit_button(label="Start Chatting", on_click=start_chat, args=(st.session_state["celeb_name"], st.session_state["QA_model_path"]))
-
-    if st.session_state["start_chat"]:
-        celeb_gender = celeb_data[st.session_state["celeb_name"]]["gender"]
-        knowledge = celeb_data[st.session_state["celeb_name"]]["knowledge"]
-        st.session_state["celeb_bot"] = CelebBot(st.session_state["celeb_name"],
-                                                 get_tokenizer(st.session_state["QA_model_path"]),
-                                                 get_seq2seq_model(st.session_state["QA_model_path"]) if "flan-t5" in st.session_state["QA_model_path"] else get_causal_model(st.session_state["QA_model_path"]),
-                                                 get_tokenizer(st.session_state["sentTr_model_path"]),
-                                                 get_auto_model(st.session_state["sentTr_model_path"]),
-                                                 *preprocess_text(st.session_state["celeb_name"], celeb_gender, knowledge, "en_core_web_sm")
-                                                 )
-
-        dialogue_container = st.container()
-        with dialogue_container:
-            for message in st.session_state["messages"]:
-                with st.chat_message(message["role"]):
-                    st.markdown(message["content"])
+    if "prompt_from_audio" not in st.session_state:
+        st.session_state["prompt_from_audio"] = ""
+    if "prompt_from_text" not in st.session_state:
+        st.session_state["prompt_from_text"] = ""
+
+    def text_submit():
+        st.session_state["prompt_from_text"] = st.session_state.widget
+        st.session_state.widget = ''
+
+    st.session_state["celeb_name"] = st.sidebar.selectbox('Choose a celebrity', options=list(celeb_data.keys()))
+    model_id=st.sidebar.selectbox("Choose Your Flan-T5 model",options=model_list)
+    st.session_state["QA_model_path"] = f"google/{model_id}" if "flan-t5" in model_id else model_id
 
+    celeb_gender = celeb_data[st.session_state["celeb_name"]]["gender"]
+    knowledge = celeb_data[st.session_state["celeb_name"]]["knowledge"]
+    st.session_state["celeb_bot"] = CelebBot(st.session_state["celeb_name"],
+                                             get_tokenizer(st.session_state["QA_model_path"]),
+                                             get_seq2seq_model(st.session_state["QA_model_path"]) if "flan-t5" in st.session_state["QA_model_path"] else get_causal_model(st.session_state["QA_model_path"]),
+                                             get_tokenizer(st.session_state["sentTr_model_path"]),
+                                             get_auto_model(st.session_state["sentTr_model_path"]),
+                                             *preprocess_text(st.session_state["celeb_name"], celeb_gender, knowledge, "en_core_web_sm")
+                                             )
 
-        if "_last_audio_id" not in st.session_state:
-            st.session_state["_last_audio_id"] = 0
-        with st.sidebar:
-            prompt_from_audio =speech_to_text(start_prompt="Start Recording",stop_prompt="Stop Recording",language='en',use_container_width=True, just_once=True,key='STT')
-            prompt_from_text = st.text_input('Or write something')
+    dialogue_container = st.container()
+    with dialogue_container:
+        for message in st.session_state["messages"]:
+            with st.chat_message(message["role"]):
+                st.markdown(message["content"])
+
+    if "_last_audio_id" not in st.session_state:
+        st.session_state["_last_audio_id"] = 0
+    with st.sidebar:
+        st.session_state["prompt_from_audio"] = speech_to_text(start_prompt="Start Recording",stop_prompt="Stop Recording",language='en',use_container_width=True, just_once=True,key='STT')
+        st.text_input('Or write something', key='widget', on_change=text_submit)
+
+    if st.session_state["prompt_from_audio"] != None:
+        prompt = st.session_state["prompt_from_audio"]
+    elif st.session_state["prompt_from_text"] != None:
+        prompt = st.session_state["prompt_from_text"]
+
+    if prompt != None and prompt != '':
+        st.session_state["celeb_bot"].text = prompt
+        # Display user message in chat message container
+        with dialogue_container:
+            st.chat_message("user").markdown(prompt)
+        # Add user message to chat history
+        st.session_state["messages"].append({"role": "user", "content": prompt})
+
+        # Add assistant response to chat history
+        response = st.session_state["celeb_bot"].question_answer()
 
-        if prompt_from_audio != None:
-            st.session_state["prompt"] = prompt_from_audio
-        elif prompt_from_text != None:
-            st.session_state["prompt"] = prompt_from_text
-        print(st.session_state["prompt"])
-        if st.session_state["prompt"] != None and st.session_state["prompt"] != '':
-            st.session_state["celeb_bot"].text = st.session_state["prompt"]
-            # Display user message in chat message container
-            with dialogue_container:
-                st.chat_message("user").markdown(st.session_state["prompt"])
-            # Add user message to chat history
-            st.session_state["messages"].append({"role": "user", "content": st.session_state["prompt"]})
-
-            # Add assistant response to chat history
-            response = st.session_state["celeb_bot"].question_answer()
-
-            # disable autoplay to play in HTML
-            b64 = st.session_state["celeb_bot"].text_to_speech(autoplay=False)
-            md = f"""
-            <p>{response}</p>
-            <audio controls autoplay style="display:none;">
-            <source src="data:audio/wav;base64,{b64}" type="audio/wav">
-            Your browser does not support the audio element.
-            </audio>
-            """
-            with dialogue_container:
-                st.chat_message("assistant").markdown(
-                    md,
-                    unsafe_allow_html=True,
-                )
-            # Display assistant response in chat message container
-            st.session_state["messages"].append({"role": "assistant", "content": response})
+        # disable autoplay to play in HTML
+        wav, sr = st.session_state["celeb_bot"].text_to_speech(autoplay=False)
+        md = f"""
+        <p>{response}</p>
+        """
+        with dialogue_container:
+            st.chat_message("assistant").markdown(
+                md,
+                unsafe_allow_html=True,
+            )
+
+        # Play the audio (non-blocking)
+        import sounddevice as sd
+        try:
+            sd.stop()
+            sd.play(wav, sr)
+            time_span = len(wav)//sr + 1
+            time.sleep(time_span)
+
+        except sd.PortAudioError as e:
+            print("\nCaught exception: %s" % repr(e))
+            print("Continuing without audio playback. Suppress this message with the \"--no_sound\" flag.\n")
+        except:
+            raise
+        # Display assistant response in chat message container
+        st.session_state["messages"].append({"role": "assistant", "content": response})
+
+        st.session_state["prompt_from_audio"] = ""
+        st.session_state["prompt_from_text"] = ""
 
 
 if __name__ == "__main__":
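The reworked text input above relies on Streamlit's `on_change` callback pattern: a widget's value lives in `st.session_state` under its `key`, and it can only be reset from inside a callback once the widget has been rendered. A minimal, self-contained sketch of that pattern (the widget and state names mirror the diff; everything else is illustrative):

import streamlit as st

if "prompt_from_text" not in st.session_state:
    st.session_state["prompt_from_text"] = ""

def text_submit():
    # Copy the typed text out of the widget, then clear the widget so the
    # box is empty on the next rerun; resetting it after the widget has
    # already been rendered in the same run would raise a
    # StreamlitAPIException, which is why the reset lives in the callback.
    st.session_state["prompt_from_text"] = st.session_state.widget
    st.session_state.widget = ""

st.text_input("Or write something", key="widget", on_change=text_submit)

if st.session_state["prompt_from_text"]:
    st.write(f"You typed: {st.session_state['prompt_from_text']}")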
celebbot.py CHANGED
@@ -103,12 +103,12 @@ class CelebBot():
         ## have a conversation
         else:
             if re.search(re.compile(rf'\b(you|your|{self.name})\b', flags=re.IGNORECASE), self.text) != None:
-                instruction1 = f'[Instruction] You are a celebrity named {self.name}. You need to answer the question based on knowledge and commonsense.'
+                instruction1 = f'You are a celebrity named {self.name}. You need to answer the question based on knowledge and commonsense.'
 
                 knowledge = self.retrieve_knowledge_assertions()
             else:
-                instruction1 = f'[Instruction] You need to answer the question based on commonsense.'
-            query = f"{instruction1} [knowledge] {knowledge} [question] {self.text}"
+                instruction1 = f'You need to answer the question based on commonsense.'
+            query = f"Context: {instruction1} {knowledge}\n\nQuestion: {self.text}\n\nAnswer:"
             input_ids = self.QA_tokenizer(f"{query}", return_tensors="pt").input_ids
             outputs = self.QA_model.generate(input_ids, max_length=1024)
             self.text = self.QA_tokenizer.decode(outputs[0], skip_special_tokens=True)
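For reference, the commit swaps the bracketed `[Instruction]/[knowledge]/[question]` tags for a plain `Context: ... Question: ... Answer:` layout. A rough, self-contained sketch of running that layout through a Flan-T5 checkpoint with transformers (the checkpoint size and the example strings are placeholders, not values from the repo):

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Placeholder checkpoint; the app loads whichever model is picked in the sidebar.
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")
model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small")

# Placeholder strings standing in for self.name, the retrieved knowledge
# assertions, and self.text.
instruction1 = "You are a celebrity named Ada Lovelace. You need to answer the question based on knowledge and commonsense."
knowledge = "Ada Lovelace was an English mathematician who worked on Babbage's Analytical Engine."
question = "What are you most famous for?"

# Same layout as the new query string in celebbot.py.
query = f"Context: {instruction1} {knowledge}\n\nQuestion: {question}\n\nAnswer:"

input_ids = tokenizer(query, return_tensors="pt").input_ids
outputs = model.generate(input_ids, max_length=1024)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))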
data.json CHANGED
The diff for this file is too large to render. See raw diff
 
run_tts.py CHANGED
@@ -109,11 +109,7 @@ def tts(text, embed_name, nlp, autoplay=True):
         print("Continuing without audio playback. Suppress this message with the \"--no_sound\" flag.\n")
     except:
         raise
-    bytes_wav = bytes()
-    byte_io = io.BytesIO(bytes_wav)
-    write(byte_io, synthesizer.sample_rate, wav.astype(np.float32))
-    result_bytes = byte_io.read()
-    return base64.b64encode(result_bytes).decode()
+    return wav, synthesizer.sample_rate
 
 
 if __name__ == "__main__":
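With `tts()` now returning the raw waveform and sample rate instead of a base64-encoded WAV, callers handle playback themselves, as the new app.py does with sounddevice. A minimal sketch of that calling pattern (the helper name and the synthetic test tone are illustrative; `sd.wait()` stands in for the `time.sleep`-based wait used in app.py):

import numpy as np
import sounddevice as sd

def play_waveform(wav: np.ndarray, sample_rate: int) -> None:
    # Plays a float waveform such as the (wav, sample_rate) pair returned by
    # the updated tts(); audio errors are logged and playback is skipped,
    # mirroring the except-and-continue handling in the diff.
    try:
        sd.stop()                  # cancel any clip that is still playing
        sd.play(wav, sample_rate)  # non-blocking start
        sd.wait()                  # block until playback finishes
    except sd.PortAudioError as e:
        print("\nCaught exception: %s" % repr(e))
        print("Continuing without audio playback.\n")

# Synthetic 1-second 440 Hz tone as a stand-in for real TTS output.
sr = 22050
t = np.linspace(0, 1.0, sr, endpoint=False)
play_waveform((0.2 * np.sin(2 * np.pi * 440 * t)).astype(np.float32), sr)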