|
|
|
import logging |
|
import tempfile |
|
|
|
from inference.infer_tool import Svc |
|
from typing import * |
|
import api.base |
|
import os |
|
import io |
|
import wave |
|
import numpy as np |
|
from service.tool import audio_normalize, read_wav_file_to_numpy_array |
|
from utils import get_hparams_from_file |
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Shared inference service used by every handler; stays None until init() runs.
_svc: Optional[Svc] = None
# Names of the model directories collected by init(); None until then.
_model_paths: Optional[List] = None
|
|
|
|
|
def init():
    """Create the shared ``Svc`` instance and discover the available models.

    Populates the module-level ``_svc`` and ``_model_paths`` globals.

    ``_model_paths`` receives the sorted names of the directories located
    directly inside ``<current working directory>/checkpoints``.  Only that
    top level is listed because the inference handlers load
    ``checkpoints/<name>/config.json``; nested sub-directories are not valid
    model names.  A missing ``checkpoints`` directory yields an empty list.
    """
    global _svc, _model_paths
    _svc = Svc()

    # Resolved against the current working directory of the process.
    checkpoints_dir = os.path.join(os.path.abspath(os.getcwd()), "checkpoints")
    logger.debug(f"CkPoints Dir: {checkpoints_dir}")

    # Consume only the first os.walk() step: it describes checkpoints_dir
    # itself.  os.walk() yields nothing for a missing directory, in which
    # case the default tuple gives an empty model list.
    _, model_dirs, _ = next(os.walk(checkpoints_dir), (checkpoints_dir, [], []))
    _model_paths = sorted(model_dirs)
|
|
|
|
|
|
|
class ModelListHandler(api.base.ApiHandler):
    """Endpoint that reports the model names stored in ``_model_paths``."""

    async def get(self):
        """Write the standard ``code``/``msg``/``data`` envelope with the model list."""
        payload = dict(code=200, msg="ok", data=_model_paths)
        self.write(payload)
|
|
|
|
|
|
|
class SwitchHandler(api.base.ApiHandler):
    """Endpoint that selects the compute device and loads a model checkpoint."""

    async def post(self):
        """Validate the request arguments and switch the shared service.

        Request arguments:
            model:  checkpoint name (required).
            mode:   ``"single"`` or ``"batch"``; defaults to ``"single"``.
            device: ``"cpu"`` or ``"cuda"``; defaults to ``"cuda"``.

        Responds with HTTP 400 on the first invalid argument, HTTP 500 if the
        service raises while switching, and otherwise echoes ``mode`` back.
        """
        model_name = self.get_argument("model", "")
        mode = self.get_argument("mode", "single")
        device = self.get_argument("device", "cuda")

        # Rules are evaluated in order; the first failing one picks the message.
        validation_rules = (
            (model_name == "", "未选择模型!"),
            (mode not in ("single", "batch"), "运行模式选择错误!"),
            (device not in ("cpu", "cuda"), "设备选择错误!"),
        )
        for is_invalid, message in validation_rules:
            if is_invalid:
                self.set_status(400)
                self.write({"code": 400, "msg": message, "data": None})
                return

        logger.debug(f"modelname: {model_name}\n"
                     f"mode: {mode}\n"
                     f"device: {device}\n")
        try:
            _svc.set_device(device=device)
            logger.debug("Device set.")
            _svc.load_checkpoint(path=model_name)
            logger.debug("Model set.")
        except Exception as e:
            logger.exception(e)
            self.set_status(500)
            self.write({"code": 500, "msg": "system_error", "data": None})
        else:
            # Only reached when both service calls succeeded.
            self.write({"code": 200, "msg": "ok", "data": {"mode": mode}})
