monet-joe committed on
Commit
b9c341a
1 Parent(s): 88b106b

Upload 3 files

Files changed (3)
  1. app.py +182 -183
  2. model.py +144 -158
  3. utils.py +58 -105
app.py CHANGED
@@ -1,183 +1,182 @@
- import os
- import torch
- import shutil
- import librosa
- import numpy as np
- import gradio as gr
- import librosa.display
- import matplotlib.pyplot as plt
- import torchvision.transforms as transforms
- from collections import Counter
- from model import EvalNet
- from PIL import Image
- from utils import *
- import warnings
-
- warnings.filterwarnings("ignore")
-
- classes = ["m_bel", "f_bel", "m_folk", "f_folk"]
-
-
- def most_common_element(input_list):
-     # Use Counter to count the occurrences of each element
-     counter = Counter(input_list)
-     # Use most_common to get the most frequent element
-     most_common_element, _ = counter.most_common(1)[0]
-     return most_common_element
-
-
- def wav_to_mel(audio_path: str, width=1.6, topdb=40):
-     os.makedirs("./tmp", exist_ok=True)
-     try:
-         y, sr = librosa.load(audio_path, sr=48000)
-         non_silents = librosa.effects.split(y, top_db=topdb)
-         non_silent = np.concatenate([y[start:end] for start, end in non_silents])
-         mel_spec = librosa.feature.melspectrogram(y=non_silent, sr=sr)
-         log_mel_spec = librosa.power_to_db(mel_spec, ref=np.max)
-         dur = librosa.get_duration(y=non_silent, sr=sr)
-         total_frames = log_mel_spec.shape[1]
-         step = int(width * total_frames / dur)
-         count = int(total_frames / step)
-         begin = int(0.5 * (total_frames - count * step))
-         end = begin + step * count
-         for i in range(begin, end, step):
-             librosa.display.specshow(log_mel_spec[:, i : i + step])
-             plt.axis("off")
-             plt.savefig(
-                 f"./tmp/mel_{round(dur, 2)}_{i}.jpg",
-                 bbox_inches="tight",
-                 pad_inches=0.0,
-             )
-             plt.close()
-
-     except Exception as e:
-         print(f"Error converting {audio_path} : {e}")
-
-
- def wav_to_cqt(audio_path: str, width=1.6, topdb=40):
-     os.makedirs("./tmp", exist_ok=True)
-     try:
-         y, sr = librosa.load(audio_path, sr=48000)
-         non_silents = librosa.effects.split(y, top_db=topdb)
-         non_silent = np.concatenate([y[start:end] for start, end in non_silents])
-         cqt_spec = librosa.cqt(y=non_silent, sr=sr)
-         log_cqt_spec = librosa.power_to_db(np.abs(cqt_spec) ** 2, ref=np.max)
-         dur = librosa.get_duration(y=non_silent, sr=sr)
-         total_frames = log_cqt_spec.shape[1]
-         step = int(width * total_frames / dur)
-         count = int(total_frames / step)
-         begin = int(0.5 * (total_frames - count * step))
-         end = begin + step * count
-         for i in range(begin, end, step):
-             librosa.display.specshow(log_cqt_spec[:, i : i + step])
-             plt.axis("off")
-             plt.savefig(
-                 f"./tmp/cqt_{round(dur, 2)}_{i}.jpg",
-                 bbox_inches="tight",
-                 pad_inches=0.0,
-             )
-             plt.close()
-
-     except Exception as e:
-         print(f"Error converting {audio_path} : {e}")
-
-
- def wav_to_chroma(audio_path: str, width=1.6, topdb=40):
-     os.makedirs("./tmp", exist_ok=True)
-     try:
-         y, sr = librosa.load(audio_path, sr=48000)
-         non_silents = librosa.effects.split(y, top_db=topdb)
-         non_silent = np.concatenate([y[start:end] for start, end in non_silents])
-         chroma_spec = librosa.feature.chroma_stft(y=non_silent, sr=sr)
-         log_chroma_spec = librosa.power_to_db(np.abs(chroma_spec) ** 2, ref=np.max)
-         dur = librosa.get_duration(y=non_silent, sr=sr)
-         total_frames = log_chroma_spec.shape[1]
-         step = int(width * total_frames / dur)
-         count = int(total_frames / step)
-         begin = int(0.5 * (total_frames - count * step))
-         end = begin + step * count
-         for i in range(begin, end, step):
-             librosa.display.specshow(log_chroma_spec[:, i : i + step])
-             plt.axis("off")
-             plt.savefig(
-                 f"./tmp/chroma_{round(dur, 2)}_{i}.jpg",
-                 bbox_inches="tight",
-                 pad_inches=0.0,
-             )
-             plt.close()
-
-     except Exception as e:
-         print(f"Error converting {audio_path} : {e}")
-
-
- def embed_img(img_path, input_size=224):
-     transform = transforms.Compose(
-         [
-             transforms.Resize([input_size, input_size]),
-             transforms.ToTensor(),
-             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
-         ]
-     )
-     img = Image.open(img_path).convert("RGB")
-     return transform(img).unsqueeze(0)
-
-
- def inference(wav_path, log_name, folder_path="./tmp"):
-     if os.path.exists(folder_path):
-         shutil.rmtree(folder_path)
-
-     if not wav_path:
-         wav_path = "./examples/f_bel.wav"
-
-     model = EvalNet(log_name).model
-     spec = log_name.split("_")[-3]
-     eval("wav_to_%s" % spec)(wav_path)
-     outputs = []
-     all_files = os.listdir(folder_path)
-     for file_name in all_files:
-         if file_name.lower().endswith(".jpg"):
-             file_path = os.path.join(folder_path, file_name)
-             input = embed_img(file_path)
-             output = model(input)
-             pred_id = torch.max(output.data, 1)[1]
-             outputs.append(pred_id)
-
-     max_count_item = most_common_element(outputs)
-     shutil.rmtree(folder_path)
-     return os.path.basename(wav_path), translate[classes[max_count_item]]
-
-
- models = get_modelist()
-
- translate = {
-     "m_bel": "male bel canto",
-     "m_folk": "male folk singing",
-     "f_bel": "female bel canto",
-     "f_folk": "female folk singing",
- }
-
- examples = []
- example_wavs = find_wav_files()
- for wav in example_wavs:
-     examples.append([wav, models[0]])
-
- with gr.Blocks() as demo:
-     gr.Markdown(
-         """
- **Please note: It may take longer to obtain recognition results when using the selected model for the first time, as downloading weights is required. Please be patient while waiting for the results.**
- """
-     )
-     gr.Interface(
-         fn=inference,
-         inputs=[
-             gr.Audio(label="Upload audio", type="filepath"),
-             gr.Dropdown(choices=models, label="Select model", value=models[0]),
-         ],
-         outputs=[
-             gr.Textbox(label="Audio filename", show_copy_button=True),
-             gr.Textbox(label="Singing method", show_copy_button=True),
-         ],
-         examples=examples,
-     )
-
- demo.launch()
+ import os
+ import torch
+ import random
+ import shutil
+ import librosa
+ import warnings
+ import numpy as np
+ import gradio as gr
+ import librosa.display
+ import matplotlib.pyplot as plt
+ import torchvision.transforms as transforms
+ from utils import get_modelist, find_wav_files
+ from collections import Counter
+ from model import EvalNet
+ from PIL import Image
+
+
+ CLASSES = ["m_bel", "f_bel", "m_folk", "f_folk"]
+
+
+ def most_common_element(input_list):
+     # Use Counter to count the occurrences of each element
+     counter = Counter(input_list)
+     # Use most_common to get the most frequent element
+     most_common_element, _ = counter.most_common(1)[0]
+     return most_common_element
+
+
+ def wav_to_mel(audio_path: str, width=1.6, topdb=40):
+     os.makedirs("./tmp", exist_ok=True)
+     try:
+         y, sr = librosa.load(audio_path, sr=48000)
+         non_silents = librosa.effects.split(y, top_db=topdb)
+         non_silent = np.concatenate([y[start:end] for start, end in non_silents])
+         mel_spec = librosa.feature.melspectrogram(y=non_silent, sr=sr)
+         log_mel_spec = librosa.power_to_db(mel_spec, ref=np.max)
+         dur = librosa.get_duration(y=non_silent, sr=sr)
+         total_frames = log_mel_spec.shape[1]
+         step = int(width * total_frames / dur)
+         count = int(total_frames / step)
+         begin = int(0.5 * (total_frames - count * step))
+         end = begin + step * count
+         for i in range(begin, end, step):
+             librosa.display.specshow(log_mel_spec[:, i : i + step])
+             plt.axis("off")
+             plt.savefig(
+                 f"./tmp/mel_{round(dur, 2)}_{i}.jpg",
+                 bbox_inches="tight",
+                 pad_inches=0.0,
+             )
+             plt.close()
+
+     except Exception as e:
+         print(f"Error converting {audio_path} : {e}")
+
+
+ def wav_to_cqt(audio_path: str, width=1.6, topdb=40):
+     os.makedirs("./tmp", exist_ok=True)
+     try:
+         y, sr = librosa.load(audio_path, sr=48000)
+         non_silents = librosa.effects.split(y, top_db=topdb)
+         non_silent = np.concatenate([y[start:end] for start, end in non_silents])
+         cqt_spec = librosa.cqt(y=non_silent, sr=sr)
+         log_cqt_spec = librosa.power_to_db(np.abs(cqt_spec) ** 2, ref=np.max)
+         dur = librosa.get_duration(y=non_silent, sr=sr)
+         total_frames = log_cqt_spec.shape[1]
+         step = int(width * total_frames / dur)
+         count = int(total_frames / step)
+         begin = int(0.5 * (total_frames - count * step))
+         end = begin + step * count
+         for i in range(begin, end, step):
+             librosa.display.specshow(log_cqt_spec[:, i : i + step])
+             plt.axis("off")
+             plt.savefig(
+                 f"./tmp/cqt_{round(dur, 2)}_{i}.jpg",
+                 bbox_inches="tight",
+                 pad_inches=0.0,
+             )
+             plt.close()
+
+     except Exception as e:
+         print(f"Error converting {audio_path} : {e}")
+
+
+ def wav_to_chroma(audio_path: str, width=1.6, topdb=40):
+     os.makedirs("./tmp", exist_ok=True)
+     try:
+         y, sr = librosa.load(audio_path, sr=48000)
+         non_silents = librosa.effects.split(y, top_db=topdb)
+         non_silent = np.concatenate([y[start:end] for start, end in non_silents])
+         chroma_spec = librosa.feature.chroma_stft(y=non_silent, sr=sr)
+         log_chroma_spec = librosa.power_to_db(np.abs(chroma_spec) ** 2, ref=np.max)
+         dur = librosa.get_duration(y=non_silent, sr=sr)
+         total_frames = log_chroma_spec.shape[1]
+         step = int(width * total_frames / dur)
+         count = int(total_frames / step)
+         begin = int(0.5 * (total_frames - count * step))
+         end = begin + step * count
+         for i in range(begin, end, step):
+             librosa.display.specshow(log_chroma_spec[:, i : i + step])
+             plt.axis("off")
+             plt.savefig(
+                 f"./tmp/chroma_{round(dur, 2)}_{i}.jpg",
+                 bbox_inches="tight",
+                 pad_inches=0.0,
+             )
+             plt.close()
+
+     except Exception as e:
+         print(f"Error converting {audio_path} : {e}")
+
+
+ def embed_img(img_path, input_size=224):
+     transform = transforms.Compose(
+         [
+             transforms.Resize([input_size, input_size]),
+             transforms.ToTensor(),
+             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
+         ]
+     )
+     img = Image.open(img_path).convert("RGB")
+     return transform(img).unsqueeze(0)
+
+
+ def inference(wav_path: str, log_name: str, folder_path="./tmp"):
+     if os.path.exists(folder_path):
+         shutil.rmtree(folder_path)
+
+     if not wav_path:
+         wav_path = "./examples/f_bel.wav"
+
+     model = EvalNet(log_name).model
+     spec = log_name.split("_")[-1]
+     eval("wav_to_%s" % spec)(wav_path)
+     outputs = []
+     all_files = os.listdir(folder_path)
+     for file_name in all_files:
+         if file_name.lower().endswith(".jpg"):
+             file_path = os.path.join(folder_path, file_name)
+             input = embed_img(file_path)
+             output = model(input)
+             pred_id = torch.max(output.data, 1)[1]
+             outputs.append(pred_id)
+
+     max_count_item = most_common_element(outputs)
+     shutil.rmtree(folder_path)
+     return os.path.basename(wav_path), translate[CLASSES[max_count_item]]
+
+
+ if __name__ == "__main__":
+     warnings.filterwarnings("ignore")
+
+     models = get_modelist()
+     translate = {
+         "m_bel": "male bel canto",
+         "m_folk": "male folk singing",
+         "f_bel": "female bel canto",
+         "f_folk": "female folk singing",
+     }
+     examples = []
+     example_wavs = find_wav_files()
+     model_num = len(models)
+     for wav in example_wavs:
+         examples.append([wav, models[random.randint(0, model_num - 1)]])
+
+     with gr.Blocks() as demo:
+         gr.Interface(
+             fn=inference,
+             inputs=[
+                 gr.Audio(label="Upload audio", type="filepath"),
+                 gr.Dropdown(choices=models, label="Select model", value=models[0]),
+             ],
+             outputs=[
+                 gr.Textbox(label="Audio filename", show_copy_button=True),
+                 gr.Textbox(label="Singing method", show_copy_button=True),
+             ],
+             examples=examples,
+             allow_flagging="never",
+             title="Recordings of about 5 s are recommended; overly long audio slows down recognition",
+         )
+
+     demo.launch()
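
For reference, the three unchanged `wav_to_*` helpers share the same chunking arithmetic: a target window of `width` seconds is converted into a frame step, the spectrogram is cut into as many full chunks as fit, and the leftover frames are split evenly on both sides so the chunks stay centered. A minimal sketch of just that arithmetic, using made-up frame and duration values rather than real librosa output:

```python
# Sketch of the chunking arithmetic shared by wav_to_mel / wav_to_cqt / wav_to_chroma.
# The values below are illustrative; in app.py they come from the loaded audio.
total_frames = 940  # spectrogram frames after silence removal (hypothetical)
dur = 10.0          # duration of the non-silent audio in seconds (hypothetical)
width = 1.6         # target chunk length in seconds

step = int(width * total_frames / dur)            # frames per 1.6 s chunk -> 150
count = int(total_frames / step)                  # number of full chunks  -> 6
begin = int(0.5 * (total_frames - count * step))  # center the chunks      -> 20
end = begin + step * count                        # last frame index used  -> 920

chunks = [(i, i + step) for i in range(begin, end, step)]
print(chunks)  # [(20, 170), (170, 320), ..., (770, 920)]
```

Each tuple corresponds to one saved `.jpg` slice, which is why `inference` later votes over all images in `./tmp`.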
 
model.py CHANGED
@@ -1,158 +1,144 @@
- import os
- import torch
- import torch.nn as nn
- import torchvision.models as models
- from datasets import load_dataset
- from utils import url_download, DOMAIN
-
-
- def get_backbone(ver, backbone_list):
-     for bb in backbone_list:
-         if ver == bb["ver"]:
-             return bb
-
-     print("Backbone name not found, using default option - alexnet.")
-     return backbone_list[0]
-
-
- def model_info(m_ver):
-     backbone_list = load_dataset(
-         "monet-joe/cv_backbones",
-         split="IMAGENET1K_V1",
-         trust_remote_code=True,
-     )
-     backbone = get_backbone(m_ver, backbone_list)
-     m_type = str(backbone["type"])
-     input_size = int(backbone["input_size"])
-     return m_type, input_size
-
-
- def download_model(log_name="vit_b_16_mel_2024-01-07_05-16-24"):
-     pre_model_url = f"{DOMAIN}?Revision=master&FilePath=released/{log_name}/save.pt"
-     pre_model_path = f"./model/{log_name}.pt"
-     m_ver = "_".join(log_name.split("_")[:-3])
-     os.makedirs("./model", exist_ok=True)
-
-     if not os.path.exists(pre_model_path):
-         url_download(pre_model_url, pre_model_path)
-
-     return pre_model_path, m_ver
-
-
- def Classifier(cls_num: int, output_size: int, linear_output: bool):
-     q = (1.0 * output_size / cls_num) ** 0.25
-     l1 = int(q * cls_num)
-     l2 = int(q * l1)
-     l3 = int(q * l2)
-
-     if linear_output:
-         return torch.nn.Sequential(
-             nn.Dropout(),
-             nn.Linear(output_size, l3),
-             nn.ReLU(inplace=True),
-             nn.Dropout(),
-             nn.Linear(l3, l2),
-             nn.ReLU(inplace=True),
-             nn.Dropout(),
-             nn.Linear(l2, l1),
-             nn.ReLU(inplace=True),
-             nn.Linear(l1, cls_num),
-         )
-
-     else:
-         return torch.nn.Sequential(
-             nn.Dropout(),
-             nn.Conv2d(output_size, l3, kernel_size=(1, 1), stride=(1, 1)),
-             nn.ReLU(inplace=True),
-             nn.AdaptiveAvgPool2d(output_size=(1, 1)),
-             nn.Flatten(),
-             nn.Linear(l3, l2),
-             nn.ReLU(inplace=True),
-             nn.Dropout(),
-             nn.Linear(l2, l1),
-             nn.ReLU(inplace=True),
-             nn.Linear(l1, cls_num),
-         )
-
-
- class EvalNet:
-     model = None
-     m_type = "squeezenet"
-     input_size = 224
-     output_size = 512
-
-     def __init__(self, log_name, cls_num=4):
-         saved_model_path, m_ver = download_model(log_name)
-         self.m_type, self.input_size = model_info(m_ver)
-
-         if not hasattr(models, m_ver):
-             print("Unsupported model.")
-             exit()
-
-         self.model = eval("models.%s()" % m_ver)
-         linear_output = self._set_outsize()
-         self._set_classifier(cls_num, linear_output)
-         checkpoint = torch.load(saved_model_path, map_location="cpu")
-         if torch.cuda.is_available():
-             checkpoint = torch.load(saved_model_path)
-
-         self.model.load_state_dict(checkpoint, False)
-         self.model.eval()
-
-     def _set_outsize(self, debug_mode=False):
-         for name, module in self.model.named_modules():
-             if (
-                 str(name).__contains__("classifier")
-                 or str(name).__eq__("fc")
-                 or str(name).__contains__("head")
-             ):
-                 if isinstance(module, torch.nn.Linear):
-                     self.output_size = module.in_features
-                     if debug_mode:
-                         print(
-                             f"{name}(Linear): {self.output_size} -> {module.out_features}"
-                         )
-                     return True
-
-                 if isinstance(module, torch.nn.Conv2d):
-                     self.output_size = module.in_channels
-                     if debug_mode:
-                         print(
-                             f"{name}(Conv2d): {self.output_size} -> {module.out_channels}"
-                         )
-                     return False
-
-         return False
-
-     def _set_classifier(self, cls_num, linear_output):
-         if self.m_type == "convnext":
-             del self.model.classifier[2]
-             self.model.classifier = nn.Sequential(
-                 *list(self.model.classifier)
-                 + list(Classifier(cls_num, self.output_size, linear_output))
-             )
-             return
-
-         if hasattr(self.model, "classifier"):
-             self.model.classifier = Classifier(cls_num, self.output_size, linear_output)
-             return
-
-         elif hasattr(self.model, "fc"):
-             self.model.fc = Classifier(cls_num, self.output_size, linear_output)
-             return
-
-         elif hasattr(self.model, "head"):
-             self.model.head = Classifier(cls_num, self.output_size, linear_output)
-             return
-
-         self.model.heads.head = Classifier(cls_num, self.output_size, linear_output)
-
-     def forward(self, x):
-         if torch.cuda.is_available():
-             x = x.cuda()
-             self.model = self.model.cuda()
-
-         if self.m_type == "googlenet" and self.training:
-             return self.model(x)[0]
-         else:
-             return self.model(x)
+ import torch
+ import torch.nn as nn
+ import torchvision.models as models
+ from modelscope.msdatasets import MsDataset
+ from utils import MODEL_DIR
+
+
+ def get_backbone(ver, backbone_list):
+     for bb in backbone_list:
+         if ver == bb["ver"]:
+             return bb
+
+     print("Backbone name not found, using default option - alexnet.")
+     return backbone_list[0]
+
+
+ def model_info(m_ver):
+     backbone_list = MsDataset.load(
+         "monetjoe/cv_backbones", subset_name="ImageNet1k_v1", split="train"
+     )
+     backbone = get_backbone(m_ver, backbone_list)
+     m_type = str(backbone["type"])
+     input_size = int(backbone["input_size"])
+     return m_type, input_size
+
+
+ def Classifier(cls_num: int, output_size: int, linear_output: bool):
+     q = (1.0 * output_size / cls_num) ** 0.25
+     l1 = int(q * cls_num)
+     l2 = int(q * l1)
+     l3 = int(q * l2)
+
+     if linear_output:
+         return torch.nn.Sequential(
+             nn.Dropout(),
+             nn.Linear(output_size, l3),
+             nn.ReLU(inplace=True),
+             nn.Dropout(),
+             nn.Linear(l3, l2),
+             nn.ReLU(inplace=True),
+             nn.Dropout(),
+             nn.Linear(l2, l1),
+             nn.ReLU(inplace=True),
+             nn.Linear(l1, cls_num),
+         )
+
+     else:
+         return torch.nn.Sequential(
+             nn.Dropout(),
+             nn.Conv2d(output_size, l3, kernel_size=(1, 1), stride=(1, 1)),
+             nn.ReLU(inplace=True),
+             nn.AdaptiveAvgPool2d(output_size=(1, 1)),
+             nn.Flatten(),
+             nn.Linear(l3, l2),
+             nn.ReLU(inplace=True),
+             nn.Dropout(),
+             nn.Linear(l2, l1),
+             nn.ReLU(inplace=True),
+             nn.Linear(l1, cls_num),
+         )
+
+
+ class EvalNet:
+     model = None
+     m_type = "squeezenet"
+     input_size = 224
+     output_size = 512
+
+     def __init__(self, log_name: str, cls_num=4):
+         saved_model_path = f"{MODEL_DIR}/{log_name}/save.pt"
+         m_ver = "_".join(log_name.split("_")[:-1])
+         self.m_type, self.input_size = model_info(m_ver)
+
+         if not hasattr(models, m_ver):
+             print("Unsupported model.")
+             exit()
+
+         self.model = eval("models.%s()" % m_ver)
+         linear_output = self._set_outsize()
+         self._set_classifier(cls_num, linear_output)
+         checkpoint = torch.load(saved_model_path, map_location="cpu")
+         if torch.cuda.is_available():
+             checkpoint = torch.load(saved_model_path)
+
+         self.model.load_state_dict(checkpoint, False)
+         self.model.eval()
+
+     def _set_outsize(self, debug_mode=False):
+         for name, module in self.model.named_modules():
+             if (
+                 str(name).__contains__("classifier")
+                 or str(name).__eq__("fc")
+                 or str(name).__contains__("head")
+             ):
+                 if isinstance(module, torch.nn.Linear):
+                     self.output_size = module.in_features
+                     if debug_mode:
+                         print(
+                             f"{name}(Linear): {self.output_size} -> {module.out_features}"
+                         )
+                     return True
+
+                 if isinstance(module, torch.nn.Conv2d):
+                     self.output_size = module.in_channels
+                     if debug_mode:
+                         print(
+                             f"{name}(Conv2d): {self.output_size} -> {module.out_channels}"
+                         )
+                     return False
+
+         return False
+
+     def _set_classifier(self, cls_num, linear_output):
+         if self.m_type == "convnext":
+             del self.model.classifier[2]
+             self.model.classifier = nn.Sequential(
+                 *list(self.model.classifier)
+                 + list(Classifier(cls_num, self.output_size, linear_output))
+             )
+             return
+
+         if hasattr(self.model, "classifier"):
+             self.model.classifier = Classifier(cls_num, self.output_size, linear_output)
+             return
+
+         elif hasattr(self.model, "fc"):
+             self.model.fc = Classifier(cls_num, self.output_size, linear_output)
+             return
+
+         elif hasattr(self.model, "head"):
+             self.model.head = Classifier(cls_num, self.output_size, linear_output)
+             return
+
+         self.model.heads.head = Classifier(cls_num, self.output_size, linear_output)
+
+     def forward(self, x):
+         if torch.cuda.is_available():
+             x = x.cuda()
+             self.model = self.model.cuda()
+
+         if self.m_type == "googlenet" and self.training:
+             return self.model(x)[0]
+         else:
+             return self.model(x)
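
The `Classifier` head (unchanged by this commit) sizes its hidden layers by geometric interpolation between the backbone's feature width and the class count: `q = (output_size / cls_num) ** 0.25` is the common shrink ratio applied at each layer. A quick sketch of the resulting widths, using the class defaults as illustrative values:

```python
# Layer widths produced by Classifier() for a 512-d feature and 4 classes.
# output_size = 512 is only the EvalNet default; _set_outsize() overrides it per backbone.
cls_num, output_size = 4, 512
q = (1.0 * output_size / cls_num) ** 0.25  # common shrink ratio, about 3.36
l1 = int(q * cls_num)  # 13
l2 = int(q * l1)       # 43
l3 = int(q * l2)       # 144
print(output_size, "->", l3, "->", l2, "->", l1, "->", cls_num)  # 512 -> 144 -> 43 -> 13 -> 4
```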
 
utils.py CHANGED
@@ -1,105 +1,58 @@
- import os
- import time
- import torch
- import zipfile
- import requests
- from tqdm import tqdm
-
- DOMAIN = "https://www.modelscope.cn/api/v1/models/ccmusic/bel_canto/repo"
-
-
- def url_download(url: str, fname: str, max_retries=3):
-     retry_count = 0
-     while retry_count < max_retries:
-         try:
-             print(f"Downloading: {url}")
-             resp = requests.get(url, stream=True)
-             # Check the response status code (raise an exception if it's not in the range 200-299)
-             resp.raise_for_status()
-             total = int(resp.headers.get("content-length", 0))
-             # create_dir(data_dir)
-             with open(fname, "wb") as file, tqdm(
-                 desc=fname,
-                 total=total,
-                 unit="iB",
-                 unit_scale=True,
-                 unit_divisor=1024,
-             ) as bar:
-                 for data in resp.iter_content(chunk_size=1024):
-                     size = file.write(data)
-                     bar.update(size)
-             print(f"Download of {url} completed.")
-             return
-
-         except requests.exceptions.HTTPError as errh:
-             print(f"HTTP error occurred: {errh}")
-             retry_count += 1
-             continue
-         except requests.exceptions.ConnectionError as errc:
-             print(f"Connection error occurred: {errc}")
-             retry_count += 1
-             continue
-         except requests.exceptions.Timeout as errt:
-             print(f"Timeout error occurred: {errt}")
-             retry_count += 1
-             continue
-         except Exception as err:
-             print(f"Other error occurred: {err}")
-             retry_count += 1
-             continue
-
-     else:
-         print(
-             f"Error: the operation could not be completed after {max_retries} retries."
-         )
-         exit()
-
-
- def unzip_file(zip_src, dst_dir):
-     r = zipfile.is_zipfile(zip_src)
-     if r:
-         fz = zipfile.ZipFile(zip_src, "r")
-         for file in fz.namelist():
-             fz.extract(file, dst_dir)
-     else:
-         print("This is not zip")
-
-
- def time_stamp(timestamp=None):
-     if timestamp != None:
-         return timestamp.strftime("%Y-%m-%d %H:%M:%S")
-
-     return time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime(time.time()))
-
-
- def toCUDA(x):
-     if hasattr(x, "cuda"):
-         if torch.cuda.is_available():
-             return x.cuda()
-
-     return x
-
-
- def find_wav_files(folder_path="./examples"):
-     wav_files = []
-     for root, _, files in os.walk(folder_path):
-         for file in files:
-             if file.endswith(".wav"):
-                 file_path = os.path.join(root, file)
-                 wav_files.append(file_path)
-
-     return wav_files
-
-
- def get_modelist(url=f"{DOMAIN}/trees"):
-     models = []
-     response = requests.get(url)
-     if response.status_code == 200:
-         json_data = response.json()
-         files = list(json_data["Data"]["Trees"])
-         for file in files:
-             if file["Name"] == "released":
-                 for item in list(file["Children"]):
-                     models.append(item["Name"])
-
-     return models
+ import os
+ import time
+ import torch
+ from modelscope import snapshot_download
+
+ MODEL_DIR = snapshot_download("ccmusic/bel_canto")
+
+
+ def time_stamp(timestamp=None):
+     if timestamp != None:
+         return timestamp.strftime("%Y-%m-%d %H:%M:%S")
+
+     return time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime(time.time()))
+
+
+ def toCUDA(x):
+     if hasattr(x, "cuda"):
+         if torch.cuda.is_available():
+             return x.cuda()
+
+     return x
+
+
+ def find_wav_files(folder_path="./examples"):
+     wav_files = []
+     for root, _, files in os.walk(folder_path):
+         for file in files:
+             if file.endswith(".wav"):
+                 file_path = os.path.join(root, file)
+                 wav_files.append(file_path)
+
+     return wav_files
+
+
+ def get_modelist(model_dir=MODEL_DIR):
+     try:
+         entries = os.listdir(model_dir)
+     except OSError as e:
+         print(f"Cannot access {model_dir}: {e}")
+         return
+
+     # Iterate over all entries
+     output = []
+     for entry in entries:
+         # Build the full path
+         full_path = os.path.join(model_dir, entry)
+
+         # Skip the '.git' folder
+         if entry == ".git":
+             print(f"Skipping .git folder: {full_path}")
+             continue
+
+         # Check whether the entry is a file or a directory
+         if os.path.isdir(full_path):
+             # Collect the directory name
+             output.append(os.path.basename(full_path))
+
+     return output
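
With this change the weights are no longer fetched file by file over the repo HTTP API; `snapshot_download` caches the whole `ccmusic/bel_canto` ModelScope repo once, and `get_modelist` simply lists the checkpoint folders inside that cache. A minimal usage sketch under that assumption (the folder names shown are hypothetical, not taken from the repo listing):

```python
# Sketch: how app.py is expected to consume utils.py after this commit.
# The checkpoint folder names below are illustrative assumptions.
from utils import MODEL_DIR, get_modelist
from model import EvalNet

models = get_modelist()          # e.g. ["vit_b_16_mel", "squeezenet1_1_cqt"]
print("snapshot cached at:", MODEL_DIR)

log_name = models[0]             # each folder is expected to contain a save.pt
net = EvalNet(log_name)          # loads f"{MODEL_DIR}/{log_name}/save.pt" on CPU
```

Because the new `inference` derives the spectrogram type from `log_name.split("_")[-1]`, the folder-name suffix (`mel`, `cqt`, or `chroma`) also selects which `wav_to_*` converter runs.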