# Spaces:
# Build error
# Build error
from pydantic import BaseModel, Field | |
import os | |
from pathlib import Path | |
from enum import Enum | |
from typing import Any | |
from synthesizer.hparams import hparams | |
from synthesizer.train import train as synt_train | |
# Folders holding the model checkpoints, relative to the current working
# directory; os.sep keeps the separator native to the host platform.
SYN_MODELS_DIRT = f"synthesizer{os.sep}saved_models"
ENC_MODELS_DIRT = f"encoder{os.sep}saved_models"
# Checkpoint folders of other model families, currently unused here:
# EXT_MODELS_DIRT = f"ppg_extractor{os.sep}saved_models"
# CONV_MODELS_DIRT = f"ppg2mel{os.sep}saved_models"


# Pre-load models: every ``*.pt`` checkpoint found (recursively) under a model
# folder becomes one member of a dynamically created Enum. The Enums are used
# as pydantic field types further down, so the discovered files become the
# available choices.
def _load_model_enum(enum_name: str, models_dir: str) -> "type[Enum]":
    """Return an Enum mapping checkpoint names to their ``Path``.

    Member names are checkpoint file names (e.g. ``"foo.pt"``). Files are
    visited shallowest-first and then in path order, so member order is
    deterministic instead of depending on filesystem listing order. If two
    checkpoints in different sub-folders share a file name, the later one is
    named by its path relative to ``models_dir``: ``Enum`` rejects duplicate
    member names, which would otherwise abort the import of this module.

    Args:
        enum_name: Name of the generated Enum class.
        models_dir: Folder searched recursively for ``*.pt`` files.

    Returns:
        The generated Enum class (possibly with no members).

    Raises:
        FileNotFoundError: If ``models_dir`` is not an existing directory.
    """
    if not os.path.isdir(models_dir):
        raise FileNotFoundError(f"Model folder {models_dir} doesn't exist.")
    root = Path(models_dir)
    # Shallowest first, so a file directly in ``root`` always keeps its bare
    # name and only nested duplicates need the fallback below.
    checkpoints = sorted(root.glob("**/*.pt"), key=lambda p: (len(p.parts), p))
    members = {}
    for ckpt in checkpoints:
        name = ckpt.name
        if name in members:
            # Fallback names contain "/", so they cannot clash with a bare
            # file name, and distinct files have distinct relative paths.
            name = ckpt.relative_to(root).as_posix()
        members[name] = ckpt
    return Enum(enum_name, list(members.items()))


synthesizers = _load_model_enum('synthesizers', SYN_MODELS_DIRT)
print("Loaded synthesizer models: " + str(len(synthesizers)))
encoders = _load_model_enum('encoders', ENC_MODELS_DIRT)
print("Loaded encoders models: " + str(len(encoders)))


class Model(str, Enum):
    # Choices for the ``model`` field of Input. Mixing in ``str`` makes each
    # member compare equal to its plain string value. Documented with comments
    # rather than a docstring because pydantic may copy an enum's docstring
    # into the generated schema.
    DEFAULT = "default"  # the only model type offered at the moment


class Input(BaseModel):
    # Options for launching a synthesizer training run. The Chinese aliases and
    # descriptions are presumably rendered as the form labels by the UI layer
    # (confirm); English glosses are given in the comments. There is no class
    # docstring on purpose: pydantic would publish it as the schema description.

    # Model type (模型类型).
    model: Model = Field(
        default=Model.DEFAULT,
        title="模型类型",
    )
    # Input directory (输入目录): root folder of the preprocessed data.
    input_root: str = Field(
        default=...,
        alias="输入目录",
        description="预处理数据根目录",
        format=True,
        example=f"..{os.sep}audiodata{os.sep}SV2TTS{os.sep}synthesizer",
    )
    # New model name / run ID. Description: "retrain under a new ID; otherwise
    # pick a model below to continue training it".
    run_id: str = Field(
        default="",
        alias="新模型名/运行ID",
        description="使用新ID进行重新训练,否则选择下面的模型进行继续训练",
    )
    # Existing synthesizer checkpoint ("choose the speech-synthesis model file").
    synthesizer: synthesizers = Field(
        default=...,
        alias="已有合成模型",
        description="选择语音合成模型文件.",
    )
    # Train on GPU ("choose 'yes' to train on the GPU").
    gpu: bool = Field(
        default=True,
        alias="GPU训练",
        description="选择“是”,则使用GPU训练",
    )
    # Verbose output ("choose 'yes' to print more details").
    verbose: bool = Field(
        default=True,
        alias="打印详情",
        description="选择“是”,输出更多详情",
    )
    # Speaker-encoder checkpoint ("choose the speech-encoder model file").
    encoder: encoders = Field(
        default=...,
        alias="语音编码模型",
        description="选择语音编码模型文件.",
    )
    # Update interval: update the model every n steps.
    save_every: int = Field(
        default=1000,
        alias="更新间隔",
        description="每隔n步则更新一次模型",
    )
    # Save interval: save (back up) the model every n steps.
    backup_every: int = Field(
        default=10000,
        alias="保存间隔",
        description="每隔n步则保存一次模型",
    )
    # Print interval: print training statistics every n steps.
    log_every: int = Field(
        default=500,
        alias="打印间隔",
        description="每隔n步则打印一次训练统计",
    )


class AudioEntity(BaseModel):
    # Audio payload model. Comments instead of a docstring so the pydantic
    # schema (which would pick up a docstring as its description) is unchanged.
    content: bytes  # audio data as raw bytes
    mel: Any  # NOTE(review): untyped; presumably a mel spectrogram — confirm


class Output(BaseModel):
    # Result of a training request: a bare integer code (pydantic custom root
    # type). Comments instead of a class docstring keep the schema unchanged.
    __root__: int

    def render_output_ui(self, streamlit_app) -> None:  # type: ignore
        """Render the training result with a custom output UI.

        If this method is implemented, it is used instead of the default
        Output UI renderer.
        """
        streamlit_app.subheader(f"Training started with code: {self.__root__}")


def train(input: Input) -> Output:
    """Train(训练)"""
    # NOTE(review): the docstring above is kept byte-identical because the UI
    # layer may display it — confirm before expanding it.
    #
    # Launch synthesizer training with the options collected in ``input``.
    # A non-empty ``run_id`` starts a new run under that name and forces a
    # restart. Otherwise the run id is derived from the selected synthesizer
    # checkpoint's file name; per the ``run_id`` field description, this
    # continues training the chosen model. Returns Output with code 0 after
    # ``synt_train`` returns.
    #
    # NOTE(review): ``gpu``, ``verbose`` and ``encoder`` are collected by Input
    # but not forwarded to ``synt_train`` here — confirm whether they should be.
    print(">>> Start training ...")
    force_restart = bool(input.run_id)
    if force_restart:
        run_id = input.run_id
    else:
        # Everything before the first "." of the file name,
        # e.g. "pretrained.pt" -> "pretrained".
        run_id = Path(input.synthesizer.value).name.split('.')[0]
    # ``run_id`` is a local: the caller's ``input`` object is not mutated.
    synt_train(
        run_id,
        input.input_root,
        SYN_MODELS_DIRT,  # same folder the synthesizer choices are loaded from
        input.save_every,
        input.backup_every,
        input.log_every,
        force_restart,
        hparams,
    )
    return Output(__root__=0)