|
|
|
|
|
|
|
class SingleInferenceHandler(api.base.ApiHandler):
    """Endpoint that converts one uploaded audio file and returns it as WAV."""

    async def post(self):
        """Run voice conversion on the first ``srcaudio`` upload.

        Request arguments:
            dsid: model directory name under ``checkpoints`` (required).
            tran: parsed as float and passed to the service as ``tran``;
                  defaults to ``"0"``.
            th:   parsed as float and passed as ``slice_db``; defaults to
                  ``"-40.0"``.
            ns:   parsed as float and passed as ``ns``; defaults to ``"0.4"``.
        Request files:
            srcaudio: the audio to convert; only the first file is used.

        Responds with the converted audio as an ``audio/wav`` attachment,
        HTTP 400 when the upload or model is missing/invalid, and HTTP 500 on
        any other failure.
        """
        try:
            from scipy.io import wavfile

            dsid = self.get_argument("dsid", "")
            tran = self.get_argument("tran", "0")
            th = self.get_argument("th", "-40.0")
            ns = self.get_argument("ns", "0.4")
            audiofile_dict = self.request.files.get("srcaudio", [])

            if not audiofile_dict:
                self.set_status(400)
                self.write({
                    "code": 400,
                    "msg": "未上传文件!",
                    "data": None
                })
                return

            if dsid == "":
                self.set_status(400)
                self.write({
                    "code": 400,
                    "msg": "未选择模型!",
                    "data": None
                })
                return

            # dsid comes straight from the request and is spliced into a
            # filesystem path below, so it must be a bare directory name.
            if dsid in (os.curdir, os.pardir) or "/" in dsid or "\\" in dsid:
                self.set_status(400)
                self.write({
                    "code": 400,
                    "msg": "模型选择错误!",
                    "data": None
                })
                return

            audiofile = audiofile_dict[0]
            audio_filename = audiofile['filename']
            audio_filebody = audiofile['body']

            # audio_normalize receives the raw bytes directly; its result is
            # used as a file path that this handler must delete afterwards.
            converted_file = await audio_normalize(full_filename=audio_filename, file_data=audio_filebody)
            try:
                sampling_rate, audio_array = read_wav_file_to_numpy_array(converted_file)
            finally:
                # Remove the intermediate file even when reading it fails.
                os.remove(converted_file)

            scraudio = (sampling_rate, audio_array)

            logger.debug(f"read file {audio_filename}\n"
                         f"sampling rate: {sampling_rate}")

            tran = float(tran)
            th = float(th)
            ns = float(ns)

            hparams = get_hparams_from_file(f"checkpoints/{dsid}/config.json")
            spk = hparams.spk
            # The service is addressed by the speaker name mapped to id 0
            # (the last such entry wins if there are several).
            real_dsid = ""
            for k, v in spk.items():
                if v == 0:
                    real_dsid = k
            logger.debug(f"read dsid is: {real_dsid}")

            output_audio_sr, output_audio_array = _svc.inference(srcaudio=scraudio,
                                                                 chara=real_dsid,
                                                                 tran=tran,
                                                                 slice_db=th,
                                                                 ns=ns)

            logger.debug(f"svc for {audio_filename} succeed. \n"
                         f"audio data type: {type(output_audio_array)}\n"
                         f"audio data sr: {output_audio_sr}")

            logger.debug("start output data.")

            with io.BytesIO() as wav_file:
                # Use the rate reported by the service for its own output,
                # not the rate of the uploaded input.
                wavfile.write(wav_file, output_audio_sr, output_audio_array)
                wav_data = wav_file.getvalue()

            self.set_header('Content-Type', 'audio/wav')
            self.set_header('Content-Disposition', f'attachment; filename="svc_output.wav"')
            self.write(wav_data)
            await self.flush()
            logger.debug("response completed.")
        except Exception as e:
            logger.exception(e)
            self.set_status(500)
            self.write({
                "code": 500,
                "msg": "system_error",
                "data": None
            })
            return
