Spaces:
Runtime error
Runtime error
# Training implementation
import os | |
import shutil | |
import uuid | |
import gradio as gr | |
from modelscope.models.audio.tts import SambertHifigan | |
from modelscope.pipelines import pipeline | |
from modelscope.utils.constant import Tasks | |
from modelscope.metainfo import Trainers | |
from modelscope.trainers import build_trainer | |
from modelscope.utils.audio.audio_utils import TtsTrainType | |
from modelscope.hub.utils.utils import get_cache_dir | |
from utils_base import ensure_empty_dir, get_last_file, models_dir, get_model_list | |
import consts | |
# Helper for resolving absolute paths
# Directory that contains this file; getAbsPath resolves relative paths against it.
curPath = os.path.dirname(os.path.abspath(__file__))


def getAbsPath(relativePath):
    """Resolve ``relativePath`` against the directory of this file.

    Args:
        relativePath: Path expressed relative to this file's directory.

    Returns:
        The joined path in absolute, normalized form.
    """
    combined = os.path.join(curPath, relativePath)
    absolute = os.path.abspath(combined)
    return os.path.normpath(absolute)
# Model training ---------------------------------------------------------
# name - name given to the training result (the small model)
# steps - number of training steps
# train_dataset_name - name of the dataset directory under ./datasets
def train(name, steps, train_dataset_name):
    """Fine-tune the base model on a dataset and store the result as a small model.

    Args:
        name: Name of the resulting small model; its files are written to
            ``models_dir/<name>``.
        steps: Number of training steps (the trainer is configured with
            ``steps + 1``).
        train_dataset_name: Name of the dataset directory under ``./datasets``
            (relative to this file). The directory is deleted after a
            successful run.

    Returns:
        A tuple of the status message and a ``gr.update`` carrying the
        refreshed list of models.
    """
    # Scratch directory that receives the trainer's output for this run.
    work_dir = getAbsPath(f'./temp/work-{ uuid.uuid4() }')
    # Dataset directory.
    train_dataset = getAbsPath(f'./datasets/{ train_dataset_name }')

    try:
        ensure_empty_dir(work_dir)

        # Run the fine-tuning.
        trainer = build_trainer(
            Trainers.speech_kantts_trainer,
            default_args=dict(
                # Model / revision to fine-tune.
                model=consts.base_model_id,
                model_revision=consts.base_model_version,
                work_dir=work_dir,
                train_dataset=train_dataset,
                # Training parameters.
                train_type={
                    # Configure training of the AM (sambert) model.
                    TtsTrainType.TRAIN_TYPE_SAMBERT: {
                        # NOTE(review): the extra step is presumably there so
                        # the checkpoint at `steps` gets saved - confirm.
                        'train_steps': steps + 1,
                        # Save a checkpoint every this many steps.
                        'save_interval_steps': 20,
                        # Print a training log line every this many steps.
                        'log_interval': 10,
                    },
                },
            ),
        )
        trainer.train()

        # Copy only the files needed for inference into the result directory.
        target_dir = os.path.join(models_dir, name)
        ensure_empty_dir(target_dir)
        os.makedirs(os.path.join(target_dir, 'tmp_am', 'ckpt'))
        os.makedirs(os.path.join(target_dir, 'data', 'se'))

        artifacts = (
            (
                get_last_file(os.path.join(work_dir, 'tmp_am', 'ckpt')),
                os.path.join(target_dir, 'tmp_am', 'ckpt'),
            ),
            (
                os.path.join(work_dir, 'tmp_am', 'config.yaml'),
                os.path.join(target_dir, 'tmp_am'),
            ),
            (
                os.path.join(work_dir, 'data', 'audio_config.yaml'),
                os.path.join(target_dir, 'data'),
            ),
            (
                os.path.join(work_dir, 'data', 'se', 'se.npy'),
                os.path.join(target_dir, 'data', 'se'),
            ),
        )
        for source, destination in artifacts:
            shutil.copy(source, destination)
    finally:
        # Always drop the scratch directory, even when training or copying
        # fails, so failed runs do not pile up under ./temp.
        shutil.rmtree(work_dir, ignore_errors=True)

    # The dataset is only removed once the model was produced successfully.
    shutil.rmtree(train_dataset)

    # Report completion and refresh the model list.
    return '训练完成', gr.update(choices=get_model_list())
# Model inference ---------------------------------------------------------
# name - name of the small model to use
# txt - text to synthesize into audio
def infer(name, txt):
    """Synthesize speech for ``txt`` with the small model called ``name``.

    The acoustic-model files are read from ``models_dir/<name>``; the vocoder
    files come from the cached base model.

    Args:
        name: Name of the small model to use.
        txt: Text to synthesize.

    Returns:
        Path of the generated ``.wav`` file under ``/tmp``, or ``False`` when
        any step fails (the failure is logged).
    """
    import logging  # local import keeps this function self-contained

    try:
        base_model_path = os.path.join(get_cache_dir(), consts.base_model_id)
        model_path = os.path.join(models_dir, name)
        am_dir = os.path.join(model_path, 'tmp_am')
        data_dir = os.path.join(model_path, 'data')
        hifigan_dir = os.path.join(base_model_path, 'basemodel_16k', 'hifigan')

        custom_infer_abs = {
            'voice_name': 'F7',
            # Files from the fine-tuned small model.
            'am_ckpt': os.path.join(am_dir, 'ckpt'),
            'am_config': os.path.join(am_dir, 'config.yaml'),
            'audio_config': os.path.join(data_dir, 'audio_config.yaml'),
            'se_file': os.path.join(data_dir, 'se', 'se.npy'),
            # Files from the base model.
            'voc_ckpt': os.path.join(hifigan_dir, 'ckpt'),
            'voc_config': os.path.join(hifigan_dir, 'config.yaml'),
        }

        model = SambertHifigan(base_model_path, custom_ckpt=custom_infer_abs)
        inference = pipeline(task=Tasks.text_to_speech, model=model)
        output = inference(input=txt)

        # 'bx' creates the file exclusively, so an existing file is never overwritten.
        output_path = f'/tmp/{ uuid.uuid4() }.wav'
        with open(output_path, mode='bx') as f:
            f.write(output['output_wav'])
        return output_path
    except Exception:
        # Keep the best-effort contract (callers receive False) but record
        # what went wrong instead of discarding it silently.
        logging.getLogger(__name__).exception(
            'infer failed for model %r', name
        )
        return False
# Model deletion ---------------------------------------------------------
# name - name of the small model to delete
def delete_model(name):
    """Delete the small model ``name`` and refresh the model list.

    Args:
        name: Name of the small model to delete. A falsy value deletes
            nothing.

    Returns:
        A ``gr.update`` with the current model list. When a name was given
        the selected value is also reset to ``None``, whether or not the
        deletion succeeded.
    """
    import logging  # local import keeps this function self-contained

    if not name:
        return gr.update(choices=get_model_list())

    try:
        root = os.path.abspath(models_dir)
        target_dir = os.path.abspath(os.path.join(root, name))
        # Refuse names such as '.' or '../x' that resolve to models_dir
        # itself or to a location outside of it.
        if target_dir == root or os.path.commonpath([root, target_dir]) != root:
            raise ValueError(
                f'refusing to delete a path outside models_dir: {name!r}'
            )
        shutil.rmtree(target_dir)
    except Exception:
        # Deletion stays best-effort, but the failure is recorded rather
        # than silently dropped.
        logging.getLogger(__name__).exception(
            'failed to delete model %r', name
        )

    return gr.update(choices=get_model_list(), value=None)