bert-vits2-maolei / extern_subprocess.py
RUSH-miaomi's picture
Upload 22 files
94aaff9
raw
history blame
4.86 kB
import json
import multiprocessing
import os
import platform
import re
import subprocess
import threading

from update_status import raw_dir_convert_to_path
# Language labels (as displayed, e.g. in a UI dropdown) mapped to a short
# language suffix. NOTE(review): not referenced in this part of the file;
# presumably used by callers to build file/dir names — confirm.
lang_dict = {
    "EN(英文)": "_en",
    "ZH(中文)": "_zh",
    "JP(日语)": "_jp",
}
def update_json(batch_size: int, log_interval: int, eval_interval: int,
                epochs: int, lr: float, keep_ckpts: int,
                config_path: str = "configs/config.json"):
    """Overwrite the training hyper-parameters stored in a JSON config file.

    Loads *config_path*, sets the keys ``batch_size``, ``log_interval``,
    ``eval_interval``, ``epochs``, ``learning_rate`` and ``keep_ckpts`` inside
    its ``"train"`` section, and writes the file back. Every other key is
    preserved.

    Args:
        batch_size: value for ``train.batch_size``.
        log_interval: value for ``train.log_interval``.
        eval_interval: value for ``train.eval_interval``.
        epochs: value for ``train.epochs``.
        lr: value for ``train.learning_rate``.
        keep_ckpts: value for ``train.keep_ckpts``.
        config_path: JSON file to update. Defaults to the path the original
            code hard-coded.

    Raises:
        KeyError: if the config has no ``"train"`` section.
        OSError / json.JSONDecodeError: if the file cannot be read or parsed.
    """
    with open(config_path, "r", encoding='utf-8') as json_file:
        hps = json.load(json_file)

    train_cfg = hps["train"]
    train_cfg["batch_size"] = batch_size
    train_cfg["log_interval"] = log_interval
    train_cfg["eval_interval"] = eval_interval
    train_cfg["epochs"] = epochs
    train_cfg["learning_rate"] = lr
    train_cfg["keep_ckpts"] = keep_ckpts
    print("现在的[BS,LI,EI,epochs,lr,keep]: ", [batch_size, log_interval, eval_interval, epochs, lr, keep_ckpts])

    # Write to a sibling temp file, then atomically swap it in. If the dump
    # is interrupted, the existing config.json is never left truncated.
    tmp_path = config_path + ".tmp"
    try:
        with open(tmp_path, "w", encoding='utf-8') as json_file:
            # The file is UTF-8, so keep non-ASCII text readable instead of
            # writing \uXXXX escapes.
            json.dump(hps, json_file, indent=4, ensure_ascii=False)
        os.replace(tmp_path, config_path)
    except BaseException:
        # Clean up the partial temp file; the real config is still intact.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
    print("config.json文件已更新")
class SubprocessManager:
    """Run at most one external command at a time and echo its stdout.

    The command is started directly with ``subprocess.Popen``. A reader
    thread prints each stdout line as it arrives, with ANSI escape sequences
    removed. When the command ends, the thread reports a non-zero exit
    status. stderr is not captured, so it goes straight to this process's
    stderr.

    ``terminate`` signals the command process itself. On Windows it kills
    the whole tree under the ``cmd.exe`` wrapper. It does not stop only an
    intermediate wrapper, so a stopped job (training run, web UI) really
    stops.
    """

    # ANSI/VT100 escape sequences (colours, cursor moves) in child output.
    _ANSI_ESCAPE = re.compile(r'\x1B\[[0-?]*[ -/]*[@-~]')
    # Seconds to wait for a graceful exit before force-killing.
    _KILL_TIMEOUT = 10

    def __init__(self):
        self.process = None    # subprocess.Popen of the current command, or None
        self._reader = None    # thread echoing self.process's stdout
        self._stopping = None  # Popen deliberately stopped by terminate()

    def _spawn(self, command):
        """Start *command* with its stdout piped back to us as text."""
        return subprocess.Popen(
            command,
            stdout=subprocess.PIPE,
            universal_newlines=True,
            # Undecodable bytes must not kill the reader thread.
            errors="replace",
            # A child Python block-buffers a piped stdout; ask it not to,
            # so progress is echoed live rather than only at exit.
            env={**os.environ, "PYTHONUNBUFFERED": "1"},
        )

    def _pump(self, proc, command):
        """Echo *proc*'s stdout until EOF, then report a non-zero exit."""
        with proc.stdout:
            for line in proc.stdout:
                print(self._ANSI_ESCAPE.sub('', line), end='')
        returncode = proc.wait()
        # A non-zero code caused by our own terminate() is not an error.
        if returncode != 0 and self._stopping is not proc:
            e = subprocess.CalledProcessError(returncode, command)
            print(f"错误: {str(e)}")

    def worker(self, command):
        """Run *command* to completion in the calling thread, echoing output."""
        try:
            proc = self._spawn(command)
        except OSError as e:
            print(f"错误: {str(e)}")
            return
        self._pump(proc, command)

    def start(self, command):
        """Launch *command* (a list of argv strings) in the background.

        Any command already managed by this instance is terminated first.
        On Windows the command is wrapped in ``cmd.exe /c``. Launch failures
        are printed rather than raised.
        """
        if self.process:
            print("已有子进程正在运行,先终止它")
            self.terminate()
        if platform.system() == "Windows":
            cmd = ["cmd.exe", "/c"] + command
        else:
            cmd = command
        print(" ".join(cmd))
        try:
            proc = self._spawn(cmd)
        except OSError as e:
            print(f"错误: {str(e)}")
            return
        self.process = proc
        # Non-daemon: like the previous non-daemon multiprocessing worker,
        # interpreter shutdown waits for the command's output to finish.
        self._reader = threading.Thread(target=self._pump, args=(proc, cmd))
        self._reader.start()

    @staticmethod
    def _request_stop(proc):
        """Ask *proc* (and, on Windows, its whole child tree) to stop."""
        if platform.system() == "Windows":
            # Killing only the cmd.exe wrapper would orphan the python
            # child; /T walks the tree, /F forces termination.
            subprocess.run(
                ["taskkill", "/F", "/T", "/PID", str(proc.pid)],
                stdout=subprocess.DEVNULL,
                stderr=subprocess.DEVNULL,
                check=False,
            )
        else:
            proc.terminate()

    def terminate(self):
        """Stop the managed command, if any, and wait for it to exit."""
        proc = self.process
        if not proc:
            return
        self._stopping = proc
        if proc.poll() is None:
            self._request_stop(proc)
            try:
                proc.wait(timeout=self._KILL_TIMEOUT)
            except subprocess.TimeoutExpired:
                proc.kill()
                proc.wait()
        if self._reader is not None:
            self._reader.join(timeout=self._KILL_TIMEOUT)
        print("子进程已被终止")
        self.process = None
        self._reader = None
