|
from __future__ import annotations |
|
|
|
import os |
|
|
|
# Signal acceptance of the Coqui terms of service before the XTTS model is
# fetched. NOTE(review): presumably read by the TTS package so the download does
# not prompt interactively — confirm. The UI footer tells users that using the
# demo means agreeing to the Coqui Public Model License.
os.environ["COQUI_TOS_AGREED"] = "1"
|
|
|
import gradio as gr |
|
import numpy as np |
|
import torch |
|
import nltk |
|
# Fetch NLTK's "punkt" sentence tokenizer; generate_speech() relies on
# nltk.sent_tokenize() to split chatbot replies into sentences. This runs on
# every start-up because it is a module-level call.
nltk.download('punkt')
|
import uuid |
|
|
|
import ffmpeg |
|
import librosa |
|
import torchaudio |
|
from TTS.api import TTS |
|
from TTS.tts.configs.xtts_config import XttsConfig |
|
from TTS.tts.models.xtts import Xtts |
|
from TTS.utils.generic_utils import get_user_data_dir |
|
|
|
|
|
# Step 1: make sure the XTTS v1 checkpoint exists locally. Constructing the
# high-level TTS API with the model name is used only for its download side
# effect (per the log messages); the instance is discarded immediately.
print("Downloading if not downloaded Coqui XTTS V1")

tts = TTS("tts_models/multilingual/multi-dataset/xtts_v1")

del tts

print("XTTS downloaded")

# Step 2: load that checkpoint through the lower-level Xtts model, which exposes
# get_conditioning_latents() and inference() used by get_latents()/get_voice().
print("Loading XTTS")

# NOTE(review): assumes get_user_data_dir("tts") is where the TTS package
# cached the model in step 1 — confirm against the TTS version in use.
model_path = os.path.join(get_user_data_dir("tts"), "tts_models--multilingual--multi-dataset--xtts_v1")

config = XttsConfig()

config.load_json(os.path.join(model_path, "config.json"))

model = Xtts.init_from_config(config)

model.load_checkpoint(
    config,
    checkpoint_path=os.path.join(model_path, "model.pth"),
    vocab_path=os.path.join(model_path, "vocab.json"),
    eval=True,
    # NOTE(review): requires the `deepspeed` package to be installed — confirm
    # it is present in the deployment environment.
    use_deepspeed=True
)

# Move the model to the GPU; the app therefore requires a CUDA device.
model.cuda()

print("Done loading TTS")
|
|
|
|
|
# Browser-tab title of the app; the Markdown header reuses it.
title = "Voice chat with Mistral 7B Instruct"
DESCRIPTION = f"# {title}"

# Stylesheet hiding elements with the `.toast-wrap` class (presumably Gradio's
# notification toasts). NOTE(review): in the visible code this is not passed to
# gr.Blocks(...), so it currently has no effect.
css = """.toast-wrap { display: none !important } """
|
|
|
from huggingface_hub import HfApi

# Hugging Face token taken from the environment; None when the variable is unset.
HF_TOKEN = os.environ.get("HF_TOKEN")

# Hub API client; generate_speech() calls api.restart_space() after a CUDA
# "device-side assert". NOTE(review): restarting a Space presumably requires a
# token with write access — with HF_TOKEN unset that call may fail.
api = HfApi(token=HF_TOKEN)

# The Space that api.restart_space() targets.
repo_id = "ylacombe/voice-chat-with-lama"
|
|
|
# Assistant persona and sampling settings.
# NOTE(review): format_prompt() builds the Mistral [INST] prompt without any
# system section, and bot() does not forward these sampling values to
# generate() (which has its own defaults 0.9 / 0.95 / 1.0). As written, none of
# these constants reach the model — confirm whether they were meant to.
system_message = (
    "\nYou are a helpful, respectful and honest assistant. "
    "Your answers are short, ideally a few words long, if it is possible. "
    "Always answer as helpfully as possible, while being safe.\n\n"
    "If a question does not make any sense, or is not factually coherent, "
    "explain why instead of answering something not correct. "
    "If you don't know the answer to a question, please don't share false information."
)

temperature, top_p, repetition_penalty = 0.9, 0.6, 1.2
|
|
|
|
|
import gradio as gr |
|
import os |
|
import time |
|
|
|
import gradio as gr |
|
from transformers import pipeline |
|
import numpy as np |
|
|
|
from gradio_client import Client |
|
from huggingface_hub import InferenceClient |
|
|
|
|
|
|
|
|
|
|
|
# gradio_client handle on the remote Whisper-JAX Space; used by transcribe()
# to turn recorded microphone audio into text.
whisper_client = Client("https://sanchit-gandhi-whisper-jax.hf.space")

