# Spaces status (captured from the hosting page): Runtime error
from huggingface_hub import InferenceClient | |
import gradio as gr | |
from transformers import pipeline | |
import torch | |
from TTS.api import TTS | |
import os | |
# Set before any TTS model is constructed further down in this file.
# NOTE(review): presumably pre-accepts the Coqui terms of service so the
# model download does not block on an interactive prompt -- confirm
# against the TTS library documentation.
os.environ["COQUI_TOS_AGREED"] = "1"
class AsrBot():
    """Speech-to-text helper built on a ``transformers`` ASR pipeline.

    The pipeline is created once in ``__init__`` with 30-second chunking,
    and its decoder prompt is forced to the requested language with the
    ``transcribe`` task.
    """

    def __init__(self, model_name, lang, device=None):
        """Create the speech-recognition pipeline.

        Args:
            model_name: Model identifier handed to ``pipeline(model=...)``.
            lang: Language passed to ``get_decoder_prompt_ids``.
            device: Device handed to ``pipeline(device=...)``. When ``None``
                (the default) the first CUDA device is used if one is
                available, otherwise the CPU.
        """
        self.model_name = model_name
        self.lang = lang
        if device is None:
            # Pipeline convention: 0 is the first CUDA device, -1 is the CPU.
            # Falling back avoids a start-up crash on hosts without a GPU.
            device = 0 if torch.cuda.is_available() else -1
        self.pipe = pipeline(
            task="automatic-speech-recognition",
            model=self.model_name,
            chunk_length_s=30,
            device=device,
        )
        # Pin the decoder prompt so output is a transcription in `lang`.
        self.pipe.model.config.forced_decoder_ids = (
            self.pipe.tokenizer.get_decoder_prompt_ids(
                language=self.lang, task="transcribe"
            )
        )

    def call(self, file_path):
        """Transcribe the audio at ``file_path`` and return the text.

        Raises:
            gr.Error: If ``file_path`` is ``None`` (nothing was submitted).
        """
        if file_path is None:
            raise gr.Error("No audio file submitted! Please upload or record an audio file before submitting your request.")
        return self.pipe(file_path)["text"]
class LlmBot():
    """Text-generation helper around ``huggingface_hub.InferenceClient``."""

    def __init__(self, model):
        """Create an inference client for ``model``."""
        self.client = InferenceClient(model)

    def format_prompt(self, message):
        """Wrap ``message`` in the ``<s>[INST] ... [/INST]`` instruction template."""
        return f"<s>[INST] {message} [/INST]"

    def call(self, prompt, system_prompt, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0):
        """Generate a completion for ``prompt`` and return its text.

        Args:
            prompt: The user message.
            system_prompt: Optional instructions placed before the user
                message, separated by ``", "``. ``None`` or an empty string
                means the user message is sent on its own.
            temperature: Sampling temperature; values below ``1e-2`` are
                raised to ``1e-2``.
            max_new_tokens: Forwarded to ``text_generation``.
            top_p: Nucleus-sampling threshold, forwarded as a float.
            repetition_penalty: Forwarded to ``text_generation``.

        Returns:
            The ``generated_text`` attribute of the client response.
        """
        # NOTE(review): the floor presumably exists because the backend
        # rejects non-positive temperatures while sampling -- confirm.
        temperature = max(float(temperature), 1e-2)
        generate_kwargs = dict(
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            top_p=float(top_p),
            repetition_penalty=repetition_penalty,
            do_sample=True,
            seed=42,
        )
        # Only prefix the system prompt when there is one; otherwise the
        # model would be sent a stray ", " (or the literal text "None").
        message = f"{system_prompt}, {prompt}" if system_prompt else prompt
        response = self.client.text_generation(
            self.format_prompt(message),
            **generate_kwargs,
            details=True,
            return_full_text=False,
        )
        return response.generated_text
class TtsBot():
    """Text-to-speech helper that keeps one loaded ``TTS`` model."""

    def __init__(self, model, device=None):
        """Load ``model`` once and move it to ``device``.

        Args:
            model: Model name handed to ``TTS(...)``.
            device: Target device for ``.to(...)``. When ``None`` (the
                default) ``"cuda"`` is used if available, else ``"cpu"``.
        """
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = TTS(model).to(device)

    def call(self, text, speaker_wav="./titty-sprinkles-101soundboards.mp3", language="en", file_path="./output.wav"):
        """Synthesize ``text`` to an audio file and return the file's path.

        Args:
            text: The text to speak.
            speaker_wav: Reference audio forwarded to ``tts_to_file``.
            language: Language code forwarded to ``tts_to_file``.
            file_path: Where the audio is written; also the return value.
        """
        # Reuse the model loaded in __init__ rather than constructing a new
        # one (and moving it to the GPU again) on every request.
        self.model.tts_to_file(
            text=text,
            speaker_wav=speaker_wav,
            language=language,
            file_path=file_path,
        )
        return file_path
# Module-level bot instances, each constructed once when this file is loaded:
# speech-to-text, text generation, and text-to-speech respectively.
asr_bot = AsrBot(model_name="openai/whisper-small", lang="en")
llm_bot = LlmBot(model="mistralai/Mixtral-8x7B-Instruct-v0.1")
tts_bot = TtsBot(model="tts_models/multilingual/multi-dataset/xtts_v2")