smjain committed on
Commit
01fa2a4
1 Parent(s): ef5c9b8

Upload trainset_preprocess_pipeline_print.py

Browse files
Files changed (1) hide show
  1. trainset_preprocess_pipeline_print.py +135 -0
trainset_preprocess_pipeline_print.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys, os, multiprocessing
2
+ from scipy import signal
3
+
4
# Put the launch directory on the import path before the project-local
# imports further down (slicer2, my_utils) are resolved.
now_dir = os.getcwd()
sys.path.append(now_dir)

# Positional command-line arguments, read at import time:
#   1: directory holding the input audio files
#   2: sample rate used for the ground-truth clips (integer)
#   3: number of work partitions / worker processes (integer)
#   4: experiment directory that receives the log and the output clips
#   5: the literal string "True" selects sequential, in-process execution
inp_root = sys.argv[1]
sr, n_p = int(sys.argv[2]), int(sys.argv[3])
exp_dir = sys.argv[4]
noparallel = sys.argv[5] == "True"
12
+ import numpy as np, os, traceback
13
+ from slicer2 import Slicer
14
+ import librosa, traceback
15
+ from scipy.io import wavfile
16
+ import multiprocessing
17
+ from my_utils import load_audio
18
+
19
# Lock taken by println() around each console/log write.
# NOTE(review): child processes only share this lock when they inherit it
# (fork start method); under spawn each child re-imports the module and
# builds its own lock -- confirm if interleaved output matters there.
mutex = multiprocessing.Lock()

# The log file lives inside the experiment directory, which is otherwise only
# created later by PreProcess.__init__. Create it here so that a fresh run
# does not fail with FileNotFoundError when the log is opened.
os.makedirs(exp_dir, exist_ok=True)
# Kept open for the lifetime of the process; println() appends to it.
f = open("%s/preprocess.log" % exp_dir, "a+")
21
+
22
+
23
def println(strr):
    """Print ``strr`` to stdout and append it as one line to the log file.

    The module-level ``mutex`` is held while writing so that a message is
    emitted to both destinations as one unit.

    Args:
        strr: Any object; it is rendered with ``%s`` / ``print`` formatting.
    """
    # The context manager releases the lock even when print() or write()
    # raises; a manual acquire()/release() pair would leave it held forever.
    with mutex:
        print(strr)
        # Wrap in a 1-tuple so tuple arguments are logged as a whole (matching
        # what print() shows) instead of being unpacked as format operands.
        f.write("%s\n" % (strr,))
        f.flush()
