Mahiruoshi commited on
Commit
b60adb9
1 Parent(s): 6045186

Upload classify.py

Browse files
Files changed (1) hide show
  1. classify.py +180 -0
classify.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sklearn.cluster import *
2
+ import os
3
+ import numpy as np
4
+ from config import config
5
+ import yaml
6
+ import argparse
7
+ import shutil
8
+
9
+ def ensure_dir(directory):
10
+ if not os.path.exists(directory):
11
+ os.makedirs(directory)
12
+
13
+ if __name__ == "__main__":
14
+ parser = argparse.ArgumentParser()
15
+ parser.add_argument("-a", "--algorithm", default="k", help="choose algorithm", type=str)
16
+ parser.add_argument("-n", "--num_clusters", default=4, help="number of clusters", type=int)
17
+ parser.add_argument("-r", "--range", default=4, help="number of files in a class", type=int)
18
+ args = parser.parse_args()
19
+
20
+ filelist_dict = {}
21
+ yml_result = {}
22
+ base_dir = "D:/Vits2/Bert-VITS2/Data/BanGDream/filelists"
23
+ output_dir = "D:/Vits2/classifedSample"
24
+
25
+ with open(os.path.join(base_dir, "Mygo.list"), mode="r", encoding="utf-8") as f:
26
+ embs = []
27
+ wavnames = []
28
+ for line in f:
29
+ parts = line.strip().split("|")
30
+ speaker = parts[1] # 假设 speaker 是第二个部分
31
+ filepath = parts[0] # 假设 filepath 是第一个部分
32
+ # ... 其余部分可以根据需要使用
33
+
34
+ if speaker not in filelist_dict:
35
+ filelist_dict[speaker] = []
36
+ yml_result[speaker] = {}
37
+ filelist_dict[speaker].append(filepath)
38
+
39
+ for speaker in filelist_dict:
40
+ print("\nspeaker: " + speaker)
41
+
42
+ embs = []
43
+ wavnames = []
44
+
45
+ for file in filelist_dict[speaker]:
46
+ try:
47
+ embs.append(np.expand_dims(np.load(f"{os.path.splitext(file)[0]}.emo.npy"), axis=0))
48
+ wavnames.append(file)
49
+ except Exception as e:
50
+ print(e)
51
+
52
+ if embs:
53
+ n_clusters = args.num_clusters
54
+ x = np.concatenate(embs, axis=0)
55
+ x = np.squeeze(x)
56
+
57
+ if args.algorithm == "b":
58
+ model = Birch(n_clusters=n_clusters, threshold=0.2)
59
+ elif args.algorithm == "s":
60
+ model = SpectralClustering(n_clusters=n_clusters)
61
+ elif args.algorithm == "a":
62
+ model = AgglomerativeClustering(n_clusters=n_clusters)
63
+ else:
64
+ model = KMeans(n_clusters=n_clusters, random_state=10)
65
+
66
+ y_predict = model.fit_predict(x)
67
+ classes = [[] for i in range(y_predict.max() + 1)]
68
+
69
+ for idx, wavname in enumerate(wavnames):
70
+ classes[y_predict[idx]].append(wavname)
71
+
72
+ for i in range(y_predict.max() + 1):
73
+ print("类别:", i, "本类中样本数量:", len(classes[i]))
74
+ yml_result[speaker][f"class{i}"] = []
75
+ class_dir = os.path.join(output_dir, speaker, f"class{i}")
76
+
77
+ num_samples_in_class = len(classes[i])
78
+ for j in range(min(args.range, num_samples_in_class)):
79
+ wav_file = classes[i][j]
80
+ print(wav_file)
81
+
82
+ # 复制文件到新目录
83
+ ensure_dir(class_dir)
84
+ shutil.copy(os.path.join(base_dir, wav_file), os.path.join(class_dir, os.path.basename(wav_file)))
85
+
86
+ yml_result[speaker][f"class{i}"].append(wav_file)
87
+
88
+ with open(os.path.join(base_dir, "emo_clustering.yml"), "w", encoding="utf-8") as f:
89
+ yaml.dump(yml_result, f)
90
+
91
+ '''
92
+ from sklearn.cluster import *
93
+ import os
94
+ import numpy as np
95
+ from config import config
96
+ import yaml
97
+ import argparse
98
+
99
+
100
+ if __name__ == "__main__":
101
+ parser = argparse.ArgumentParser()
102
+ parser.add_argument(
103
+ "-a", "--algorithm", default="s", help="choose algorithm", type=str
104
+ )
105
+ parser.add_argument(
106
+ "-n", "--num_clusters", default=3, help="number of clusters", type=int
107
+ )
108
+ parser.add_argument(
109
+ "-r", "--range", default=4, help="number of files in a class", type=int
110
+ )
111
+ args = parser.parse_args()
112
+ filelist_dict = {}
113
+ yml_result = {}
114
+ with open(
115
+ "D:/Vits2/Bert-VITS2/Data/BanGDream/filelists/Mygo.list", mode="r", encoding="utf-8"
116
+ ) as f:
117
+ embs = []
118
+ wavnames = []
119
+ for line in f:
120
+ speaker = line.split("|")[1]
121
+ if speaker not in filelist_dict:
122
+ filelist_dict[speaker] = []
123
+ yml_result[speaker] = {}
124
+ filelist_dict[speaker].append(line.split("|")[0])
125
+ #print(filelist_dict)
126
+
127
+ for speaker in filelist_dict:
128
+ print("\nspeaker: " + speaker)
129
+
130
+ # 清空 embs 和 wavnames 列表
131
+ embs = []
132
+ wavnames = []
133
+
134
+ for file in filelist_dict[speaker]:
135
+ try:
136
+ embs.append(
137
+ np.expand_dims(
138
+ np.load(f"{os.path.splitext(file)[0]}.emo.npy"), axis=0
139
+ )
140
+ )
141
+ wavnames.append(os.path.basename(file))
142
+ except Exception as e:
143
+ print(e)
144
+
145
+ if embs:
146
+ # 聚类算法类的数量
147
+ n_clusters = args.num_clusters
148
+ x = np.concatenate(embs, axis=0)
149
+ x = np.squeeze(x)
150
+ # 聚类算法类的数量
151
+ n_clusters = args.num_clusters
152
+ if args.algorithm == "b":
153
+ model = Birch(n_clusters=n_clusters, threshold=0.2)
154
+ elif args.algorithm == "s":
155
+ model = SpectralClustering(n_clusters=n_clusters)
156
+ elif args.algorithm == "a":
157
+ model = AgglomerativeClustering(n_clusters=n_clusters)
158
+ else:
159
+ model = KMeans(n_clusters=n_clusters, random_state=10)
160
+ # 可以自行尝试各种不同的聚类算法
161
+ y_predict = model.fit_predict(x)
162
+ classes = [[] for i in range(y_predict.max() + 1)]
163
+
164
+ for idx, wavname in enumerate(wavnames):
165
+ classes[y_predict[idx]].append(wavname)
166
+
167
+ for i in range(y_predict.max() + 1):
168
+ print("类别:", i, "本类中样本数量:", len(classes[i]))
169
+ yml_result[speaker][f"class{i}"] = []
170
+
171
+ # 修正:确保不会尝试访问超出范围的元素
172
+ num_samples_in_class = len(classes[i])
173
+ for j in range(min(args.range, num_samples_in_class)):
174
+ print(classes[i][j])
175
+ yml_result[speaker][f"class{i}"].append(classes[i][j])
176
+ with open(
177
+ os.path.join('D:/Vits2/Bert-VITS2/Data/BanGDream', "emo_clustering.yml"), "w", encoding="utf-8"
178
+ ) as f:
179
+ yaml.dump(yml_result, f)
180
+ '''