Improve UX a bit and switch back to Whisper large v2
app.py
CHANGED
@@ -19,7 +19,6 @@ import datetime
 
 from scipy.io.wavfile import write
 from pydub import AudioSegment
-import ffmpeg
 
 import re
 import io, wave
@@ -57,7 +56,7 @@ model.load_checkpoint(
     checkpoint_path=os.path.join(model_path, "model.pth"),
     vocab_path=os.path.join(model_path, "vocab.json"),
     eval=True,
-    use_deepspeed=True
+    use_deepspeed=False, # TODO: replace by True
 )
 model.cuda()
 print("Done loading TTS")
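The hunk above hardcodes `use_deepspeed=False` with a TODO to flip it back. As context only, here is a minimal sketch of the same XTTS checkpoint load with the DeepSpeed flag driven by an environment variable instead; the `model_path` default and the env-var name are assumptions, not part of the Space's code.

```python
import os

from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts

# Hypothetical: where the checkpoint lives and whether to enable DeepSpeed.
model_path = os.environ.get("XTTS_MODEL_PATH", "tts_model")
USE_DEEPSPEED = os.environ.get("USE_DEEPSPEED", "0") == "1"

# Load the XTTS config and model the same way app.py does.
config = XttsConfig()
config.load_json(os.path.join(model_path, "config.json"))
model = Xtts.init_from_config(config)
model.load_checkpoint(
    config,
    checkpoint_path=os.path.join(model_path, "model.pth"),
    vocab_path=os.path.join(model_path, "vocab.json"),
    eval=True,
    use_deepspeed=USE_DEEPSPEED,  # currently False in the Space, per the TODO above
)
model.cuda()
```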
@@ -113,10 +112,7 @@ from gradio_client import Client
 from huggingface_hub import InferenceClient
 
 WHISPER_TIMEOUT = int(os.environ.get("WHISPER_TIMEOUT", 30))
-
-# whisper_client = Client("https://sanchit-gandhi-whisper-large-v2.hf.space/")
-# Replacement whisper client, it may be time limited
-whisper_client = Client("https://sanchit-gandhi-whisper-jax.hf.space")
+whisper_client = Client("https://sanchit-gandhi-whisper-large-v2.hf.space/")
 text_client = InferenceClient(
     "mistralai/Mistral-7B-Instruct-v0.1",
     timeout=WHISPER_TIMEOUT,
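For context on the client switch above, a hedged sketch of exercising both remote endpoints in isolation. The prompt format and the example audio path are placeholders, and the hosted Spaces may be rate limited or change their endpoints.

```python
from gradio_client import Client
from huggingface_hub import InferenceClient

# Sketch only: quick manual check of the two remote services the Space relies on.
whisper_client = Client("https://sanchit-gandhi-whisper-large-v2.hf.space/")
text_client = InferenceClient("mistralai/Mistral-7B-Instruct-v0.1", timeout=30)

# ASR: arguments mirror the transcribe() call in app.py.
text = whisper_client.predict(
    "examples/female.wav",  # any local audio file path works here
    "transcribe",
    api_name="/predict",
).strip()
print("Whisper heard:", text)

# LLM: a plain text-generation call through the serverless Inference API,
# using a minimal Mistral-Instruct prompt wrapper (an assumption).
reply = text_client.text_generation(
    f"<s>[INST] {text} [/INST]",
    max_new_tokens=64,
)
print("Mistral replied:", reply)
```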
@@ -203,13 +199,12 @@ def generate(
 
 def transcribe(wav_path):
     try:
-        # get
+        # get result from whisper and strip it to delete begin and end space
         return whisper_client.predict(
-
-
-
-
-        )[0].strip()
+            wav_path, # str (filepath or URL to file) in 'inputs' Audio component
+            "transcribe", # str in 'Task' Radio component
+            api_name="/predict"
+        ).strip()
     except:
         gr.Warning("There was a problem with Whisper endpoint, telling a joke for you.")
         return "There was a problem with my voice, tell me joke"
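The old endpoint's result was indexed with `[0]` while the new one returns the transcription directly, hence the change from `[0].strip()` to `.strip()`. A small hedged sketch of a wrapper that tolerates either return shape, useful if the Space is ever pointed at a different Whisper endpoint again; `transcribe_any` is a hypothetical helper, not part of app.py.

```python
def transcribe_any(wav_path: str) -> str:
    """Hypothetical helper: accept endpoints that return either a plain
    transcription string or a (text, ...) tuple/list."""
    result = whisper_client.predict(
        wav_path,          # filepath to the recorded audio
        "transcribe",      # task selector exposed by the Space
        api_name="/predict",
    )
    if isinstance(result, (list, tuple)):
        result = result[0]
    return str(result).strip()
```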
@@ -242,8 +237,8 @@ def add_file(history, file):
 
 ##NOTE: not using this as it yields a chacter each time while we need to feed history to TTS
 def bot(history, system_prompt=""):
-    history = [] if history is None else history
-
+    history = [["", None]] if history is None else history
+
     if system_prompt == "":
         system_prompt = system_message
 
@@ -267,21 +262,6 @@ latent_map = {}
 latent_map["Female_Voice"] = get_latents("examples/female.wav")
 
 
-def get_voice(prompt, language, latent_tuple, suffix="0"):
-    gpt_cond_latent, diffusion_conditioning, speaker_embedding = latent_tuple
-    # Direct version
-    t0 = time.time()
-    out = model.inference(
-        prompt, language, gpt_cond_latent, speaker_embedding, diffusion_conditioning
-    )
-    inference_time = time.time() - t0
-    print(f"I: Time to generate audio: {round(inference_time*1000)} milliseconds")
-    real_time_factor = (time.time() - t0) / out["wav"].shape[-1] * 24000
-    print(f"Real-time factor (RTF): {real_time_factor}")
-    wav_filename = f"output_{suffix}.wav"
-    torchaudio.save(wav_filename, torch.tensor(out["wav"]).unsqueeze(0), 24000)
-    return wav_filename
-
 
 def wave_header_chunk(frame_input=b"", channels=1, sample_width=2, sample_rate=24000):
     # This will create a wave header then append the frame input
@@ -333,7 +313,7 @@ def get_voice_streaming(prompt, language, latent_tuple, suffix="0"):
         if "device-side assert" in str(e):
             # cannot do anything on cuda device side error, need tor estart
             print(
-                f"Exit due to: Unrecoverable exception caused by prompt:{
+                f"Exit due to: Unrecoverable exception caused by prompt:{prompt}",
                 flush=True,
             )
             gr.Warning("Unhandled Exception encounter, please retry in a minute")
@@ -353,10 +333,12 @@ def get_voice_streaming(prompt, language, latent_tuple, suffix="0"):
 
 def get_sentence(history, system_prompt=""):
     history = [["", None]] if history is None else history
-
+
     if system_prompt == "":
         system_prompt = system_message
 
+    history[-1][1] = ""
+
     mistral_start = time.time()
     print("Mistral start")
     sentence_list = []
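Both `bot()` and `get_sentence()` now seed missing history with `[["", None]]`, and `get_sentence()` additionally blanks out `history[-1][1]`. This matches the Gradio Chatbot convention of a list of `[user_message, bot_message]` pairs, where `None` marks a reply not yet produced. A short illustrative sketch with made-up values:

```python
# Illustrative only: how the chat history evolves during one turn.
history = [["", None]]                      # safe default when no history exists yet
history.append(["Tell me a joke", None])    # user turn added, bot reply still pending

history[-1][1] = ""                         # get_sentence() starts with an empty reply...
for token in ["Why ", "did ", "the ", "chicken..."]:
    history[-1][1] += token                 # ...then streams text into it

print(history[-1])                          # ['Tell me a joke', 'Why did the chicken...']
```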
@@ -422,8 +404,8 @@ def generate_speech(history):
         try:
             # generate speech using precomputed latents
             # This is not streaming but it will be fast
-            # wav = get_voice(sentence,language, latent_map["Female_Voice"], suffix=len(wav_list))
             if len(sentence) > 250:
+                gr.Warning("There was a problem with the last sentence, which was too long, so it won't be spoken.")
                 # should not generate voice it will hit token limit
                 # It should not generate audio for it
                 audio_stream = None
@@ -520,6 +502,7 @@ with gr.Blocks(title=title) as demo:
             show_label=False,
             placeholder="Enter text and press enter, or speak to your microphone",
             container=False,
+            interactive=True,
         )
         txt_btn = gr.Button(value="Submit text", scale=1)
         btn = gr.Audio(source="microphone", type="filepath", scale=4)
@@ -536,7 +519,7 @@ with gr.Blocks(title=title) as demo:
     # final_audio = gr.Audio(label="Final audio response", streaming=False, autoplay=False, interactive=False,show_label=True, visible=False)
 
     clear_btn = gr.ClearButton([chatbot, audio])
-
+
     txt_msg = txt_btn.click(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
         generate_speech, chatbot, [audio, chatbot]
     )
@@ -553,13 +536,13 @@ with gr.Blocks(title=title) as demo:
         add_file, [chatbot, btn], [chatbot, txt], queue=False
     ).then(generate_speech, chatbot, [audio, chatbot])
 
-    file_msg.then(lambda: gr.update(interactive=True), None, [txt], queue=False)
+    file_msg.then(lambda: (gr.update(interactive=True),gr.update(interactive=True,value=None)), None, [txt, btn], queue=False)
 
     gr.Markdown(
         """
 This Space demonstrates how to speak to a chatbot, based solely on open-source models.
 It relies on 3 models:
-1. [Whisper-large-v2](https://
+1. [Whisper-large-v2](https://sanchit-gandhi-whisper-large-v2.hf.space/) as an ASR model, to transcribe recorded audio to text. It is called through a [gradio client](https://www.gradio.app/docs/client).
 2. [Mistral-7b-instruct](https://huggingface.co/spaces/osanseviero/mistral-super-fast) as the chat model, the actual chat model. It is called from [huggingface_hub](https://huggingface.co/docs/huggingface_hub/guides/inference).
 3. [Coqui's XTTS](https://huggingface.co/spaces/coqui/xtts) as a TTS model, to generate the chatbot answers. This time, the model is hosted locally.
 
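The updated `.then()` callback now re-enables the text box and also resets the microphone component once speech generation finishes. A hedged sketch of the overall locking pattern this follows; component and function names are taken from app.py, but the `txt.submit` trigger and the lock step are assumptions used for illustration only.

```python
# Sketch (Gradio 3.x style, not the Space's exact wiring): lock the inputs while
# the bot responds, then unlock them and clear the recorded audio afterwards.
def lock_inputs():
    return gr.update(interactive=False), gr.update(interactive=False)

def unlock_inputs():
    # Second update also clears the Audio component so the mic is ready again.
    return gr.update(interactive=True), gr.update(interactive=True, value=None)

turn = (
    txt.submit(lock_inputs, None, [txt, btn], queue=False)
    .then(add_text, [chatbot, txt], [chatbot, txt], queue=False)
    .then(generate_speech, chatbot, [audio, chatbot])
)
turn.then(unlock_inputs, None, [txt, btn], queue=False)
```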
@@ -567,4 +550,4 @@ Note:
 - By using this demo you agree to the terms of the Coqui Public Model License at https://coqui.ai/cpml"""
 )
 demo.queue()
-demo.launch(debug=True
+demo.launch(debug=True)