import base64
import logging

import soundfile as sf
import torch
from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoTokenizer, set_seed

logger = logging.getLogger()
logger.setLevel(logging.DEBUG)


class EndpointHandler:
    def __init__(self, path=""):
        # Run on GPU when available; fall back to CPU otherwise.
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.model = ParlerTTSForConditionalGeneration.from_pretrained(
            "parler-tts/parler-tts-mini-expresso",
            torch_dtype=torch.float16,
        ).to(self.device)
        # Optional: compile the forward pass for lower latency after warm-up.
        # self.model.forward = torch.compile(self.model.forward, mode="reduce-overhead", fullgraph=True)
        self.tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler-tts-mini-expresso")

    def __call__(self, data):
        inputs = data["inputs"]
        prompt = inputs["prompt"]
        description = inputs["description"]

        # The description conditions the voice and style; the prompt is the text to speak.
        input_ids = self.tokenizer(description, return_tensors="pt").input_ids.to(self.device)
        prompt_input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(self.device)

        # Fix the seed so repeated requests with the same inputs yield the same audio.
        set_seed(42)
        generation = self.model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids)

        # The model runs in float16, but soundfile cannot write float16 arrays,
        # so cast to float32 before converting to numpy.
        audio_arr = generation.to(torch.float32).cpu().numpy().squeeze()
        sf.write("parler_tts_out.wav", audio_arr, self.model.config.sampling_rate)

        # Return the WAV bytes as a base64 string so the endpoint response is JSON-safe.
        with open("parler_tts_out.wav", "rb") as f:
            return base64.b64encode(f.read()).decode()
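
# --- Local smoke test (illustrative sketch, not part of the endpoint contract) ---
# Running this file directly exercises the handler end to end: it builds the payload
# shape the endpoint expects ({"inputs": {"prompt": ..., "description": ...}}),
# calls the handler, and decodes the base64 response back into a WAV file.
# The prompt/description strings and the output filename are example values,
# not fixed API inputs; this assumes a machine that can download the model weights.
if __name__ == "__main__":
    handler = EndpointHandler()
    payload = {
        "inputs": {
            "prompt": "Why do you make this so complicated?",
            "description": "Thomas speaks in a sad tone with emphasis and high quality audio.",
        }
    }
    audio_b64 = handler(payload)
    # Decode the base64 string back into playable WAV bytes.
    with open("smoke_test_out.wav", "wb") as f:
        f.write(base64.b64decode(audio_b64))
    print("wrote smoke_test_out.wav")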