| | import os |
| | import torch |
| | import numpy as np |
| | import librosa |
| | import soundfile as sf |
| | import traceback |
| | import base64 |
| | import io |
| | import wave |
| |
|
| | from transformers import AutoModelForCausalLM, AutoTokenizer |
| | from snac import SNAC |
| | from vllm import LLM, SamplingParams |
| |
|
class EndpointHandler:
    """Inference-endpoint handler for an Orpheus-style TTS model with a SNAC audio codec."""

    def __init__(self, path=""):
        """
        Load the language model (via vLLM), its tokenizer, and the SNAC codec.

        Parameters:
        - path: model repository path passed to both vLLM and the tokenizer.

        Raises:
        - RuntimeError: if the SNAC codec fails to load.
        """
        # Special token ids delimiting text/audio turns in the prompt
        # (Llama-3-style vocabulary extended with audio tokens).
        self.START_OF_HUMAN = 128259
        self.START_OF_TEXT = 128000
        self.END_OF_TEXT = 128009
        self.END_OF_HUMAN = 128260
        self.START_OF_AI = 128261
        self.START_OF_SPEECH = 128257
        self.END_OF_SPEECH = 128258
        self.END_OF_AI = 128262
        # First id of the audio-code region; seven interleaved SNAC slots of
        # 4096 ids each follow from here (see tokenize_audio).
        self.AUDIO_TOKENS_START = 128266

        # NOTE(review): max_model_len / gpu_memory_utilization are hard-coded;
        # confirm they match the deployment GPU before changing.
        self.model = LLM(path, max_model_len=4096, gpu_memory_utilization=0.3)
        self.tokenizer = AutoTokenizer.from_pretrained(path)

        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        try:
            self.snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
            self.snac_model.to(self.device)
            # Fix: put the codec in eval mode so dropout/normalization layers
            # behave deterministically during encode/decode.
            self.snac_model.eval()
        except Exception as e:
            # Chain the original exception for easier debugging.
            raise RuntimeError(f"Failed to load SNAC model: {e}") from e
