namelessai committed: upload main files

Files changed:
- __init__.py  +2 -0
- __main__.py  +123 -0
- lowpass.py   +249 -0
- pipeline.py  +175 -0

__init__.py
ADDED
@@ -0,0 +1,2 @@
from .utils import seed_everything, save_wave, get_time, get_duration, read_list
from .pipeline import *
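Since pipeline.py (below) defines no __all__, the star import re-exports every public name from that module, so consumers can pull the high-level API straight from the package, exactly as __main__.py does:

from audiosr import build_model, super_resolution, save_wave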
__main__.py
ADDED
@@ -0,0 +1,123 @@
#!/usr/bin/python3
import argparse
import logging
import os

import torch

from audiosr import super_resolution, build_model, save_wave, get_time, read_list

os.environ["TOKENIZERS_PARALLELISM"] = "true"
matplotlib_logger = logging.getLogger("matplotlib")
matplotlib_logger.setLevel(logging.WARNING)

parser = argparse.ArgumentParser()

parser.add_argument(
    "-i",
    "--input_audio_file",
    type=str,
    required=False,
    help="Input audio file for audio super resolution",
)

parser.add_argument(
    "-il",
    "--input_file_list",
    type=str,
    required=False,
    default="",
    help="A file listing all audio files to process with audio super resolution",
)

parser.add_argument(
    "-s",
    "--save_path",
    type=str,
    required=False,
    help="The path to save model output",
    default="./output",
)

parser.add_argument(
    "--model_name",
    type=str,
    required=False,
    help="The checkpoint to use",
    default="basic",
    choices=["basic", "speech"],
)

parser.add_argument(
    "-d",
    "--device",
    type=str,
    required=False,
    help="The device for computation. If not specified, the script will automatically choose the device based on your environment.",
    default="auto",
)

parser.add_argument(
    "--ddim_steps",
    type=int,
    required=False,
    default=50,
    help="The number of DDIM sampling steps",
)

parser.add_argument(
    "-gs",
    "--guidance_scale",
    type=float,
    required=False,
    default=3.5,
    help="Guidance scale (large => better quality and relevance to the conditioning; small => better diversity)",
)

parser.add_argument(
    "--seed",
    type=int,
    required=False,
    default=42,
    help="Changing this value (any integer) leads to a different generation result.",
)

parser.add_argument(
    "--suffix",
    type=str,
    required=False,
    help="Suffix for the output file",
    default="_AudioSR_Processed_48K",
)

args = parser.parse_args()
torch.set_float32_matmul_precision("high")
save_path = os.path.join(args.save_path, get_time())

# The list argument defaults to "" and the file argument to None, so test
# truthiness here; the original `is not None` check could never fail.
assert args.input_file_list or args.input_audio_file, "Please provide either a list of audio files or a single audio file"

input_file = args.input_audio_file
random_seed = args.seed
sample_rate = 48000
latent_t_per_second = 12.8
guidance_scale = args.guidance_scale

os.makedirs(save_path, exist_ok=True)
audiosr = build_model(model_name=args.model_name, device=args.device)

if args.input_file_list:
    print("Performing super resolution on the audio files listed in %s" % args.input_file_list)
    files_todo = read_list(args.input_file_list)
else:
    files_todo = [input_file]

for input_file in files_todo:
    name = os.path.splitext(os.path.basename(input_file))[0] + args.suffix

    waveform = super_resolution(
        audiosr,
        input_file,
        seed=random_seed,
        guidance_scale=guidance_scale,
        ddim_steps=args.ddim_steps,
        latent_t_per_second=latent_t_per_second,
    )
    save_wave(waveform, inputpath=input_file, savepath=save_path, name=name, samplerate=sample_rate)
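For reference, the loop above condenses to this minimal programmatic sketch. The functions are exactly the ones imported at the top of the script; `example.wav` and the output name are hypothetical placeholders:

from audiosr import build_model, super_resolution, save_wave

model = build_model(model_name="basic", device="auto")
waveform = super_resolution(model, "example.wav", seed=42, guidance_scale=3.5, ddim_steps=50)
save_wave(waveform, inputpath="example.wav", savepath="./output", name="example_AudioSR_Processed_48K", samplerate=48000)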
lowpass.py
ADDED
@@ -0,0 +1,249 @@
import numpy as np
from scipy.signal import butter, cheby1, cheby2, ellip, bessel
from scipy.signal import sosfiltfilt
from scipy.signal import resample_poly


def align_length(x=None, y=None, Lx=None):
    """Align the length of y to that of x.

    Args:
        x (np.array): reference signal
        y (np.array): the signal whose length needs to be aligned
        Lx (int): target length; inferred from x when not given

    Return:
        yy (np.array): signal with the same length as x
    """
    assert y is not None

    if Lx is None:
        Lx = len(x)
    Ly = len(y)

    if Lx == Ly:
        return y
    elif Lx > Ly:
        # pad y with zeros
        return np.pad(y, (0, Lx - Ly), mode="constant")
    else:
        # cut y
        return y[:Lx]


def bandpass_filter(x, lowcut, highcut, fs, order, ftype):
    """Process the input signal x with a bandpass filter.

    Args:
        x (np.array): input signal
        lowcut (float): low cutoff frequency
        highcut (float): high cutoff frequency
        fs (float): sampling rate of x
        order (int): the order of the filter
        ftype (string): type of filter,
            one of ['butter', 'cheby1', 'cheby2', 'ellip', 'bessel']

    Return:
        y (np.array): filtered signal
    """
    nyq = 0.5 * fs
    lo = lowcut / nyq
    hi = highcut / nyq

    if ftype == "butter":
        sos = butter(order, [lo, hi], btype="band", output="sos")
    elif ftype == "cheby1":
        sos = cheby1(order, 0.1, [lo, hi], btype="band", output="sos")
    elif ftype == "cheby2":
        sos = cheby2(order, 60, [lo, hi], btype="band", output="sos")
    elif ftype == "ellip":
        sos = ellip(order, 0.1, 60, [lo, hi], btype="band", output="sos")
    elif ftype == "bessel":
        sos = bessel(order, [lo, hi], btype="band", output="sos")
    else:
        raise ValueError(f"The bandpass filter {ftype} is not supported!")

    # zero-phase filtering with second-order sections
    y = sosfiltfilt(sos, x)

    if len(y) != len(x):
        y = align_length(x, y)
    return y


def lowpass_filter(x, highcut, fs, order, ftype):
    """Process the input signal x with a lowpass filter.

    Args:
        x (np.array): input signal
        highcut (float): cutoff frequency
        fs (float): sampling rate of x
        order (int): the order of the filter
        ftype (string): type of filter,
            one of ['butter', 'cheby1', 'cheby2', 'ellip', 'bessel']

    Return:
        y (np.array): filtered signal
    """
    nyq = 0.5 * fs
    hi = highcut / nyq

    if ftype == "butter":
        sos = butter(order, hi, btype="low", output="sos")
    elif ftype == "cheby1":
        sos = cheby1(order, 0.1, hi, btype="low", output="sos")
    elif ftype == "cheby2":
        sos = cheby2(order, 60, hi, btype="low", output="sos")
    elif ftype == "ellip":
        sos = ellip(order, 0.1, 60, hi, btype="low", output="sos")
    elif ftype == "bessel":
        sos = bessel(order, hi, btype="low", output="sos")
    else:
        raise ValueError(f"The lowpass filter {ftype} is not supported!")

    y = sosfiltfilt(sos, x)

    if len(y) != len(x):
        y = align_length(x, y)

    y_len = len(y)

    # remove residual energy above the cutoff with a resampling round trip
    y = stft_hard_lowpass(y, hi, fs_ori=fs)

    y = sosfiltfilt(sos, y)

    if len(y) != y_len:
        y = align_length(y=y, Lx=y_len)

    return y


def stft_hard_lowpass(data, lowpass_ratio, fs_ori=44100):
    fs_down = int(lowpass_ratio * fs_ori)
    # downsample to the low sampling rate
    y = resample_poly(data, fs_down, fs_ori)

    # upsample to the original sampling rate
    y = resample_poly(y, fs_ori, fs_down)

    if len(y) != len(data):
        y = align_length(data, y)
    return y


def limit(integer, high, low):
    if integer > high:
        return high
    elif integer < low:
        return low
    else:
        return int(integer)


def lowpass(data, highcut, fs, order=5, _type="butter"):
    """
    :param data: 1-D np.float32 array of shape (samples,); (samples, 1) is not accepted
    :param highcut: cutoff frequency
    :param fs: sample rate of the original data
    :param order: order of the filter
    :param _type: type of filter
    :return: filtered data, (samples,)
    """
    if len(list(data.shape)) != 1:
        raise ValueError(
            "Error (lowpass): Data "
            + str(data.shape)
            + " should be a 1d time array, (samples,), not (samples, 1)"
        )

    order = limit(order, high=10, low=2)
    # note: equality comparison; the original `_type in "butter"` substring
    # test would also match partial names such as "b" or the empty string.
    if _type == "butter":
        return lowpass_filter(x=data, highcut=int(highcut), fs=fs, order=order, ftype="butter")
    elif _type == "cheby1":
        return lowpass_filter(x=data, highcut=int(highcut), fs=fs, order=order, ftype="cheby1")
    elif _type == "ellip":
        return lowpass_filter(x=data, highcut=int(highcut), fs=fs, order=order, ftype="ellip")
    elif _type == "bessel":
        return lowpass_filter(x=data, highcut=int(highcut), fs=fs, order=order, ftype="bessel")
    # elif _type == "stft":
    #     return stft_hard_lowpass(data, lowpass_ratio=highcut / int(fs / 2))
    else:
        raise ValueError("Error: Unexpected filter type " + _type)


def bandpass(data, lowcut, highcut, fs, order=5, _type="butter"):
    """
    :param data: 1-D np.float32 array of shape (samples,); (samples, 1) is not accepted
    :param lowcut: low cutoff frequency
    :param highcut: high cutoff frequency
    :param fs: sample rate of the original data
    :param order: order of the filter
    :param _type: type of filter
    :return: filtered data, (samples,)
    """
    if len(list(data.shape)) != 1:
        raise ValueError(
            "Error (bandpass): Data "
            + str(data.shape)
            + " should be a 1d time array, (samples,), not (samples, 1)"
        )

    order = limit(order, high=10, low=2)
    if _type == "butter":
        return bandpass_filter(x=data, lowcut=int(lowcut), highcut=int(highcut), fs=fs, order=order, ftype="butter")
    elif _type == "cheby1":
        return bandpass_filter(x=data, lowcut=int(lowcut), highcut=int(highcut), fs=fs, order=order, ftype="cheby1")
    # elif _type == "cheby2":
    #     return bandpass_filter(x=data, lowcut=int(lowcut), highcut=int(highcut), fs=fs, order=order, ftype="cheby2")
    elif _type == "ellip":
        return bandpass_filter(x=data, lowcut=int(lowcut), highcut=int(highcut), fs=fs, order=order, ftype="ellip")
    elif _type == "bessel":
        return bandpass_filter(x=data, lowcut=int(lowcut), highcut=int(highcut), fs=fs, order=order, ftype="bessel")
    else:
        raise ValueError("Error: Unexpected filter type " + _type)
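A minimal usage sketch of the lowpass helper above; the input signal is synthetic and the import path is an assumption about where this module sits in the package:

import numpy as np
from audiosr.lowpass import lowpass  # import path is an assumption

# 1 s of white noise at 48 kHz, shaped (samples,) as the docstring requires
sig = np.random.randn(48000).astype(np.float32)

# keep only content below 4 kHz with a 5th-order Butterworth filter
filtered = lowpass(sig, highcut=4000, fs=48000, order=5, _type="butter")
assert filtered.shape == sig.shape  # align_length keeps the output length fixed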
pipeline.py
ADDED
@@ -0,0 +1,175 @@
import os
import re

import yaml
import torch
import torchaudio
import numpy as np

import audiosr.latent_diffusion.modules.phoneme_encoder.text as text
from audiosr.latent_diffusion.models.ddpm import LatentDiffusion
from audiosr.latent_diffusion.util import get_vits_phoneme_ids_no_padding
from audiosr.utils import (
    default_audioldm_config,
    download_checkpoint,
    read_audio_file,
    lowpass_filtering_prepare_inference,
    wav_feature_extraction,
)


def seed_everything(seed):
    import random, os
    import numpy as np
    import torch

    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark mode lets cuDNN pick non-deterministic kernels, so keep it off
    torch.backends.cudnn.benchmark = False


def text2phoneme(data):
    return text._clean_text(re.sub(r"<.*?>", "", data), ["english_cleaners2"])


def text_to_filename(text):
    return text.replace(" ", "_").replace("'", "_").replace('"', "_")


def extract_kaldi_fbank_feature(waveform, sampling_rate, log_mel_spec):
    norm_mean = -4.2677393
    norm_std = 4.5689974

    if sampling_rate != 16000:
        waveform_16k = torchaudio.functional.resample(
            waveform, orig_freq=sampling_rate, new_freq=16000
        )
    else:
        waveform_16k = waveform

    waveform_16k = waveform_16k - waveform_16k.mean()
    fbank = torchaudio.compliance.kaldi.fbank(
        waveform_16k,
        htk_compat=True,
        sample_frequency=16000,
        use_energy=False,
        window_type="hanning",
        num_mel_bins=128,
        dither=0.0,
        frame_shift=10,
    )

    TARGET_LEN = log_mel_spec.size(0)

    # cut and pad to match the mel spectrogram length
    n_frames = fbank.shape[0]
    p = TARGET_LEN - n_frames
    if p > 0:
        m = torch.nn.ZeroPad2d((0, 0, 0, p))
        fbank = m(fbank)
    elif p < 0:
        fbank = fbank[:TARGET_LEN, :]

    fbank = (fbank - norm_mean) / (norm_std * 2)

    return {"ta_kaldi_fbank": fbank}  # [1024, 128]


def make_batch_for_super_resolution(input_file, waveform=None, fbank=None):
    log_mel_spec, stft, waveform, duration, target_frame = read_audio_file(input_file)

    batch = {
        "waveform": torch.FloatTensor(waveform),
        "stft": torch.FloatTensor(stft),
        "log_mel_spec": torch.FloatTensor(log_mel_spec),
        "sampling_rate": 48000,
    }

    batch.update(lowpass_filtering_prepare_inference(batch))

    assert "waveform_lowpass" in batch.keys()
    lowpass_mel, lowpass_stft = wav_feature_extraction(
        batch["waveform_lowpass"], target_frame
    )
    batch["lowpass_mel"] = lowpass_mel

    # add the batch dimension to every tensor entry
    for k in batch.keys():
        if isinstance(batch[k], torch.Tensor):
            batch[k] = torch.FloatTensor(batch[k]).unsqueeze(0)

    return batch, duration


def round_up_duration(duration):
    return int(round(duration / 2.5) + 1) * 2.5


def build_model(ckpt_path=None, config=None, device=None, model_name="basic"):
    if device is None or device == "auto":
        if torch.cuda.is_available():
            device = torch.device("cuda:0")
        elif torch.backends.mps.is_available():
            device = torch.device("mps")
        else:
            device = torch.device("cpu")

    print("Loading AudioSR: %s" % model_name)
    print("Loading model on %s" % device)

    # download the released checkpoint unless one is supplied; the original
    # code overwrote ckpt_path unconditionally, silently ignoring the argument
    if ckpt_path is None:
        ckpt_path = download_checkpoint(model_name)

    if config is not None:
        assert isinstance(config, str)
        with open(config, "r") as f:
            config = yaml.load(f, Loader=yaml.FullLoader)
    else:
        config = default_audioldm_config(model_name)

    config["model"]["params"]["device"] = device

    latent_diffusion = LatentDiffusion(**config["model"]["params"])

    checkpoint = torch.load(ckpt_path, map_location=device)
    latent_diffusion.load_state_dict(checkpoint["state_dict"], strict=False)

    latent_diffusion.eval()
    latent_diffusion = latent_diffusion.to(device)

    return latent_diffusion


def super_resolution(
    latent_diffusion,
    input_file,
    seed=42,
    ddim_steps=200,
    guidance_scale=3.5,
    latent_t_per_second=12.8,
    config=None,
):
    seed_everything(int(seed))
    waveform = None

    batch, duration = make_batch_for_super_resolution(input_file, waveform=waveform)

    with torch.no_grad():
        waveform = latent_diffusion.generate_batch(
            batch,
            unconditional_guidance_scale=guidance_scale,
            ddim_steps=ddim_steps,
            duration=duration,
        )

    return waveform
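A couple of worked values for round_up_duration above, following directly from its definition (the 2.5 s step appears to be the model's chunking granularity; that is an inference, not something stated in this commit):

# round_up_duration(3.2) -> int(round(3.2 / 2.5) + 1) * 2.5 -> (1 + 1) * 2.5 -> 5.0
# round_up_duration(7.0) -> int(round(7.0 / 2.5) + 1) * 2.5 -> (3 + 1) * 2.5 -> 10.0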