MarcusSu1216 commited on
Commit
5ec3a59
1 Parent(s): cf3daeb

Update preprocess_flist_config.py

Browse files
Files changed (1) hide show
  1. preprocess_flist_config.py +14 -6
preprocess_flist_config.py CHANGED
@@ -7,7 +7,7 @@ from random import shuffle
7
  import json
8
  import wave
9
 
10
- config_template = json.load(open("configs_template/config_template.json"))
11
 
12
  pattern = re.compile(r'^[\.a-zA-Z0-9_\/]+$')
13
 
@@ -25,11 +25,13 @@ if __name__ == "__main__":
25
  parser = argparse.ArgumentParser()
26
  parser.add_argument("--train_list", type=str, default="./filelists/train.txt", help="path to train list")
27
  parser.add_argument("--val_list", type=str, default="./filelists/val.txt", help="path to val list")
 
28
  parser.add_argument("--source_dir", type=str, default="./dataset/44k", help="path to source dir")
29
  args = parser.parse_args()
30
 
31
  train = []
32
  val = []
 
33
  idx = 0
34
  spk_dict = {}
35
  spk_id = 0
@@ -41,19 +43,21 @@ if __name__ == "__main__":
41
  for file in wavs:
42
  if not file.endswith("wav"):
43
  continue
44
- #if not pattern.match(file):
45
- # print(f"warning:文件名{file}中包含非字母数字下划线,可能会导致错误。(也可能不会)")
46
  if get_wav_duration(file) < 0.3:
47
  print("skip too short audio:", file)
48
  continue
49
  new_wavs.append(file)
50
  wavs = new_wavs
51
  shuffle(wavs)
52
- train += wavs[2:]
53
  val += wavs[:2]
 
54
 
55
  shuffle(train)
56
  shuffle(val)
 
57
 
58
  print("Writing", args.train_list)
59
  with open(args.train_list, "w") as f:
@@ -66,10 +70,14 @@ if __name__ == "__main__":
66
  for fname in tqdm(val):
67
  wavpath = fname
68
  f.write(wavpath + "\n")
 
 
 
 
 
 
69
 
70
  config_template["spk"] = spk_dict
71
- config_template["model"]["n_speakers"] = spk_id
72
-
73
  print("Writing configs/config.json")
74
  with open("configs/config.json", "w") as f:
75
  json.dump(config_template, f, indent=2)
 
7
  import json
8
  import wave
9
 
10
+ config_template = json.load(open("configs/config.json"))
11
 
12
  pattern = re.compile(r'^[\.a-zA-Z0-9_\/]+$')
13
 
 
25
  parser = argparse.ArgumentParser()
26
  parser.add_argument("--train_list", type=str, default="./filelists/train.txt", help="path to train list")
27
  parser.add_argument("--val_list", type=str, default="./filelists/val.txt", help="path to val list")
28
+ parser.add_argument("--test_list", type=str, default="./filelists/test.txt", help="path to test list")
29
  parser.add_argument("--source_dir", type=str, default="./dataset/44k", help="path to source dir")
30
  args = parser.parse_args()
31
 
32
  train = []
33
  val = []
34
+ test = []
35
  idx = 0
36
  spk_dict = {}
37
  spk_id = 0
 
43
  for file in wavs:
44
  if not file.endswith("wav"):
45
  continue
46
+ if not pattern.match(file):
47
+ print(f"warning:文件名{file}中包含非字母数字下划线,可能会导致错误。(也可能不会)")
48
  if get_wav_duration(file) < 0.3:
49
  print("skip too short audio:", file)
50
  continue
51
  new_wavs.append(file)
52
  wavs = new_wavs
53
  shuffle(wavs)
54
+ train += wavs[2:-2]
55
  val += wavs[:2]
56
+ test += wavs[-2:]
57
 
58
  shuffle(train)
59
  shuffle(val)
60
+ shuffle(test)
61
 
62
  print("Writing", args.train_list)
63
  with open(args.train_list, "w") as f:
 
70
  for fname in tqdm(val):
71
  wavpath = fname
72
  f.write(wavpath + "\n")
73
+
74
+ print("Writing", args.test_list)
75
+ with open(args.test_list, "w") as f:
76
+ for fname in tqdm(test):
77
+ wavpath = fname
78
+ f.write(wavpath + "\n")
79
 
80
  config_template["spk"] = spk_dict
 
 
81
  print("Writing configs/config.json")
82
  with open("configs/config.json", "w") as f:
83
  json.dump(config_template, f, indent=2)