File size: 4,857 Bytes
66475d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import os
import subprocess
import re
import multiprocessing
import json
import platform
from update_status import raw_dir_convert_to_path
# Maps each language label (English code plus its Chinese name in parentheses)
# to a short underscore-prefixed suffix string.
lang_dict = {
    "EN(英文)": "_en",
    "ZH(中文)": "_zh",
    "JP(日语)": "_jp",
}

def update_json(batch_size:int, log_interval:int, eval_interval:int,
                epochs:int, lr:float, keep_ckpts:int,
                config_path:str="configs/config.json"):
    """Overwrite the training hyper-parameters stored in a JSON config file.

    The file is loaded, the six values below are written into its ``"train"``
    section (all other content is kept), and the file is rewritten in place
    with 4-space indentation.

    Args:
        batch_size: Stored as ``train.batch_size``.
        log_interval: Stored as ``train.log_interval``.
        eval_interval: Stored as ``train.eval_interval``.
        epochs: Stored as ``train.epochs``.
        lr: Stored as ``train.learning_rate``.
        keep_ckpts: Stored as ``train.keep_ckpts``.
        config_path: Path of the JSON file to update. Defaults to
            ``configs/config.json`` relative to the current working directory.

    Raises:
        OSError: If the file cannot be opened for reading or writing.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    with open(config_path, "r", encoding='utf-8') as json_file:
        hps = json.load(json_file)

    # Create the section when it is absent instead of failing with KeyError.
    train_cfg = hps.setdefault("train", {})
    train_cfg.update({
        "batch_size": batch_size,
        "log_interval": log_interval,
        "eval_interval": eval_interval,
        "epochs": epochs,
        "learning_rate": lr,
        "keep_ckpts": keep_ckpts,
    })
    print("现在的[BS,LI,EI,epochs,lr,keep]: ", [batch_size, log_interval, eval_interval, epochs, lr, keep_ckpts])

    with open(config_path, "w", encoding='utf-8') as json_file:
        # The file is written as UTF-8, so keep non-ASCII text readable rather
        # than turning it into \uXXXX escapes on every save.
        json.dump(hps, json_file, indent=4, ensure_ascii=False)
    print("config.json文件已更新")


class SubprocessManager:
    """Run one external command at a time inside a background worker process.

    ``start`` spawns a ``multiprocessing.Process`` whose target is ``worker``;
    ``worker`` runs the command to completion, then prints its captured
    standard output with ANSI escape sequences removed.

    NOTE(review): ``terminate`` stops the worker process only. Per the
    ``multiprocessing`` documentation, ``Process.terminate()`` does not
    terminate descendant processes, so the external command launched by the
    worker may keep running after ``terminate`` returns.
    """

    # CSI-style ANSI escape sequences (colours, cursor movement, ...).
    _ANSI_ESCAPE = re.compile(r'\x1B\[[0-?]*[ -/]*[@-~]')

    def __init__(self):
        # The multiprocessing.Process started by the last start() call,
        # or None when nothing has been started / it has been terminated.
        self.process = None

    def worker(self, command):
        """Run ``command`` to completion and print its cleaned-up output.

        Args:
            command: Argument list handed to ``subprocess.check_output``.

        Failures are reported on stdout instead of being raised: a non-zero
        exit status (``CalledProcessError``) and a command that cannot be
        launched at all (``OSError``, e.g. executable not found).
        """
        try:
            result = subprocess.check_output(
                command,
                universal_newlines=True
            )
        except (subprocess.CalledProcessError, OSError) as e:
            print(f"错误: {str(e)}")
            return
        print(self._ANSI_ESCAPE.sub('', result))

    def start(self, command):
        """Launch ``command`` in a new worker process.

        A worker that is still alive is terminated first. On Windows the
        command is run through ``cmd.exe /c``.

        Args:
            command: List of string arguments; the full command line is
                echoed to stdout before launching.
        """
        if self.process is not None:
            if self.process.is_alive():
                print("已有子进程正在运行,先终止它")
                self.terminate()
            else:
                # The previous worker already finished on its own: reap it
                # quietly rather than announcing a termination.
                self.process.join()
                self.process = None

        if platform.system() == "Windows":
            cmd = ["cmd.exe", "/c", *command]
        else:
            cmd = command
        print(" ".join(cmd))
        self.process = multiprocessing.Process(target=self.worker, args=(cmd,))
        self.process.start()

    def terminate(self):
        """Terminate and join the current worker process, if there is one."""
        if self.process:
            self.process.terminate()
            self.process.join()
            print("子进程已被终止")
            self.process = None


# One SubprocessManager per launcher below, addressed by fixed index:
#   0 do_transcribe, 1 do_preprocess_text, 2 do_resample, 3 do_bert_gen,
#   4 do_training / terminate_training, 5 do_inference_webui / terminate_webui,
#   6 do_test
managers = [
    SubprocessManager()
    for _slot in range(7)
]

def do_transcribe(target_path, language, workers):
    """Start ``asr_transcript.py`` through manager slot 0.

    Args:
        target_path: Passed with ``language`` to ``raw_dir_convert_to_path``;
            the result is forwarded to the script as ``-f``.
        language: Forwarded as ``-l`` and also stored in the
            ``SELECT_LANGUAGE`` environment variable of this process.
        workers: Forwarded as ``-w`` after conversion to ``str``.

    Returns:
        The status message that is also printed to stdout.
    """
    status = "开始转写文本!请稍后~首先是验证下载模型,然后是转写~"
    resolved_path = raw_dir_convert_to_path(target_path, language)
    transcribe_command = [
        "python", "asr_transcript.py",
        "-f", resolved_path,
        "-l", language,
        "-w", str(workers),
    ]
    # Set before launching; presumably read by the child script — TODO confirm.
    os.environ["SELECT_LANGUAGE"] = language
    managers[0].start(transcribe_command)
    print(status)
    return status


def do_preprocess_text(target_path=""):
    """Start ``preprocess_text.py`` through manager slot 1.

    The ``filelists`` directory is created first if it does not exist.
    ``target_path`` is accepted but not used.

    Returns:
        The status message that is also printed to stdout.
    """
    status = "开始生成训练集和验证集!请稍后~"
    os.makedirs("filelists", exist_ok=True)
    managers[1].start(["python", "preprocess_text.py"])
    print(status)
    return status

def do_resample(target_path=""):
    """Start ``resample.py`` through manager slot 2.

    ``target_path`` is accepted but not used.

    Returns:
        The status message that is also printed to stdout.
    """
    status = "开始音频重采样!请稍后~"
    managers[2].start(["python", "resample.py"])
    print(status)
    return status

def do_bert_gen(num_processes, target_path=""):
    """Start ``bert_gen.py`` through manager slot 3.

    Args:
        num_processes: Forwarded as ``--num_processes`` after conversion to
            ``str``.
        target_path: Accepted but not used.

    Returns:
        The status message that is also printed to stdout.
    """
    status = "开始bert生成!请稍后~"
    bert_command = ["python", "bert_gen.py", "--num_processes", str(num_processes)]
    managers[3].start(bert_command)
    print(status)
    return status

def terminate_training():
    """Terminate whatever manager slot 4 (training) is running.

    Returns:
        The status message that is also printed to stdout.
    """
    status = "终止训练!"
    managers[4].terminate()
    print(status)
    return status

def do_training(model_folder:str, batch_size:int, log_interval:int, eval_interval:int,
                epochs:int, lr:float, keep_ckpts:int):
    """Write the hyper-parameters to the config, then start ``train_ms.py``.

    The values are saved via ``update_json``, any training already held by
    manager slot 4 is terminated, and a new run is launched in that slot.

    Args:
        model_folder: Forwarded to the script as ``-m``.
        batch_size, log_interval, eval_interval, epochs, lr, keep_ckpts:
            Passed through unchanged to ``update_json``.

    Returns:
        The status message that is also printed to stdout.
    """
    update_json(batch_size, log_interval, eval_interval, epochs, lr, keep_ckpts)
    train_command = [
        "python", "train_ms.py",
        "-m", model_folder,
        "-c", './configs/config.json',
    ]
    terminate_training()
    managers[4].start(train_command)
    status = "开启训练成功!\n http://127.0.0.1:8000"
    print(status)
    return status

def terminate_webui():
    """Terminate whatever manager slot 5 (inference web UI) is running.

    Returns:
        The status message that is also printed to stdout.
    """
    status = "关闭推理页面!"
    managers[5].terminate()
    print(status)
    return status

def do_inference_webui(model_path:str, config_path:str):
    """Start ``webui.py`` through manager slot 5 once both paths exist.

    Args:
        model_path: Must exist on disk; forwarded to the script as ``-m``.
        config_path: Must exist on disk; forwarded to the script as ``-c``.

    Returns:
        An error message if either path is missing (the model path is checked
        first and nothing is launched), otherwise the status message that is
        also printed to stdout.
    """
    required_paths = (
        (model_path, "找不到对应模型!请确保模型路径正确!"),
        (config_path, "找不到对应配置文件!请确保配置路径正确!"),
    )
    for path, missing_message in required_paths:
        if not os.path.exists(path):
            return missing_message

    terminate_webui()
    managers[5].start(["python", "webui.py", "-m", model_path, "-c", config_path])
    status = "开启推理页面成功 \n http://127.0.0.1:7860"
    print(status)
    return status


def do_test(model_path:str=""):
    """Run ``python -m pip list`` through manager slot 6.

    ``model_path`` is accepted but not used.

    Returns:
        A message telling the user to look at the console.
    """
    pip_list_command = ["python", "-m", "pip", "list"]
    managers[6].start(pip_list_command)
    return "正在测试,请看控制台"