File size: 2,411 Bytes
958dd28
959ccd3
a409194
efef3cb
 
0af484e
 
 
 
26e5292
a4f86ca
e01d276
26e5292
e01d276
a409194
 
 
d14b541
a409194
 
 
 
 
d14b541
 
a409194
 
 
 
 
 
26e5292
958dd28
 
 
 
518e42a
958dd28
 
 
91d3606
518e42a
958dd28
 
 
 
 
 
 
 
 
 
 
 
 
 
518e42a
f468151
72ed6a0
958dd28
91d3606
26e5292
efef3cb
96c2a49
 
efef3cb
 
 
 
a4f86ca
 
a409194
958dd28
96c2a49
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
from huggingface_hub import InferenceClient
import gradio as gr
from transformers import pipeline
import torch
from TTS.api import TTS
import os


# Pre-accept the Coqui TTS terms of service so the XTTS model download does not
# block on an interactive license prompt when TtsBot is constructed below.
os.environ["COQUI_TOS_AGREED"] = "1"


class AsrBot():
    """Automatic-speech-recognition bot wrapping a Hugging Face ASR pipeline.

    Loads the model once in ``__init__`` and forces the decoder to transcribe
    in a fixed language, so every ``call`` produces a transcription in ``lang``
    regardless of what Whisper auto-detects.
    """

    def __init__(self, model_name, lang, device=0):
        """Build the ASR pipeline.

        Args:
            model_name: Hugging Face model id (e.g. ``"openai/whisper-small"``).
            lang: language code passed to the decoder prompt (e.g. ``"en"``).
            device: device index for the pipeline; defaults to 0 (first GPU),
                matching the previous hard-coded behavior. Pass ``-1`` for CPU.
        """
        self.model_name = model_name
        self.lang = lang

        self.pipe = pipeline(
            task="automatic-speech-recognition",
            model=self.model_name,
            chunk_length_s=30,  # long audio is chunked into 30 s windows
            device=device,
        )
        # Pin the decoder to transcription in `lang` instead of auto-detection.
        self.pipe.model.config.forced_decoder_ids = self.pipe.tokenizer.get_decoder_prompt_ids(language=self.lang, task="transcribe")

    def call(self, file_path):
        """Transcribe the audio at ``file_path`` and return the text.

        Raises:
            gr.Error: if no audio file was supplied (``file_path is None``).
        """
        if file_path is None:
            raise gr.Error("No audio file submitted! Please upload or record an audio file before submitting your request.")

        return self.pipe(file_path)["text"]

class LlmBot():
    """Chat bot backed by a Hugging Face Inference endpoint.

    Formats requests using the Mistral ``[INST]`` instruction template and
    returns the generated completion as plain text.
    """

    def __init__(self, model):
        self.client = InferenceClient(model)

    def format_prompt(self, message):
        """Wrap *message* in the Mistral instruction template."""
        return f"<s>[INST] {message} [/INST]"

    def call(self, prompt, system_prompt, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0):
        """Generate a completion for ``prompt`` prefixed by ``system_prompt``.

        Sampling parameters are forwarded to the endpoint; the fixed seed keeps
        responses reproducible for identical inputs.
        """
        # Clamp temperature away from zero — degenerate sampling otherwise.
        temperature = max(float(temperature), 1e-2)
        top_p = float(top_p)

        formatted_prompt = self.format_prompt(f"{system_prompt}, {prompt}")
        response = self.client.text_generation(
            formatted_prompt,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            do_sample=True,
            seed=42,
            details=True,
            return_full_text=False,
        )
        return response.generated_text
        
        

class TtsBot():
    """Text-to-speech bot wrapping a Coqui TTS model.

    BUG FIX: the original ``call`` ignored the model loaded in ``__init__`` and
    re-instantiated a hard-coded XTTS-v2 model on every invocation — the
    constructor argument was silently discarded and the model reloaded per
    call. ``call`` now reuses ``self.model``.
    """

    def __init__(self, model):
        # Load the TTS model once onto the GPU; reused across all calls.
        self.model = TTS(model).to("cuda")

    def call(self, text, speaker_wav="./titty-sprinkles-101soundboards.mp3", language="en", file_path="./output.wav"):
        """Synthesize ``text`` to a wav file and return its path.

        Args:
            text: text to speak.
            speaker_wav: reference voice sample for XTTS voice cloning
                (defaults preserve the original hard-coded behavior).
            language: language code for synthesis.
            file_path: output wav path, returned to the caller.
        """
        self.model.tts_to_file(text=text, speaker_wav=speaker_wav, language=language, file_path=file_path)
        return file_path


# Module-level singletons: each pipeline stage is instantiated once at import
# time (model downloads/loads happen here, so first import is slow).
asr_bot = AsrBot("openai/whisper-small", "en")  # speech -> text
llm_bot = LlmBot("mistralai/Mixtral-8x7B-Instruct-v0.1")  # text -> reply
tts_bot = TtsBot("tts_models/multilingual/multi-dataset/xtts_v2")  # reply -> speech