| |
|
| | |
| | def encode_text(self, text): |
| | return self.tokenizer.encode(text, return_tensors="pt", add_special_tokens=False) |
| |
|
| | def encode_audio(self, base64_audio_str): |
| | audio_bytes = base64.b64decode(base64_audio_str) |
| | audio_buffer = io.BytesIO(audio_bytes) |
| | waveform, sr = sf.read(audio_buffer, dtype='float32') |
| |
|
| | if waveform.ndim > 1: |
| | waveform = np.mean(waveform, axis=1) |
| | if sr != 24000: |
| | waveform = librosa.resample(waveform, orig_sr=sr, target_sr=24000) |
| | return self.tokenize_audio(waveform) |
| |
|
| | def format_text_block(self, text_ids): |
| | return [ |
| | torch.tensor([[self.START_OF_HUMAN]], dtype=torch.int64), |
| | torch.tensor([[self.START_OF_TEXT]], dtype=torch.int64), |
| | text_ids, |
| | torch.tensor([[self.END_OF_TEXT]], dtype=torch.int64), |
| | torch.tensor([[self.END_OF_HUMAN]], dtype=torch.int64) |
| | ] |
| |
|
| | def format_audio_block(self, audio_codes): |
| | return [ |
| | torch.tensor([[self.START_OF_AI]], dtype=torch.int64), |
| | torch.tensor([[self.START_OF_SPEECH]], dtype=torch.int64), |
| | torch.tensor([audio_codes], dtype=torch.int64), |
| | torch.tensor([[self.END_OF_SPEECH]], dtype=torch.int64), |
| | torch.tensor([[self.END_OF_AI]], dtype=torch.int64) |
| | ] |
| |
|
| | def enroll_user(self, enrollment_pairs): |
| | """ |
| | Parameters: |
| | - enrollment_pairs: List of tuples (text, audio_data), where audio_data is |
| | base64-encoded audio data |
| | Returns: |
| | - cloning_features (str): serialized enrollment data |
| | """ |
| | enrollment_data = [] |
| |
|
| | for text, base64_audio in enrollment_pairs: |
| | text_ids = self.encode_text(text).cpu() |
| | audio_codes = self.encode_audio(base64_audio) |
| | enrollment_data.append({ |
| | "text_ids": text_ids, |
| | "audio_codes": audio_codes |
| | }) |
| |
|
| | |
| | buffer = io.BytesIO() |
| | torch.save(enrollment_data, buffer) |
| | buffer.seek(0) |
| |
|
| | |
| | cloning_features = base64.b64encode(buffer.read()).decode('utf-8') |
| | return cloning_features |
| |
|
| | def prepare_audio_tokens_for_decoder(self, audio_codes_list): |
| | """ |
| | Given a list containing sequences of generated audio codes, do the following: |
| | 1. Trim length to a multiple of 7 (SNAC decoder requires 7 tokens per audio frame) |
| | 2. Adjust token values to SNAC decoder's expected range |
| | """ |
| | modified_audio_codes_list = [] |
| | for audio_codes in audio_codes_list: |
| |
|
| | |
| | length = (audio_codes.size(0) // 7) * 7 |
| | trimmed = audio_codes[:length] |
| |
|
| | |
| | audio_codes = trimmed - self.AUDIO_TOKENS_START |
| |
|
| | |
| | modified_audio_codes_list.append(audio_codes) |
| |
|
| | return modified_audio_codes_list |
| |
|
| | |
    def tokenize_audio(self, waveform):
        """
        Encode a waveform into a flat list of LLM audio-token ids.

        The SNAC encoder returns three codebook layers at increasing temporal
        resolution (1x, 2x, 4x codes per coarse frame). For each coarse frame
        the seven codes are interleaved in a fixed order, and each of the
        seven slots is offset into its own 4096-wide band above
        AUDIO_TOKENS_START so every slot occupies a distinct vocabulary range.
        """
        # Shape to (1, 1, num_samples) for the codec; assumes `waveform` is a
        # 1-D float numpy array at 24 kHz — TODO confirm with callers.
        waveform = torch.from_numpy(waveform).unsqueeze(0).unsqueeze(0).to(self.device)

        with torch.inference_mode():
            codes = self.snac_model.encode(waveform)

        all_codes = []
        # One iteration per coarse frame; the slot order below must match the
        # de-interleaving in convert_codes_to_waveform exactly.
        for i in range(codes[0].shape[1]):
            # Slot order: layer0[i], layer1[2i], layer2[4i], layer2[4i+1],
            #             layer1[2i+1], layer2[4i+2], layer2[4i+3].
            all_codes.append(codes[0][0][(1 * i) + 0].item() + self.AUDIO_TOKENS_START + (0 * 4096))
            all_codes.append(codes[1][0][(2 * i) + 0].item() + self.AUDIO_TOKENS_START + (1 * 4096))
            all_codes.append(codes[2][0][(4 * i) + 0].item() + self.AUDIO_TOKENS_START + (2 * 4096))
            all_codes.append(codes[2][0][(4 * i) + 1].item() + self.AUDIO_TOKENS_START + (3 * 4096))
            all_codes.append(codes[1][0][(2 * i) + 1].item() + self.AUDIO_TOKENS_START + (4 * 4096))
            all_codes.append(codes[2][0][(4 * i) + 2].item() + self.AUDIO_TOKENS_START + (5 * 4096))
            all_codes.append(codes[2][0][(4 * i) + 3].item() + self.AUDIO_TOKENS_START + (6 * 4096))

        return all_codes
| |
|
| | def preprocess(self, data): |
| |
|
| | |
| |
|
| | self.voice_cloning = data.get("clone", False) |
| | clone_on_the_fly = data.get("clone_on_the_fly", False) |
| |
|
| | |
| | target_text = data["inputs"] |
| | parameters = data.get("parameters", {}) |
| | cloning_features = data.get("cloning_features", None) |
| |
|
| | temperature = float(parameters.get("temperature", 0.6)) |
| | top_p = float(parameters.get("top_p", 0.95)) |
| | max_new_tokens = int(parameters.get("max_new_tokens", 1200)) |
| | repetition_penalty = float(parameters.get("repetition_penalty", 1.1)) |
| |
|
| | if self.voice_cloning: |
| | if clone_on_the_fly: |
| | |
| | enrollment_pairs = data.get("enrollments", []) |
| | enrollment_data = [] |
| |
|
| | |
| | if not enrollment_pairs: |
| | raise ValueError("No enrollment pairs provided") |
| | |
| | for text, base64_audio in enrollment_pairs: |
| | text_ids = self.encode_text(text).cpu() |
| | audio_codes = self.encode_audio(base64_audio) |
| | enrollment_data.append({ |
| | "text_ids": text_ids, |
| | "audio_codes": audio_codes |
| | }) |
| |
|
| | elif not cloning_features: |
| | raise ValueError("No cloning features were provided") |
| | else: |
| | |
| | enrollment_data = torch.load(io.BytesIO(base64.b64decode(cloning_features))) |
| |
|
| | |
| | input_sequence = [] |
| | for item in enrollment_data: |
| | text_ids = item["text_ids"] |
| | audio_codes = item["audio_codes"] |
| | input_sequence.extend(self.format_text_block(text_ids)) |
| | input_sequence.extend(self.format_audio_block(audio_codes)) |
| |
|
| | |
| | target_text_ids = self.encode_text(target_text) |
| | input_sequence.extend(self.format_text_block(target_text_ids)) |
| |
|
| | |
| | input_sequence.extend([ |
| | torch.tensor([[self.START_OF_AI]], dtype=torch.int64), |
| | torch.tensor([[self.START_OF_SPEECH]], dtype=torch.int64) |
| | ]) |
| |
|
| | |
| | input_ids = torch.cat(input_sequence, dim=1) |
| |
|
| | |
| | attention_mask = torch.ones_like(input_ids) |
| | input_ids = input_ids.to(self.device) |
| | attention_mask = attention_mask.to(self.device) |
| |
|
| | else: |
| | |
| |
|
| | |
| | voice = parameters.get("voice", "Eniola") |
| | prompt = f"{voice}: {target_text}" |
| | input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids |
| |
|
| | |
| | input_ids = torch.cat(self.format_text_block(input_ids), dim=1) |
| |
|
| | |
| | input_ids = input_ids.to(self.device) |
| |
|
| | return { |
| | "input_ids": input_ids, |
| | "temperature": temperature, |
| | "top_p": top_p, |
| | "max_new_tokens": max_new_tokens, |
| | "repetition_penalty": repetition_penalty, |
| | } |
| |
|
| | def inference(self, inputs): |
| | """ |
| | Run model inference on the preprocessed inputs |
| | """ |
| | |
| | input_ids = inputs["input_ids"] |
| |
|
| | sampling_params = SamplingParams( |
| | temperature = inputs["temperature"], |
| | top_p = inputs["top_p"], |
| | max_tokens = inputs["max_new_tokens"], |
| | repetition_penalty = inputs["repetition_penalty"], |
| | stop_token_ids = [self.END_OF_SPEECH], |
| | ) |
| | |
| | prompt_string = self.tokenizer.decode(input_ids[0]) |
| |
|
| | |
| | generated_ids = self.model.generate(prompt_string, sampling_params) |
| |
|
| | |
| | return { |
| | "gen_ids": torch.tensor(generated_ids[0].outputs[0].token_ids).unsqueeze(0), |
| | "input_ids": input_ids |
| | } |
| |
|
| | def __call__(self, data): |
| | |
| | |
| |
|
| | try: |
| | enroll_user = data.get("enroll_user", False) |
| |
|
| | if enroll_user: |
| | |
| | enrollment_pairs = data.get("enrollments", []) |
| | cloning_features = self.enroll_user(enrollment_pairs) |
| | return {"cloning_features": cloning_features} |
| | else: |
| | |
| | preprocessed_inputs = self.preprocess(data) |
| | model_outputs = self.inference(preprocessed_inputs) |
| | response = self.postprocess(model_outputs) |
| | return response |
| |
|
| | |
| | except Exception as e: |
| | traceback.print_exc() |
| | return {"error": str(e)} |
| |
|
| | |
| | def convert_codes_to_waveform(self, code_list): |
| | """ |
| | Reorganize tokens for SNAC decoding |
| | """ |
| | layer_1 = [] |
| | layer_2 = [] |
| | layer_3 = [] |
| |
|
| | num_groups = len(code_list) // 7 |
| | for i in range(num_groups): |
| | idx = 7 * i |
| | layer_1.append(code_list[7 * i + 0] - (0 * 4096)) |
| | layer_2.append(code_list[7 * i + 1] - (1 * 4096)) |
| | layer_3.append(code_list[7 * i + 2] - (2 * 4096)) |
| | layer_3.append(code_list[7 * i + 3] - (3 * 4096)) |
| | layer_2.append(code_list[7 * i + 4] - (4 * 4096)) |
| | layer_3.append(code_list[7 * i + 5] - (5 * 4096)) |
| | layer_3.append(code_list[7 * i + 6] - (6 * 4096)) |
| |
|
| | codes = [ |
| | torch.tensor(layer_1).unsqueeze(0).to(self.device), |
| | torch.tensor(layer_2).unsqueeze(0).to(self.device), |
| | torch.tensor(layer_3).unsqueeze(0).to(self.device) |
| | ] |
| |
|
| | |
| | audio_hat = self.snac_model.decode(codes) |
| | return audio_hat |
| |
|
| | def postprocess(self, model_outputs): |
| |
|
| | generated_ids = model_outputs["gen_ids"] |
| | input_ids = model_outputs["input_ids"] |
| |
|
| | if self.voice_cloning: |
| | """ |
| | For cloning applications, use this postprocess function to get generated audio samples |
| | """ |
| | |
| | code_lists = self.prepare_audio_tokens_for_decoder(generated_ids) |
| |
|
| | |
| | temp = self.convert_codes_to_waveform(code_lists[0]) |
| | audio_sample = temp.detach().squeeze().to("cpu").numpy() |
| |
|
| | else: |
| | """ |
| | Process generated tokens into audio |
| | """ |
| | |
| | token_indices = (generated_ids == self.START_OF_SPEECH).nonzero(as_tuple=True) |
| |
|
| | if len(token_indices[1]) > 0: |
| | last_occurrence_idx = token_indices[1][-1].item() |
| | cropped_tensor = generated_ids[:, last_occurrence_idx+1:] |
| | else: |
| | cropped_tensor = generated_ids |
| |
|
| | |
| | processed_rows = [] |
| | for row in cropped_tensor: |
| | masked_row = row[row != self.END_OF_SPEECH] |
| | processed_rows.append(masked_row) |
| |
|
| | code_lists = self.prepare_audio_tokens_for_decoder(processed_rows) |
| |
|
| | |
| | audio_samples = [] |
| | for code_list in code_lists: |
| | if len(code_list) > 0: |
| | audio = self.convert_codes_to_waveform(code_list) |
| | audio_samples.append(audio) |
| | else: |
| | raise ValueError("Empty code list, no audio to generate") |
| |
|
| | if not audio_samples: |
| | return {"error": "No audio samples generated"} |
| |
|
| | |
| | audio_sample = audio_samples[0].detach().squeeze().cpu().numpy() |
| |
|
| | |
| | audio_int16 = (audio_sample * 32767).astype(np.int16) |
| |
|
| | |
| | buffer = io.BytesIO() |
| | sf.write(buffer, audio_sample, samplerate=24000, format='WAV', subtype='PCM_16') |
| | buffer.seek(0) |
| |
|
| | |
| | audio_b64 = base64.b64encode(buffer.read()).decode('utf-8') |
| |
|
| | return { |
| | "audio_sample": audio_sample, |
| | "audio_b64": audio_b64, |
| | "sample_rate": 24000, |
| | "input_ids_len": input_ids.shape[1], |
| | "gen_ids_len": generated_ids.shape[1] |
| | } |