# Hugging Face inference client for Mistral-7B-Instruct; used by generate()
# to stream chat replies.
text_client = InferenceClient(
    "mistralai/Mistral-7B-Instruct-v0.1"
)
|
|
|
|
|
def get_latents(speaker_wav):
    """Compute XTTS speaker-conditioning latents for a reference recording.

    Args:
        speaker_wav: Path to the reference audio clip.

    Returns:
        Tuple ``(gpt_cond_latent, diffusion_conditioning, speaker_embedding)``,
        the order in which get_voice() unpacks it.
    """
    latents = model.get_conditioning_latents(audio_path=speaker_wav)
    gpt_latent, diffusion_cond, speaker_emb = latents
    return gpt_latent, diffusion_cond, speaker_emb
|
|
|
|
|
def format_prompt(message, history):
    """Render a chat in the Mistral-Instruct ``[INST]`` prompt format.

    Each earlier turn becomes ``[INST] user [/INST] assistant</s> ``. The new
    *message* is appended as a final open ``[INST] ... [/INST]`` for the model
    to answer, and the whole prompt starts with ``<s>``.

    Args:
        message: The new user message.
        history: Iterable of ``(user_message, assistant_reply)`` pairs.

    Returns:
        The formatted prompt string.
    """
    past_turns = "".join(
        f"[INST] {user_turn} [/INST] {assistant_turn}</s> "
        for user_turn, assistant_turn in history
    )
    return f"<s>{past_turns}[INST] {message} [/INST]"
|
|
|
def generate(
    prompt, history, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0,
):
    """Stream a Mistral-7B-Instruct reply to *prompt*.

    Args:
        prompt: The new user message.
        history: Earlier ``(user, assistant)`` turns, rendered by format_prompt().
        temperature: Sampling temperature; values below 1e-2 are raised to 1e-2.
        max_new_tokens: Upper bound on generated tokens.
        top_p: Nucleus-sampling threshold (coerced to float).
        repetition_penalty: Forwarded unchanged to the inference client.

    Yields:
        The reply accumulated so far, once per streamed token. If the request
        fails, one fallback message is yielded instead, replacing any partial
        text, so callers such as bot() always have something to show and speak.
    """
    temperature = max(float(temperature), 1e-2)
    top_p = float(top_p)

    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=42,  # fixed seed, presumably for reproducible sampling
    )

    formatted_prompt = format_prompt(prompt, history)

    try:
        stream = text_client.text_generation(
            formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False
        )
        output = ""
        for response in stream:
            output += response.token.text
            yield output
    except Exception as e:
        # Best-effort boundary: any client/network failure becomes a spoken apology.
        if "Too Many Requests" in str(e):
            print("ERROR: Too many requests on mistral client")
            gr.Warning("Unfortunately Mistral is unable to process")
            output = "Unfortunately I am not able to process your request now !"
        else:
            print("Unhandled Exception: ", str(e))
            gr.Warning("Unfortunately Mistral is unable to process")
            output = "I do not know what happened but I could not understand you ."
        # This function is a generator: a `return output` would only set
        # StopIteration.value, which `for` loops (like the one in bot()) discard.
        # Yield the fallback so it actually reaches the chat.
        yield output
|
|
|
|
|
def transcribe(wav_path):
    """Transcribe the audio file at *wav_path* using the remote Whisper-JAX Space.

    The Space's ``/predict`` endpoint is called with task ``"transcribe"`` and
    a third positional argument of ``False``. NOTE(review): the latter is
    presumably the timestamps flag — confirm against the Space's API page.

    Returns:
        The first element of the response, with surrounding whitespace removed.
    """
    response = whisper_client.predict(
        wav_path,
        "transcribe",
        False,
        api_name="/predict",
    )
    transcript = response[0]
    return transcript.strip()
|
|
|
|
|
|
|
|
|
|
|
def add_text(history, text):
    """Append a typed user message to the chat and lock the textbox.

    Args:
        history: Current chatbot history, or None for an empty chat.
        text: The message the user typed.

    Returns:
        A new history list ending in ``(text, None)`` (no reply yet), and a
        Gradio update that clears the textbox and makes it non-interactive.
    """
    if history is None:
        history = []
    updated_history = history + [(text, None)]
    locked_textbox = gr.update(value="", interactive=False)
    return updated_history, locked_textbox
|
|
|
|
|
def add_file(history, file):
    """Transcribe a microphone recording and append it to the chat as a user turn.

    Transcription is best-effort: on any error the exception text is printed,
    a Gradio warning is shown, and a canned prompt is used instead, so the
    conversation can still continue.

    Args:
        history: Current chatbot history, or None for an empty chat.
        file: Path of the recorded audio (the Audio input uses type="filepath").

    Returns:
        A new history list ending in ``(transcribed_text, None)``.
    """
    if history is None:
        history = []

    fallback_text = "Transcription seems failed, please tell me a joke about chickens"
    try:
        text = transcribe(file)
        print("Transcribed text:", text)
    except Exception as err:
        print(str(err))
        gr.Warning("There was an issue with transcription, please try writing for now")
        text = fallback_text

    return history + [(text, None)]
