# Tg / app.py
# Last change: commit 9efb144 ("Update app.py") by Athspi.
import gradio as gr
import torch
import numpy as np
from transformers import AutoTokenizer
import onnxruntime
from huggingface_hub import hf_hub_download
import os
# --- Configuration ---
repo_id = "Athspi/Gg"  # Hub repo used for both the ONNX file and the tokenizer
onnx_filename = "mms_tts_eng.onnx"
sampling_rate = 16000  # returned to Gradio together with the generated waveform

# --- Download ONNX Model ---
# The returned path points into the local Hugging Face download cache.
onnx_model_path = hf_hub_download(repo_id, onnx_filename)
print(f"ONNX model downloaded to (cache): {onnx_model_path}")

# --- Load Tokenizer ---
tokenizer = AutoTokenizer.from_pretrained(repo_id)
# --- ONNX Runtime Session Setup ---
# Build a CPU inference session with full graph optimizations and one
# intra-op thread per physical core (falling back to a fixed default).
session_options = onnxruntime.SessionOptions()
session_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL

_DEFAULT_NUM_THREADS = 4
try:
    import psutil
except ImportError:
    print("psutil not installed. Install with: pip install psutil")
    num_physical_cores = None
else:
    # psutil.cpu_count(logical=False) returns None when the physical core
    # count cannot be determined; assigning None to intra_op_num_threads
    # would raise a TypeError, so treat it like the "no psutil" case.
    num_physical_cores = psutil.cpu_count(logical=False)

if not num_physical_cores:
    num_physical_cores = _DEFAULT_NUM_THREADS
    print(f"Using default: {num_physical_cores}")

session_options.intra_op_num_threads = num_physical_cores
session_options.inter_op_num_threads = 1

ort_session = onnxruntime.InferenceSession(
    onnx_model_path,
    providers=['CPUExecutionProvider'],
    sess_options=session_options,
)
# --- IO Binding Setup ---
# A module-level IOBinding with pre-allocated CPU buffers, sized from the
# tokenization of the one-character probe string "a".
io_binding = ort_session.io_binding()
input_meta, output_meta = ort_session.get_inputs()[0], ort_session.get_outputs()[0]

# Input buffer: same shape and dtype as the probe's token ids.
dummy_input = tokenizer("a", return_tensors="pt")["input_ids"].to(torch.long)
input_shape = tuple(dummy_input.shape)
input_type = dummy_input.numpy().dtype
input_tensor = torch.empty(
    input_shape, dtype=torch.int64, device="cpu"
).contiguous()

# Output buffer laid out as (1, 1, samples).
# NOTE(review): "10 samples per token" is a heuristic; nothing visible here
# ties it to the length of the waveform the model actually produces.
max_output_length = 10 * input_shape[1]
output_shape = (1, 1, max_output_length)
output_tensor = torch.empty(
    output_shape, dtype=torch.float32, device="cpu"
).contiguous()

# Initial binding: point ONNX Runtime at the raw memory of both buffers.
io_binding.bind_input(
    name=input_meta.name,
    device_type="cpu",
    device_id=0,
    element_type=input_type,
    shape=input_shape,
    buffer_ptr=input_tensor.data_ptr(),
)
io_binding.bind_output(
    name=output_meta.name,
    device_type="cpu",
    device_id=0,
    element_type=np.float32,
    shape=output_shape,
    buffer_ptr=output_tensor.data_ptr(),
)
# --- Inference Function ---
def tts_inference_io_binding(text: str):
    """Synthesize speech for ``text`` using ONNX Runtime IO binding.

    A fresh ``IOBinding`` is created on every call:

    * The input is bound with the exact shape of this request's token ids.
      Re-using an over-sized shared buffer would feed stale tokens left over
      from earlier, longer requests to the model.
    * The output is bound to the CPU device without a fixed shape or buffer,
      so ONNX Runtime allocates an output of whatever length the model
      produces. Binding a guessed fixed shape makes ORT reject the run when
      the real waveform length differs.
    * No binding state is shared between concurrent Gradio requests.

    Args:
        text: Input text to synthesize.

    Returns:
        A ``(sampling_rate, waveform)`` tuple, as consumed by
        ``gr.Audio(type="numpy")``. ``waveform`` is the model's first output
        with all size-1 axes squeezed away.

    Raises:
        gr.Error: If the tokenizer produces no tokens for ``text``.
    """
    input_ids = tokenizer(text, return_tensors="pt").input_ids.to(torch.long)
    if input_ids.shape[-1] == 0:
        raise gr.Error("Please enter some text to synthesize.")

    # ORT reads directly from this tensor's memory, so it must be contiguous
    # and stay referenced (this local) until the run has finished.
    input_ids = input_ids.contiguous()

    binding = ort_session.io_binding()
    binding.bind_input(
        name=ort_session.get_inputs()[0].name,
        device_type="cpu",
        device_id=0,
        element_type=np.int64,
        shape=tuple(input_ids.shape),
        buffer_ptr=input_ids.data_ptr(),
    )
    # Bind by device only: ONNX Runtime allocates a correctly sized output.
    binding.bind_output(ort_session.get_outputs()[0].name, "cpu")

    ort_session.run_with_iobinding(binding)
    waveform = binding.copy_outputs_to_cpu()[0]
    return (sampling_rate, waveform.squeeze())
# --- Gradio Interface ---
# Sample prompts offered in the UI.
_EXAMPLE_PROMPTS = [
    "Hello, this is a demonstration.",
    "This uses ONNX Runtime and IO Binding.",
    "The quick brown fox jumps over the lazy dog.",
    "Try your own text!",
]

iface = gr.Interface(
    fn=tts_inference_io_binding,
    inputs=gr.Textbox(lines=3, placeholder="Enter text here..."),
    outputs=gr.Audio(type="numpy", label="Generated Speech"),
    title="Optimized MMS-TTS (English)",
    description="Fast TTS with ONNX Runtime and IO Binding (Hugging Face Hub).",
    # One row per example, each holding the value for the single Textbox input.
    examples=[[prompt] for prompt in _EXAMPLE_PROMPTS],
    cache_examples=False,
)

if __name__ == "__main__":
    iface.launch()