# File size: 6,507 Bytes
# b60adb9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
from sklearn.cluster import *
import os
import numpy as np
from config import config
import yaml
import argparse
import shutil

def ensure_dir(directory):
    """Create *directory* (including missing parents) if it does not exist.

    ``exist_ok=True`` replaces a separate ``os.path.exists`` check. Without a
    separate check, nothing can create the directory between the check and
    ``os.makedirs``, which would otherwise raise ``FileExistsError``.
    If *directory* exists but is a regular file, ``os.makedirs`` still
    raises ``FileExistsError``. That surfaces the problem here rather than at
    a later copy into it.
    """
    os.makedirs(directory, exist_ok=True)

def read_filelist(path):
    """Group the audio paths of a ``|``-separated filelist by speaker.

    Field 0 of each line is taken as the audio path and field 1 as the
    speaker; further fields are ignored. Blank lines are skipped, and lines
    with fewer than two fields are reported and skipped instead of aborting
    the whole run.

    Returns a dict mapping speaker -> list of audio paths, in file order.
    """
    filelist_dict = {}
    with open(path, mode="r", encoding="utf-8") as f:
        for lineno, raw in enumerate(f, start=1):
            line = raw.strip()
            if not line:
                continue
            parts = line.split("|")
            if len(parts) < 2:
                print(f"skipping malformed line {lineno}: {line!r}")
                continue
            filepath, speaker = parts[0], parts[1]
            filelist_dict.setdefault(speaker, []).append(filepath)
    return filelist_dict


def load_embeddings(wav_paths):
    """Load the ``<stem>.emo.npy`` file stored next to each audio path.

    Embeddings that cannot be read are reported and skipped.

    Returns ``(x, wavnames)``. ``wavnames`` lists the audio paths whose
    embedding loaded, in input order. ``x`` is a 2-D array with one row per
    entry of ``wavnames``, or ``None`` when nothing loaded.
    """
    rows = []
    wavnames = []
    for wav_path in wav_paths:
        emb_path = f"{os.path.splitext(wav_path)[0]}.emo.npy"
        try:
            emb = np.load(emb_path)
        except (OSError, ValueError, EOFError) as e:
            print(e)
            continue
        # Flatten each embedding to one row, whether the file stores a flat
        # vector or one with extra singleton axes. Unlike np.squeeze on the
        # stacked array, this keeps a single-sample speaker 2-D, which the
        # sklearn estimators require.
        rows.append(np.reshape(emb, -1))
        wavnames.append(wav_path)
    if not rows:
        return None, wavnames
    return np.stack(rows, axis=0), wavnames


def build_model(algorithm, n_clusters):
    """Return an unfitted clusterer for the ``-a`` option letter.

    ``"b"`` -> Birch, ``"s"`` -> SpectralClustering,
    ``"a"`` -> AgglomerativeClustering; any other value falls back to KMeans
    (the CLI default is ``"k"``).
    """
    if algorithm == "b":
        return Birch(n_clusters=n_clusters, threshold=0.2)
    if algorithm == "s":
        return SpectralClustering(n_clusters=n_clusters)
    if algorithm == "a":
        return AgglomerativeClustering(n_clusters=n_clusters)
    return KMeans(n_clusters=n_clusters, random_state=10)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-a", "--algorithm", default="k", help="choose algorithm", type=str)
    parser.add_argument("-n", "--num_clusters", default=4, help="number of clusters", type=int)
    parser.add_argument("-r", "--range", default=4, help="number of files in a class", type=int)
    parser.add_argument(
        "--base_dir",
        default="D:/Vits2/Bert-VITS2/Data/BanGDream/filelists",
        help="directory holding the filelist; emo_clustering.yml is written here",
    )
    parser.add_argument("--filelist", default="Mygo.list", help="filelist name inside --base_dir")
    parser.add_argument(
        "--output_dir",
        default="D:/Vits2/classifedSample",
        help="root directory for the copied per-class sample files",
    )
    args = parser.parse_args()

    base_dir = args.base_dir
    output_dir = args.output_dir

    filelist_dict = read_filelist(os.path.join(base_dir, args.filelist))
    # Every listed speaker gets an entry, even if none of its embeddings load.
    yml_result = {speaker: {} for speaker in filelist_dict}

    for speaker, files in filelist_dict.items():
        print("\nspeaker: " + speaker)

        x, wavnames = load_embeddings(files)
        if x is None:
            continue

        if len(wavnames) == 1:
            # Some sklearn clusterers (e.g. agglomerative) refuse to fit a
            # single sample; a lone file simply forms class 0.
            y_predict = np.zeros(1, dtype=int)
        else:
            # sklearn raises when asked for more clusters than samples.
            n_clusters = min(args.num_clusters, len(wavnames))
            y_predict = build_model(args.algorithm, n_clusters).fit_predict(x)

        classes = [[] for _ in range(y_predict.max() + 1)]
        for wavname, label in zip(wavnames, y_predict):
            classes[label].append(wavname)

        for i, members in enumerate(classes):
            print("类别:", i, "本类中样本数量:", len(members))
            # max(..., 0): a negative -r selects nothing, as range() did.
            selected = members[: max(args.range, 0)]
            yml_result[speaker][f"class{i}"] = selected
            if not selected:
                continue

            class_dir = os.path.join(output_dir, speaker, f"class{i}")
            ensure_dir(class_dir)
            for wav_file in selected:
                print(wav_file)
                # Copy from the same path the embedding was loaded next to,
                # not from base_dir, so relative filelist paths stay consistent.
                shutil.copy(wav_file, os.path.join(class_dir, os.path.basename(wav_file)))

    with open(os.path.join(base_dir, "emo_clustering.yml"), "w", encoding="utf-8") as f:
        # The file is UTF-8, so write non-ASCII speaker names verbatim
        # instead of as \u escapes.
        yaml.dump(yml_result, f, allow_unicode=True)

