smjain commited on
Commit
552a5de
·
verified ·
1 Parent(s): 1b83a0a

Upload preprocess.py

Browse files
Files changed (1) hide show
  1. preprocess.py +147 -0
preprocess.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import multiprocessing
2
+ import os
3
+ import sys
4
+
5
+ from scipy import signal
6
+
7
+ now_dir = os.getcwd()
8
+ sys.path.append(now_dir)
9
+ print(sys.argv)
10
+ inp_root = sys.argv[1]
11
+ sr = int(sys.argv[2])
12
+ n_p = int(sys.argv[3])
13
+ exp_dir = sys.argv[4]
14
+ noparallel = sys.argv[5] == "True"
15
+ per = float(sys.argv[6])
16
+ import multiprocessing
17
+ import os
18
+ import traceback
19
+
20
+ import librosa
21
+ import numpy as np
22
+ from scipy.io import wavfile
23
+
24
+ from infer.lib.audio import load_audio
25
+ from infer.lib.slicer2 import Slicer
26
+
27
+ mutex = multiprocessing.Lock()
28
+ f = open("%s/preprocess.log" % exp_dir, "a+")
29
+
30
+
31
+ def println(strr):
32
+ mutex.acquire()
33
+ print(strr)
34
+ f.write("%s\n" % strr)
35
+ f.flush()
36
+ mutex.release()
37
+
38
+
39
+ class PreProcess:
40
+ def __init__(self, sr, exp_dir, per=3.0):
41
+ self.slicer = Slicer(
42
+ sr=sr,
43
+ threshold=-42,
44
+ min_length=1500,
45
+ min_interval=400,
46
+ hop_size=15,
47
+ max_sil_kept=500,
48
+ )
49
+ self.sr = sr
50
+ self.bh, self.ah = signal.butter(N=5, Wn=48, btype="high", fs=self.sr)
51
+ self.per = per
52
+ self.overlap = 0.3
53
+ self.tail = self.per + self.overlap
54
+ self.max = 0.9
55
+ self.alpha = 0.75
56
+ self.exp_dir = exp_dir
57
+ self.gt_wavs_dir = "%s/0_gt_wavs" % exp_dir
58
+ self.wavs16k_dir = "%s/1_16k_wavs" % exp_dir
59
+ os.makedirs(self.exp_dir, exist_ok=True)
60
+ os.makedirs(self.gt_wavs_dir, exist_ok=True)
61
+ os.makedirs(self.wavs16k_dir, exist_ok=True)
62
+
63
+ def norm_write(self, tmp_audio, idx0, idx1):
64
+ tmp_max = np.abs(tmp_audio).max()
65
+ if tmp_max > 2.5:
66
+ print("%s-%s-%s-filtered" % (idx0, idx1, tmp_max))
67
+ return
68
+ tmp_audio = (tmp_audio / tmp_max * (self.max * self.alpha)) + (
69
+ 1 - self.alpha
70
+ ) * tmp_audio
71
+ wavfile.write(
72
+ "%s/%s_%s.wav" % (self.gt_wavs_dir, idx0, idx1),
73
+ self.sr,
74
+ tmp_audio.astype(np.float32),
75
+ )
76
+ tmp_audio = librosa.resample(
77
+ tmp_audio, orig_sr=self.sr, target_sr=16000
78
+ ) # , res_type="soxr_vhq"
79
+ wavfile.write(
80
+ "%s/%s_%s.wav" % (self.wavs16k_dir, idx0, idx1),
81
+ 16000,
82
+ tmp_audio.astype(np.float32),
83
+ )
84
+
85
+ def pipeline(self, path, idx0):
86
+ try:
87
+ audio = load_audio(path, self.sr)
88
+ # zero phased digital filter cause pre-ringing noise...
89
+ # audio = signal.filtfilt(self.bh, self.ah, audio)
90
+ audio = signal.lfilter(self.bh, self.ah, audio)
91
+
92
+ idx1 = 0
93
+ for audio in self.slicer.slice(audio):
94
+ i = 0
95
+ while 1:
96
+ start = int(self.sr * (self.per - self.overlap) * i)
97
+ i += 1
98
+ if len(audio[start:]) > self.tail * self.sr:
99
+ tmp_audio = audio[start : start + int(self.per * self.sr)]
100
+ self.norm_write(tmp_audio, idx0, idx1)
101
+ idx1 += 1
102
+ else:
103
+ tmp_audio = audio[start:]
104
+ idx1 += 1
105
+ break
106
+ self.norm_write(tmp_audio, idx0, idx1)
107
+ println("%s->Suc." % path)
108
+ except:
109
+ println("%s->%s" % (path, traceback.format_exc()))
110
+
111
+ def pipeline_mp(self, infos):
112
+ for path, idx0 in infos:
113
+ self.pipeline(path, idx0)
114
+
115
+ def pipeline_mp_inp_dir(self, inp_root, n_p):
116
+ try:
117
+ infos = [
118
+ ("%s/%s" % (inp_root, name), idx)
119
+ for idx, name in enumerate(sorted(list(os.listdir(inp_root))))
120
+ ]
121
+ if noparallel:
122
+ for i in range(n_p):
123
+ self.pipeline_mp(infos[i::n_p])
124
+ else:
125
+ ps = []
126
+ for i in range(n_p):
127
+ p = multiprocessing.Process(
128
+ target=self.pipeline_mp, args=(infos[i::n_p],)
129
+ )
130
+ ps.append(p)
131
+ p.start()
132
+ for i in range(n_p):
133
+ ps[i].join()
134
+ except:
135
+ println("Fail. %s" % traceback.format_exc())
136
+
137
+
138
+ def preprocess_trainset(inp_root, sr, n_p, exp_dir, per):
139
+ pp = PreProcess(sr, exp_dir, per)
140
+ println("start preprocess")
141
+ println(sys.argv)
142
+ pp.pipeline_mp_inp_dir(inp_root, n_p)
143
+ println("end preprocess")
144
+
145
+
146
+ if __name__ == "__main__":
147
+ preprocess_trainset(inp_root, sr, n_p, exp_dir, per)