"""Gradio app: English text-to-speech with the MMS-TTS model exported to ONNX.

Downloads the model from the Hugging Face Hub and runs it with ONNX Runtime,
using IO binding to cut per-call tensor copies.
"""

import gradio as gr
import torch
import numpy as np
from transformers import AutoTokenizer
import onnxruntime
from huggingface_hub import hf_hub_download
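
# Model repository on the Hugging Face Hub, the exported ONNX filename, and
# the sampling rate of the generated waveform (MMS-TTS emits 16 kHz audio).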
|
repo_id = "Athspi/Gg"
onnx_filename = "mms_tts_eng.onnx"
sampling_rate = 16000
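
# Fetch the ONNX graph from the Hub; hf_hub_download caches the file locally
# and returns the cached path on subsequent runs.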
|
onnx_model_path = hf_hub_download(repo_id=repo_id, filename=onnx_filename)
print(f"ONNX model downloaded to (cache): {onnx_model_path}")
|
tokenizer = AutoTokenizer.from_pretrained(repo_id)
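
# Session tuning: enable all graph optimizations and match intra-op threads to
# physical cores; inter-op parallelism stays at 1 for this single-stream app.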
|
session_options = onnxruntime.SessionOptions()
session_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL

try:
    import psutil
    # Physical (not logical) core count; can be None on some platforms.
    num_physical_cores = psutil.cpu_count(logical=False) or 4
except ImportError:
    print("psutil not installed (pip install psutil); defaulting to 4 threads.")
    num_physical_cores = 4

session_options.intra_op_num_threads = num_physical_cores
session_options.inter_op_num_threads = 1
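
# Build the CPU inference session with the tuned options.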
|
ort_session = onnxruntime.InferenceSession(
    onnx_model_path,
    providers=['CPUExecutionProvider'],
    sess_options=session_options,
)
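
# IO binding lets the session read inputs from (and write outputs to)
# pre-arranged buffers, avoiding extra allocations and copies per call.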
|
io_binding = ort_session.io_binding()
input_meta = ort_session.get_inputs()[0]
output_meta = ort_session.get_outputs()[0]

# Reusable int64 input buffer, sized from a dummy tokenization and grown on
# demand in the inference function. The waveform length depends on the text,
# so the output is bound per call and allocated by ONNX Runtime instead.
dummy_input = tokenizer("a", return_tensors="pt")["input_ids"].to(torch.long)
input_type = np.int64  # element type of the token-id input
input_tensor = torch.empty(tuple(dummy_input.shape), dtype=torch.int64, device="cpu").contiguous()
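
# Each call below reuses the input buffer, rebinds it with the request's
# exact shape, and lets ONNX Runtime allocate the variable-length output.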
|
def tts_inference_io_binding(text: str):
    """Synthesize speech for `text` with ONNX Runtime IO binding."""
    global input_tensor

    inputs = tokenizer(text, return_tensors="pt")
    input_ids = inputs.input_ids.to(torch.long)
    current_shape = tuple(input_ids.shape)
|

    # Grow the reusable input buffer if this request is the longest so far.
    if current_shape[1] > input_tensor.shape[1]:
        input_tensor = torch.empty(current_shape, dtype=torch.int64, device="cpu").contiguous()

    # Copy the ids into the front of the buffer, then rebind with this
    # request's exact shape so stale tokens from earlier calls are never read.
    input_tensor[:, :current_shape[1]].copy_(input_ids)
    io_binding.bind_input(
        name=input_meta.name, device_type="cpu", device_id=0,
        element_type=input_type, shape=current_shape,
        buffer_ptr=input_tensor.data_ptr(),
    )
|

    # The output length is data-dependent, so let ONNX Runtime allocate the
    # output tensor itself rather than binding a fixed-size buffer.
    io_binding.clear_binding_outputs()
    io_binding.bind_output(output_meta.name, "cpu")

    ort_session.run_with_iobinding(io_binding)
|

    # Retrieve the ORT-allocated output and return (rate, waveform) for Gradio.
    output_data = io_binding.get_outputs()[0].numpy()
    return (sampling_rate, output_data.squeeze())
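
# Optional warm-up (a suggestion, not part of the original flow): calling
# tts_inference_io_binding("Hello") once at startup hides first-request latency.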
|

iface = gr.Interface(
    fn=tts_inference_io_binding,
    inputs=gr.Textbox(lines=3, placeholder="Enter text here..."),
    outputs=gr.Audio(type="numpy", label="Generated Speech"),
    title="Optimized MMS-TTS (English)",
    description="Fast TTS with ONNX Runtime and IO Binding (Hugging Face Hub).",
    examples=[
        ["Hello, this is a demonstration."],
        ["This uses ONNX Runtime and IO Binding."],
        ["The quick brown fox jumps over the lazy dog."],
        ["Try your own text!"],
    ],
    cache_examples=False,
)
|

if __name__ == "__main__":
    iface.launch()