# One independent manager per action, indexed by slot:
# 0 transcribe, 1 preprocess text, 2 resample, 3 bert gen,
# 4 training, 5 inference web UI, 6 environment test.
_MANAGER_SLOTS = 7
managers = [SubprocessManager() for _slot in range(_MANAGER_SLOTS)]
def do_transcribe(target_path, language, workers):
    """Launch ``asr_transcript.py`` in manager slot 0 and return a status text.

    *target_path* is first resolved with ``raw_dir_convert_to_path``.
    *language* is passed as ``-l`` and also exported as the
    ``SELECT_LANGUAGE`` environment variable, which the child inherits.
    NOTE(review): presumably the script reads that variable; confirm.
    """
    resolved_path = raw_dir_convert_to_path(target_path, language)
    command = [
        r"python", "asr_transcript.py",
        "-f", resolved_path,
        "-l", language,
        "-w", str(workers),
    ]
    os.environ["SELECT_LANGUAGE"] = language
    managers[0].start(command)
    message = "开始转写文本!请稍后~首先是验证下载模型,然后是转写~"
    print(message)
    return message
def do_preprocess_text(target_path=""):
    """Ensure ``filelists/`` exists, then run ``preprocess_text.py`` (slot 1).

    *target_path* is accepted for interface compatibility but not used.
    """
    os.makedirs("filelists", exist_ok=True)
    managers[1].start([r"python", "preprocess_text.py"])
    status = "开始生成训练集和验证集!请稍后~"
    print(status)
    return status
def do_resample(target_path=""):
    """Run ``resample.py`` in manager slot 2 and return a status text.

    *target_path* is accepted for interface compatibility but not used.
    """
    managers[2].start([r"python", "resample.py"])
    status = "开始音频重采样!请稍后~"
    print(status)
    return status
def do_bert_gen(num_processes, target_path=""):
    """Run ``bert_gen.py --num_processes N`` in manager slot 3.

    *target_path* is accepted for interface compatibility but not used.
    """
    managers[3].start([r"python", "bert_gen.py", "--num_processes", str(num_processes)])
    status = "开始bert生成!请稍后~"
    print(status)
    return status
def terminate_training():
    """Stop whatever runs in manager slot 4 (training) and return a notice."""
    managers[4].terminate()
    notice = "终止训练!"
    print(notice)
    return notice
def do_training(model_folder: str, batch_size: int, log_interval: int, eval_interval: int,
                epochs: int, lr: float, keep_ckpts: int):
    """Update the training config, then (re)start ``train_ms.py`` in slot 4.

    The hyper-parameters are written via ``update_json``. Any training
    already running in slot 4 is stopped before the new one starts.
    NOTE(review): the returned URL (port 8000) is presumably a monitoring
    page served by the training script; this file does not establish that.
    """
    update_json(batch_size, log_interval, eval_interval, epochs, lr, keep_ckpts)
    train_cmd = [r"python", "train_ms.py", "-m", model_folder, "-c", './configs/config.json']
    terminate_training()
    managers[4].start(train_cmd)
    banner = "开启训练成功!\n http://127.0.0.1:8000"
    print(banner)
    return banner
def terminate_webui():
    """Stop whatever runs in manager slot 5 (inference web UI)."""
    managers[5].terminate()
    notice = "关闭推理页面!"
    print(notice)
    return notice
def do_inference_webui(model_path: str, config_path: str):
    """(Re)start ``webui.py`` for the given model/config in manager slot 5.

    Returns an error message, without launching anything, if either path
    does not exist. Otherwise it stops any running web UI, starts a new one,
    and returns a success message with the URL.
    """
    if not os.path.exists(model_path):
        return "找不到对应模型!请确保模型路径正确!"
    if not os.path.exists(config_path):
        return "找不到对应配置文件!请确保配置路径正确!"
    terminate_webui()
    managers[5].start([r"python", "webui.py", "-m", model_path, "-c", config_path])
    banner = "开启推理页面成功 \n http://127.0.0.1:7860"
    print(banner)
    return banner
def do_test(model_path: str = ""):
    """Print the installed packages (``python -m pip list``) via slot 6.

    *model_path* is accepted for interface compatibility but not used.
    """
    managers[6].start([r"python", "-m", "pip", "list"])
    return "正在测试,请看控制台"