# NOTE: everything from the opening ''' below to the closing ''' is a single
# module-level string literal, not live code. It holds what appears to be an
# earlier draft of the clustering script above. That draft differs in
# several ways:
#   - it records basenames rather than full paths;
#   - it writes emo_clustering.yml one directory higher;
#   - it does not copy sample files.
# Python evaluates and discards the string, so it has no runtime effect.
# Consider deleting it and relying on version-control history instead.
'''
from sklearn.cluster import *
import os
import numpy as np
from config import config
import yaml
import argparse


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-a", "--algorithm", default="s", help="choose algorithm", type=str
    )
    parser.add_argument(
        "-n", "--num_clusters", default=3, help="number of clusters", type=int
    )
    parser.add_argument(
        "-r", "--range", default=4, help="number of files in a class", type=int
    )
    args = parser.parse_args()
    filelist_dict = {}
    yml_result = {}
    with open(
        "D:/Vits2/Bert-VITS2/Data/BanGDream/filelists/Mygo.list", mode="r", encoding="utf-8"
    ) as f:
        embs = []
        wavnames = []
        for line in f:
            speaker = line.split("|")[1]
            if speaker not in filelist_dict:
                filelist_dict[speaker] = []
                yml_result[speaker] = {}
            filelist_dict[speaker].append(line.split("|")[0])
    #print(filelist_dict)

    for speaker in filelist_dict:
        print("\nspeaker: " + speaker)

        # 清空 embs 和 wavnames 列表
        embs = []
        wavnames = []

        for file in filelist_dict[speaker]:
            try:
                embs.append(
                    np.expand_dims(
                        np.load(f"{os.path.splitext(file)[0]}.emo.npy"), axis=0
                    )
                )
                wavnames.append(os.path.basename(file))
            except Exception as e:
                print(e)

        if embs:
        # 聚类算法类的数量
            n_clusters = args.num_clusters
            x = np.concatenate(embs, axis=0)
            x = np.squeeze(x)
            # 聚类算法类的数量
            n_clusters = args.num_clusters
            if args.algorithm == "b":
                model = Birch(n_clusters=n_clusters, threshold=0.2)
            elif args.algorithm == "s":
                model = SpectralClustering(n_clusters=n_clusters)
            elif args.algorithm == "a":
                model = AgglomerativeClustering(n_clusters=n_clusters)
            else:
                model = KMeans(n_clusters=n_clusters, random_state=10)
            # 可以自行尝试各种不同的聚类算法
            y_predict = model.fit_predict(x)
            classes = [[] for i in range(y_predict.max() + 1)]

            for idx, wavname in enumerate(wavnames):
                classes[y_predict[idx]].append(wavname)

            for i in range(y_predict.max() + 1):
                print("类别:", i, "本类中样本数量:", len(classes[i]))
                yml_result[speaker][f"class{i}"] = []

                # 修正:确保不会尝试访问超出范围的元素
                num_samples_in_class = len(classes[i])
                for j in range(min(args.range, num_samples_in_class)):
                    print(classes[i][j])
                    yml_result[speaker][f"class{i}"].append(classes[i][j])
    with open(
        os.path.join('D:/Vits2/Bert-VITS2/Data/BanGDream', "emo_clustering.yml"), "w", encoding="utf-8"
    ) as f:
        yaml.dump(yml_result, f)
        '''