|
from io import BytesIO

import soundfile as sf
import torch

from .config import model_mms_tts_eng, tokenizer_mms_tts_eng
|
|
|
# Sample rate, in Hz, used when encoding synthesized audio as WAV.
SAMPLING_RATE: int = 16_000
|
|
|
class T2A:
    """Synthesize English speech from text and return it as in-memory WAV bytes.

    The text is tokenized once, at construction time, with
    ``tokenizer_mms_tts_eng``. Each call runs ``model_mms_tts_eng`` on those
    tokens and encodes the first waveform of the output as a WAV file.
    """

    def __init__(self, input_text: str):
        """Validate and tokenize *input_text*.

        Args:
            input_text: The text to synthesize.

        Raises:
            ValueError: If *input_text* is ``None`` or a string that is empty
                or whitespace-only.
        """
        # Validate before tokenizing, so the caller gets a clear message
        # instead of whatever error the tokenizer raises on bad input.
        if input_text is None:
            raise ValueError("Input text is None. Please provide text")
        if isinstance(input_text, str) and not input_text.strip():
            raise ValueError("Input text is empty. Please provide text")
        self.inputs = tokenizer_mms_tts_eng(input_text, return_tensors="pt")

    def __call__(self) -> bytes:
        """Run the TTS model on the tokenized text and return WAV-encoded audio.

        Returns:
            The bytes of a WAV file containing the first waveform in the
            model output.

        Raises:
            ValueError: If ``self.inputs`` is ``None``.
        """
        if self.inputs is None:
            raise ValueError("Input text is None. Please provide text")

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            output_model = model_mms_tts_eng(**self.inputs)

        # Hugging Face's VitsModel output names this field ``waveform``.
        # Fall back to ``audio`` for wrappers that expose it under that key.
        key = "waveform" if "waveform" in output_model else "audio"
        audio = output_model[key][0]  # first (and only) item of the batch
        if isinstance(audio, torch.Tensor):
            # soundfile needs host memory; this also covers GPU-resident outputs.
            audio = audio.detach().cpu().numpy()

        # Prefer the sample rate the model advertises, else the module default.
        config = getattr(model_mms_tts_eng, "config", None)
        sampling_rate = getattr(config, "sampling_rate", None) or SAMPLING_RATE

        with BytesIO() as buffer:
            # A file-like target has no extension, so the format must be explicit.
            sf.write(buffer, audio, sampling_rate, format="wav")
            return buffer.getvalue()