29
+
30
+
31
class PreProcess:
    """Cut input audio files into short normalised clips for training.

    Every input file is loaded at ``sr``, high-pass filtered, split by the
    project ``Slicer`` and then cut into overlapping fixed-length clips.
    Each clip is written twice: at ``sr`` into ``<exp_dir>/0_gt_wavs`` and
    resampled to 16 kHz into ``<exp_dir>/1_16k_wavs``.
    """

    def __init__(self, sr, exp_dir):
        """Set up the slicer, the filter and the output directories.

        Args:
            sr: Sample rate used for loading audio and for the ground-truth
                clips.
            exp_dir: Experiment directory; it and both clip sub-directories
                are created if missing.
        """
        self.slicer = Slicer(
            sr=sr,
            threshold=-40,
            min_length=800,
            min_interval=400,
            hop_size=15,
            max_sil_kept=150,
        )
        self.sr = sr
        # 5th-order Butterworth high-pass filter, 48 Hz cutoff.
        self.bh, self.ah = signal.butter(N=5, Wn=48, btype="high", fs=self.sr)
        self.per = 3.7  # clip length in seconds
        self.overlap = 0.3  # seconds shared by two consecutive clips
        # A remainder no longer than this (seconds) is kept as one final clip.
        self.tail = self.per + self.overlap
        self.max = 0.95  # peak level targeted by the normalised component
        self.alpha = 0.8  # weight of the normalised signal in the blend
        self.exp_dir = exp_dir
        self.gt_wavs_dir = "%s/0_gt_wavs" % exp_dir
        self.wavs16k_dir = "%s/1_16k_wavs" % exp_dir
        os.makedirs(self.exp_dir, exist_ok=True)
        os.makedirs(self.gt_wavs_dir, exist_ok=True)
        os.makedirs(self.wavs16k_dir, exist_ok=True)

    def norm_write(self, tmp_audio, idx0, idx1):
        """Normalise one clip and write it at ``self.sr`` and at 16 kHz.

        The stored signal is a blend of the peak-normalised clip (weight
        ``alpha``, scaled to ``max``) and the untouched clip (weight
        ``1 - alpha``). Both files are named ``<idx0>_<idx1>.wav``.

        Empty or all-zero clips are skipped: they have no peak to normalise
        against (the division would raise or produce NaN samples).

        Args:
            tmp_audio: 1-D sample array of the clip.
            idx0: Index of the source file.
            idx1: Index of the clip within the source file.
        """
        if np.size(tmp_audio) == 0:
            return
        peak = np.abs(tmp_audio).max()
        if peak == 0:
            return
        tmp_audio = (tmp_audio / peak * (self.max * self.alpha)) + (
            1 - self.alpha
        ) * tmp_audio
        wavfile.write(
            "%s/%s_%s.wav" % (self.gt_wavs_dir, idx0, idx1),
            self.sr,
            tmp_audio.astype(np.float32),
        )
        tmp_audio = librosa.resample(
            tmp_audio, orig_sr=self.sr, target_sr=16000
        )  # , res_type="soxr_vhq"
        wavfile.write(
            "%s/%s_%s.wav" % (self.wavs16k_dir, idx0, idx1),
            16000,
            tmp_audio.astype(np.float32),
        )

    def pipeline(self, path, idx0):
        """Process one input file: load, filter, slice, cut and write clips.

        Failures are logged through ``println`` together with the traceback
        and do not propagate, so one bad file does not stop the batch.

        Args:
            path: Path of the audio file to process.
            idx0: Index of this file; used as the clip file-name prefix.
        """
        try:
            audio = load_audio(path, self.sr)
            # A zero-phase filter (signal.filtfilt) causes pre-ringing noise,
            # so the causal lfilter is used instead.
            audio = signal.lfilter(self.bh, self.ah, audio)

            clip_len = int(self.per * self.sr)
            idx1 = 0
            for segment in self.slicer.slice(audio):
                i = 0
                while True:
                    start = int(self.sr * (self.per - self.overlap) * i)
                    i += 1
                    if len(segment[start:]) > self.tail * self.sr:
                        self.norm_write(
                            segment[start : start + clip_len], idx0, idx1
                        )
                        idx1 += 1
                    else:
                        # Write the remainder under the current index and only
                        # then advance. Advancing first left idx1 pointing at
                        # the tail's own file name, so the next segment's first
                        # clip overwrote it.
                        self.norm_write(segment[start:], idx0, idx1)
                        idx1 += 1
                        break
            println("%s->Suc." % path)
        except Exception:
            # Exception (not a bare except) so Ctrl-C / SystemExit still stop
            # the run instead of being logged as a per-file failure.
            println("%s->%s" % (path, traceback.format_exc()))

    def pipeline_mp(self, infos):
        """Run ``pipeline`` for every ``(path, idx0)`` pair in ``infos``."""
        for path, idx0 in infos:
            self.pipeline(path, idx0)

    def pipeline_mp_inp_dir(self, inp_root, n_p):
        """Process every entry of ``inp_root``, split into ``n_p`` partitions.

        Entries are sorted by name and numbered; partition ``i`` receives
        every ``n_p``-th entry starting at ``i``. Depending on the
        module-level ``noparallel`` flag the partitions run one after another
        in this process or each in its own ``multiprocessing.Process``.

        Args:
            inp_root: Directory containing the input audio files.
            n_p: Number of partitions / worker processes.
        """
        try:
            infos = [
                ("%s/%s" % (inp_root, name), idx)
                for idx, name in enumerate(sorted(os.listdir(inp_root)))
            ]
            if noparallel:
                for i in range(n_p):
                    self.pipeline_mp(infos[i::n_p])
            else:
                workers = []
                for i in range(n_p):
                    worker = multiprocessing.Process(
                        target=self.pipeline_mp, args=(infos[i::n_p],)
                    )
                    worker.start()
                    workers.append(worker)
                for worker in workers:
                    worker.join()
        except Exception:
            println("Fail. %s" % traceback.format_exc())
124
+
125
+
126
def preprocess_trainset(inp_root, sr, n_p, exp_dir):
    """Run the complete preprocessing job and log its start and end.

    Args:
        inp_root: Directory containing the input audio files.
        sr: Sample rate for the ground-truth clips.
        n_p: Number of partitions / worker processes.
        exp_dir: Experiment directory receiving the clips and the log.
    """
    preprocessor = PreProcess(sr, exp_dir)
    println("start preprocess")
    # Record the exact command line in the log for later reference.
    println(sys.argv)
    preprocessor.pipeline_mp_inp_dir(inp_root, n_p)
    println("end preprocess")
132
+
133
+
134
# Script entry point: run the job with the values parsed from sys.argv at the
# top of the module.
if __name__ == "__main__":
    preprocess_trainset(inp_root=inp_root, sr=sr, n_p=n_p, exp_dir=exp_dir)