# Source page: lovelive-ShojoKageki-vits / inference_onnx.py
# Page metadata (not part of the program): "Mahiruoshi's picture",
# "Upload 73 files", commit 4de73fc, raw / history / blame, 5.36 kB
# Copyright (c) 2022, Yongqiang Li (yongqiangli@alumni.hust.edu.cn)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import json
import pathlib
import sys
import time

import numpy as np
import torch
from scipy.io import wavfile

import commons
import utils
from text import text_to_sequence

try:
    import onnxruntime as ort
except ImportError:
    print('Please install onnxruntime!')
    sys.exit(1)
def to_numpy(tensor: torch.Tensor) -> np.ndarray:
    """Convert a torch tensor to a NumPy array.

    The tensor is always detached from the autograd graph and moved to
    host memory before conversion. Previously only tensors with
    ``requires_grad=True`` were moved to the CPU, so a tensor living on
    another device without gradients would fail in ``.numpy()``.

    Args:
        tensor: The tensor to convert.

    Returns:
        A NumPy array holding the tensor's values.
    """
    # ``.cpu()`` returns the tensor itself when it is already on the CPU,
    # so the unconditional call adds no copy for host tensors.
    return tensor.detach().cpu().numpy()
def get_args(argv=None):
    """Parse the command-line options of the inference script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) the process command line (``sys.argv[1:]``) is
            parsed, which is the original behaviour.

    Returns:
        An ``argparse.Namespace`` with the attributes ``onnx_model``,
        ``cfg``, ``outdir`` (default ``"onnx_output"``) and ``test_file``.
    """
    parser = argparse.ArgumentParser(description='inference')
    parser.add_argument('--onnx_model', required=True, help='onnx model')
    parser.add_argument('--cfg', required=True, help='config file')
    parser.add_argument('--outdir', default="onnx_output",
                        help='output directory')
    # NOTE: the --phone_table and --speaker_table options are not exposed;
    # they were commented out in this script.
    parser.add_argument('--test_file', required=True, help='test file')
    return parser.parse_args(argv)
def get_symbols_from_json(path):
    """Return the value stored under the ``"symbols"`` key of a JSON file.

    Args:
        path: Path of the JSON config file to read.

    Returns:
        Whatever the top-level ``"symbols"`` entry of the file contains.

    Raises:
        FileNotFoundError: If ``path`` is not an existing regular file.
        KeyError: If the JSON document has no ``"symbols"`` key.
    """
    # Raise a real exception instead of using ``assert``: assertions are
    # stripped when Python runs with -O, which would silently skip the check.
    if not pathlib.Path(path).is_file():
        raise FileNotFoundError('config file not found: {}'.format(path))
    # JSON is UTF-8 by specification; do not depend on the locale encoding,
    # which breaks on non-ASCII symbols on some platforms.
    with open(path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    return data['symbols']
def main():
    """Synthesize one WAV file per non-empty line of the test file.

    Reads the command-line options, creates the output directory, loads
    the hyper-parameters from ``--cfg`` and the ONNX model from
    ``--onnx_model``. For every non-blank line of ``--test_file`` the
    first ``|``-separated field is taken as the audio path; its last
    ``/``-separated component becomes the output file name inside
    ``--outdir``. The waveform returned by the ONNX session is rescaled
    so that its peak sits at 60% of the int16 range and is written as
    16-bit PCM at ``hps.data.sampling_rate``.
    """
    args = get_args()
    print(args)

    # exist_ok=True already covers the "directory exists" case.
    pathlib.Path(args.outdir).mkdir(exist_ok=True, parents=True)

    hps = utils.get_hparams_from_file(args.cfg)
    ort_sess = ort.InferenceSession(args.onnx_model)

    with open(args.test_file, encoding='utf-8') as fin:
        for line in fin:
            line = line.strip()
            if not line:
                # A blank line would yield an empty output file name and
                # make the WAV write fail; skip it.
                continue
            audio_path = line.split("|")[0]

            # TODO: make the speaker id configurable.
            # NOTE(review): speaker id and text are hard-coded, so the
            # remaining fields of the test-file line are ignored and every
            # line produces the same utterance -- confirm this is intended.
            sid = 3
            text = '[ZH]你好,重庆市位于四川省东边[ZH]'

            seq = text_to_sequence(text,
                                   cleaner_names=hps.data.text_cleaners)
            if hps.data.add_blank:
                seq = commons.intersperse(seq, 0)

            # The session is fed plain numpy arrays (torch is not needed).
            x = np.array([seq], dtype=np.int64)
            ort_inputs = {
                'input': x,
                'input_lengths': np.array([x.shape[1]], dtype=np.int64),
                # Per the original note: noise (degree of emotional
                # variation), length (overall speaking rate) and noisew
                # (variation of phoneme durations) -- presumably in that
                # order. See https://github.com/gbxh/genshinTTS
                # Shape is (1, 3).
                'scales': np.array([[0.667, 0.8, 1]], dtype=np.float32),
                'sid': np.array([sid], dtype=np.int64),
            }

            start_time = time.perf_counter()
            audio = np.squeeze(ort_sess.run(None, ort_inputs))
            # Scale the peak to 60% of int16 full scale; the 0.01 floor
            # avoids dividing by (almost) zero on silent output.
            audio *= 32767.0 / max(0.01, np.max(np.abs(audio))) * 0.6
            audio = np.clip(audio, -32767.0, 32767.0)
            end_time = time.perf_counter()
            print("infer time cost: ", end_time - start_time, "s")

            wavfile.write(args.outdir + "/" + audio_path.split("/")[-1],
                          hps.data.sampling_rate, audio.astype(np.int16))


if __name__ == '__main__':
    main()