Spaces:
Runtime error
Runtime error
File size: 2,411 Bytes
958dd28 959ccd3 a409194 efef3cb 0af484e 26e5292 a4f86ca e01d276 26e5292 e01d276 a409194 d14b541 a409194 d14b541 a409194 26e5292 958dd28 518e42a 958dd28 91d3606 518e42a 958dd28 518e42a f468151 72ed6a0 958dd28 91d3606 26e5292 efef3cb 96c2a49 efef3cb a4f86ca a409194 958dd28 96c2a49 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 |
from huggingface_hub import InferenceClient
import gradio as gr
from transformers import pipeline
import torch
from TTS.api import TTS
import os
# Set before any TTS model is constructed below. Presumably this pre-accepts
# the Coqui model licence prompt so start-up is non-interactive - TODO confirm
# against the Coqui TTS documentation.
os.environ["COQUI_TOS_AGREED"] = "1"
class AsrBot():
    """Speech-to-text wrapper around a ``transformers`` ASR pipeline.

    After the pipeline is built, the tokenizer's decoder prompt ids for the
    requested language (task ``"transcribe"``) are installed as the model's
    ``forced_decoder_ids``.
    """

    def __init__(self, model_name, lang):
        """Build the ASR pipeline.

        Args:
            model_name: Model identifier handed to ``transformers.pipeline``.
            lang: Language handed to the tokenizer's
                ``get_decoder_prompt_ids``.
        """
        self.model_name = model_name
        self.lang = lang
        # Use the first GPU when CUDA is present; otherwise fall back to CPU
        # (-1) instead of failing at start-up on a host without a GPU.
        device = 0 if torch.cuda.is_available() else -1
        self.pipe = pipeline(
            task="automatic-speech-recognition",
            model=self.model_name,
            chunk_length_s=30,
            device=device,
        )
        self.pipe.model.config.forced_decoder_ids = (
            self.pipe.tokenizer.get_decoder_prompt_ids(
                language=self.lang,
                task="transcribe",
            )
        )

    def call(self, file_path):
        """Transcribe the audio at *file_path* and return the text.

        Args:
            file_path: Audio input passed straight to the pipeline.

        Returns:
            The ``"text"`` entry of the pipeline result.

        Raises:
            gr.Error: If *file_path* is ``None`` (nothing was submitted).
        """
        if file_path is None:
            raise gr.Error("No audio file submitted! Please upload or record an audio file before submitting your request.")
        return self.pipe(file_path)["text"]
class LlmBot():
    """Client for a hosted text-generation model reached via ``InferenceClient``."""

    def __init__(self, model):
        """Create the inference client for *model*."""
        self.client = InferenceClient(model)

    def format_prompt(self, message):
        """Return *message* wrapped as ``<s>[INST] message [/INST]``."""
        return "".join(("<s>", f"[INST] {message} [/INST]"))

    def call(self, prompt, system_prompt, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0):
        """Generate a reply for *prompt*, prefixed by *system_prompt*.

        The two texts are joined as ``"{system_prompt}, {prompt}"`` and
        wrapped by ``format_prompt`` before being sent to the client.

        Args:
            prompt: User message.
            system_prompt: Text placed before the user message.
            temperature: Sampling temperature; values below 0.01 are raised
                to 0.01.
            max_new_tokens: Forwarded unchanged to the client.
            top_p: Converted to ``float`` and forwarded to the client.
            repetition_penalty: Forwarded unchanged to the client.

        Returns:
            The ``generated_text`` attribute of the client's response.
        """
        request_text = self.format_prompt(f"{system_prompt}, {prompt}")
        response = self.client.text_generation(
            request_text,
            temperature=max(float(temperature), 1e-2),
            max_new_tokens=max_new_tokens,
            top_p=float(top_p),
            repetition_penalty=repetition_penalty,
            do_sample=True,
            seed=42,
            details=True,
            return_full_text=False,
        )
        return response.generated_text
class TtsBot():
    """Text-to-speech wrapper that loads a Coqui ``TTS`` model once and reuses it."""

    def __init__(self, model):
        """Load *model* onto CUDA when available, otherwise onto the CPU.

        Args:
            model: Model name handed to ``TTS``.
        """
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = TTS(model).to(device)

    def call(self, text, speaker_wav="./titty-sprinkles-101soundboards.mp3", language="en", file_path="./output.wav"):
        """Synthesize *text* to an audio file and return that file's path.

        Args:
            text: Text to speak.
            speaker_wav: Reference audio passed to ``tts_to_file``.
            language: Language code passed to ``tts_to_file``.
            file_path: Where the synthesized audio is written.

        Returns:
            *file_path*, unchanged.
        """
        # Reuse the model loaded in __init__; building a new TTS instance on
        # every request reloads the whole model and ignores the constructor's
        # ``model`` argument.
        self.model.tts_to_file(
            text=text,
            speaker_wav=speaker_wav,
            language=language,
            file_path=file_path,
        )
        return file_path
# Module-level singletons: each bot loads its model (or client) once at import.
asr_bot = AsrBot(model_name="openai/whisper-small", lang="en")
llm_bot = LlmBot(model="mistralai/Mixtral-8x7B-Instruct-v0.1")
tts_bot = TtsBot(model="tts_models/multilingual/multi-dataset/xtts_v2")
|