litagin committed on
Commit
9ed5ec3
1 Parent(s): 6593afe

Upload style_gen.py

Browse files
Files changed (1) hide show
  1. style_gen.py +79 -17
style_gen.py CHANGED
@@ -1,6 +1,5 @@
1
  import argparse
2
- import concurrent.futures
3
- import sys
4
  import warnings
5
 
6
  import numpy as np
@@ -8,6 +7,8 @@ import torch
8
  from tqdm import tqdm
9
 
10
  import utils
 
 
11
  from config import config
12
 
13
  warnings.filterwarnings("ignore", category=UserWarning)
@@ -19,14 +20,44 @@ device = torch.device(config.style_gen_config.device)
19
  inference.to(device)
20
 
21
 
22
- def extract_style_vector(wav_path):
 
 
 
 
 
 
 
23
  return inference(wav_path)
24
 
25
 
26
  def save_style_vector(wav_path):
27
- style_vec = extract_style_vector(wav_path)
28
- # `test.wav` -> `test.wav.npy`
29
- np.save(f"{wav_path}.npy", style_vec)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
 
32
  if __name__ == "__main__":
@@ -45,22 +76,53 @@ if __name__ == "__main__":
45
 
46
  device = config.style_gen_config.device
47
 
48
- lines = []
49
  with open(hps.data.training_files, encoding="utf-8") as f:
50
- lines.extend(f.readlines())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
 
 
52
  with open(hps.data.validation_files, encoding="utf-8") as f:
53
- lines.extend(f.readlines())
54
-
55
- wavnames = [line.split("|")[0] for line in lines]
56
 
57
- with concurrent.futures.ThreadPoolExecutor(max_workers=num_processes) as executor:
58
- list(
59
  tqdm(
60
- executor.map(save_style_vector, wavnames),
61
- total=len(wavnames),
62
- file=sys.stdout,
63
  )
64
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
 
66
- print(f"Finished generating style vectors! total: {len(wavnames)} npy files.")
 
1
  import argparse
2
+ from concurrent.futures import ThreadPoolExecutor
 
3
  import warnings
4
 
5
  import numpy as np
 
7
  from tqdm import tqdm
8
 
9
  import utils
10
+ from common.log import logger
11
+ from common.stdout_wrapper import SAFE_STDOUT
12
  from config import config
13
 
14
  warnings.filterwarnings("ignore", category=UserWarning)
 
20
  inference.to(device)
21
 
22
 
23
+ class NaNValueError(ValueError):
24
+ """カスタム例外クラス。NaN値が見つかった場合に使用されます。"""
25
+
26
+ pass
27
+
28
+
29
+ # 推論時にインポートするために短いが関数を書く
30
+ def get_style_vector(wav_path):
31
  return inference(wav_path)
32
 
33
 
34
  def save_style_vector(wav_path):
35
+ try:
36
+ style_vec = get_style_vector(wav_path)
37
+ except Exception as e:
38
+ print("\n")
39
+ logger.error(f"Error occurred with file: {wav_path}, Details:\n{e}\n")
40
+ raise
41
+ # 値にNaNが含まれていると悪影響なのでチェックする
42
+ if np.isnan(style_vec).any():
43
+ print("\n")
44
+ logger.warning(f"NaN value found in style vector: {wav_path}")
45
+ raise NaNValueError(f"NaN value found in style vector: {wav_path}")
46
+ np.save(f"{wav_path}.npy", style_vec) # `test.wav` -> `test.wav.npy`
47
+
48
+
49
+ def process_line(line):
50
+ wavname = line.split("|")[0]
51
+ try:
52
+ save_style_vector(wavname)
53
+ return line, None
54
+ except NaNValueError:
55
+ return line, "nan_error"
56
+
57
+
58
+ def save_average_style_vector(style_vectors, filename="style_vectors.npy"):
59
+ average_vector = np.mean(style_vectors, axis=0)
60
+ np.save(filename, average_vector)
61
 
62
 
63
  if __name__ == "__main__":
 
76
 
77
  device = config.style_gen_config.device
78
 
79
+ training_lines = []
80
  with open(hps.data.training_files, encoding="utf-8") as f:
81
+ training_lines.extend(f.readlines())
82
+ with ThreadPoolExecutor(max_workers=num_processes) as executor:
83
+ training_results = list(
84
+ tqdm(
85
+ executor.map(process_line, training_lines),
86
+ total=len(training_lines),
87
+ file=SAFE_STDOUT,
88
+ )
89
+ )
90
+ ok_training_lines = [line for line, error in training_results if error is None]
91
+ nan_training_lines = [
92
+ line for line, error in training_results if error == "nan_error"
93
+ ]
94
+ if nan_training_lines:
95
+ nan_files = [line.split("|")[0] for line in nan_training_lines]
96
+ logger.warning(
97
+ f"Found NaN value in {len(nan_training_lines)} files: {nan_files}, so they will be deleted from training data."
98
+ )
99
 
100
+ val_lines = []
101
  with open(hps.data.validation_files, encoding="utf-8") as f:
102
+ val_lines.extend(f.readlines())
 
 
103
 
104
+ with ThreadPoolExecutor(max_workers=num_processes) as executor:
105
+ val_results = list(
106
  tqdm(
107
+ executor.map(process_line, val_lines),
108
+ total=len(val_lines),
109
+ file=SAFE_STDOUT,
110
  )
111
  )
112
+ ok_val_lines = [line for line, error in val_results if error is None]
113
+ nan_val_lines = [line for line, error in val_results if error == "nan_error"]
114
+ if nan_val_lines:
115
+ nan_files = [line.split("|")[0] for line in nan_val_lines]
116
+ logger.warning(
117
+ f"Found NaN value in {len(nan_val_lines)} files: {nan_files}, so they will be deleted from validation data."
118
+ )
119
+
120
+ with open(hps.data.training_files, "w", encoding="utf-8") as f:
121
+ f.writelines(ok_training_lines)
122
+
123
+ with open(hps.data.validation_files, "w", encoding="utf-8") as f:
124
+ f.writelines(ok_val_lines)
125
+
126
+ ok_num = len(ok_training_lines) + len(ok_val_lines)
127
 
128
+ logger.info(f"Finished generating style vectors! total: {ok_num} npy files.")