|
|
|
|
|
|
|
class BatchInferenceHandler(api.base.ApiHandler):
    """Endpoint that converts several uploaded audio files and returns a ZIP."""

    async def post(self):
        """Run voice conversion on every ``srcaudio`` upload.

        Request arguments:
            dsid: model directory name under ``checkpoints`` (required).
            tran: parsed as float and passed to the service as ``tran``;
                  defaults to ``"0"``.
            th:   parsed as float and passed as ``slice_db``; defaults to
                  ``"-40.0"``.
            ns:   parsed as float and passed as ``ns``; defaults to ``"0.4"``.
        Request files:
            srcaudio: one or more audio files to convert.

        Responds with ``output.zip`` (``application/zip``) holding one
        converted file per upload, HTTP 400 when the uploads or model are
        missing/invalid, and HTTP 500 on any other failure.  Intermediate
        files live in a per-request directory under ``<cwd>/temp`` that is
        removed before the response is sent.
        """
        try:
            from zipfile import ZipFile
            from scipy.io import wavfile

            dsid = self.get_argument("dsid", "")
            tran = self.get_argument("tran", "0")
            th = self.get_argument("th", "-40.0")
            ns = self.get_argument("ns", "0.4")
            audiofile_dict = self.request.files.get("srcaudio", [])

            logger.debug(len(self.request.files))

            if not audiofile_dict:
                self.set_status(400)
                self.write({
                    "code": 400,
                    "msg": "未上传文件!",
                    "data": None
                })
                return

            if dsid == "":
                self.set_status(400)
                self.write({
                    "code": 400,
                    "msg": "未选择模型!",
                    "data": None
                })
                return

            # dsid comes straight from the request and is spliced into a
            # filesystem path below, so it must be a bare directory name.
            if dsid in (os.curdir, os.pardir) or "/" in dsid or "\\" in dsid:
                self.set_status(400)
                self.write({
                    "code": 400,
                    "msg": "模型选择错误!",
                    "data": None
                })
                return

            tran = float(tran)
            th = float(th)
            ns = float(ns)

            hparams = get_hparams_from_file(f"checkpoints/{dsid}/config.json")
            spk = hparams.spk
            # The service is addressed by the speaker name mapped to id 0
            # (the last such entry wins if there are several).
            real_dsid = ""
            for k, v in spk.items():
                if v == 0:
                    real_dsid = k
            logger.debug(f"read dsid is: {real_dsid}")

            temp_dir = os.path.join(os.path.abspath(os.getcwd()), "temp")
            logger.debug(f"TempDir: {temp_dir}")
            os.makedirs(temp_dir, exist_ok=True)

            # The working directory is deleted on exit from the with-block,
            # so the archive is read into memory before leaving it.
            with tempfile.TemporaryDirectory(prefix="batch_", dir=temp_dir) as tmp_workdir_name:
                output_files = []
                # Reserve the archive's own name so an upload cannot clash with it.
                used_names = {"output.zip"}

                for idx, file in enumerate(audiofile_dict):
                    audio_filename = file["filename"]
                    audio_filebody = file["body"]

                    # basename() keeps client-supplied paths inside the work
                    # directory; fall back to the index for an empty name.
                    # NOTE(review): the uploaded extension is kept although
                    # the written data is WAV - confirm clients expect this.
                    filename = os.path.basename(audio_filename) or f"{idx}.wav"
                    # Keep every archive member unique so that uploads sharing
                    # a name do not overwrite each other.
                    while filename in used_names:
                        filename = f"{idx}_{filename}"
                    used_names.add(filename)

                    # audio_normalize receives the raw bytes directly; its
                    # result is used as a file path that must be deleted here.
                    converted_file = await audio_normalize(full_filename=audio_filename,
                                                           file_data=audio_filebody)
                    try:
                        sampling_rate, audio_array = read_wav_file_to_numpy_array(converted_file)
                    finally:
                        # Remove the intermediate file even when reading fails.
                        os.remove(converted_file)

                    scraudio = (sampling_rate, audio_array)

                    logger.debug(f"{idx}, {len(audio_filebody)}, {filename}")

                    output_sampling_rate, output_audio = _svc.inference(scraudio, chara=real_dsid, tran=tran,
                                                                        slice_db=th, ns=ns)
                    new_filepath = os.path.join(tmp_workdir_name, filename)
                    wavfile.write(filename=new_filepath, rate=output_sampling_rate, data=output_audio)
                    output_files.append(new_filepath)

                zipfilename = os.path.join(tmp_workdir_name, "output.zip")
                with ZipFile(zipfilename, "w") as zip_obj:
                    for filepath in output_files:
                        zip_obj.write(filepath, os.path.basename(filepath))

                with open(zipfilename, "rb") as zip_file:
                    zip_data = zip_file.read()

            logger.debug("start output data.")

            self.set_header("Content-Type", "application/zip")
            self.set_header("Content-Disposition", "attachment; filename=output.zip")
            self.write(zip_data)
            await self.flush()
            logger.debug("response completed.")
        except Exception as e:
            logger.exception(e)
            self.set_status(500)
            self.write({
                "code": 500,
                "msg": "system_error",
                "data": None
            })
            return
|
|
|
|
|
if __name__ == "__main__":
    # Manual check when run as a script: build the service, then show the
    # model names that were discovered.
    init()
    print(_model_paths)
|
|