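# Header text copied from the Hugging Face Space file viewer (kept as comments so the file parses):
# ylacombe's picture
# ylacombe HF staff
# xtts-and-whisper-update (#3)
# c534b30
# raw
# history blame
# 11.8 kB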
from __future__ import annotations
import os
# By using XTTS you agree to CPML license https://coqui.ai/cpml
os.environ["COQUI_TOS_AGREED"] = "1"
import gradio as gr
import numpy as np
import torch
import nltk # we'll use this to split into sentences
nltk.download('punkt')
import uuid
import ffmpeg
import librosa
import torchaudio
from TTS.api import TTS
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts
from TTS.utils.generic_utils import get_user_data_dir
# Instantiating the high-level TTS API triggers the model download when it is
# not already cached; only that side effect is needed, not the instance.
print("Downloading if not downloaded Coqui XTTS V1")
TTS("tts_models/multilingual/multi-dataset/xtts_v1")
print("XTTS downloaded")

# Load the same checkpoint directly with the low-level Xtts model so inference
# can be driven with precomputed speaker latents.
print("Loading XTTS")
model_path = os.path.join(
    get_user_data_dir("tts"),
    "tts_models--multilingual--multi-dataset--xtts_v1",
)
config = XttsConfig()
config.load_json(os.path.join(model_path, "config.json"))
model = Xtts.init_from_config(config)
model.load_checkpoint(
    config,
    checkpoint_path=os.path.join(model_path, "model.pth"),
    vocab_path=os.path.join(model_path, "vocab.json"),
    eval=True,
    use_deepspeed=True,
)
model.cuda()
print("Done loading TTS")
# Page title / header shown by the Gradio app.
title = "Voice chat with Mistral 7B Instruct"
DESCRIPTION = """# Voice chat with Mistral 7B Instruct"""
css = """.toast-wrap { display: none !important } """

from huggingface_hub import HfApi

# Hub client used to restart this Space when an unrecoverable error occurs.
HF_TOKEN = os.environ.get("HF_TOKEN")
api = HfApi(token=HF_TOKEN)
repo_id = "ylacombe/voice-chat-with-lama"

# NOTE(review): neither system_message nor the sampling constants below are
# passed to generate() by bot(); confirm whether they are meant to be used.
system_message = (
    "\nYou are a helpful, respectful and honest assistant. Your answers are short, "
    "ideally a few words long, if it is possible. Always answer as helpfully as "
    "possible, while being safe.\n\nIf a question does not make any sense, or is not "
    "factually coherent, explain why instead of answering something not correct. "
    "If you don't know the answer to a question, please don't share false information."
)
temperature = 0.9
top_p = 0.6
repetition_penalty = 1.2
import gradio as gr
import os
import time
import gradio as gr
from transformers import pipeline
import numpy as np
from gradio_client import Client
from huggingface_hub import InferenceClient
# Speech-to-text runs on a remote Whisper-JAX Space. It replaces the former
# endpoint https://sanchit-gandhi-whisper-large-v2.hf.space/, which is down;
# this replacement may be time limited.
whisper_client = Client("https://sanchit-gandhi-whisper-jax.hf.space")

# Chat replies come from Mistral-7B-Instruct via the huggingface_hub client.
text_client = InferenceClient("mistralai/Mistral-7B-Instruct-v0.1")
###### COQUI TTS FUNCTIONS ######
def get_latents(speaker_wav):
    """Compute XTTS conditioning latents for the reference clip *speaker_wav*.

    Kept as a separate function so voice cleanup/filtering can be added here.

    Returns:
        A ``(gpt_cond_latent, diffusion_conditioning, speaker_embedding)`` tuple.
    """
    latents = model.get_conditioning_latents(audio_path=speaker_wav)
    gpt_cond_latent, diffusion_conditioning, speaker_embedding = latents
    return gpt_cond_latent, diffusion_conditioning, speaker_embedding
def format_prompt(message, history):
    """Build a Mistral-Instruct prompt from prior turns plus a new *message*.

    The prompt opens with ``<s>``. Each ``(user, bot)`` pair in *history*
    becomes ``[INST] user [/INST] bot</s> ``, and the new message is appended
    as a final ``[INST] message [/INST]`` block awaiting the model's reply.
    """
    past_turns = "".join(
        f"[INST] {user_prompt} [/INST] {bot_response}</s> "
        for user_prompt, bot_response in history
    )
    return f"<s>{past_turns}[INST] {message} [/INST]"
def generate(
    prompt, history, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0,
):
    """Stream a Mistral-7B-Instruct reply to *prompt* given the prior *history*.

    Args:
        prompt: the new user message.
        history: prior turns as ``(user_message, bot_message)`` pairs; they are
            rendered by ``format_prompt``.
        temperature: sampling temperature; values below 0.01 are clamped to 0.01.
        max_new_tokens, top_p, repetition_penalty: forwarded to
            ``text_client.text_generation``.

    Yields:
        The accumulated reply text after each streamed token. If the request
        fails, a Gradio warning is shown and a single fallback message is
        yielded instead.
    """
    temperature = max(float(temperature), 1e-2)
    top_p = float(top_p)

    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=42,
    )

    formatted_prompt = format_prompt(prompt, history)

    try:
        stream = text_client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
        output = ""
        for response in stream:
            output += response.token.text
            yield output
    except Exception as e:
        # Broad catch is deliberate: any client failure becomes an apology in
        # the chat instead of a failed event.
        if "Too Many Requests" in str(e):
            print("ERROR: Too many requests on mistral client")
            gr.Warning("Unfortunately Mistral is unable to process")
            output = "Unfortunately I am not able to process your request now !"
        else:
            print("Unhandled Exception: ", str(e))
            gr.Warning("Unfortunately Mistral is unable to process")
            output = "I do not know what happened but I could not understand you ."
        # This is a generator: a `return output` value is discarded by callers
        # that iterate it (bot() uses a for-loop), so the fallback must be yielded.
        yield output
def transcribe(wav_path):
    """Transcribe the audio at *wav_path* with the remote Whisper-JAX Space.

    Returns:
        The first element of the ``/predict`` result, with leading and
        trailing whitespace removed.
    """
    result = whisper_client.predict(
        wav_path,      # str (filepath or URL to file) for the 'inputs' Audio component
        "transcribe",  # str for the 'Task' Radio component
        False,         # return_timestamps=False for whisper-jax https://gist.github.com/sanchit-gandhi/781dd7003c5b201bfe16d28634c8d4cf#file-whisper_jax_endpoint-py
        api_name="/predict",
    )
    text = result[0]
    return text.strip()
# Chat event handlers used by the Gradio UI.
def add_text(history, text):
    """Append the typed *text* as a new user turn and lock the textbox.

    Returns:
        The new history (a fresh list with ``(text, None)`` appended) and a
        Gradio update that clears the textbox and makes it non-interactive.
    """
    if history is None:
        history = []
    updated = history + [(text, None)]
    textbox_update = gr.update(value="", interactive=False)
    return updated, textbox_update
def add_file(history, file):
    """Transcribe a recorded audio *file* and append it as a new user turn.

    If transcription fails, a warning is shown and a fixed placeholder prompt
    is used instead so the conversation can still continue.

    Returns:
        A fresh history list with ``(text, None)`` appended.
    """
    if history is None:
        history = []
    try:
        text = transcribe(file)
    except Exception as err:
        print(str(err))
        gr.Warning("There was an issue with transcription, please try writing for now")
        # Fall back to a canned prompt so the bot still has something to answer.
        text = "Transcription seems failed, please tell me a joke about chickens"
    else:
        print("Transcribed text:", text)
    return history + [(text, None)]
def bot(history, system_prompt=""):
    """Stream the assistant's reply into the last turn of *history*.

    Args:
        history: chat turns as ``[user_message, bot_message]`` pairs; the last
            pair's bot message is overwritten as tokens arrive. ``None`` or an
            empty history yields the (empty) history unchanged.
        system_prompt: accepted for interface compatibility only.
            NOTE(review): it is not forwarded to ``generate``, so it has no effect.

    Yields:
        The updated history after each streamed chunk.
    """
    history = [] if history is None else history
    if not history:
        # Nothing to answer (e.g. the chat was just cleared).
        yield history
        return
    # Turns may arrive as tuples (add_text/add_file build tuples); make the
    # last one mutable before writing the reply into it.
    history[-1] = list(history[-1])
    history[-1][1] = ""
    for partial_reply in generate(history[-1][0], history[:-1]):
        history[-1][1] = partial_reply
        yield history
def get_latents(speaker_wav):
    """Return the XTTS conditioning latents for the reference clip *speaker_wav*.

    NOTE(review): this redefines an identical get_latents declared earlier in
    the file; only one copy is needed.
    """
    gpt_cond, diffusion_cond, speaker_emb = model.get_conditioning_latents(
        audio_path=speaker_wav
    )
    return (gpt_cond, diffusion_cond, speaker_emb)
# Speaker latents computed once at startup and reused for every TTS call.
latent_map = {"Female_Voice": get_latents("examples/female.wav")}
def get_voice(prompt, language, latent_tuple, suffix="0", sample_rate=24000):
    """Synthesize *prompt* with XTTS and save it as ``output_{suffix}.wav``.

    Args:
        prompt: text to speak.
        language: language code passed to ``model.inference`` (e.g. ``"en"``).
        latent_tuple: ``(gpt_cond_latent, diffusion_conditioning, speaker_embedding)``
            as returned by ``get_latents``.
        suffix: appended to the output filename so callers can keep one file
            per sentence.
        sample_rate: rate used for the real-time-factor estimate and the saved
            WAV. The default 24000 is the value previously hard-coded here.
            NOTE(review): presumably the XTTS output rate; confirm against config.

    Returns:
        Path of the written WAV file, relative to the current working directory.
    """
    gpt_cond_latent, diffusion_conditioning, speaker_embedding = latent_tuple

    t0 = time.time()
    out = model.inference(
        prompt,
        language,
        gpt_cond_latent,
        speaker_embedding,
        diffusion_conditioning,
    )
    inference_time = time.time() - t0
    print(f"I: Time to generate audio: {round(inference_time * 1000)} milliseconds")

    num_samples = out["wav"].shape[-1]
    if num_samples:
        # RTF = generation seconds / audio seconds. Reuse the measured
        # inference time rather than re-reading the clock after the print.
        real_time_factor = inference_time / num_samples * sample_rate
        print(f"Real-time factor (RTF): {real_time_factor}")
    else:
        # Guard against ZeroDivisionError when the model returns no audio.
        print("Real-time factor (RTF): n/a (empty audio)")

    # NOTE(review): the filename depends only on `suffix`; if requests are ever
    # processed concurrently they can overwrite each other's files.
    wav_filename = f"output_{suffix}.wav"
    torchaudio.save(wav_filename, torch.tensor(out["wav"]).unsqueeze(0), sample_rate)
    return wav_filename
def _prepare_sentence(sentence):
    """Clean one sentence for XTTS; return "" when nothing speakable is left."""
    # Sometimes the prompt's </s> token leaks into the output; remove it.
    sentence = sentence.replace("</s>", "").strip()
    # A fast fix for the last character: separating trailing punctuation with a
    # space (may still produce odd sounds if it sits against the text).
    if sentence and sentence[-1] in ("!", "?", ".", ","):
        sentence = sentence[:-1] + " " + sentence[-1]
    return sentence


def generate_speech(history):
    """Speak the latest bot reply sentence by sentence, then offer a combined clip.

    Each sentence is synthesized with the precomputed "Female_Voice" latents and
    yielded as a WAV path so the autoplaying audio component plays it; the
    generator then sleeps for the clip's duration so playback is not cut off.
    Finally all clips are concatenated with ffmpeg into ``combined.wav``, which
    is yielded as a non-autoplaying ``gr.Audio`` update.

    Args:
        history: chat turns as ``(user_message, bot_message)`` pairs. Nothing is
            yielded when it is empty or the last bot message is missing/blank.

    Yields:
        WAV file paths, one per spoken sentence, then one ``gr.Audio.update``.
    """
    reply = history[-1][1] if history else None
    text = (reply or "").replace("\n", " ").strip()
    if not text:
        # Nothing to speak; leave the audio component untouched.
        return

    language = "en"
    wav_list = []
    for i, raw_sentence in enumerate(nltk.sent_tokenize(text)):
        sentence = _prepare_sentence(raw_sentence)
        if not sentence:
            # e.g. a sentence that consisted only of a leaked "</s>".
            continue
        print("Sentence:", sentence)
        try:
            # Not streaming, but fast thanks to precomputed latents. The suffix
            # keeps one file per sentence so they can be merged at the end.
            wav = get_voice(sentence, language, latent_map["Female_Voice"], suffix=i)
        except RuntimeError as e:
            if "device-side assert" not in str(e):
                print("RuntimeError: non device-side assert error:", str(e))
                raise
            # Nothing can be done on a CUDA device-side error; the Space needs a restart.
            print(f"Exit due to: Unrecoverable exception caused by prompt:{sentence}", flush=True)
            gr.Warning("Unhandled Exception encounter, please retry in a minute")
            print("Cuda device-assert Runtime encountered need restart")
            # HF Space specific: this error is unrecoverable, so restart the Space.
            api.restart_space(repo_id=repo_id)
            # Further sentences would hit the same broken CUDA state; stop here
            # instead of requesting a restart again for every sentence.
            break
        wav_list.append(wav)
        # On mobile there is no autoplay support due to mobile security.
        yield wav
        wait_time = librosa.get_duration(path=wav)
        print("Sleeping till audio end")
        time.sleep(wait_time)

    if not wav_list:
        # ffmpeg.concat needs at least one input.
        return

    # Every sentence was spoken on autoplay; now produce one concatenated file.
    # Requires: pip install ffmpeg-python
    files_to_concat = [ffmpeg.input(w) for w in wav_list]
    combined_file_name = "combined.wav"
    ffmpeg.concat(*files_to_concat, v=0, a=1).output(combined_file_name).run(overwrite_output=True)
    # This is a generator: a `return` value would be discarded by Gradio, so the
    # combined clip must be yielded to reach the audio component.
    yield gr.Audio.update(value=combined_file_name, autoplay=False)
with gr.Blocks(title=title) as demo:
    gr.Markdown(DESCRIPTION)

    chatbot = gr.Chatbot(
        [],
        elem_id="chatbot",
        avatar_images=("examples/lama.jpeg", "examples/lama2.jpeg"),
        bubble_full_width=False,
    )

    # Input row: typed text (Enter or button) or a microphone recording.
    with gr.Row():
        txt = gr.Textbox(
            scale=3,
            show_label=False,
            placeholder="Enter text and press enter, or speak to your microphone",
            container=False,
        )
        txt_btn = gr.Button(value="Submit text", scale=1)
        btn = gr.Audio(source="microphone", type="filepath", scale=4)

    # Output row: the spoken reply.
    with gr.Row():
        audio = gr.Audio(type="numpy", streaming=False, autoplay=True, label="Generated audio response", show_label=True)

    clear_btn = gr.ClearButton([chatbot, audio])

    # Typed input: the button click and Enter run the same chain (add the user
    # turn, stream the reply, speak it), then the textbox is re-enabled.
    for text_event in (txt_btn.click, txt.submit):
        txt_msg = text_event(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
            bot, chatbot, chatbot
        ).then(generate_speech, chatbot, audio)
        txt_msg.then(lambda: gr.update(interactive=True), None, [txt], queue=False)

    # Voice input: transcribe when recording stops, then the same reply chain.
    file_msg = btn.stop_recording(add_file, [chatbot, btn], [chatbot], queue=False).then(
        bot, chatbot, chatbot
    ).then(generate_speech, chatbot, audio)

    gr.Markdown("""
This Space demonstrates how to speak to a chatbot, based solely on open-source models.
It relies on 3 models:
1. [Whisper-large-v2](https://huggingface.co/spaces/sanchit-gandhi/whisper-jax) as an ASR model, to transcribe recorded audio to text. It is called through a [gradio client](https://www.gradio.app/docs/client).
2. [Mistral-7b-instruct](https://huggingface.co/spaces/osanseviero/mistral-super-fast) as the chat model, the actual chat model. It is called from [huggingface_hub](https://huggingface.co/docs/huggingface_hub/guides/inference).
3. [Coqui's XTTS](https://huggingface.co/spaces/coqui/xtts) as a TTS model, to generate the chatbot answers. This time, the model is hosted locally.
Note:
- By using this demo you agree to the terms of the Coqui Public Model License at https://coqui.ai/cpml""")
# Enable request queuing on the Blocks app, then serve it with debug output.
demo.queue().launch(debug=True)