root
commited on
Commit
·
7c1f567
1
Parent(s):
3eb547e
add http client
Browse files
src/f5_tts/runtime/triton_trtllm/README.md
CHANGED
|
@@ -25,7 +25,10 @@ Inside docker container, we would follow the official guide of TensorRT-LLM to b
|
|
| 25 |
```sh
|
| 26 |
bash run.sh 0 4 F5TTS_Base
|
| 27 |
```
|
| 28 |
-
|
|
|
|
|
|
|
|
|
|
| 29 |
### Benchmark using Dataset
|
| 30 |
```sh
|
| 31 |
num_task=2
|
|
|
|
| 25 |
```sh
|
| 26 |
bash run.sh 0 4 F5TTS_Base
|
| 27 |
```
|
| 28 |
+
### HTTP Client
|
| 29 |
+
```sh
|
| 30 |
+
python3 client_http.py
|
| 31 |
+
```
|
| 32 |
### Benchmark using Dataset
|
| 33 |
```sh
|
| 34 |
num_task=2
|
src/f5_tts/runtime/triton_trtllm/client_http.py
ADDED
|
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Redistribution and use in source and binary forms, with or without
|
| 4 |
+
# modification, are permitted provided that the following conditions
|
| 5 |
+
# are met:
|
| 6 |
+
# * Redistributions of source code must retain the above copyright
|
| 7 |
+
# notice, this list of conditions and the following disclaimer.
|
| 8 |
+
# * Redistributions in binary form must reproduce the above copyright
|
| 9 |
+
# notice, this list of conditions and the following disclaimer in the
|
| 10 |
+
# documentation and/or other materials provided with the distribution.
|
| 11 |
+
# * Neither the name of NVIDIA CORPORATION nor the names of its
|
| 12 |
+
# contributors may be used to endorse or promote products derived
|
| 13 |
+
# from this software without specific prior written permission.
|
| 14 |
+
#
|
| 15 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
|
| 16 |
+
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 17 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
| 18 |
+
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
|
| 19 |
+
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
| 20 |
+
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
| 21 |
+
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
| 22 |
+
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
|
| 23 |
+
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
| 24 |
+
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 25 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 26 |
+
import requests
|
| 27 |
+
import soundfile as sf
|
| 28 |
+
import numpy as np
|
| 29 |
+
import argparse
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def get_args():
|
| 33 |
+
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
| 34 |
+
|
| 35 |
+
parser.add_argument(
|
| 36 |
+
"--server-url",
|
| 37 |
+
type=str,
|
| 38 |
+
default="localhost:8000",
|
| 39 |
+
help="Address of the server",
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
parser.add_argument(
|
| 43 |
+
"--reference-audio",
|
| 44 |
+
type=str,
|
| 45 |
+
default="../../infer/examples/basic/basic_ref_en.wav",
|
| 46 |
+
help="Path to a single audio file. It can't be specified at the same time with --manifest-dir",
|
| 47 |
+
)
|
| 48 |
+
|
| 49 |
+
parser.add_argument(
|
| 50 |
+
"--reference-text",
|
| 51 |
+
type=str,
|
| 52 |
+
default="Some call me nature, others call me mother nature.",
|
| 53 |
+
help="",
|
| 54 |
+
)
|
| 55 |
+
|
| 56 |
+
parser.add_argument(
|
| 57 |
+
"--target-text",
|
| 58 |
+
type=str,
|
| 59 |
+
default="I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring.",
|
| 60 |
+
help="",
|
| 61 |
+
)
|
| 62 |
+
|
| 63 |
+
parser.add_argument(
|
| 64 |
+
"--model-name",
|
| 65 |
+
type=str,
|
| 66 |
+
default="f5_tts",
|
| 67 |
+
choices=["f5_tts", "spark_tts"],
|
| 68 |
+
help="triton model_repo module name to request",
|
| 69 |
+
)
|
| 70 |
+
|
| 71 |
+
parser.add_argument(
|
| 72 |
+
"--output-audio",
|
| 73 |
+
type=str,
|
| 74 |
+
default="output.wav",
|
| 75 |
+
help="Path to save the output audio",
|
| 76 |
+
)
|
| 77 |
+
return parser.parse_args()
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def prepare_request(
|
| 81 |
+
samples,
|
| 82 |
+
reference_text,
|
| 83 |
+
target_text,
|
| 84 |
+
sample_rate=16000,
|
| 85 |
+
audio_save_dir: str = "./",
|
| 86 |
+
):
|
| 87 |
+
assert len(samples.shape) == 1, "samples should be 1D"
|
| 88 |
+
lengths = np.array([[len(samples)]], dtype=np.int32)
|
| 89 |
+
samples = samples.reshape(1, -1).astype(np.float32)
|
| 90 |
+
|
| 91 |
+
data = {
|
| 92 |
+
"inputs": [
|
| 93 |
+
{"name": "reference_wav", "shape": samples.shape, "datatype": "FP32", "data": samples.tolist()},
|
| 94 |
+
{
|
| 95 |
+
"name": "reference_wav_len",
|
| 96 |
+
"shape": lengths.shape,
|
| 97 |
+
"datatype": "INT32",
|
| 98 |
+
"data": lengths.tolist(),
|
| 99 |
+
},
|
| 100 |
+
{"name": "reference_text", "shape": [1, 1], "datatype": "BYTES", "data": [reference_text]},
|
| 101 |
+
{"name": "target_text", "shape": [1, 1], "datatype": "BYTES", "data": [target_text]},
|
| 102 |
+
]
|
| 103 |
+
}
|
| 104 |
+
|
| 105 |
+
return data
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def load_audio(wav_path, target_sample_rate=16000):
|
| 109 |
+
assert target_sample_rate == 16000, "hard coding in server"
|
| 110 |
+
if isinstance(wav_path, dict):
|
| 111 |
+
samples = wav_path["array"]
|
| 112 |
+
sample_rate = wav_path["sampling_rate"]
|
| 113 |
+
else:
|
| 114 |
+
samples, sample_rate = sf.read(wav_path)
|
| 115 |
+
if sample_rate != target_sample_rate:
|
| 116 |
+
from scipy.signal import resample
|
| 117 |
+
|
| 118 |
+
num_samples = int(len(samples) * (target_sample_rate / sample_rate))
|
| 119 |
+
samples = resample(samples, num_samples)
|
| 120 |
+
return samples, target_sample_rate
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
if __name__ == "__main__":
|
| 124 |
+
args = get_args()
|
| 125 |
+
server_url = args.server_url
|
| 126 |
+
if not server_url.startswith(("http://", "https://")):
|
| 127 |
+
server_url = f"http://{server_url}"
|
| 128 |
+
|
| 129 |
+
url = f"{server_url}/v2/models/{args.model_name}/infer"
|
| 130 |
+
samples, sr = load_audio(args.reference_audio)
|
| 131 |
+
assert sr == 16000, "sample rate hardcoded in server"
|
| 132 |
+
|
| 133 |
+
samples = np.array(samples, dtype=np.float32)
|
| 134 |
+
data = prepare_request(samples, args.reference_text, args.target_text)
|
| 135 |
+
|
| 136 |
+
rsp = requests.post(
|
| 137 |
+
url, headers={"Content-Type": "application/json"}, json=data, verify=False, params={"request_id": "0"}
|
| 138 |
+
)
|
| 139 |
+
result = rsp.json()
|
| 140 |
+
audio = result["outputs"][0]["data"]
|
| 141 |
+
audio = np.array(audio, dtype=np.float32)
|
| 142 |
+
sf.write(args.output_audio, audio, 24000, "PCM_16")
|
src/f5_tts/runtime/triton_trtllm/run.sh
CHANGED
|
@@ -61,4 +61,12 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
|
|
| 61 |
log_dir=./log_concurrent_tasks_${num_task}
|
| 62 |
rm -r $log_dir
|
| 63 |
python3 client_grpc.py --num-tasks $num_task --huggingface-dataset yuekai/seed_tts --split-name wenetspeech4tts --log-dir $log_dir
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
fi
|
|
|
|
| 61 |
log_dir=./log_concurrent_tasks_${num_task}
|
| 62 |
rm -r $log_dir
|
| 63 |
python3 client_grpc.py --num-tasks $num_task --huggingface-dataset yuekai/seed_tts --split-name wenetspeech4tts --log-dir $log_dir
|
| 64 |
+
fi
|
| 65 |
+
|
| 66 |
+
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
|
| 67 |
+
echo "Testing http client"
|
| 68 |
+
audio=../../infer/examples/basic/basic_ref_en.wav
|
| 69 |
+
reference_text="Some call me nature, others call me mother nature."
|
| 70 |
+
target_text="I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring."
|
| 71 |
+
python3 client_http.py --reference-audio $audio --reference-text "$reference_text" --target-text "$target_text"
|
| 72 |
fi
|