import torch
from datasets import load_dataset
from transformers import Pipeline, SpeechT5Processor, SpeechT5HifiGan


class TTSPipeline(Pipeline):
    def __init__(self, *args, vocoder=None, processor=None, **kwargs):
        super().__init__(*args, **kwargs)
        if vocoder is None:
            raise ValueError("Must pass a vocoder to the TTSPipeline.")
        if processor is None:
            raise ValueError("Must pass a processor to the TTSPipeline.")
        # Accept either instantiated objects or checkpoint names for the
        # vocoder and processor.
        if isinstance(vocoder, str):
            vocoder = SpeechT5HifiGan.from_pretrained(vocoder)
        if isinstance(processor, str):
            processor = SpeechT5Processor.from_pretrained(processor)
        self.processor = processor
        self.vocoder = vocoder

    def preprocess(self, text, speaker_embeddings=None):
        # Tokenize the input text.
        inputs = self.processor(text=text, return_tensors="pt")
        # Fall back to a default x-vector speaker embedding if none is given.
        if speaker_embeddings is None:
            embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
            speaker_embeddings = torch.tensor(embeddings_dataset[7305]["xvector"]).unsqueeze(0)
        return {"inputs": inputs, "speaker_embeddings": speaker_embeddings}

    def _forward(self, model_inputs):
        inputs = model_inputs["inputs"]
        speaker_embeddings = model_inputs["speaker_embeddings"]
        with torch.no_grad():
            # When a vocoder is passed, generate_speech returns the waveform directly.
            speech = self.model.generate_speech(inputs["input_ids"], speaker_embeddings, vocoder=self.vocoder)
        return speech

    def _sanitize_parameters(self, **pipeline_parameters):
        # No extra call-time parameters are routed to preprocess/_forward/postprocess.
        return {}, {}, {}

    def postprocess(self, speech):
        return speech
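

# Usage sketch (assumptions, not part of the pipeline above: the public
# "microsoft/speecht5_tts" and "microsoft/speecht5_hifigan" checkpoints, and
# soundfile for writing the 16 kHz waveform; swap in your own checkpoints).
if __name__ == "__main__":
    import soundfile as sf
    from transformers import SpeechT5ForTextToSpeech

    model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts")
    tts = TTSPipeline(
        model=model,
        processor="microsoft/speecht5_tts",
        vocoder="microsoft/speecht5_hifigan",
    )
    # The pipeline returns a 1-D torch tensor containing the waveform.
    speech = tts("Hello, this is a test of the custom TTS pipeline.")
    sf.write("speech.wav", speech.numpy(), samplerate=16000)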