namelessai committed
Commit cb25d6b · verified · 1 Parent(s): 5167d22

upload main files

Files changed (4):
  1. __init__.py +2 -0
  2. __main__.py +123 -0
  3. lowpass.py +249 -0
  4. pipeline.py +175 -0
__init__.py ADDED
@@ -0,0 +1,2 @@
+ from .utils import seed_everything, save_wave, get_time, get_duration, read_list
+ from .pipeline import *
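For reference, a minimal usage sketch of what these two re-exports make available at the package root. The package name audiosr is taken from the imports in __main__.py below, not from this file.

    from audiosr import seed_everything, save_wave, get_time, read_list, build_model, super_resolution

    seed_everything(42)   # fix random seeds before generation
    print(get_time())     # timestamp string; __main__.py uses it to name the output folder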
__main__.py ADDED
@@ -0,0 +1,123 @@
+ #!/usr/bin/python3
+ import os
+ import torch
+ import logging
+ from audiosr import super_resolution, build_model, save_wave, get_time, read_list
+ import argparse
+
+ os.environ["TOKENIZERS_PARALLELISM"] = "true"
+ matplotlib_logger = logging.getLogger('matplotlib')
+ matplotlib_logger.setLevel(logging.WARNING)
+
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+     "-i",
+     "--input_audio_file",
+     type=str,
+     required=False,
+     help="Input audio file for audio super resolution",
+ )
+
+ parser.add_argument(
+     "-il",
+     "--input_file_list",
+     type=str,
+     required=False,
+     default="",
+     help="A file that lists all audio files on which to perform audio super resolution",
+ )
+
+ parser.add_argument(
+     "-s",
+     "--save_path",
+     type=str,
+     required=False,
+     help="The path to save model output",
+     default="./output",
+ )
+
+ parser.add_argument(
+     "--model_name",
+     type=str,
+     required=False,
+     help="The checkpoint to use",
+     default="basic",
+     choices=["basic", "speech"],
+ )
+
+ parser.add_argument(
+     "-d",
+     "--device",
+     type=str,
+     required=False,
+     help="The device for computation. If not specified, the script will automatically choose the device based on your environment.",
+     default="auto",
+ )
+
+ parser.add_argument(
+     "--ddim_steps",
+     type=int,
+     required=False,
+     default=50,
+     help="The number of DDIM sampling steps",
+ )
+
+ parser.add_argument(
+     "-gs",
+     "--guidance_scale",
+     type=float,
+     required=False,
+     default=3.5,
+     help="Guidance scale (large => better quality and relevance to text; small => better diversity)",
+ )
+
+ parser.add_argument(
+     "--seed",
+     type=int,
+     required=False,
+     default=42,
+     help="Changing this value (any integer) leads to a different generation result.",
+ )
+
+ parser.add_argument(
+     "--suffix",
+     type=str,
+     required=False,
+     help="Suffix for the output file",
+     default="_AudioSR_Processed_48K",
+ )
+
+ args = parser.parse_args()
+ torch.set_float32_matmul_precision("high")
+ save_path = os.path.join(args.save_path, get_time())
+
+ assert args.input_file_list or args.input_audio_file is not None, "Please provide either a list of audio files or a single audio file"
+
+ input_file = args.input_audio_file
+ random_seed = args.seed
+ sample_rate = 48000
+ latent_t_per_second = 12.8
+ guidance_scale = args.guidance_scale
+
+ os.makedirs(save_path, exist_ok=True)
+ audiosr = build_model(model_name=args.model_name, device=args.device)
+
+ if args.input_file_list:
+     print("Performing audio super resolution on the files listed in %s" % args.input_file_list)
+     files_todo = read_list(args.input_file_list)
+ else:
+     files_todo = [input_file]
+
+ for input_file in files_todo:
+     name = os.path.splitext(os.path.basename(input_file))[0] + args.suffix
+
+     waveform = super_resolution(
+         audiosr,
+         input_file,
+         seed=random_seed,
+         guidance_scale=guidance_scale,
+         ddim_steps=args.ddim_steps,
+         latent_t_per_second=latent_t_per_second
+     )
+     save_wave(waveform, inputpath=input_file, savepath=save_path, name=name, samplerate=sample_rate)
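A usage sketch for this CLI. Running it as `python3 -m audiosr` assumes the package directory is named audiosr; the file names below are placeholders, and the flags are exactly those defined by the parser above.

    # Single file, default "basic" checkpoint; results go to ./output/<timestamp>/
    python3 -m audiosr -i example.wav

    # Batch mode: one audio path per line in files.lst, speech checkpoint, 100 DDIM steps
    python3 -m audiosr -il files.lst --model_name speech --ddim_steps 100 -s ./output --suffix _AudioSR_Processed_48K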
lowpass.py ADDED
@@ -0,0 +1,249 @@
+ from scipy.signal import butter, lfilter
+ import torch
+ from scipy import signal
+ import librosa
+ import numpy as np
+
+ from scipy.signal import sosfiltfilt
+ from scipy.signal import butter, cheby1, cheby2, ellip, bessel
+ from scipy.signal import resample_poly
+
+
+ def align_length(x=None, y=None, Lx=None):
+     """align the length of y to that of x
+
+     Args:
+         x (np.array): reference signal
+         y (np.array): the signal needs to be length aligned
+
+     Return:
+         yy (np.array): signal with the same length as x
+     """
+     assert y is not None
+
+     if Lx is None:
+         Lx = len(x)
+     Ly = len(y)
+
+     if Lx == Ly:
+         return y
+     elif Lx > Ly:
+         # pad y with zeros
+         return np.pad(y, (0, Lx - Ly), mode="constant")
+     else:
+         # cut y
+         return y[:Lx]
+
+
+ def bandpass_filter(x, lowcut, highcut, fs, order, ftype):
+     """process input signal x using bandpass filter
+
+     Args:
+         x (np.array): input signal
+         lowcut (float): low cutoff frequency
+         highcut (float): high cutoff frequency
+         order (int): the order of filter
+         ftype (string): type of filter
+             ['butter', 'cheby1', 'cheby2', 'ellip', 'bessel']
+
+     Return:
+         y (np.array): filtered signal
+     """
+     nyq = 0.5 * fs
+     lo = lowcut / nyq
+     hi = highcut / nyq
+
+     if ftype == "butter":
+         # b, a = butter(order, [lo, hi], btype='band')
+         sos = butter(order, [lo, hi], btype="band", output="sos")
+     elif ftype == "cheby1":
+         sos = cheby1(order, 0.1, [lo, hi], btype="band", output="sos")
+     elif ftype == "cheby2":
+         sos = cheby2(order, 60, [lo, hi], btype="band", output="sos")
+     elif ftype == "ellip":
+         sos = ellip(order, 0.1, 60, [lo, hi], btype="band", output="sos")
+     elif ftype == "bessel":
+         sos = bessel(order, [lo, hi], btype="band", output="sos")
+     else:
+         raise Exception(f"The bandpass filter {ftype} is not supported!")
+
+     # y = lfilter(b, a, x)
+     y = sosfiltfilt(sos, x)
+
+     if len(y) != len(x):
+         y = align_length(x, y)
+     return y
+
+
+ def lowpass_filter(x, highcut, fs, order, ftype):
+     """process input signal x using lowpass filter
+
+     Args:
+         x (np.array): input signal
+         highcut (float): high cutoff frequency
+         order (int): the order of filter
+         ftype (string): type of filter
+             ['butter', 'cheby1', 'cheby2', 'ellip', 'bessel']
+
+     Return:
+         y (np.array): filtered signal
+     """
+     nyq = 0.5 * fs
+     hi = highcut / nyq
+
+     if ftype == "butter":
+         sos = butter(order, hi, btype="low", output="sos")
+     elif ftype == "cheby1":
+         sos = cheby1(order, 0.1, hi, btype="low", output="sos")
+     elif ftype == "cheby2":
+         sos = cheby2(order, 60, hi, btype="low", output="sos")
+     elif ftype == "ellip":
+         sos = ellip(order, 0.1, 60, hi, btype="low", output="sos")
+     elif ftype == "bessel":
+         sos = bessel(order, hi, btype="low", output="sos")
+     else:
+         raise Exception(f"The lowpass filter {ftype} is not supported!")
+
+     y = sosfiltfilt(sos, x)
+
+     if len(y) != len(x):
+         y = align_length(x, y)
+
+     y_len = len(y)
+
+     y = stft_hard_lowpass(y, hi, fs_ori=fs)
+
+     y = sosfiltfilt(sos, y)
+
+     if len(y) != y_len:
+         y = align_length(y=y, Lx=y_len)
+
+     return y
+
+
+ def stft_hard_lowpass(data, lowpass_ratio, fs_ori=44100):
+     fs_down = int(lowpass_ratio * fs_ori)
+     # downsample to the low sampling rate
+     y = resample_poly(data, fs_down, fs_ori)
+
+     # upsample to the original sampling rate
+     y = resample_poly(y, fs_ori, fs_down)
+
+     if len(y) != len(data):
+         y = align_length(data, y)
+     return y
+
+
+ def limit(integer, high, low):
+     if integer > high:
+         return high
+     elif integer < low:
+         return low
+     else:
+         return int(integer)
+
+
+ def lowpass(data, highcut, fs, order=5, _type="butter"):
+     """
+     :param data: np.float32 type 1d time numpy array, (samples,) , can not be (samples, 1)
+     :param highcut: cutoff frequency
+     :param fs: sample rate of the original data
+     :param order: order of the filter
+     :return: filtered data, (samples,)
+     """
+
+     if len(list(data.shape)) != 1:
+         raise ValueError(
+             "Error (chebyshev_lowpass_filter): Data "
+             + str(data.shape)
+             + " should be type 1d time array, (samples,) , can not be (samples, 1)"
+         )
+
+     if _type in "butter":
+         order = limit(order, high=10, low=2)
+         return lowpass_filter(
+             x=data, highcut=int(highcut), fs=fs, order=order, ftype="butter"
+         )
+     elif _type in "cheby1":
+         order = limit(order, high=10, low=2)
+         return lowpass_filter(
+             x=data, highcut=int(highcut), fs=fs, order=order, ftype="cheby1"
+         )
+     elif _type in "ellip":
+         order = limit(order, high=10, low=2)
+         return lowpass_filter(
+             x=data, highcut=int(highcut), fs=fs, order=order, ftype="ellip"
+         )
+     elif _type in "bessel":
+         order = limit(order, high=10, low=2)
+         return lowpass_filter(
+             x=data, highcut=int(highcut), fs=fs, order=order, ftype="bessel"
+         )
+     # elif(_type in "stft"):
+     #     return stft_hard_lowpass(data, lowpass_ratio=highcut / int(fs / 2))
+     # elif(_type in "stft_hard"):
+     #     return stft_hard_lowpass_v0(data, lowpass_ratio=highcut / int(fs / 2))
+     else:
+         raise ValueError("Error: Unexpected filter type " + _type)
+
+
+ def bandpass(data, lowcut, highcut, fs, order=5, _type="butter"):
+     """
+     :param data: np.float32 type 1d time numpy array, (samples,) , can not be (samples, 1)
+     :param lowcut: low cutoff frequency
+     :param highcut: high cutoff frequency
+     :param fs: sample rate of the original data
+     :param order: order of the filter
+     :param _type: type of filter
+     :return: filtered data, (samples,)
+     """
+     if len(list(data.shape)) != 1:
+         raise ValueError(
+             "Error (chebyshev_lowpass_filter): Data "
+             + str(data.shape)
+             + " should be type 1d time array, (samples,) , can not be (samples, 1)"
+         )
+     if _type in "butter":
+         order = limit(order, high=10, low=2)
+         return bandpass_filter(
+             x=data,
+             lowcut=int(lowcut),
+             highcut=int(highcut),
+             fs=fs,
+             order=order,
+             ftype="butter",
+         )
+     elif _type in "cheby1":
+         order = limit(order, high=10, low=2)
+         return bandpass_filter(
+             x=data,
+             lowcut=int(lowcut),
+             highcut=int(highcut),
+             fs=fs,
+             order=order,
+             ftype="cheby1",
+         )
+     # elif(_type in "cheby2"):
+     #     return bandpass_filter(x=data, lowcut=int(lowcut), highcut=int(highcut), fs=fs, order=order, ftype="cheby2")
+     elif _type in "ellip":
+         order = limit(order, high=10, low=2)
+         return bandpass_filter(
+             x=data,
+             lowcut=int(lowcut),
+             highcut=int(highcut),
+             fs=fs,
+             order=order,
+             ftype="ellip",
+         )
+     elif _type in "bessel":
+         order = limit(order, high=10, low=2)
+         return bandpass_filter(
+             x=data,
+             lowcut=int(lowcut),
+             highcut=int(highcut),
+             fs=fs,
+             order=order,
+             ftype="bessel",
+         )
+     else:
+         raise ValueError("Error: Unexpected filter type " + _type)
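A minimal sketch of how these filters might be called, e.g. to simulate a band-limited input. The cutoff values, sample rate, and import path are illustrative assumptions, not values taken from this file.

    import numpy as np
    from audiosr.lowpass import lowpass, bandpass  # assumed module path inside the package

    fs = 48000
    x = np.random.randn(fs).astype(np.float32)  # 1 s of noise, shape (samples,)

    # Zero-phase Butterworth low-pass plus the resample-based hard cutoff in lowpass_filter above
    y_lp = lowpass(x, highcut=4000, fs=fs, order=5, _type="butter")

    # Chebyshev type-I band-pass keeping roughly 300 Hz to 4 kHz
    y_bp = bandpass(x, lowcut=300, highcut=4000, fs=fs, order=5, _type="cheby1")

    assert len(y_lp) == len(x) and len(y_bp) == len(x)  # align_length keeps the input length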
pipeline.py ADDED
@@ -0,0 +1,175 @@
+ import os
+ import re
+
+ import yaml
+ import torch
+ import torchaudio
+ import numpy as np
+
+ import audiosr.latent_diffusion.modules.phoneme_encoder.text as text
+ from audiosr.latent_diffusion.models.ddpm import LatentDiffusion
+ from audiosr.latent_diffusion.util import get_vits_phoneme_ids_no_padding
+ from audiosr.utils import (
+     default_audioldm_config,
+     download_checkpoint,
+     read_audio_file,
+     lowpass_filtering_prepare_inference,
+     wav_feature_extraction,
+ )
+ import os
+
+
+ def seed_everything(seed):
+     import random, os
+     import numpy as np
+     import torch
+
+     random.seed(seed)
+     os.environ["PYTHONHASHSEED"] = str(seed)
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+     torch.cuda.manual_seed(seed)
+     torch.backends.cudnn.deterministic = True
+     torch.backends.cudnn.benchmark = True
+
+
+ def text2phoneme(data):
+     return text._clean_text(re.sub(r"<.*?>", "", data), ["english_cleaners2"])
+
+
+ def text_to_filename(text):
+     return text.replace(" ", "_").replace("'", "_").replace('"', "_")
+
+
+ def extract_kaldi_fbank_feature(waveform, sampling_rate, log_mel_spec):
+     norm_mean = -4.2677393
+     norm_std = 4.5689974
+
+     if sampling_rate != 16000:
+         waveform_16k = torchaudio.functional.resample(
+             waveform, orig_freq=sampling_rate, new_freq=16000
+         )
+     else:
+         waveform_16k = waveform
+
+     waveform_16k = waveform_16k - waveform_16k.mean()
+     fbank = torchaudio.compliance.kaldi.fbank(
+         waveform_16k,
+         htk_compat=True,
+         sample_frequency=16000,
+         use_energy=False,
+         window_type="hanning",
+         num_mel_bins=128,
+         dither=0.0,
+         frame_shift=10,
+     )
+
+     TARGET_LEN = log_mel_spec.size(0)
+
+     # cut and pad
+     n_frames = fbank.shape[0]
+     p = TARGET_LEN - n_frames
+     if p > 0:
+         m = torch.nn.ZeroPad2d((0, 0, 0, p))
+         fbank = m(fbank)
+     elif p < 0:
+         fbank = fbank[:TARGET_LEN, :]
+
+     fbank = (fbank - norm_mean) / (norm_std * 2)
+
+     return {"ta_kaldi_fbank": fbank}  # [1024, 128]
+
+
+ def make_batch_for_super_resolution(input_file, waveform=None, fbank=None):
+     log_mel_spec, stft, waveform, duration, target_frame = read_audio_file(input_file)
+
+     batch = {
+         "waveform": torch.FloatTensor(waveform),
+         "stft": torch.FloatTensor(stft),
+         "log_mel_spec": torch.FloatTensor(log_mel_spec),
+         "sampling_rate": 48000,
+     }
+
+     # print(batch["waveform"].size(), batch["stft"].size(), batch["log_mel_spec"].size())
+
+     batch.update(lowpass_filtering_prepare_inference(batch))
+
+     assert "waveform_lowpass" in batch.keys()
+     lowpass_mel, lowpass_stft = wav_feature_extraction(
+         batch["waveform_lowpass"], target_frame
+     )
+     batch["lowpass_mel"] = lowpass_mel
+
+     for k in batch.keys():
+         if type(batch[k]) == torch.Tensor:
+             batch[k] = torch.FloatTensor(batch[k]).unsqueeze(0)
+
+     return batch, duration
+
+
+ def round_up_duration(duration):
+     return int(round(duration / 2.5) + 1) * 2.5
+
+
+ def build_model(ckpt_path=None, config=None, device=None, model_name="basic"):
+     if device is None or device == "auto":
+         if torch.cuda.is_available():
+             device = torch.device("cuda:0")
+         elif torch.backends.mps.is_available():
+             device = torch.device("mps")
+         else:
+             device = torch.device("cpu")
+
+     print("Loading AudioSR: %s" % model_name)
+     print("Loading model on %s" % device)
+
+     ckpt_path = download_checkpoint(model_name)
+
+     if config is not None:
+         assert type(config) is str
+         config = yaml.load(open(config, "r"), Loader=yaml.FullLoader)
+     else:
+         config = default_audioldm_config(model_name)
+
+     # # Use text as condition instead of using waveform during training
+     config["model"]["params"]["device"] = device
+     # config["model"]["params"]["cond_stage_key"] = "text"
+
+     # No normalization here
+     latent_diffusion = LatentDiffusion(**config["model"]["params"])
+
+     resume_from_checkpoint = ckpt_path
+
+     checkpoint = torch.load(resume_from_checkpoint, map_location=device)
+
+     latent_diffusion.load_state_dict(checkpoint["state_dict"], strict=False)
+
+     latent_diffusion.eval()
+     latent_diffusion = latent_diffusion.to(device)
+
+     return latent_diffusion
+
+
+ def super_resolution(
+     latent_diffusion,
+     input_file,
+     seed=42,
+     ddim_steps=200,
+     guidance_scale=3.5,
+     latent_t_per_second=12.8,
+     config=None,
+ ):
+     seed_everything(int(seed))
+     waveform = None
+
+     batch, duration = make_batch_for_super_resolution(input_file, waveform=waveform)
+
+     with torch.no_grad():
+         waveform = latent_diffusion.generate_batch(
+             batch,
+             unconditional_guidance_scale=guidance_scale,
+             ddim_steps=ddim_steps,
+             duration=duration,
+         )
+
+     return waveform
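A minimal sketch of the Python API defined in this file, mirroring what __main__.py does; the file name and output path are placeholders.

    import os
    from audiosr import build_model, super_resolution, save_wave, get_time

    model = build_model(model_name="basic", device="auto")  # picks cuda:0, mps, or cpu
    out_dir = os.path.join("./output", get_time())
    os.makedirs(out_dir, exist_ok=True)

    waveform = super_resolution(
        model,
        "example.wav",
        seed=42,
        guidance_scale=3.5,
        ddim_steps=50,
        latent_t_per_second=12.8,
    )
    save_wave(waveform, inputpath="example.wav", savepath=out_dir, name="example_AudioSR_Processed_48K", samplerate=48000)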