# speecht5_tts / handler.py
# Page header captured together with the file (not part of the code):
# uploader Dupaja, commit d16558b "Revert direct file send.", size 1.83 kB.
import librosa
import numpy as np
import torch
from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan
from datasets import load_dataset
#import soundfile as sf
from typing import Dict, List, Any
class EndpointHandler:
    """Callable text-to-speech handler built on SpeechT5.

    Construction loads the acoustic model, its processor, the HiFi-GAN
    vocoder and one fixed speaker embedding. Calling the instance with a
    request payload synthesizes the payload's text with that speaker.

    The class attributes below hold the values that used to be hard-coded
    in the method bodies; subclasses may override them.
    """

    # Hub repositories to load from. The original file kept
    # "microsoft/speecht5_tts", "microsoft/speecht5_hifigan" and
    # "Matthijs/cmu-arctic-xvectors" as commented-out alternatives.
    CHECKPOINT = "Dupaja/speecht5_tts"
    VOCODER_ID = "Dupaja/speecht5_hifigan"
    DATASET_ID = "Dupaja/cmu-arctic-xvectors"

    # Row of the embeddings dataset whose "xvector" is used as the voice.
    SPEAKER_INDEX = 7306

    # Sampling rate (Hz) reported alongside the generated audio.
    SAMPLING_RATE = 16000

    def __init__(self, path: str = ""):
        """Load models, processor, vocoder and the speaker embedding.

        Args:
            path: Accepted for interface compatibility but not used; all
                artifacts are loaded from the repository ids declared on
                the class.
        """
        self.model = SpeechT5ForTextToSpeech.from_pretrained(self.CHECKPOINT)
        self.processor = SpeechT5Processor.from_pretrained(self.CHECKPOINT)
        self.vocoder = SpeechT5HifiGan.from_pretrained(self.VOCODER_ID)

        # NOTE(review): trust_remote_code=True lets code shipped in the
        # dataset repository run on this machine. Keep only while that
        # repository is trusted / pinned.
        embeddings_dataset = load_dataset(
            self.DATASET_ID, split="validation", trust_remote_code=True
        )
        # The full dataset stays on the instance, as in the original code.
        self.embeddings_dataset = embeddings_dataset

        # unsqueeze(0) adds a leading dimension in front of the x-vector.
        xvector = embeddings_dataset[self.SPEAKER_INDEX]["xvector"]
        self.speaker_embeddings = torch.tensor(xvector).unsqueeze(0)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Synthesize speech for the text found under ``data["inputs"]``.

        Args:
            data: Request payload. The text to speak is read from the
                ``"inputs"`` key.

        Returns:
            On success ``{"statusCode": 200, "body": {"audio": <array>,
            "sampling_rate": 16000}}`` where ``audio`` is the result of
            calling ``.numpy()`` on the generated speech.
            When ``"inputs"`` is missing, not a string, or blank, a
            ``{"statusCode": 400, "body": {"error": <message>}}`` response
            is returned and the model is not invoked.
        """
        given_text = data.get("inputs", "")

        # Reject unusable input up front instead of feeding it to the model
        # and returning whatever comes out with a 200 status.
        if not isinstance(given_text, str) or not given_text.strip():
            return {
                "statusCode": 400,
                "body": {"error": "'inputs' must be a non-empty string."},
            }

        inputs = self.processor(text=given_text, return_tensors="pt")
        speech = self.model.generate_speech(
            inputs["input_ids"],
            self.speaker_embeddings,
            vocoder=self.vocoder,
        )

        # The audio is returned as a raw array.
        # NOTE(review): consider encoding it to a transport-friendly format
        # (the original author left the same reminder here).
        return {
            "statusCode": 200,
            "body": {
                "audio": speech.numpy(),
                "sampling_rate": self.SAMPLING_RATE,
            },
        }
# Module-level instance: constructing it runs EndpointHandler.__init__, so the
# models, vocoder and embeddings dataset are loaded as soon as this module is
# imported.
# NOTE(review): presumably the serving runtime instantiates EndpointHandler
# itself; if so, this is a second, redundant load — confirm before removing.
handler = EndpointHandler()