# handler.py - "Create handler.py" (commit dbf0dc1, verified), uploaded by Oysiyl; 3.47 kB.
from typing import Dict, List, Text, Any
import re
from transformers import SpeechT5ForTextToSpeech
from transformers import SpeechT5Processor
from transformers import SpeechT5HifiGan
import soundfile
import torch
import numpy as np
# Select the compute device. This handler is GPU-only: importing the module on
# a machine without CUDA fails immediately instead of falling back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if device.type != 'cuda':
    raise ValueError("need to run on GPU")
# Choose the reduced-precision dtype used for the acoustic model and the
# speaker embedding: bfloat16 when the GPU's major compute capability is 8 or
# higher, float16 otherwise. The comparison is `>= 8` (not `== 8`) so that
# GPU generations newer than 8.x are not silently downgraded to float16.
dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
class EndpointHandler:
    """Inference-endpoint handler that turns text into speech with SpeechT5.

    The handler cleans the incoming text, transliterates Cyrillic letters to
    Latin ones, runs the SpeechT5 text-to-speech model to obtain a spectrogram
    and converts that spectrogram to audio with the HiFi-GAN vocoder.

    It relies on the module-level ``device`` and ``dtype`` globals.
    """

    # Characters stripped from the input before transliteration. ``\u0301`` is
    # the combining acute accent; it is spelled as an escape because the
    # literal character is invisible in most editors.
    # NOTE(review): the straight apostrophe is removed here, while the curly
    # and Armenian apostrophes are converted to a straight one afterwards --
    # kept as in the original; confirm this asymmetry is intended.
    _CHARS_TO_REMOVE = re.compile(r'[…–"“%‘”�»«„`\'\u0301]')

    # Single-character Cyrillic -> Latin transliteration table.
    # Both 'г' and 'х' map to 'h', so the transliteration is not reversible.
    _CYRILLIC_TO_LATIN = str.maketrans({
        'а': 'a',
        'б': 'b',
        'в': 'v',
        'г': 'h',
        'д': 'd',
        'е': 'e',
        'ж': 'zh',
        'з': 'z',
        'и': 'y',
        'й': 'j',
        'к': 'k',
        'л': 'l',
        'м': 'm',
        'н': 'n',
        'о': 'o',
        'п': 'p',
        'р': 'r',
        'с': 's',
        'т': 't',
        'у': 'u',
        'ф': 'f',
        'х': 'h',
        'ц': 'ts',
        'ч': 'ch',
        'ш': 'sh',
        'щ': 'sch',
        'ь': "'",
        'ю': 'ju',
        'я': 'ja',
        'є': 'je',
        'і': 'i',
        'ї': 'ji',
        'ґ': 'g',
    })

    def __init__(self, path=""):
        """Load the TTS model, text processor, vocoder and speaker embedding.

        :param path: Accepted for compatibility with the handler interface;
            it is not used. NOTE(review): ``embed.npy`` is therefore resolved
            against the current working directory, not against ``path``.
        """
        # Load all required models
        self.model_id = "Oysiyl/speecht5_tts_common_voice_uk"
        self.model = SpeechT5ForTextToSpeech.from_pretrained(self.model_id, torch_dtype=dtype).to(device)
        self.processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
        # The vocoder is loaded without a dtype override, unlike the model.
        self.vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan").to(device)
        # Speaker embedding is cast to the same dtype as the acoustic model.
        self.speaker_embeddings = torch.tensor(np.load("embed.npy"), dtype=dtype).to(device)

    @staticmethod
    def remove_special_characters_s(text: Text) -> Text:
        """Normalize raw input text before transliteration.

        Strips the punctuation listed in ``_CHARS_TO_REMOVE``, converts the
        Armenian and curly apostrophes to a straight one, lowercases the text
        and replaces 'ы' with 'и'.

        :param text: Raw input text.
        :return: The cleaned, lowercased text.
        """
        text = EndpointHandler._CHARS_TO_REMOVE.sub('', text)
        text = text.replace("՚", "'").replace("’", "'")
        # Lowercase *before* the 'ы' substitution so that an uppercase 'Ы' is
        # handled too; otherwise it would survive as 'ы', which the
        # transliteration table does not cover.
        text = text.lower()
        text = text.replace('ы', 'и')
        return text

    @staticmethod
    def cyrillic_to_latin(text: Text) -> Text:
        """Transliterate lowercase Cyrillic letters to Latin in a single pass.

        Characters without an entry in ``_CYRILLIC_TO_LATIN`` (including
        uppercase letters) are passed through unchanged.

        :param text: Text to transliterate; expected to be already lowercased.
        :return: The transliterated text.
        """
        return text.translate(EndpointHandler._CYRILLIC_TO_LATIN)

    def __call__(self, data: Any) -> Any:
        """Synthesize speech for the text found under ``data["inputs"]``.

        :param data: A dictionary containing `inputs` (the text to speak).
            The `inputs` key is popped from the dictionary.
        :return: A numpy array with the vocoder output (presumably the audio
            waveform -- TODO confirm shape and sample rate), or the dictionary
            ``{"error": ...}`` when `inputs` is missing.
        """
        text = data.pop("inputs", None)
        # Check if text is not provided
        if text is None:
            return {"error": "Please provide a text."}
        # run inference pipeline
        text = self.remove_special_characters_s(text)
        text = self.cyrillic_to_latin(text)
        input_ids = self.processor(text=text, return_tensors="pt")['input_ids'].to(device)
        with torch.no_grad():
            spectrogram = self.model.generate_speech(input_ids, self.speaker_embeddings)
            speech = self.vocoder(spectrogram)
        # `.cpu()` is a no-op for tensors already on the CPU, so no device
        # branch is needed before converting to numpy.
        return speech.cpu().numpy()