# MyGO_VIts-bert / classify.py
# Mahiruoshi: "Upload classify.py" (b60adb9)
from sklearn.cluster import *
import os
import numpy as np
from config import config
import yaml
import argparse
import shutil
def ensure_dir(directory):
    """Create *directory* (including missing parents) if it does not exist.

    Args:
        directory: Path of the directory to create.

    Raises:
        OSError: If the directory cannot be created, e.g. because the path
            already exists as a regular file or permissions are missing.
    """
    # exist_ok=True avoids the check-then-create race of testing
    # os.path.exists() first and calling makedirs() afterwards.
    os.makedirs(directory, exist_ok=True)
if __name__ == "__main__":
    # Cluster each speaker's ".emo.npy" embeddings, copy up to --range files
    # per cluster into output_dir/<speaker>/class<i>, and write the selection
    # to emo_clustering.yml.
    parser = argparse.ArgumentParser()
    parser.add_argument("-a", "--algorithm", default="k", help="choose algorithm", type=str)
    parser.add_argument("-n", "--num_clusters", default=4, help="number of clusters", type=int)
    parser.add_argument("-r", "--range", default=4, help="number of files in a class", type=int)
    args = parser.parse_args()

    filelist_dict = {}  # speaker -> file paths read from the list file
    yml_result = {}  # speaker -> {"class<i>": [selected file paths]}
    base_dir = "D:/Vits2/Bert-VITS2/Data/BanGDream/filelists"
    output_dir = "D:/Vits2/classifedSample"

    with open(os.path.join(base_dir, "Mygo.list"), mode="r", encoding="utf-8") as f:
        for line in f:
            parts = line.strip().split("|")
            if len(parts) < 2:
                # Blank or malformed line: there is no speaker field to read.
                continue
            filepath = parts[0]  # the file path is assumed to be the first field
            speaker = parts[1]  # the speaker is assumed to be the second field
            # Any remaining fields are not needed here.
            if speaker not in filelist_dict:
                filelist_dict[speaker] = []
                yml_result[speaker] = {}
            filelist_dict[speaker].append(filepath)

    for speaker in filelist_dict:
        print("\nspeaker: " + speaker)
        embs = []
        wavnames = []
        for file in filelist_dict[speaker]:
            # Only the load itself is guarded; files whose embedding is
            # missing or unreadable are reported and left out of clustering.
            try:
                emb = np.load(f"{os.path.splitext(file)[0]}.emo.npy")
            except (OSError, ValueError, EOFError) as e:
                print(e)
                continue
            embs.append(np.expand_dims(emb, axis=0))
            wavnames.append(file)

        if not embs:
            continue

        n_clusters = args.num_clusters
        if len(embs) < n_clusters:
            # The estimators cannot form more clusters than there are samples;
            # skip this speaker instead of aborting the whole run.
            print(f"skipping {speaker}: {len(embs)} embeddings for {n_clusters} clusters")
            continue

        # Flatten to (n_samples, n_features). An unconditional squeeze would
        # also drop the sample axis when a speaker has exactly one embedding.
        x = np.concatenate(embs, axis=0).reshape(len(embs), -1)

        if args.algorithm == "b":
            model = Birch(n_clusters=n_clusters, threshold=0.2)
        elif args.algorithm == "s":
            model = SpectralClustering(n_clusters=n_clusters)
        elif args.algorithm == "a":
            model = AgglomerativeClustering(n_clusters=n_clusters)
        else:
            model = KMeans(n_clusters=n_clusters, random_state=10)
        y_predict = model.fit_predict(x)

        n_labels = y_predict.max() + 1
        classes = [[] for _ in range(n_labels)]
        for wavname, label in zip(wavnames, y_predict):
            classes[label].append(wavname)

        for i in range(n_labels):
            print("类别:", i, "本类中样本数量:", len(classes[i]))
            yml_result[speaker][f"class{i}"] = []
            class_dir = os.path.join(output_dir, speaker, f"class{i}")
            # Take at most --range files; a non-positive value selects none.
            selected = classes[i][: max(args.range, 0)]
            if selected:
                # Create the target directory once per class, and only when
                # there is something to copy into it.
                ensure_dir(class_dir)
            for wav_file in selected:
                print(wav_file)
                # Copy the file into the new directory.
                # NOTE(review): the embedding above is loaded from the path as
                # written in the list file, while the wav is resolved against
                # base_dir; the two agree only when the list holds absolute
                # paths - confirm against the list file format.
                shutil.copy(os.path.join(base_dir, wav_file), os.path.join(class_dir, os.path.basename(wav_file)))
                yml_result[speaker][f"class{i}"].append(wav_file)

    with open(os.path.join(base_dir, "emo_clustering.yml"), "w", encoding="utf-8") as f:
        yaml.dump(yml_result, f)
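# The string literal below holds a variant of this script (different defaults,
# no file copying). It is never executed and is kept for reference only.
'''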
from sklearn.cluster import *
import os
import numpy as np
from config import config
import yaml
import argparse
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"-a", "--algorithm", default="s", help="choose algorithm", type=str
)
parser.add_argument(
"-n", "--num_clusters", default=3, help="number of clusters", type=int
)
parser.add_argument(
"-r", "--range", default=4, help="number of files in a class", type=int
)
args = parser.parse_args()
filelist_dict = {}
yml_result = {}
with open(
"D:/Vits2/Bert-VITS2/Data/BanGDream/filelists/Mygo.list", mode="r", encoding="utf-8"
) as f:
embs = []
wavnames = []
for line in f:
speaker = line.split("|")[1]
if speaker not in filelist_dict:
filelist_dict[speaker] = []
yml_result[speaker] = {}
filelist_dict[speaker].append(line.split("|")[0])
#print(filelist_dict)
for speaker in filelist_dict:
print("\nspeaker: " + speaker)
# 清空 embs 和 wavnames 列表
embs = []
wavnames = []
for file in filelist_dict[speaker]:
try:
embs.append(
np.expand_dims(
np.load(f"{os.path.splitext(file)[0]}.emo.npy"), axis=0
)
)
wavnames.append(os.path.basename(file))
except Exception as e:
print(e)
if embs:
# 聚类算法类的数量
n_clusters = args.num_clusters
x = np.concatenate(embs, axis=0)
x = np.squeeze(x)
# 聚类算法类的数量
n_clusters = args.num_clusters
if args.algorithm == "b":
model = Birch(n_clusters=n_clusters, threshold=0.2)
elif args.algorithm == "s":
model = SpectralClustering(n_clusters=n_clusters)
elif args.algorithm == "a":
model = AgglomerativeClustering(n_clusters=n_clusters)
else:
model = KMeans(n_clusters=n_clusters, random_state=10)
# 可以自行尝试各种不同的聚类算法
y_predict = model.fit_predict(x)
classes = [[] for i in range(y_predict.max() + 1)]
for idx, wavname in enumerate(wavnames):
classes[y_predict[idx]].append(wavname)
for i in range(y_predict.max() + 1):
print("类别:", i, "本类中样本数量:", len(classes[i]))
yml_result[speaker][f"class{i}"] = []
# 修正:确保不会尝试访问超出范围的元素
num_samples_in_class = len(classes[i])
for j in range(min(args.range, num_samples_in_class)):
print(classes[i][j])
yml_result[speaker][f"class{i}"].append(classes[i][j])
with open(
os.path.join('D:/Vits2/Bert-VITS2/Data/BanGDream', "emo_clustering.yml"), "w", encoding="utf-8"
) as f:
yaml.dump(yml_result, f)
'''