|
|
|
|
|
|
|
def bot(history, system_prompt=""):
    """Stream the assistant's reply to the latest user turn into *history*.

    Args:
        history: Chatbot history whose last entry is ``(user_message, None)``,
            as appended by add_text()/add_file(); None is treated as empty.
        system_prompt: Accepted for interface compatibility but unused.
            generate()/format_prompt() build the prompt without a system section.

    Yields:
        *history* with the last entry's reply set to the text generated so far.
        An empty history is yielded once, unchanged.
    """
    history = [] if history is None else history
    if not history:
        # Nothing to answer yet; indexing history[-1] would raise IndexError.
        yield history
        return

    # add_text()/add_file() build entries as tuples; make the last one mutable
    # before writing the streamed reply into it.
    history[-1] = list(history[-1])
    history[-1][1] = ""
    for partial_reply in generate(history[-1][0], history[:-1]):
        history[-1][1] = partial_reply
        yield history
|
|
|
|
|
# Speaker conditioning below uses get_latents(), defined earlier in this file.
|
|
|
# Speaker-conditioning latents keyed by voice name, computed once at start-up
# from the bundled reference clip; generate_speech() uses "Female_Voice".
latent_map = {"Female_Voice": get_latents("examples/female.wav")}
|
|
|
def get_voice(prompt, language, latent_tuple, suffix="0"):
    """Synthesise *prompt* with XTTS and save it as ``output_{suffix}.wav``.

    Also prints the synthesis time and the real-time factor (synthesis time
    divided by the duration of the produced audio).

    Args:
        prompt: Text to speak.
        language: Language code passed to ``model.inference`` (callers use "en").
        latent_tuple: ``(gpt_cond_latent, diffusion_conditioning, speaker_embedding)``
            as returned by get_latents().
        suffix: Appended to the output file name; generate_speech() passes the
            sentence index.

    Returns:
        Path of the written WAV file, relative to the working directory.
    """
    # Sample rate assumed for XTTS output; used for both the RTF and the saved file.
    sample_rate = 24000
    gpt_cond_latent, diffusion_conditioning, speaker_embedding = latent_tuple

    t0 = time.time()
    out = model.inference(
        prompt,
        language,
        gpt_cond_latent,
        speaker_embedding,
        diffusion_conditioning,
    )
    inference_time = time.time() - t0
    print(f"I: Time to generate audio: {round(inference_time*1000)} milliseconds")

    # Reuse the measured inference time: re-reading the clock here would also
    # count the print above as synthesis work.
    num_samples = out["wav"].shape[-1]
    real_time_factor = inference_time / num_samples * sample_rate if num_samples else float("inf")
    print(f"Real-time factor (RTF): {real_time_factor}")

    wav_filename = f"output_{suffix}.wav"
    # as_tensor avoids an extra copy (and torch's copy-construct warning) when
    # out["wav"] is already a tensor; a NumPy array is wrapped as before.
    torchaudio.save(wav_filename, torch.as_tensor(out["wav"]).unsqueeze(0), sample_rate)
    return wav_filename
|
|
|
def generate_speech(history):
    """Speak the latest chatbot reply sentence by sentence, then publish the full clip.

    The reply in ``history[-1][1]`` is flattened to one line and split into
    sentences with ``nltk.sent_tokenize``. Each sentence is synthesised with
    the "Female_Voice" latents via get_voice(), and its WAV path is yielded so
    the audio output can play it. The generator then sleeps for the clip's
    duration so the next sentence is not pushed while the previous one is
    still playing. Finally, all clips are concatenated with ffmpeg into
    ``combined.wav``, which is yielded as a non-autoplaying ``gr.Audio`` update.

    Args:
        history: Chatbot history; ``history[-1][1]`` is the reply to speak.

    Yields:
        One WAV path per spoken sentence, then a ``gr.Audio.update`` for the
        concatenated file. Nothing is yielded when there is no reply text.

    Raises:
        RuntimeError: Re-raised from synthesis, except for a CUDA
            "device-side assert". That case instead requests a Space restart
            and ends the generator.
    """
    if not history or not history[-1][1]:
        # No reply to speak (empty chat, or generation produced nothing).
        return

    reply = history[-1][1].replace("\n", " ").strip()
    sentences = nltk.sent_tokenize(reply)

    language = "en"
    wav_list = []
    for i, sentence in enumerate(sentences):
        sentence = sentence.replace("</s>", "").strip()
        if not sentence:
            # e.g. a "sentence" that was only the end-of-sequence marker;
            # sentence[-1] below would raise IndexError.
            continue

        if sentence[-1] in ["!", "?", ".", ","]:
            # Separate trailing punctuation from the last word ("word." -> "word .").
            sentence = sentence[:-1] + " " + sentence[-1]

        print("Sentence:", sentence)

        try:
            wav = get_voice(sentence, language, latent_map["Female_Voice"], suffix=i)
        except RuntimeError as e:
            if "device-side assert" not in str(e):
                print("RuntimeError: non device-side assert error:", str(e))
                raise
            print(f"Exit due to: Unrecoverable exception caused by prompt:{sentence}", flush=True)
            gr.Warning("Unhandled Exception encounter, please retry in a minute")
            print("Cuda device-assert Runtime encountered need restart")
            api.restart_space(repo_id=repo_id)
            # The error is treated as unrecoverable and a restart was requested.
            # Stop instead of synthesising further sentences, which would likely
            # hit the same error and request yet more restarts.
            return

        wav_list.append(wav)
        yield wav

        wait_time = librosa.get_duration(path=wav)
        print("Sleeping till audio end")
        time.sleep(wait_time)

    if not wav_list:
        # Nothing was synthesised, so there is nothing to concatenate.
        return

    files_to_concat = [ffmpeg.input(w) for w in wav_list]
    combined_file_name = "combined.wav"
    ffmpeg.concat(*files_to_concat, v=0, a=1).output(combined_file_name).run(overwrite_output=True)

    # This is a generator: `return value` would be discarded by whoever iterates
    # it, so the final update must be yielded to reach the UI.
    yield gr.Audio.update(value=combined_file_name, autoplay=False)
