MuGeminorum committed
Commit f945864
1 Parent(s): dea9f72
Files changed (10)
  1. .gitattributes +11 -11
  2. .gitignore +5 -0
  3. app.py +190 -0
  4. examples/f_bel.wav +3 -0
  5. examples/f_folk.wav +3 -0
  6. examples/m_bel.wav +3 -0
  7. examples/m_folk.wav +3 -0
  8. model.py +148 -0
  9. requirements.txt +9 -0
  10. utils.py +96 -0
.gitattributes CHANGED
@@ -1,35 +1,35 @@
  *.7z filter=lfs diff=lfs merge=lfs -text
  *.arrow filter=lfs diff=lfs merge=lfs -text
  *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bin.* filter=lfs diff=lfs merge=lfs -text
  *.bz2 filter=lfs diff=lfs merge=lfs -text
- *.ckpt filter=lfs diff=lfs merge=lfs -text
  *.ftz filter=lfs diff=lfs merge=lfs -text
  *.gz filter=lfs diff=lfs merge=lfs -text
  *.h5 filter=lfs diff=lfs merge=lfs -text
  *.joblib filter=lfs diff=lfs merge=lfs -text
  *.lfs.* filter=lfs diff=lfs merge=lfs -text
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
  *.model filter=lfs diff=lfs merge=lfs -text
  *.msgpack filter=lfs diff=lfs merge=lfs -text
- *.npy filter=lfs diff=lfs merge=lfs -text
- *.npz filter=lfs diff=lfs merge=lfs -text
  *.onnx filter=lfs diff=lfs merge=lfs -text
  *.ot filter=lfs diff=lfs merge=lfs -text
  *.parquet filter=lfs diff=lfs merge=lfs -text
  *.pb filter=lfs diff=lfs merge=lfs -text
- *.pickle filter=lfs diff=lfs merge=lfs -text
- *.pkl filter=lfs diff=lfs merge=lfs -text
  *.pt filter=lfs diff=lfs merge=lfs -text
  *.pth filter=lfs diff=lfs merge=lfs -text
  *.rar filter=lfs diff=lfs merge=lfs -text
- *.safetensors filter=lfs diff=lfs merge=lfs -text
  saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.tar.* filter=lfs diff=lfs merge=lfs -text
- *.tar filter=lfs diff=lfs merge=lfs -text
  *.tflite filter=lfs diff=lfs merge=lfs -text
  *.tgz filter=lfs diff=lfs merge=lfs -text
- *.wasm filter=lfs diff=lfs merge=lfs -text
  *.xz filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
- *.zst filter=lfs diff=lfs merge=lfs -text
- *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.zstandard filter=lfs diff=lfs merge=lfs -text
+ *.tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.db* filter=lfs diff=lfs merge=lfs -text
+ *.ark* filter=lfs diff=lfs merge=lfs -text
+ **/*ckpt*data* filter=lfs diff=lfs merge=lfs -text
+ **/*ckpt*.meta filter=lfs diff=lfs merge=lfs -text
+ **/*ckpt*.index filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.wav filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,5 @@
+ *.pt
+ __pycache__/*
+ tmp/*
+ flagged/*
+ test.py
app.py ADDED
@@ -0,0 +1,190 @@
+ import os
+ import torch
+ import shutil
+ import librosa
+ import numpy as np
+ import gradio as gr
+ import librosa.display
+ import matplotlib.pyplot as plt
+ import torchvision.transforms as transforms
+ from collections import Counter
+ from model import EvalNet
+ from PIL import Image
+ from utils import *
+ import warnings
+ warnings.filterwarnings("ignore")
+
+ classes = ['m_bel', 'f_bel', 'm_folk', 'f_folk']
+
+
+ def most_common_element(input_list):
+     # Tally how many times each element occurs
+     counter = Counter(input_list)
+     # most_common(1) yields the single most frequent element
+     most_common_element, _ = counter.most_common(1)[0]
+     return most_common_element
+
+
+ def wav_to_mel(audio_path: str, width=1.6, topdb=40):
+     create_dir('./tmp')
+     try:
+         y, sr = librosa.load(audio_path, sr=48000)
+         non_silents = librosa.effects.split(y, top_db=topdb)
+         non_silent = np.concatenate(
+             [y[start:end] for start, end in non_silents]
+         )
+         mel_spec = librosa.feature.melspectrogram(y=non_silent, sr=sr)
+         log_mel_spec = librosa.power_to_db(mel_spec, ref=np.max)
+         dur = librosa.get_duration(y=non_silent, sr=sr)
+         total_frames = log_mel_spec.shape[1]
+         step = int(width * total_frames / dur)
+         count = int(total_frames / step)
+         begin = int(0.5 * (total_frames - count * step))
+         end = begin + step * count
+         for i in range(begin, end, step):
+             librosa.display.specshow(log_mel_spec[:, i:i + step])
+             plt.axis('off')
+             plt.savefig(
+                 f'./tmp/mel_{round(dur, 2)}_{i}.jpg',
+                 bbox_inches='tight',
+                 pad_inches=0.0
+             )
+             plt.close()
+
+     except Exception as e:
+         print(f'Error converting {audio_path}: {e}')
+
+
+ def wav_to_cqt(audio_path: str, width=1.6, topdb=40):
+     create_dir('./tmp')
+     try:
+         y, sr = librosa.load(audio_path, sr=48000)
+         non_silents = librosa.effects.split(y, top_db=topdb)
+         non_silent = np.concatenate(
+             [y[start:end] for start, end in non_silents]
+         )
+         cqt_spec = librosa.cqt(y=non_silent, sr=sr)
+         log_cqt_spec = librosa.power_to_db(np.abs(cqt_spec)**2, ref=np.max)
+         dur = librosa.get_duration(y=non_silent, sr=sr)
+         total_frames = log_cqt_spec.shape[1]
+         step = int(width * total_frames / dur)
+         count = int(total_frames / step)
+         begin = int(0.5 * (total_frames - count * step))
+         end = begin + step * count
+         for i in range(begin, end, step):
+             librosa.display.specshow(log_cqt_spec[:, i:i + step])
+             plt.axis('off')
+             plt.savefig(
+                 f'./tmp/cqt_{round(dur, 2)}_{i}.jpg',
+                 bbox_inches='tight',
+                 pad_inches=0.0
+             )
+             plt.close()
+
+     except Exception as e:
+         print(f'Error converting {audio_path}: {e}')
+
+
+ def wav_to_chroma(audio_path: str, width=1.6, topdb=40):
+     create_dir('./tmp')
+     try:
+         y, sr = librosa.load(audio_path, sr=48000)
+         non_silents = librosa.effects.split(y, top_db=topdb)
+         non_silent = np.concatenate(
+             [y[start:end] for start, end in non_silents]
+         )
+         chroma_spec = librosa.feature.chroma_stft(y=non_silent, sr=sr)
+         log_chroma_spec = librosa.power_to_db(
+             np.abs(chroma_spec)**2,
+             ref=np.max
+         )
+         dur = librosa.get_duration(y=non_silent, sr=sr)
+         total_frames = log_chroma_spec.shape[1]
+         step = int(width * total_frames / dur)
+         count = int(total_frames / step)
+         begin = int(0.5 * (total_frames - count * step))
+         end = begin + step * count
+         for i in range(begin, end, step):
+             librosa.display.specshow(log_chroma_spec[:, i:i + step])
+             plt.axis('off')
+             plt.savefig(
+                 f'./tmp/chroma_{round(dur, 2)}_{i}.jpg',
+                 bbox_inches='tight',
+                 pad_inches=0.0
+             )
+             plt.close()
+
+     except Exception as e:
+         print(f'Error converting {audio_path}: {e}')
+
+
+ def embed_img(img_path, input_size=224):
+     transform = transforms.Compose([
+         transforms.Resize([input_size, input_size]),
+         transforms.ToTensor(),
+         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
+     ])
+     img = Image.open(img_path).convert("RGB")
+     return transform(img).unsqueeze(0)
+
+
+ def inference(wav_path, log_name, folder_path='./tmp'):
+     if os.path.exists(folder_path):
+         shutil.rmtree(folder_path)
+
+     if not wav_path:
+         wav_path = './examples/f_bel.wav'
+
+     model = EvalNet(log_name).model
+     spec = log_name.split('_')[-3]
+     eval('wav_to_%s' % spec)(wav_path)
+     outputs = []
+     all_files = os.listdir(folder_path)
+     for file_name in all_files:
+         if file_name.lower().endswith('.jpg'):
+             file_path = os.path.join(folder_path, file_name)
+             input = embed_img(file_path)
+             output = model(input)
+             pred_id = torch.max(output.data, 1)[1].item()  # plain int so Counter can tally votes
+             outputs.append(pred_id)
+
+     max_count_item = most_common_element(outputs)
+     shutil.rmtree(folder_path)
+     return translate[classes[max_count_item]]
+
+
+ models = [
+     'vit_b_16_mel_2024-01-07_05-16-24',
+     'swin_b_chroma_2024-01-07_14-01-10'
+ ]
+
+ translate = {
+     'm_bel': 'male bel canto',
+     'm_folk': 'male folk singing',
+     'f_bel': 'female bel canto',
+     'f_folk': 'female folk singing'
+ }
+
+ examples = []
+ example_wavs = find_wav_files()
+ for wav in example_wavs:
+     examples.append([
+         wav,
+         models[0]
+     ])
+
+ iface = gr.Interface(
+     fn=inference,
+     inputs=[
+         gr.Audio(label='Upload audio', type='filepath'),
+         gr.Dropdown(
+             choices=models,
+             label='Select model',
+             value=models[0]
+         )
+     ],
+     outputs=gr.Textbox(label='Singing method'),
+     examples=examples
+ )
+
+ iface.launch()
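
The window arithmetic shared by the three wav_to_* converters above is easiest to see with round numbers. A quick sketch with assumed values (a 16 s clip whose spectrogram happens to have 1600 frames, the default width=1.6); the real values come from librosa at runtime:

    # Hypothetical frame count and duration, purely for illustration.
    total_frames, dur, width = 1600, 16.0, 1.6
    step = int(width * total_frames / dur)            # 160 frames, i.e. ~1.6 s per slice
    count = int(total_frames / step)                  # 10 full slices fit in the clip
    begin = int(0.5 * (total_frames - count * step))  # 0 here; leftover frames split evenly
    end = begin + step * count                        # 1600
    # range(begin, end, step) then yields 10 slice offsets centered in the clip,
    # and each slice is saved as one .jpg that gets a vote in inference().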
examples/f_bel.wav ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:26abdaf26e98f1ac58a510462740ca47a569b4060917e2f413cd4a84aa0d8b66
+ size 839708
examples/f_folk.wav ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:752c041e9c44762a90b5f0983cda805bcdc09d308d564574d6146c2bfdca2d97
+ size 1183688
examples/m_bel.wav ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b7b1aa8cfc6e004df1d1a7649927c06187535ce8531f3dda2177709b9d11b70d
+ size 2881538
examples/m_folk.wav ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:51c3b595ae7c0a361a6364df282439aa923a1098c9b62abfa13b6e82558a10c5
+ size 1154582
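
These four .wav entries are Git LFS pointers rather than audio: the repository stores only the version/oid/size stanza, and the *.wav rule added to .gitattributes above routes the actual payloads to LFS. A minimal sketch of checking a fetched clip against its pointer, assuming the real file has already been pulled (e.g. with git lfs pull):

    import hashlib

    # Compare the local payload with the pointer metadata for examples/f_bel.wav.
    with open('./examples/f_bel.wav', 'rb') as f:
        blob = f.read()
    print(len(blob))                         # should equal the pointer's 'size 839708'
    print(hashlib.sha256(blob).hexdigest())  # should equal the sha256 oid above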
model.py ADDED
@@ -0,0 +1,148 @@
+ import os
+ import torch
+ import torch.nn as nn
+ import torchvision.models as models
+ from modelscope.msdatasets import MsDataset
+ from utils import url_download, create_dir, DOMAIN
+
+
+ def get_backbone(ver, backbone_list):
+     for bb in backbone_list:
+         if ver == bb['ver']:
+             return bb
+
+     print('Backbone name not found, using default option - alexnet.')
+     return backbone_list[0]
+
+
+ def model_info(m_ver):
+     backbone_list = MsDataset.load(
+         'monetjoe/cv_backbones',
+         subset_name='ImageNet1k_v1',
+         split='train'
+     )
+     backbone = get_backbone(m_ver, backbone_list)
+     m_type = str(backbone['type'])
+     input_size = int(backbone['input_size'])
+     return m_type, input_size
+
+
+ def download_model(log_name='vit_b_16_mel_2024-01-07_05-16-24'):
+     pre_model_url = f'{DOMAIN}{log_name}/save.pt'
+     pre_model_path = f"./model/{log_name}.pt"
+     m_ver = '_'.join(log_name.split('_')[:-3])
+     create_dir('./model')
+
+     if not os.path.exists(pre_model_path):
+         url_download(pre_model_url, pre_model_path)
+
+     return pre_model_path, m_ver
+
+
+ def Classifier(cls_num: int, output_size: int, linear_output: bool):
+     q = (1.0 * output_size / cls_num) ** 0.25  # geometric ratio between layer widths
+     l1 = int(q * cls_num)
+     l2 = int(q * l1)
+     l3 = int(q * l2)
+
+     if linear_output:
+         return torch.nn.Sequential(
+             nn.Dropout(),
+             nn.Linear(output_size, l3),
+             nn.ReLU(inplace=True),
+             nn.Dropout(),
+             nn.Linear(l3, l2),
+             nn.ReLU(inplace=True),
+             nn.Dropout(),
+             nn.Linear(l2, l1),
+             nn.ReLU(inplace=True),
+             nn.Linear(l1, cls_num)
+         )
+
+     else:
+         return torch.nn.Sequential(
+             nn.Dropout(),
+             nn.Conv2d(output_size, l3, kernel_size=(1, 1), stride=(1, 1)),
+             nn.ReLU(inplace=True),
+             nn.AdaptiveAvgPool2d(output_size=(1, 1)),
+             nn.Flatten(),
+             nn.Linear(l3, l2),
+             nn.ReLU(inplace=True),
+             nn.Dropout(),
+             nn.Linear(l2, l1),
+             nn.ReLU(inplace=True),
+             nn.Linear(l1, cls_num)
+         )
+
+
+ class EvalNet:
+     model = None
+     m_type = 'squeezenet'
+     input_size = 224
+     output_size = 512
+
+     def __init__(self, log_name, cls_num=4):
+         saved_model_path, m_ver = download_model(log_name)
+         self.m_type, self.input_size = model_info(m_ver)
+
+         if not hasattr(models, m_ver):
+             print('Unsupported model.')
+             exit()
+
+         self.model = getattr(models, m_ver)()
+         linear_output = self._set_outsize()
+         self._set_classifier(cls_num, linear_output)
+         checkpoint = torch.load(saved_model_path, map_location='cpu')
+         if torch.cuda.is_available():
+             checkpoint = torch.load(saved_model_path)
+
+         self.model.load_state_dict(checkpoint, strict=False)
+         self.model.eval()
+
+     def _set_outsize(self, debug_mode=False):
+         for name, module in self.model.named_modules():
+             if 'classifier' in name or name == 'fc' or 'head' in name:
+                 if isinstance(module, torch.nn.Linear):
+                     self.output_size = module.in_features
+                     if debug_mode:
+                         print(
+                             f"{name}(Linear): {self.output_size} -> {module.out_features}")
+                     return True
+
+                 if isinstance(module, torch.nn.Conv2d):
+                     self.output_size = module.in_channels
+                     if debug_mode:
+                         print(
+                             f"{name}(Conv2d): {self.output_size} -> {module.out_channels}")
+                     return False
+
+         return False
+
+     def _set_classifier(self, cls_num, linear_output):
+         if hasattr(self.model, 'classifier'):
+             self.model.classifier = Classifier(
+                 cls_num, self.output_size, linear_output)
+             return
+
+         elif hasattr(self.model, 'fc'):
+             self.model.fc = Classifier(
+                 cls_num, self.output_size, linear_output)
+             return
+
+         elif hasattr(self.model, 'head'):
+             self.model.head = Classifier(
+                 cls_num, self.output_size, linear_output)
+             return
+
+         self.model.heads.head = Classifier(  # ViT keeps its head under model.heads
+             cls_num, self.output_size, linear_output)
+
+     def forward(self, x):
+         if torch.cuda.is_available():
+             x = x.cuda()
+             self.model = self.model.cuda()
+
+         if self.m_type == 'googlenet' and self.model.training:  # googlenet also returns aux logits
+             return self.model(x)[0]
+         else:
+             return self.model(x)
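
For intuition on the Classifier head: it tapers from the backbone's feature size down to the class count with a constant geometric ratio q. Plugging in the defaults used in this Space (output_size=512, cls_num=4), the arithmetic works out exactly:

    # Worked example of the head sizing in Classifier().
    cls_num, output_size = 4, 512
    q = (1.0 * output_size / cls_num) ** 0.25  # 128 ** 0.25, about 3.364
    l1 = int(q * cls_num)  # 13
    l2 = int(q * l1)       # 43
    l3 = int(q * l2)       # 144
    # Resulting MLP head: 512 -> 144 -> 43 -> 13 -> 4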
requirements.txt ADDED
@@ -0,0 +1,9 @@
+ librosa
+ torch
+ matplotlib
+ torchvision
+ pillow
+ gradio
+ modelscope
+ tqdm
+ requests
utils.py ADDED
@@ -0,0 +1,96 @@
+ import os
+ import time
+ import torch
+ import zipfile
+ import requests
+ from tqdm import tqdm
+
+ DOMAIN = 'https://huggingface.co/ccmusic-database/bel_canto/resolve/main/'
+
+
+ def create_dir(dir):
+     if not os.path.exists(dir):
+         os.mkdir(dir)
+
+
+ def url_download(url: str, fname: str, max_retries=3):
+     retry_count = 0
+     while retry_count < max_retries:
+         try:
+             print(f"Downloading: {url}")
+             resp = requests.get(url, stream=True)
+             # Raise an exception if the status code is not in the 200-299 range
+             resp.raise_for_status()
+             total = int(resp.headers.get('content-length', 0))
+             # Stream the payload to disk in 1 KiB chunks with a progress bar
+             with open(fname, 'wb') as file, tqdm(
+                 desc=fname,
+                 total=total,
+                 unit='iB',
+                 unit_scale=True,
+                 unit_divisor=1024,
+             ) as bar:
+                 for data in resp.iter_content(chunk_size=1024):
+                     size = file.write(data)
+                     bar.update(size)
+             print(f'Download of {url} completed.')
+             return
+
+         except requests.exceptions.HTTPError as errh:
+             print(f"HTTP error occurred: {errh}")
+             retry_count += 1
+             continue
+         except requests.exceptions.ConnectionError as errc:
+             print(f"Connection error occurred: {errc}")
+             retry_count += 1
+             continue
+         except requests.exceptions.Timeout as errt:
+             print(f"Timeout error occurred: {errt}")
+             retry_count += 1
+             continue
+         except Exception as err:
+             print(f"Other error occurred: {err}")
+             retry_count += 1
+             continue
+
+     else:  # the while-else branch runs once all retries are exhausted
+         print(
+             f"Error: the operation could not be completed after {max_retries} retries."
+         )
+         exit()
+
+
+ def unzip_file(zip_src, dst_dir):
+     r = zipfile.is_zipfile(zip_src)
+     if r:
+         fz = zipfile.ZipFile(zip_src, 'r')
+         for file in fz.namelist():
+             fz.extract(file, dst_dir)
+     else:
+         print('This is not a zip file')
+
+
+ def time_stamp(timestamp=None):
+     if timestamp is not None:
+         return timestamp.strftime("%Y-%m-%d %H:%M:%S")
+
+     return time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime(time.time()))
+
+
+ def toCUDA(x):
+     if hasattr(x, 'cuda'):
+         if torch.cuda.is_available():
+             return x.cuda()
+
+     return x
+
+
+ def find_wav_files(folder_path='./examples'):
+     wav_files = []
+     for root, _, files in os.walk(folder_path):
+         for file in files:
+             if file.endswith(".wav"):
+                 file_path = os.path.join(root, file)  # join with root so nested dirs resolve
+                 wav_files.append(file_path)
+
+     return wav_files
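
A minimal usage sketch for these helpers; the checkpoint name is taken from the models list in app.py, and the directory layout matches this repo:

    from utils import DOMAIN, create_dir, url_download, find_wav_files

    create_dir('./model')  # url_download expects the target directory to exist
    url_download(
        f'{DOMAIN}vit_b_16_mel_2024-01-07_05-16-24/save.pt',
        './model/vit_b_16_mel_2024-01-07_05-16-24.pt'
    )
    print(find_wav_files('./examples'))  # the four bundled demo clips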