liumaolin commited on
Commit ·
f458b69
1
Parent(s): 8178cb9
Use `tempfile` to handle temporary configuration files and ensure cleanup in training stages
Browse files
training_pipeline/stages/training.py
CHANGED
|
@@ -8,6 +8,7 @@
|
|
| 8 |
|
| 9 |
import json
|
| 10 |
import os
|
|
|
|
| 11 |
from typing import Dict, Any, Generator
|
| 12 |
|
| 13 |
import yaml
|
|
@@ -74,25 +75,27 @@ class SoVITSTrainStage(BaseStage):
|
|
| 74 |
data["version"] = version_str
|
| 75 |
|
| 76 |
# 写入临时配置
|
| 77 |
-
|
| 78 |
-
os.makedirs(tmp_dir, exist_ok=True)
|
| 79 |
-
tmp_config_path = f"{tmp_dir}/tmp_s2.json"
|
| 80 |
-
with open(tmp_config_path, "w") as f:
|
| 81 |
f.write(json.dumps(data))
|
|
|
|
| 82 |
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
|
|
|
|
|
|
|
|
|
| 88 |
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
|
|
|
| 96 |
|
| 97 |
|
| 98 |
class GPTTrainStage(BaseStage):
|
|
@@ -142,23 +145,25 @@ class GPTTrainStage(BaseStage):
|
|
| 142 |
data["output_dir"] = f"{s1_dir}/logs_s1_{version_str}"
|
| 143 |
|
| 144 |
# 写入临时配置
|
| 145 |
-
|
| 146 |
-
os.makedirs(tmp_dir, exist_ok=True)
|
| 147 |
-
tmp_config_path = f"{tmp_dir}/tmp_s1.yaml"
|
| 148 |
-
with open(tmp_config_path, "w") as f:
|
| 149 |
f.write(yaml.dump(data, default_flow_style=False))
|
|
|
|
| 150 |
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 164 |
|
|
|
|
| 8 |
|
| 9 |
import json
|
| 10 |
import os
|
| 11 |
+
import tempfile
|
| 12 |
from typing import Dict, Any, Generator
|
| 13 |
|
| 14 |
import yaml
|
|
|
|
| 75 |
data["version"] = version_str
|
| 76 |
|
| 77 |
# 写入临时配置
|
| 78 |
+
with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
|
|
|
|
|
|
|
|
|
|
| 79 |
f.write(json.dumps(data))
|
| 80 |
+
tmp_config_path = f.name
|
| 81 |
|
| 82 |
+
try:
|
| 83 |
+
# 选择训练脚本
|
| 84 |
+
if version_str in ["v1", "v2", "v2Pro", "v2ProPlus"]:
|
| 85 |
+
cmd = f'PYTHONPATH=.:GPT_SoVITS "{cfg.python_exec}" -s GPT_SoVITS/s2_train.py --config "{tmp_config_path}"'
|
| 86 |
+
else:
|
| 87 |
+
cmd = f'PYTHONPATH=.:GPT_SoVITS "{cfg.python_exec}" -s GPT_SoVITS/s2_train_v3_lora.py --config "{tmp_config_path}"'
|
| 88 |
+
|
| 89 |
+
yield self._make_progress("SoVITS训练启动中...", 0.1)
|
| 90 |
|
| 91 |
+
self._process = self._run_command(cmd, wait=True)
|
| 92 |
+
self._process = None
|
| 93 |
+
|
| 94 |
+
self._status = StageStatus.COMPLETED
|
| 95 |
+
yield self._make_progress("SoVITS训练完成", 1.0)
|
| 96 |
+
finally:
|
| 97 |
+
if os.path.exists(tmp_config_path):
|
| 98 |
+
os.remove(tmp_config_path)
|
| 99 |
|
| 100 |
|
| 101 |
class GPTTrainStage(BaseStage):
|
|
|
|
| 145 |
data["output_dir"] = f"{s1_dir}/logs_s1_{version_str}"
|
| 146 |
|
| 147 |
# 写入临时配置
|
| 148 |
+
with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
|
|
|
|
|
|
|
|
|
|
| 149 |
f.write(yaml.dump(data, default_flow_style=False))
|
| 150 |
+
tmp_config_path = f.name
|
| 151 |
|
| 152 |
+
try:
|
| 153 |
+
# 设置环境变量
|
| 154 |
+
os.environ["_CUDA_VISIBLE_DEVICES"] = cfg.gpu_numbers.replace("-", ",")
|
| 155 |
+
os.environ["hz"] = "25hz"
|
| 156 |
+
|
| 157 |
+
cmd = f'PYTHONPATH=.:GPT_SoVITS "{cfg.python_exec}" -s GPT_SoVITS/s1_train.py --config_file "{tmp_config_path}"'
|
| 158 |
+
|
| 159 |
+
yield self._make_progress("GPT训练启动中...", 0.1)
|
| 160 |
+
|
| 161 |
+
self._process = self._run_command(cmd, wait=True)
|
| 162 |
+
self._process = None
|
| 163 |
+
|
| 164 |
+
self._status = StageStatus.COMPLETED
|
| 165 |
+
yield self._make_progress("GPT训练完成", 1.0)
|
| 166 |
+
finally:
|
| 167 |
+
if os.path.exists(tmp_config_path):
|
| 168 |
+
os.remove(tmp_config_path)
|
| 169 |
|