|
|
|
|
|
# Gradio UI: chat window, text and microphone inputs, and an audio output that
# plays the spoken reply. NOTE(review): the module-level `css` is not passed to
# gr.Blocks here, so its `.toast-wrap` rule is not applied.
with gr.Blocks(title=title) as demo:
    gr.Markdown(DESCRIPTION)

    chatbot = gr.Chatbot(
        [],
        elem_id="chatbot",
        avatar_images=('examples/lama.jpeg', 'examples/lama2.jpeg'),
        bubble_full_width=False,
    )

    # Input row: free-text box, explicit submit button, and microphone recorder
    # (type="filepath", so add_file() receives a path to the recording).
    with gr.Row():
        txt = gr.Textbox(
            scale=3,
            show_label=False,
            placeholder="Enter text and press enter, or speak to your microphone",
            container=False,
        )
        txt_btn = gr.Button(value="Submit text",scale=1)
        btn = gr.Audio(source="microphone", type="filepath", scale=4)

    # Output row: generate_speech() yields WAV paths (and a final update) into
    # this component; autoplay=True plays each sentence as it arrives.
    with gr.Row():
        audio = gr.Audio(type="numpy", streaming=False, autoplay=True, label="Generated audio response", show_label=True)

    clear_btn = gr.ClearButton([chatbot, audio])

    # Text via button: append the user turn and lock the textbox (outside the
    # queue), then stream the LLM reply into the chat, then speak it.
    txt_msg = txt_btn.click(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
        bot, chatbot, chatbot
    ).then(generate_speech, chatbot, audio)

    # Re-enable the textbox once the whole chain (including speech) has finished.
    txt_msg.then(lambda: gr.update(interactive=True), None, [txt], queue=False)

    # Text via Enter key: same pipeline as the button.
    txt_msg = txt.submit(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
        bot, chatbot, chatbot
    ).then(generate_speech, chatbot, audio)

    txt_msg.then(lambda: gr.update(interactive=True), None, [txt], queue=False)

    # Voice input: when recording stops, transcribe it into a user turn, then
    # reply and speak exactly as for text input.
    file_msg = btn.stop_recording(add_file, [chatbot, btn], [chatbot], queue=False).then(
        bot, chatbot, chatbot
    ).then(generate_speech, chatbot, audio)

    gr.Markdown("""
This Space demonstrates how to speak to a chatbot, based solely on open-source models.
It relies on 3 models:
1. [Whisper-large-v2](https://huggingface.co/spaces/sanchit-gandhi/whisper-jax) as an ASR model, to transcribe recorded audio to text. It is called through a [gradio client](https://www.gradio.app/docs/client).
2. [Mistral-7b-instruct](https://huggingface.co/spaces/osanseviero/mistral-super-fast) as the chat model, the actual chat model. It is called from [huggingface_hub](https://huggingface.co/docs/huggingface_hub/guides/inference).
3. [Coqui's XTTS](https://huggingface.co/spaces/coqui/xtts) as a TTS model, to generate the chatbot answers. This time, the model is hosted locally.

Note:
- By using this demo you agree to the terms of the Coqui Public Model License at https://coqui.ai/cpml""")

# Enable the event queue (needed for generator-based handlers) and start the server.
demo.queue()
demo.launch(debug=True)
|
|