liumaolin commited on
Commit
f458b69
·
1 Parent(s): 8178cb9

Use `tempfile` to handle temporary configuration files and ensure cleanup in training stages

Browse files
Files changed (1) hide show
  1. training_pipeline/stages/training.py +38 -33
training_pipeline/stages/training.py CHANGED
@@ -8,6 +8,7 @@
8
 
9
  import json
10
  import os
 
11
  from typing import Dict, Any, Generator
12
 
13
  import yaml
@@ -74,25 +75,27 @@ class SoVITSTrainStage(BaseStage):
74
  data["version"] = version_str
75
 
76
  # 写入临时配置
77
- tmp_dir = os.path.join(os.getcwd(), "TEMP")
78
- os.makedirs(tmp_dir, exist_ok=True)
79
- tmp_config_path = f"{tmp_dir}/tmp_s2.json"
80
- with open(tmp_config_path, "w") as f:
81
  f.write(json.dumps(data))
 
82
 
83
- # 选择训练脚本
84
- if version_str in ["v1", "v2", "v2Pro", "v2ProPlus"]:
85
- cmd = f'PYTHONPATH=.:GPT_SoVITS "{cfg.python_exec}" -s GPT_SoVITS/s2_train.py --config "{tmp_config_path}"'
86
- else:
87
- cmd = f'PYTHONPATH=.:GPT_SoVITS "{cfg.python_exec}" -s GPT_SoVITS/s2_train_v3_lora.py --config "{tmp_config_path}"'
 
 
 
88
 
89
- yield self._make_progress("SoVITS训练启动中...", 0.1)
90
-
91
- self._process = self._run_command(cmd, wait=True)
92
- self._process = None
93
-
94
- self._status = StageStatus.COMPLETED
95
- yield self._make_progress("SoVITS训练完成", 1.0)
 
96
 
97
 
98
  class GPTTrainStage(BaseStage):
@@ -142,23 +145,25 @@ class GPTTrainStage(BaseStage):
142
  data["output_dir"] = f"{s1_dir}/logs_s1_{version_str}"
143
 
144
  # 写入临时配置
145
- tmp_dir = os.path.join(os.getcwd(), "TEMP")
146
- os.makedirs(tmp_dir, exist_ok=True)
147
- tmp_config_path = f"{tmp_dir}/tmp_s1.yaml"
148
- with open(tmp_config_path, "w") as f:
149
  f.write(yaml.dump(data, default_flow_style=False))
 
150
 
151
- # 设置环境变量
152
- os.environ["_CUDA_VISIBLE_DEVICES"] = cfg.gpu_numbers.replace("-", ",")
153
- os.environ["hz"] = "25hz"
154
-
155
- cmd = f'PYTHONPATH=.:GPT_SoVITS "{cfg.python_exec}" -s GPT_SoVITS/s1_train.py --config_file "{tmp_config_path}"'
156
-
157
- yield self._make_progress("GPT训练启动中...", 0.1)
158
-
159
- self._process = self._run_command(cmd, wait=True)
160
- self._process = None
161
-
162
- self._status = StageStatus.COMPLETED
163
- yield self._make_progress("GPT训练完成", 1.0)
 
 
 
 
164
 
 
8
 
9
  import json
10
  import os
11
+ import tempfile
12
  from typing import Dict, Any, Generator
13
 
14
  import yaml
 
75
  data["version"] = version_str
76
 
77
  # 写入临时配置
78
+ with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
 
 
 
79
  f.write(json.dumps(data))
80
+ tmp_config_path = f.name
81
 
82
+ try:
83
+ # 选择训练脚本
84
+ if version_str in ["v1", "v2", "v2Pro", "v2ProPlus"]:
85
+ cmd = f'PYTHONPATH=.:GPT_SoVITS "{cfg.python_exec}" -s GPT_SoVITS/s2_train.py --config "{tmp_config_path}"'
86
+ else:
87
+ cmd = f'PYTHONPATH=.:GPT_SoVITS "{cfg.python_exec}" -s GPT_SoVITS/s2_train_v3_lora.py --config "{tmp_config_path}"'
88
+
89
+ yield self._make_progress("SoVITS训练启动中...", 0.1)
90
 
91
+ self._process = self._run_command(cmd, wait=True)
92
+ self._process = None
93
+
94
+ self._status = StageStatus.COMPLETED
95
+ yield self._make_progress("SoVITS训练完成", 1.0)
96
+ finally:
97
+ if os.path.exists(tmp_config_path):
98
+ os.remove(tmp_config_path)
99
 
100
 
101
  class GPTTrainStage(BaseStage):
 
145
  data["output_dir"] = f"{s1_dir}/logs_s1_{version_str}"
146
 
147
  # 写入临时配置
148
+ with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
 
 
 
149
  f.write(yaml.dump(data, default_flow_style=False))
150
+ tmp_config_path = f.name
151
 
152
+ try:
153
+ # 设置环境变量
154
+ os.environ["_CUDA_VISIBLE_DEVICES"] = cfg.gpu_numbers.replace("-", ",")
155
+ os.environ["hz"] = "25hz"
156
+
157
+ cmd = f'PYTHONPATH=.:GPT_SoVITS "{cfg.python_exec}" -s GPT_SoVITS/s1_train.py --config_file "{tmp_config_path}"'
158
+
159
+ yield self._make_progress("GPT训练启动中...", 0.1)
160
+
161
+ self._process = self._run_command(cmd, wait=True)
162
+ self._process = None
163
+
164
+ self._status = StageStatus.COMPLETED
165
+ yield self._make_progress("GPT训练完成", 1.0)
166
+ finally:
167
+ if os.path.exists(tmp_config_path):
168
+ os.remove(tmp_config_path)
169