test-spase / models.py
taras5500's picture
Update models.py
72ed6a0 verified
raw
history blame contribute delete
No virus
2.41 kB
from huggingface_hub import InferenceClient
import gradio as gr
from transformers import pipeline
import torch
from TTS.api import TTS
import os
# Pre-accept the Coqui TTS model licence so loading XTTS does not stop at an
# interactive terms-of-service prompt (presumably what Coqui checks this
# variable for — confirm against the installed TTS version).
os.environ.update(COQUI_TOS_AGREED="1")
class AsrBot:
    """Speech-to-text wrapper around a Hugging Face ASR pipeline (e.g. Whisper)."""

    def __init__(self, model_name, lang, device=None):
        """Load the ASR pipeline.

        Args:
            model_name: Hugging Face model id handed to ``pipeline``.
            lang: language used to build the decoder prompt ids, so output is
                transcribed in this language.
            device: device forwarded to ``pipeline``. ``None`` (default) picks
                GPU 0 when CUDA is available and the CPU otherwise; pass ``0``
                to reproduce the previous always-GPU behaviour explicitly.
        """
        self.model_name = model_name
        self.lang = lang
        if device is None:
            # Was hard-coded to 0, which crashes on CPU-only hosts.
            device = 0 if torch.cuda.is_available() else "cpu"
        self.pipe = pipeline(
            task="automatic-speech-recognition",
            model=self.model_name,
            chunk_length_s=30,  # process long audio in 30-second chunks
            device=device,
        )
        # Force the decoder into "transcribe" mode in self.lang.
        # NOTE(review): newer transformers releases prefer passing
        # language/task via generate_kwargs instead of mutating model.config;
        # confirm against the pinned transformers version.
        self.pipe.model.config.forced_decoder_ids = self.pipe.tokenizer.get_decoder_prompt_ids(
            language=self.lang, task="transcribe"
        )

    def call(self, file_path):
        """Transcribe the audio file at ``file_path`` and return its text.

        Raises:
            gr.Error: if ``file_path`` is None (nothing uploaded or recorded).
        """
        if file_path is None:
            raise gr.Error("No audio file submitted! Please upload or record an audio file before submitting your request.")
        return self.pipe(file_path)["text"]
class LlmBot:
    """Text-generation client for a Mistral/Mixtral-style instruct model."""

    def __init__(self, model):
        """Create an ``InferenceClient`` bound to ``model`` (model id or endpoint)."""
        self.client = InferenceClient(model)

    def format_prompt(self, message):
        """Wrap ``message`` in the ``<s>[INST] ... [/INST]`` instruct template."""
        return f"<s>[INST] {message} [/INST]"

    def call(self, prompt, system_prompt, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0):
        """Generate a single (non-streamed) completion.

        Args:
            prompt: the user message.
            system_prompt: instructions prepended to ``prompt`` as
                ``"<system_prompt>, <prompt>"``; ignored when empty or None.
            temperature: sampling temperature, clamped to at least 0.01.
            max_new_tokens: generation length limit; coerced to ``int``.
            top_p: nucleus-sampling threshold; coerced to ``float``.
            repetition_penalty: forwarded unchanged.

        Returns:
            The generated text only (the prompt is not echoed back).
        """
        # Sampling (do_sample=True) needs a strictly positive temperature.
        temperature = max(float(temperature), 1e-2)
        # Callers such as UI sliders may hand over floats; the API wants an int.
        if max_new_tokens is not None:
            max_new_tokens = int(max_new_tokens)
        generate_kwargs = dict(
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            top_p=float(top_p),
            repetition_penalty=repetition_penalty,
            do_sample=True,
            seed=42,  # fixed seed: identical requests give identical samples
        )
        # Only prepend the system prompt when one is given; previously a
        # missing/empty one produced "None, ..." or ", ..." in the prompt.
        message = f"{system_prompt}, {prompt}" if system_prompt else prompt
        response = self.client.text_generation(
            self.format_prompt(message),
            **generate_kwargs,
            details=True,
            return_full_text=False,
        )
        return response.generated_text
class TtsBot:
    """Text-to-speech wrapper around a Coqui ``TTS`` model."""

    def __init__(self, model, device=None):
        """Load the TTS model once.

        Args:
            model: Coqui model name passed to ``TTS``.
            device: target device; ``None`` (default) uses "cuda" when
                available and falls back to "cpu" (previously always "cuda").
        """
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = TTS(model).to(device)

    def call(self, text, speaker_wav="./titty-sprinkles-101soundboards.mp3", language="en", file_path="./output.wav"):
        """Synthesize ``text`` to ``file_path`` and return that path.

        Uses the model loaded in ``__init__``. The old code built and moved a
        fresh XTTS model to the GPU on every call and ignored ``self.model``.
        The default arguments reproduce the previously hard-coded values.

        Args:
            text: text to speak.
            speaker_wav: reference speaker audio passed to ``tts_to_file``.
            language: language code passed to ``tts_to_file``.
            file_path: output audio path. NOTE: with the default, concurrent
                calls overwrite the same file.

        Returns:
            ``file_path``.
        """
        self.model.tts_to_file(text=text, speaker_wav=speaker_wav, language=language, file_path=file_path)
        return file_path
# Model wrappers are built once, at import time: importing this module loads
# the Whisper pipeline, the Mixtral inference client and the XTTS model.
asr_bot = AsrBot(model_name="openai/whisper-small", lang="en")
llm_bot = LlmBot(model="mistralai/Mixtral-8x7B-Instruct-v0.1")
tts_bot = TtsBot(model="tts_models/multilingual/multi-dataset/xtts_v2")