import logging

# Silence noisy third-party loggers before the heavier imports pull them in.
logging.getLogger('numba').setLevel(logging.WARNING)
logging.getLogger('matplotlib').setLevel(logging.WARNING)
logging.getLogger('urllib3').setLevel(logging.WARNING)

import json
import os
import re
import time
from threading import Thread
from urllib.parse import quote

import gradio as gr
import numpy as np
import onnxruntime as ort
import openai
import requests
import torch
from flask import Flask, request
from scipy.io import wavfile

import commons
import utils
from text import text_to_sequence
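# A VITS text-to-speech demo wired to three chat backends: a local ChatGLM-style
# model, gpt-3.5-turbo, and text-davinci-003. A Gradio control/debug UI and a
# Flask /chat endpoint run on separate threads (see the __main__ block below).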

class VitsGradio:
    def __init__(self):
        self.lan = ["Chinese", "Japanese", "Auto"]
        self.chatapi = ["gpt-3.5-turbo", "gpt3"]
        # Every directory under checkpoints/ is offered as a selectable model.
        self.modelPaths = []
        for root, dirs, files in os.walk("checkpoints"):
            for d in dirs:
                self.modelPaths.append(d)
        with gr.Blocks() as self.Vits:
            with gr.Tab("Debug"):
                with gr.Row():
                    with gr.Column():
                        with gr.Row():
                            with gr.Column():
                                self.text = gr.TextArea(label="Text", value="你好")
                                with gr.Accordion(label="Test the API", open=False):
                                    # Display-only toggle; the button below always performs the simulation.
                                    sim_check = gr.Checkbox(value=False, label="Simulate a request with URL + text")
                                    self.url_input = gr.TextArea(label="Test URL", value="http://127.0.0.1:8080/chat?Text=")
                                    butto = gr.Button("Simulate the front end fetching the audio file")
                                btnVC = gr.Button("Test the TTS + chat pipeline")
                            with gr.Column():
                                output2 = gr.TextArea(label="Reply")
                                output1 = gr.Audio(label="Sample rate 22050")
                                output3 = gr.File(label="44100 Hz: output.wav")
                butto.click(self.Simul, inputs=[self.text, self.url_input], outputs=[output2, output3])
                btnVC.click(self.tts_fn, inputs=[self.text], outputs=[output1, output2])
            with gr.Tab("Control panel"):
                with gr.Row():
                    with gr.Column():
                        with gr.Row():
                            with gr.Column():
                                self.api_input1 = gr.TextArea(label="API key, or the local path of the chat model", value="https://platform.openai.com/account/api-keys")
                                with gr.Accordion(label="Chatbot selection", open=False):
                                    self.api_input2 = gr.Checkbox(value=True, label="Use gpt-3.5")
                                    self.local_chat1 = gr.Checkbox(value=False, label="Run a local chatbot")
                                    self.local_chat2 = gr.Checkbox(value=True, label="Quantize the local model")
                                    res = gr.TextArea()
                                    Botselection = gr.Button("Apply chatbot settings")
                                    Botselection.click(self.check_bot, inputs=[self.api_input1, self.api_input2, self.local_chat1, self.local_chat2], outputs=[res])
                                self.input1 = gr.Dropdown(label="Model", choices=self.modelPaths, value=self.modelPaths[0] if self.modelPaths else None, type="value")
                                self.input2 = gr.Dropdown(label="Language", choices=self.lan, value="Auto", interactive=True)
                            with gr.Column():
                                btnVC = gr.Button("Apply VITS TTS settings")
                                self.input3 = gr.Dropdown(label="Speaker", choices=list(range(101)), value=0, interactive=True)
                                self.input4 = gr.Slider(minimum=0, maximum=1.0, label="noise scale (controls emotional variation)", value=0.267)
                                self.input5 = gr.Slider(minimum=0, maximum=1.0, label="noise scale w (controls phoneme length)", value=0.7)
                                self.input6 = gr.Slider(minimum=0.1, maximum=10, label="duration", value=1)
                                statusa = gr.TextArea()
                                btnVC.click(self.create_tts_fn, inputs=[self.input1, self.input2, self.input3, self.input4, self.input5, self.input6], outputs=[statusa])

    def Simul(self, text, url_input):
        # Hit the /chat endpoint the way a front end would and save the reply audio.
        web = url_input + text
        res = requests.get(web)
        with open('output.wav', 'wb') as f:
            f.write(res.content)
        return web, 'output.wav'
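    # Example: with the default URL field, Simul("你好", "http://127.0.0.1:8080/chat?Text=")
    # requests http://127.0.0.1:8080/chat?Text=你好 and stores the WAV reply as output.wav.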

    def chatgpt(self, text):
        # Chat via the OpenAI chat API; self.messages is seeded in check_bot.
        self.messages.append({"role": "user", "content": text})
        chat = openai.ChatCompletion.create(model="gpt-3.5-turbo", messages=self.messages)
        reply = chat.choices[0].message.content
        # Keep the assistant turn so the conversation retains context.
        self.messages.append({"role": "assistant", "content": reply})
        return reply

    def chatglm(self, text):
        # Local chat model; sending 'clear' resets the conversation history.
        if text == 'clear':
            self.history = []
        response, new_history = self.model.chat(self.tokenizer, text, self.history)
        response = response.replace(" ", '').replace("\n", '.')
        self.history = new_history
        return response

    def gpt3_chat(self, text):
        # One-shot completion against text-davinci-003; the API key is set in check_bot.
        call_name = "Waifu"
        identity = ""
        start_sequence = '\n' + str(call_name) + ':'
        if text == 'quit':
            return text
        prompt = identity + text + start_sequence
        response = openai.Completion.create(
            model="text-davinci-003",
            prompt=prompt,
            temperature=0.5,
            max_tokens=1000,
            top_p=1.0,
            frequency_penalty=0.5,
            presence_penalty=0.0,
            stop=["\nYou:"]
        )
        return response['choices'][0]['text'].strip()
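    # Example prompt for text "Hi": "Hi\nWaifu:" (identity is empty), with generation
    # stopping at the next "\nYou:" turn marker.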

    def check_bot(self, api_input1, api_input2, local_chat1, local_chat2):
        # Remember the chosen backend as plain bools so tts_fn does not have to
        # test the Gradio checkbox components themselves.
        self.use_local = local_chat1
        self.use_gpt35 = api_input2
        if local_chat1:
            from transformers import AutoTokenizer, AutoModel
            self.tokenizer = AutoTokenizer.from_pretrained(api_input1, trust_remote_code=True)
            if local_chat2:
                self.model = AutoModel.from_pretrained(api_input1, trust_remote_code=True).half().quantize(4).cuda()
            else:
                self.model = AutoModel.from_pretrained(api_input1, trust_remote_code=True)
            self.history = []
        else:
            self.messages = []
            openai.api_key = api_input1
        return "Finished"

    def is_japanese(self, string):
        # True if the string contains any hiragana or katakana character.
        for ch in string:
            if 0x3040 < ord(ch) < 0x30FF:
                return True
        return False

    def is_english(self, string):
        # Currently unused by sle; handy for extending language detection.
        pattern = re.compile('^[A-Za-z0-9.,:;!?()_*"\' ]+$')
        return bool(pattern.fullmatch(string))

    def get_symbols_from_json(self, path):
        assert os.path.isfile(path)
        with open(path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        return data['symbols']

    def sle(self, language, text):
        # Normalize punctuation, then wrap the text in the language tags the cleaners expect.
        text = text.replace('\n', '。').replace(' ', ',')
        if language == "Chinese":
            return "[ZH]" + text + "[ZH]"
        elif language == "Japanese":
            return "[JA]" + text + "[JA]"
        elif language == "Auto":
            return f"[JA]{text}[JA]" if self.is_japanese(text) else f"[ZH]{text}[ZH]"
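    # Example: sle("Auto", "こんにちは") -> "[JA]こんにちは[JA]";
    #          sle("Auto", "你好") -> "[ZH]你好[ZH]".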

    def get_text(self, text, hps_ms):
        # Unused in this app (tts_fn builds its sequence inline); kept as a torch-tensor helper.
        text_norm = text_to_sequence(text, hps_ms.data.text_cleaners)
        if hps_ms.data.add_blank:
            text_norm = commons.intersperse(text_norm, 0)
        text_norm = torch.LongTensor(text_norm)
        return text_norm

    def create_tts_fn(self, path, input2, input3, n_scale=0.667, n_scale_w=0.8, l_scale=1):
        # Load the selected checkpoint's symbols, hyperparameters and ONNX model,
        # and store the inference settings chosen in the control panel.
        self.symbols = self.get_symbols_from_json(f"checkpoints/{path}/config.json")
        self.hps = utils.get_hparams_from_file(f"checkpoints/{path}/config.json")
        self.ort_sess = ort.InferenceSession(f"checkpoints/{path}/model.onnx")
        self.language = input2
        self.speaker_id = input3
        self.n_scale = n_scale
        self.n_scale_w = n_scale_w
        self.l_scale = l_scale
        print(self.language, self.speaker_id, self.n_scale)
        return 'success'
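    # The ONNX graph exported for this app takes four inputs: 'input' (phoneme ids),
    # 'input_lengths', 'scales' ([noise, noise_w, length]) and 'sid' (speaker id);
    # tts_fn below assembles them in exactly that layout.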

    def tts_fn(self, text):
        # Route the input through whichever chat backend check_bot configured,
        # then synthesize the reply with the ONNX VITS model.
        if self.use_local:
            text = self.chatglm(text)
        elif self.use_gpt35:
            text = self.chatgpt(text)
        else:
            text = self.gpt3_chat(text)
        print(text)
        text = self.sle(self.language, text)
        seq = text_to_sequence(text, cleaner_names=self.hps.data.text_cleaners)
        if self.hps.data.add_blank:
            seq = commons.intersperse(seq, 0)
        with torch.no_grad():
            x = np.array([seq], dtype=np.int64)
            x_len = np.array([x.shape[1]], dtype=np.int64)
            sid = np.array([self.speaker_id], dtype=np.int64)
            scales = np.array([self.n_scale, self.n_scale_w, self.l_scale], dtype=np.float32)
            scales.resize(1, 3)
            ort_inputs = {
                'input': x,
                'input_lengths': x_len,
                'scales': scales,
                'sid': sid
            }
            t1 = time.time()
            audio = np.squeeze(self.ort_sess.run(None, ort_inputs))
            # Peak-normalize to 60% of int16 full scale, then clip for safety.
            audio *= 32767.0 / max(0.01, np.max(np.abs(audio))) * 0.6
            audio = np.clip(audio, -32767.0, 32767.0)
            t2 = time.time()
            print("Inference time: " + str(t2 - t1) + "s")
        os.makedirs('moe', exist_ok=True)
        wavfile.write('moe/temp1.wav', self.hps.data.sampling_rate, audio.astype(np.int16))
        # Resample to 44100 Hz for the file served by the /chat endpoint.
        os.system('ffmpeg -y -i moe/temp1.wav -ar 44100 moe/temp2.wav')
        return (self.hps.data.sampling_rate, audio.astype(np.int16)), text.replace('[JA]', '').replace('[ZH]', '')
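    # Note: after the normalization above, max|audio| is 0.6 * 32767 ≈ 19660, so the
    # int16 stream written to moe/temp1.wav has headroom and never clips.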


app = Flask(__name__)
print("Starting deployment")
grVits = VitsGradio()


@app.route('/chat')
def text_api():
    message = request.args.get('Text', '')
    audio, text = grVits.tts_fn(message)
    with open('moe/temp2.wav', 'rb') as f:
        wav_bytes = f.read()
    headers = {
        'Content-Type': 'audio/wav',
        # HTTP header values must be latin-1 safe, so percent-encode the reply text.
        'Text': quote(text)
    }
    return wav_bytes, 200, headers
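# Example client call (default host/port from the __main__ block below):
#   curl "http://127.0.0.1:8080/chat?Text=你好" -o reply.wav
# The synthesized speech is the WAV response body; the percent-encoded reply text
# is returned in the 'Text' response header.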


def gradio_interface():
    return grVits.Vits.launch()


if __name__ == '__main__':
    # Serve the Flask API and the Gradio UI on separate threads.
    api_thread = Thread(target=app.run, kwargs={"host": "0.0.0.0", "port": 8080})
    gradio_thread = Thread(target=gradio_interface)
    api_thread.start()
    gradio_thread.start()