MuGeminorum committed on
Commit ea63f7b
1 Parent(s): 1e2d53f
.gitattributes CHANGED
@@ -1,35 +1,35 @@
  *.7z filter=lfs diff=lfs merge=lfs -text
  *.arrow filter=lfs diff=lfs merge=lfs -text
  *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bin.* filter=lfs diff=lfs merge=lfs -text
  *.bz2 filter=lfs diff=lfs merge=lfs -text
- *.ckpt filter=lfs diff=lfs merge=lfs -text
  *.ftz filter=lfs diff=lfs merge=lfs -text
  *.gz filter=lfs diff=lfs merge=lfs -text
  *.h5 filter=lfs diff=lfs merge=lfs -text
  *.joblib filter=lfs diff=lfs merge=lfs -text
  *.lfs.* filter=lfs diff=lfs merge=lfs -text
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
  *.model filter=lfs diff=lfs merge=lfs -text
  *.msgpack filter=lfs diff=lfs merge=lfs -text
- *.npy filter=lfs diff=lfs merge=lfs -text
- *.npz filter=lfs diff=lfs merge=lfs -text
  *.onnx filter=lfs diff=lfs merge=lfs -text
  *.ot filter=lfs diff=lfs merge=lfs -text
  *.parquet filter=lfs diff=lfs merge=lfs -text
  *.pb filter=lfs diff=lfs merge=lfs -text
- *.pickle filter=lfs diff=lfs merge=lfs -text
- *.pkl filter=lfs diff=lfs merge=lfs -text
  *.pt filter=lfs diff=lfs merge=lfs -text
  *.pth filter=lfs diff=lfs merge=lfs -text
  *.rar filter=lfs diff=lfs merge=lfs -text
- *.safetensors filter=lfs diff=lfs merge=lfs -text
  saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.tar.* filter=lfs diff=lfs merge=lfs -text
- *.tar filter=lfs diff=lfs merge=lfs -text
  *.tflite filter=lfs diff=lfs merge=lfs -text
  *.tgz filter=lfs diff=lfs merge=lfs -text
- *.wasm filter=lfs diff=lfs merge=lfs -text
  *.xz filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
