hushell committed on
Commit
8c276e4
1 Parent(s): 036168d

add whisper STT

Browse files
Files changed (2) hide show
  1. app.py +44 -5
  2. requirements.txt +2 -0
app.py CHANGED
@@ -7,6 +7,40 @@ import openai
7
  from requests.models import ChunkedEncodingError
8
  from streamlit.components import v1
9
  from custom import css_code, js_code, set_context_all
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
  st.set_page_config(page_title='ChatGPT Assistant', layout='wide', page_icon='🤖')
12
  # 自定义元素样式
@@ -153,8 +187,8 @@ with tap_set:
153
  key='context_level' + current_chat, help="表示每次会话中包含的历史对话次数,预设内容不计算在内。")
154
 
155
  st.markdown("模型参数:")
156
- st.selectbox("Model", ["gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "gpt-4-32k"], index=0,
157
- help="[模型选择参考](https://platform.openai.com/docs/models)",
158
  on_change=write_data, key='model' + current_chat)
159
  st.slider("Temperature", 0.0, 2.0, st.session_state["temperature" + current_chat + "default"], 0.1,
160
  help="""在0和2之间,应该使用什么样的采样温度?较高的值(如0.8)会使输出更随机,而较低的值(如0.2)则会使其更加集中和确定性。
@@ -190,11 +224,16 @@ with tap_input:
190
  write_data(new_name)
191
 
192
 
193
-
194
  with st.form("input_form", clear_on_submit=True):
195
  user_input = st.text_area("**输入:**", key="user_input_area")
196
  submitted = st.form_submit_button("确认提交", use_container_width=True, on_click=input_callback)
197
- if submitted:
 
 
 
 
 
 
198
  st.session_state['user_input_content'] = user_input
199
 
200
  if st.session_state['user_input_content'] != '':
@@ -276,4 +315,4 @@ if ("r" in st.session_state) and (current_chat == st.session_state["chat_of_r"])
276
  st.session_state.pop("r")
277
 
278
  # 添加事件监听
279
- v1.html(js_code, height=0)
 
7
  from requests.models import ChunkedEncodingError
8
  from streamlit.components import v1
9
  from custom import css_code, js_code, set_context_all
10
+ from st_audiorec import st_audiorec
11
+
12
+ device = "cpu"
13
+
14
+ # STT
15
+ import whisper
16
+ WHISPER_LANG = "en" # detecting language if None
17
+ warnings.filterwarnings("ignore")
18
+ WHISPER_MODEL = whisper.load_model("base")
19
+ WHISPER_MODEL.to(device)
20
+
21
def transcribe(aud_inp):
    """Transcribe an audio clip to text with the module-level Whisper model.

    Args:
        aud_inp: ``None``, a path to an audio file (``str``), or the raw
            bytes of an audio file (``bytes``/``bytearray``) as returned by
            the ``st_audiorec`` recorder widget.

    Returns:
        The decoded text, or ``""`` when there is no input or Whisper
        produced no text.

    Raises:
        TypeError: if ``aud_inp`` is not ``None``, ``str`` or bytes-like.
    """
    # Local imports keep this function self-contained; both are stdlib.
    import os
    import tempfile

    if aud_inp is None:
        return ""

    if isinstance(aud_inp, str):
        aud = whisper.load_audio(aud_inp)
    elif isinstance(aud_inp, (bytes, bytearray)):
        if not aud_inp:
            return ""
        # The recorder hands back a complete WAV file (the caller plays the
        # same bytes with format='audio/wav'), so the payload must be decoded
        # as a container, not reinterpreted byte-by-byte as samples.  Routing
        # it through whisper.load_audio gives the same decoding/resampling
        # path as the file-path branch.
        # The original code also referenced an undefined name (`wav_bytes`)
        # here, which raised NameError for every recording.
        tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
        try:
            try:
                tmp.write(aud_inp)
            finally:
                # Close before decoding so the file can be reopened by name
                # on every platform.
                tmp.close()
            aud = whisper.load_audio(tmp.name)
        finally:
            os.remove(tmp.name)
    else:
        raise TypeError(
            "aud_inp must be None, a file path (str) or audio bytes, "
            f"got {type(aud_inp).__name__}"
        )

    aud = whisper.pad_or_trim(aud)
    mel = whisper.log_mel_spectrogram(aud).to(device)

    # The separate detect_language() pass was dropped: its result was never
    # used.  NOTE(review): whisper.decode is expected to detect the language
    # itself when WHISPER_LANG is None -- confirm against the installed version.
    # Half precision is only requested off-CPU, as in the original branches.
    options = whisper.DecodingOptions(fp16=(device != "cpu"), language=WHISPER_LANG)

    result = whisper.decode(WHISPER_MODEL, mel, options)
    print("result.text", result.text)
    if result and result.text:
        return result.text
    return ""
43
+
44
 
45
  st.set_page_config(page_title='ChatGPT Assistant', layout='wide', page_icon='🤖')
46
  # 自定义元素样式
 
187
  key='context_level' + current_chat, help="表示每次会话中包含的历史对话次数,预设内容不计算在内。")
188
 
189
  st.markdown("模型参数:")
190
+ st.selectbox("Model", ["gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "gpt-4-32k"], index=0,
191
+ help="[模型选择参考](https://platform.openai.com/docs/models)",
192
  on_change=write_data, key='model' + current_chat)
193
  st.slider("Temperature", 0.0, 2.0, st.session_state["temperature" + current_chat + "default"], 0.1,
194
  help="""在0和2之间,应该使用什么样的采样温度?较高的值(如0.8)会使输出更随机,而较低的值(如0.2)则会使其更加集中和确定性。
 
224
  write_data(new_name)
225
 
226
 
 
227
  with st.form("input_form", clear_on_submit=True):
228
  user_input = st.text_area("**输入:**", key="user_input_area")
229
  submitted = st.form_submit_button("确认提交", use_container_width=True, on_click=input_callback)
230
+
231
+ wav_audio_data = st_audiorec()
232
+ if wav_audio_data is not None:
233
+ st.audio(wav_audio_data, format='audio/wav')
234
+ user_input = transcribe(wav_audio_data)
235
+
236
+ if submitted or wav_audio_data is not None:
237
  st.session_state['user_input_content'] = user_input
238
 
239
  if st.session_state['user_input_content'] != '':
 
315
  st.session_state.pop("r")
316
 
317
  # 添加事件监听
318
+ v1.html(js_code, height=0)
requirements.txt CHANGED
@@ -3,3 +3,5 @@ streamlit==1.19.0
3
  pandas==1.5.3
4
  requests==2.28.2
5
  altair<5
 
 
 
3
  pandas==1.5.3
4
  requests==2.28.2
5
  altair<5
6
+ streamlit-audiorec
7
+ git+https://github.com/openai/whisper.git