- *.zst filter=lfs diff=lfs merge=lfs -text
- *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.zstandard filter=lfs diff=lfs merge=lfs -text
+ *.tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.db* filter=lfs diff=lfs merge=lfs -text
+ *.ark* filter=lfs diff=lfs merge=lfs -text
+ **/*ckpt*data* filter=lfs diff=lfs merge=lfs -text
+ **/*ckpt*.meta filter=lfs diff=lfs merge=lfs -text
+ **/*ckpt*.index filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.wav filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,5 @@
+ *.pt
+ __pycache__/*
+ tmp/*
+ flagged/*
+ test.py
app.py ADDED
@@ -0,0 +1,177 @@
+ import os
+ import torch
+ import shutil
+ import librosa
+ import numpy as np
+ import gradio as gr
+ import librosa.display
+ import matplotlib.pyplot as plt
+ import torchvision.transforms as transforms
+ from collections import Counter
+ from model import EvalNet
+ from PIL import Image
+ from utils import *
+ import warnings
+ warnings.filterwarnings("ignore")
+
+
+ def most_common_element(input_list):
+     # Use Counter to count the occurrences of each element
+     counter = Counter(input_list)
+     # Use most_common to get the most frequent element
+     most_common_element, _ = counter.most_common(1)[0]
+     return most_common_element
+
+
+ def wav_to_mel(audio_path: str, width=0.07):
+     create_dir('./tmp')
+     try:
+         y, sr = librosa.load(audio_path, sr=48000)
+         mel_spec = librosa.feature.melspectrogram(y=y, sr=sr)
+         log_mel_spec = librosa.power_to_db(mel_spec, ref=np.max)
+         dur = librosa.get_duration(y=y, sr=sr)
+         total_frames = log_mel_spec.shape[1]
+         step = int(width * total_frames / dur)
+         count = int(total_frames / step)
+         begin = int(0.5 * (total_frames - count * step))
+         end = begin + step * count
+         # Slice the log-mel spectrogram into fixed-width windows and save each as a JPEG
+         for i in range(begin, end, step):
+             librosa.display.specshow(log_mel_spec[:, i:i + step])
+             plt.axis('off')
+             plt.savefig(
+                 f'./tmp/mel_{round(dur, 2)}_{i}.jpg',
+                 bbox_inches='tight',
+                 pad_inches=0.0
+             )
+             plt.close()
+
+     except Exception as e:
+         print(f'Error converting {audio_path} : {e}')
+
+
+ def wav_to_cqt(audio_path: str, width=0.07):
+     create_dir('./tmp')
+     try:
+         y, sr = librosa.load(audio_path, sr=48000)
+         cqt_spec = librosa.cqt(y=y, sr=sr)
+         log_cqt_spec = librosa.power_to_db(np.abs(cqt_spec)**2, ref=np.max)
+         dur = librosa.get_duration(y=y, sr=sr)
+         total_frames = log_cqt_spec.shape[1]
+         step = int(width * total_frames / dur)
+         count = int(total_frames / step)
+         begin = int(0.5 * (total_frames - count * step))
+         end = begin + step * count
+         for i in range(begin, end, step):
+             librosa.display.specshow(log_cqt_spec[:, i:i + step])
+             plt.axis('off')
+             plt.savefig(
+                 f'./tmp/cqt_{round(dur, 2)}_{i}.jpg',
+                 bbox_inches='tight',
+                 pad_inches=0.0
+             )
+             plt.close()
+
+     except Exception as e:
+         print(f'Error converting {audio_path} : {e}')
+
+
+ def wav_to_chroma(audio_path: str, width=0.07):
+     create_dir('./tmp')
+     try:
+         y, sr = librosa.load(audio_path, sr=48000)
+         chroma_spec = librosa.feature.chroma_stft(y=y, sr=sr)
+         log_chroma_spec = librosa.power_to_db(
+             np.abs(chroma_spec)**2,
+             ref=np.max
+         )
+         dur = librosa.get_duration(y=y, sr=sr)
+         total_frames = log_chroma_spec.shape[1]
+         step = int(width * total_frames / dur)
+         count = int(total_frames / step)
+         begin = int(0.5 * (total_frames - count * step))
+         end = begin + step * count
+         for i in range(begin, end, step):
+             librosa.display.specshow(log_chroma_spec[:, i:i + step])
+             plt.axis('off')
+             plt.savefig(
+                 f'./tmp/chroma_{round(dur, 2)}_{i}.jpg',
+                 bbox_inches='tight',
+                 pad_inches=0.0
+             )
+             plt.close()
+
+     except Exception as e:
+         print(f'Error converting {audio_path} : {e}')
+
+
+ def embed_img(img_path, input_size=224):
+     transform = transforms.Compose([
+         transforms.Resize([input_size, input_size]),
+         transforms.ToTensor(),
+         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
+     ])
+     img = Image.open(img_path).convert("RGB")
+     return transform(img).unsqueeze(0)
+
+
+ def inference(wav_path, log_name, folder_path='./tmp'):
+     if os.path.exists(folder_path):
+         shutil.rmtree(folder_path)
+
+     if not wav_path:
+         wav_path = './examples/m_chest.wav'
+
+     model = EvalNet(log_name).model
+     spec = log_name.split('_')[-3]
+     eval('wav_to_%s' % spec)(wav_path)
+     outputs = []
+     all_files = os.listdir(folder_path)
+     for file_name in all_files:
+         if file_name.lower().endswith('.jpg'):
+             file_path = os.path.join(folder_path, file_name)
+             input = embed_img(file_path)
+             output = model(input)
+             # Collect each slice's predicted class id as a plain int so the
+             # majority vote below counts equal predictions together
+             pred_id = torch.max(output.data, 1)[1].item()
+             outputs.append(pred_id)
+
+     max_count_item = most_common_element(outputs)
+     shutil.rmtree(folder_path)
+     return translate[classes[max_count_item]]
+
+
+ classes = ['m_chest', 'f_chest', 'm_falsetto', 'f_falsetto']
+
+ models = [
+     'squeezenet1_1_cqt_2023-12-21_14-40-13'
+ ]
+
+ translate = {
+     'm_chest': 'male chest voice',
+     'f_chest': 'female chest voice',
+     'm_falsetto': 'male falsetto voice',
+     'f_falsetto': 'female falsetto voice'
+ }
+
+ examples = []
+ example_wavs = find_wav_files()
+ for wav in example_wavs:
+     examples.append([
+         wav,
+         models[0]
+     ])
+
+ iface = gr.Interface(
+     fn=inference,
+     inputs=[
+         gr.Audio(label='上传录音', type='filepath'),  # label: "Upload a recording"
+         gr.Dropdown(
+             choices=models,
+             label='选择模型',  # label: "Select a model"
+             value=models[0]
+         )
+     ],
+     outputs=gr.Textbox(label='真假声识别'),  # label: "Chest/falsetto voice recognition"
+     examples=examples
+ )
+
+ iface.launch()
examples/f_chest.wav ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f8dddb0301cca48ec0572d39c31c73e16e8073fde1b437bc6f9cd24e4b9db8ad
+ size 22054
examples/f_falsetto.wav ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6a58f18110cb1d93531a978ae39e75850a45151b061392dc41f2abc7853ae54b
+ size 22088
examples/m_chest.wav ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:45cea6c7d58cfb8effb622db74f93ed90e166634f40cead5be4c4a491ae987ed
+ size 22086
examples/m_falsetto.wav ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0b049990994af0f49f1170c285a0df902cbb4e4b10af92a8f805bd74f4a6584a
+ size 21966
model.py ADDED
@@ -0,0 +1,148 @@
+ import os
+ import torch
+ import torch.nn as nn
+ import torchvision.models as models
+ from modelscope.msdatasets import MsDataset
+ from utils import url_download, create_dir, DOMAIN
+
+
+ def get_backbone(ver, backbone_list):
+     for bb in backbone_list:
+         if ver == bb['ver']:
+             return bb
+
+     print('Backbone name not found, using default option - alexnet.')
+     return backbone_list[0]
+
+
+ def model_info(m_ver):
+     backbone_list = MsDataset.load(
+         'monetjoe/cv_backbones',
+         subset_name='ImageNet1k_v1',
+         split='train'
+     )
+     backbone = get_backbone(m_ver, backbone_list)
+     m_type = str(backbone['type'])
+     input_size = int(backbone['input_size'])
+     return m_type, input_size
+
+
+ def download_model(log_name='squeezenet1_1_cqt_2023-12-21_14-40-13'):
+     pre_model_url = f'{DOMAIN}{log_name}/save.pt'
+     pre_model_path = f"./model/{log_name}.pt"
+     m_ver = '_'.join(log_name.split('_')[:-3])
+     create_dir('./model')
+
+     if not os.path.exists(pre_model_path):
+         url_download(pre_model_url, pre_model_path)
+
+     return pre_model_path, m_ver
+
+
+ def Classifier(cls_num: int, output_size: int, linear_output: bool):
+     # Interpolate the hidden layer widths geometrically between output_size and cls_num
+     q = (1.0 * output_size / cls_num) ** 0.25
+     l1 = int(q * cls_num)
+     l2 = int(q * l1)
+     l3 = int(q * l2)
+
+     if linear_output:
+         return torch.nn.Sequential(
+             nn.Dropout(),
+             nn.Linear(output_size, l3),
+             nn.ReLU(inplace=True),
+             nn.Dropout(),
+             nn.Linear(l3, l2),
+             nn.ReLU(inplace=True),
+             nn.Dropout(),
+             nn.Linear(l2, l1),
+             nn.ReLU(inplace=True),
+             nn.Linear(l1, cls_num)
+         )
+
+     else:
+         return torch.nn.Sequential(
+             nn.Dropout(),
+             nn.Conv2d(output_size, l3, kernel_size=(1, 1), stride=(1, 1)),
+             nn.ReLU(inplace=True),
+             nn.AdaptiveAvgPool2d(output_size=(1, 1)),
+             nn.Flatten(),
+             nn.Linear(l3, l2),
+             nn.ReLU(inplace=True),
+             nn.Dropout(),
+             nn.Linear(l2, l1),
+             nn.ReLU(inplace=True),
+             nn.Linear(l1, cls_num)
+         )
+
+
+ class EvalNet():
+     model = None
+     m_type = 'squeezenet'
+     input_size = 224
+     output_size = 512
+
+     def __init__(self, log_name, cls_num=4):
+         saved_model_path, m_ver = download_model(log_name)
+         self.m_type, self.input_size = model_info(m_ver)
+
+         if not hasattr(models, m_ver):
+             print('Unsupported model.')
+             exit()
+
+         self.model = eval('models.%s()' % m_ver)
+         linear_output = self._set_outsize()
+         self._set_classifier(cls_num, linear_output)
+         # Load the checkpoint on CPU by default, or directly if CUDA is available
+         checkpoint = torch.load(saved_model_path, map_location='cpu')
+         if torch.cuda.is_available():
+             checkpoint = torch.load(saved_model_path)
+
+         self.model.load_state_dict(checkpoint, False)
+         self.model.eval()
+
+     def _set_outsize(self, debug_mode=False):
+         # Record the input size of the backbone's classification head;
+         # return True if it is a Linear layer, False if it is a Conv2d layer
+         for name, module in self.model.named_modules():
+             if str(name).__contains__('classifier') or str(name).__eq__('fc') or str(name).__contains__('head'):
+                 if isinstance(module, torch.nn.Linear):
+                     self.output_size = module.in_features
+                     if debug_mode:
+                         print(
+                             f"{name}(Linear): {self.output_size} -> {module.out_features}")
+                     return True
+
+                 if isinstance(module, torch.nn.Conv2d):
+                     self.output_size = module.in_channels
+                     if debug_mode:
+                         print(
+                             f"{name}(Conv2d): {self.output_size} -> {module.out_channels}")
+                     return False
+
+         return False
+
+     def _set_classifier(self, cls_num, linear_output):
+         # Replace the backbone's classification head with one sized for cls_num classes
+         if hasattr(self.model, 'classifier'):
+             self.model.classifier = Classifier(
+                 cls_num, self.output_size, linear_output)
+             return
+
+         elif hasattr(self.model, 'fc'):
+             self.model.fc = Classifier(
+                 cls_num, self.output_size, linear_output)
+             return
+
+         elif hasattr(self.model, 'head'):
+             self.model.head = Classifier(
+                 cls_num, self.output_size, linear_output)
+             return
+
+         self.model.heads.head = Classifier(
+             cls_num, self.output_size, linear_output)
+
+     def forward(self, x):
+         if torch.cuda.is_available():
+             x = x.cuda()
+             self.model = self.model.cuda()
+
+         if self.m_type == 'googlenet' and self.training:
+             return self.model(x)[0]
+         else:
+             return self.model(x)
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ librosa
+ torch
+ matplotlib
+ torchvision
+ pillow
+ gradio
+ modelscope
utils.py ADDED
@@ -0,0 +1,96 @@
+ import os
+ import time
+ import torch
+ import zipfile
+ import requests
+ from tqdm import tqdm
+
+ DOMAIN = 'https://www.modelscope.cn/api/v1/models/ccmusic/chest_falsetto/repo?Revision=master&FilePath='
+
+
+ def create_dir(dir):
+     if not os.path.exists(dir):
+         os.mkdir(dir)
+
+
+ def url_download(url: str, fname: str, max_retries=3):
+     retry_count = 0
+     while retry_count < max_retries:
+         try:
+             print(f"Downloading: {url}")
+             resp = requests.get(url, stream=True)
+             # Check the response status code (raise an exception if it's not in the range 200-299)
+             resp.raise_for_status()
+             total = int(resp.headers.get('content-length', 0))
+             # create_dir(data_dir)
+             with open(fname, 'wb') as file, tqdm(
+                 desc=fname,
+                 total=total,
+                 unit='iB',
+                 unit_scale=True,
+                 unit_divisor=1024,
+             ) as bar:
+                 for data in resp.iter_content(chunk_size=1024):
+                     size = file.write(data)
+                     bar.update(size)
+             print(f'Download of {url} completed.')
+             return
+
+         except requests.exceptions.HTTPError as errh:
+             print(f"HTTP error occurred: {errh}")
+             retry_count += 1
+             continue
+         except requests.exceptions.ConnectionError as errc:
+             print(f"Connection error occurred: {errc}")
+             retry_count += 1
+             continue
+         except requests.exceptions.Timeout as errt:
+             print(f"Timeout error occurred: {errt}")
+             retry_count += 1
+             continue
+         except Exception as err:
+             print(f"Other error occurred: {err}")
+             retry_count += 1
+             continue
+
+     else:
+         # The while-else branch runs once the retry budget is exhausted
+         print(
+             f"Error: the operation could not be completed after {max_retries} retries."
+         )
+         exit()
+
+
+ def unzip_file(zip_src, dst_dir):
+     r = zipfile.is_zipfile(zip_src)
+     if r:
+         fz = zipfile.ZipFile(zip_src, 'r')
+         for file in fz.namelist():
+             fz.extract(file, dst_dir)
+     else:
+         print('This is not a zip file.')
+
+
+ def time_stamp(timestamp=None):
+     if timestamp is not None:
+         return timestamp.strftime("%Y-%m-%d %H:%M:%S")
+
+     return time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime(time.time()))
+
+
+ def toCUDA(x):
+     if hasattr(x, 'cuda'):
+         if torch.cuda.is_available():
+             return x.cuda()
+
+     return x
+
+
+ def find_wav_files(folder_path='./examples'):
+     wav_files = []
+     for root, _, files in os.walk(folder_path):
+         for file in files:
+             if file.endswith(".wav"):
+                 file_path = os.path.join(root, file)
+                 wav_files.append(file_path)
+
+     return wav_files