import contextlib
import importlib
import json
import os
import subprocess
import tempfile
import time
import wave
from inspect import isfunction

import numpy as np
import progressbar
import soundfile as sf
import torch
import torchaudio
from huggingface_hub import hf_hub_download
from librosa.filters import mel as librosa_mel_fn

from audiosr.lowpass import lowpass

# Per-device caches for the Hann window and mel filterbank used by the STFT helpers below
hann_window = {}
mel_basis = {}

def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
    return torch.log(torch.clamp(x, min=clip_val) * C)


def dynamic_range_decompression_torch(x, C=1):
    return torch.exp(x) / C


def spectral_normalize_torch(magnitudes):
    return dynamic_range_compression_torch(magnitudes)


def spectral_de_normalize_torch(magnitudes):
    return dynamic_range_decompression_torch(magnitudes)

def _locate_cutoff_freq(stft, percentile=0.97):
    # Estimate the highest frequency bin that still carries significant energy,
    # i.e. the effective cutoff frequency of the input
    def _find_cutoff(x, percentile=0.95):
        percentile = x[-1] * percentile
        for i in range(1, x.shape[0]):
            if x[-i] < percentile:
                return x.shape[0] - i
        return 0

    magnitude = torch.abs(stft)
    energy = torch.cumsum(torch.sum(magnitude, dim=0), dim=0)
    return _find_cutoff(energy, percentile)

def pad_wav(waveform, target_length):
    waveform_length = waveform.shape[-1]
    assert waveform_length > 100, "Waveform is too short, %s" % waveform_length
    if waveform_length == target_length:
        return waveform
    # Zero-pad the waveform on the right up to target_length
    temp_wav = np.zeros((1, target_length), dtype=np.float32)
    temp_wav[:, :waveform_length] = waveform
    return temp_wav

def lowpass_filtering_prepare_inference(dl_output):
    waveform = dl_output["waveform"]  # [1, samples]
    sampling_rate = dl_output["sampling_rate"]

    # Estimate the cutoff frequency from the STFT (1024 = n_fft // 2 bins, 24 kHz Nyquist)
    cutoff_freq = (
        _locate_cutoff_freq(dl_output["stft"], percentile=0.985) / 1024
    ) * 24000

    # If the audio is almost empty, skip filtering by keeping the full bandwidth
    if cutoff_freq < 1000:
        cutoff_freq = 24000

    order = 8
    ftype = np.random.choice(["butter", "cheby1", "ellip", "bessel"])
    filtered_audio = lowpass(
        waveform.numpy().squeeze(),
        highcut=cutoff_freq,
        fs=sampling_rate,
        order=order,
        _type=ftype,
    )

    filtered_audio = torch.FloatTensor(filtered_audio.copy()).unsqueeze(0)
    if waveform.size(-1) <= filtered_audio.size(-1):
        filtered_audio = filtered_audio[..., : waveform.size(-1)]
    else:
        filtered_audio = torch.nn.functional.pad(
            filtered_audio, (0, waveform.size(-1) - filtered_audio.size(-1))
        )
    return {"waveform_lowpass": filtered_audio}

def mel_spectrogram_train(y):
    global mel_basis, hann_window

    sampling_rate = 48000
    filter_length = 2048
    hop_length = 480
    win_length = 2048
    n_mel = 256
    mel_fmin = 20
    mel_fmax = 24000

    # Cache the mel filterbank and Hann window per device
    cache_key = str(mel_fmax) + "_" + str(y.device)
    if cache_key not in mel_basis:
        mel = librosa_mel_fn(
            sr=sampling_rate,
            n_fft=filter_length,
            n_mels=n_mel,
            fmin=mel_fmin,
            fmax=mel_fmax,
        )
        mel_basis[cache_key] = torch.from_numpy(mel).float().to(y.device)
        hann_window[str(y.device)] = torch.hann_window(win_length).to(y.device)

    y = torch.nn.functional.pad(
        y.unsqueeze(1),
        (int((filter_length - hop_length) / 2), int((filter_length - hop_length) / 2)),
        mode="reflect",
    )
    y = y.squeeze(1)

    stft_spec = torch.stft(
        y,
        filter_length,
        hop_length=hop_length,
        win_length=win_length,
        window=hann_window[str(y.device)],
        center=False,
        pad_mode="reflect",
        normalized=False,
        onesided=True,
        return_complex=True,
    )
    stft_spec = torch.abs(stft_spec)

    mel = spectral_normalize_torch(torch.matmul(mel_basis[cache_key], stft_spec))
    return mel[0], stft_spec[0]

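# Illustrative shape sketch (not part of the original module): mel_spectrogram_train
# expects a [1, num_samples] float tensor at 48 kHz and returns a [256, frames] mel
# spectrogram together with a [1025, frames] magnitude STFT (n_fft // 2 + 1 bins).
# With hop_length=480 at 48 kHz this works out to 100 frames per second of audio.
def _example_mel_shapes():
    y = torch.zeros(1, 48000 * 5)  # five seconds of silence at 48 kHz
    mel, stft_mag = mel_spectrogram_train(y)
    return mel.shape, stft_mag.shape  # torch.Size([256, 500]), torch.Size([1025, 500])
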
def pad_spec(log_mel_spec, target_frame):
    n_frames = log_mel_spec.shape[0]
    p = target_frame - n_frames
    # Pad (or cut) along the time axis so the spectrogram has exactly target_frame frames
    if p > 0:
        m = torch.nn.ZeroPad2d((0, 0, 0, p))
        log_mel_spec = m(log_mel_spec)
    elif p < 0:
        log_mel_spec = log_mel_spec[0:target_frame, :]
    # Keep an even number of frequency bins
    if log_mel_spec.size(-1) % 2 != 0:
        log_mel_spec = log_mel_spec[..., :-1]
    return log_mel_spec

def wav_feature_extraction(waveform, target_frame):
    waveform = waveform[0, ...]
    waveform = torch.FloatTensor(waveform)

    log_mel_spec, stft = mel_spectrogram_train(waveform.unsqueeze(0))

    # Transpose to [frames, bins] and pad/cut to target_frame frames
    log_mel_spec = torch.FloatTensor(log_mel_spec.T)
    stft = torch.FloatTensor(stft.T)
    log_mel_spec, stft = pad_spec(log_mel_spec, target_frame), pad_spec(stft, target_frame)
    return log_mel_spec, stft

def normalize_wav(waveform):
    waveform = waveform - np.mean(waveform)
    waveform = waveform / (np.max(np.abs(waveform)) + 1e-8)
    return waveform * 0.5

def read_wav_file(filename):
    waveform, sr = torchaudio.load(filename)
    duration = waveform.size(-1) / sr

    if duration > 10.24:
        print(
            "\033[93m {}\033[00m".format(
                "Warning: audio is longer than 10.24 seconds, which may degrade model performance. "
                "It is recommended to truncate your audio to 5.12 seconds before passing it to AudioSR for best results."
            )
        )

    # Pad the duration up to the next multiple of 5.12 seconds
    if duration % 5.12 != 0:
        pad_duration = duration + (5.12 - duration % 5.12)
    else:
        pad_duration = duration

    target_frame = int(pad_duration * 100)

    waveform = torchaudio.functional.resample(waveform, sr, 48000)
    waveform = waveform.numpy()[0, ...]
    waveform = normalize_wav(waveform)  # TODO rescaling the waveform will cause low LSD score
    waveform = waveform[None, ...]
    waveform = pad_wav(waveform, target_length=int(48000 * pad_duration))
    return waveform, target_frame, pad_duration

def read_audio_file(filename):
    waveform, target_frame, duration = read_wav_file(filename)
    log_mel_spec, stft = wav_feature_extraction(waveform, target_frame)
    return log_mel_spec, stft, waveform, duration, target_frame

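# Illustrative usage sketch (an assumption, not the original AudioSR pipeline): it
# shows how the helpers above can be wired together for inference preprocessing.
# The dict keys ("waveform", "sampling_rate", "stft") follow the ones referenced by
# lowpass_filtering_prepare_inference; the real pipeline may build this dict differently.
def _example_preprocess(filename):
    log_mel_spec, stft, waveform, duration, target_frame = read_audio_file(filename)
    dl_output = {
        "waveform": torch.FloatTensor(waveform),  # [1, samples] at 48 kHz
        "sampling_rate": 48000,                   # read_wav_file resamples to 48 kHz
        "stft": stft,                             # [frames, freq_bins] magnitude STFT
    }
    lowpass_out = lowpass_filtering_prepare_inference(dl_output)
    return log_mel_spec, lowpass_out["waveform_lowpass"]
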
def read_list(fname):
    result = []
    with open(fname, "r", encoding="utf-8") as f:
        for each in f.readlines():
            each = each.strip("\n")
            result.append(each)
    return result

def get_duration(fname):
    with contextlib.closing(wave.open(fname, "r")) as f:
        frames = f.getnframes()
        rate = f.getframerate()
        return frames / float(rate)


def get_bit_depth(fname):
    with contextlib.closing(wave.open(fname, "r")) as f:
        return f.getsampwidth() * 8


def get_time():
    t = time.localtime()
    return time.strftime("%d_%m_%Y_%H_%M_%S", t)

def seed_everything(seed):
    import random

    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True

def strip_silence(original_path, input_path, output_path):
    # Read the duration of the original file so the processed output can be trimmed to match it
    get_dur = subprocess.run(
        [
            "ffprobe",
            "-v", "error",
            "-select_streams", "a:0",
            "-show_entries", "format=duration",
            "-sexagesimal",
            "-of", "json",
            original_path,
        ],
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )
    duration = json.loads(get_dur.stdout)["format"]["duration"]
    subprocess.run(
        [
            "ffmpeg",
            "-y",
            "-ss", "00:00:00",
            "-i", input_path,
            "-t", duration,
            "-c", "copy",
            output_path,
        ]
    )
    os.remove(input_path)

def save_wave(waveform, inputpath, savepath, name="outwav", samplerate=16000):
    if type(name) is not list:
        name = [name] * waveform.shape[0]

    for i in range(waveform.shape[0]):
        if waveform.shape[0] > 1:
            fname = "%s_%s.wav" % (
                os.path.basename(name[i])
                if ".wav" not in name[i]
                else os.path.basename(name[i]).split(".")[0],
                i,
            )
        else:
            fname = (
                "%s.wav" % os.path.basename(name[i])
                if ".wav" not in name[i]
                else os.path.basename(name[i]).split(".")[0]
            )
        # Avoid file names that are too long to be saved
        if len(fname) > 255:
            fname = f"{hex(hash(fname))}.wav"

        save_path = os.path.join(savepath, fname)
        temp_path = os.path.join(tempfile.gettempdir(), fname)
        print(
            "\033[98m {}\033[00m".format(
                "Don't forget to try different seeds by setting --seed <int> so that AudioSR can have optimal performance on your hardware."
            )
        )
        print("Save audio to %s." % save_path)
        sf.write(temp_path, waveform[i, 0], samplerate=samplerate)
        strip_silence(inputpath, temp_path, save_path)

def exists(x):
    return x is not None


def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d


def count_params(model, verbose=False):
    total_params = sum(p.numel() for p in model.parameters())
    if verbose:
        print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
    return total_params

def get_obj_from_str(string, reload=False):
    module, cls = string.rsplit(".", 1)
    if reload:
        module_imp = importlib.import_module(module)
        importlib.reload(module_imp)
    return getattr(importlib.import_module(module, package=None), cls)

def instantiate_from_config(config):
    if "target" not in config:
        if config == "__is_first_stage__":
            return None
        elif config == "__is_unconditional__":
            return None
        raise KeyError("Expected key `target` to instantiate.")
    try:
        return get_obj_from_str(config["target"])(**config.get("params", dict()))
    except:
        import ipdb

        ipdb.set_trace()

def default_audioldm_config(model_name="basic"):
    # model_name is currently unused; the basic config is returned for every model
    basic_config = get_basic_config()
    return basic_config

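# Illustrative sketch (an assumption, not the original AudioSR entry point): configs
# such as the one returned by get_basic_config() are typically consumed by handing the
# "model" sub-dict to instantiate_from_config, which imports the dotted "target" path
# and constructs it with "params". Requires the audiosr package to be importable.
def _example_build_model():
    config = default_audioldm_config()
    return instantiate_from_config(config["model"])
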
class MyProgressBar:
    def __init__(self):
        self.pbar = None

    def __call__(self, block_num, block_size, total_size):
        if not self.pbar:
            self.pbar = progressbar.ProgressBar(maxval=total_size)
            self.pbar.start()
        downloaded = block_num * block_size
        if downloaded < total_size:
            self.pbar.update(downloaded)
        else:
            self.pbar.finish()

def download_checkpoint(checkpoint_name="basic"):
    if checkpoint_name == "basic":
        model_id = "haoheliu/audiosr_basic"
    elif checkpoint_name == "speech":
        model_id = "haoheliu/audiosr_speech"
    else:
        raise ValueError("Invalid Model Name %s" % checkpoint_name)
    checkpoint_path = hf_hub_download(repo_id=model_id, filename="pytorch_model.bin")
    return checkpoint_path

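# Illustrative sketch (an assumption, not shown in this file): download_checkpoint
# returns a local path to pytorch_model.bin, whose contents can be read with
# torch.load. Mapping the resulting state dict onto a model is left to the caller.
def _example_load_checkpoint(device="cpu"):
    checkpoint_path = download_checkpoint("basic")
    return torch.load(checkpoint_path, map_location=device)
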
def get_basic_config():
    return {
        "preprocessing": {
            "audio": {
                "sampling_rate": 48000,
                "max_wav_value": 32768,
                "duration": 10.24,
            },
            "stft": {"filter_length": 2048, "hop_length": 480, "win_length": 2048},
            "mel": {"n_mel_channels": 256, "mel_fmin": 20, "mel_fmax": 24000},
        },
        "augmentation": {"mixup": 0.5},
        "model": {
            "target": "audiosr.latent_diffusion.models.ddpm.LatentDiffusion",
            "params": {
                "first_stage_config": {
                    "base_learning_rate": 0.000008,
                    "target": "audiosr.latent_encoder.autoencoder.AutoencoderKL",
                    "params": {
                        "reload_from_ckpt": "/mnt/bn/lqhaoheliu/project/audio_generation_diffusion/log/vae/vae_48k_256/ds_8_kl_1/checkpoints/ckpt-checkpoint-484999.ckpt",
                        "sampling_rate": 48000,
                        "batchsize": 4,
                        "monitor": "val/rec_loss",
                        "image_key": "fbank",
                        "subband": 1,
                        "embed_dim": 16,
                        "time_shuffle": 1,
                        "ddconfig": {
                            "double_z": True,
                            "mel_bins": 256,
                            "z_channels": 16,
                            "resolution": 256,
                            "downsample_time": False,
                            "in_channels": 1,
                            "out_ch": 1,
                            "ch": 128,
                            "ch_mult": [1, 2, 4, 8],
                            "num_res_blocks": 2,
                            "attn_resolutions": [],
                            "dropout": 0.1,
                        },
                    },
                },
                "base_learning_rate": 0.0001,
                "warmup_steps": 5000,
                "optimize_ddpm_parameter": True,
                "sampling_rate": 48000,
                "batchsize": 16,
                "beta_schedule": "cosine",
                "linear_start": 0.0015,
                "linear_end": 0.0195,
                "num_timesteps_cond": 1,
                "log_every_t": 200,
                "timesteps": 1000,
                "unconditional_prob_cfg": 0.1,
                "parameterization": "v",
                "first_stage_key": "fbank",
                "latent_t_size": 128,
                "latent_f_size": 32,
                "channels": 16,
                "monitor": "val/loss_simple_ema",
                "scale_by_std": True,
                "unet_config": {
                    "target": "audiosr.latent_diffusion.modules.diffusionmodules.openaimodel.UNetModel",
                    "params": {
                        "image_size": 64,
                        "in_channels": 32,
                        "out_channels": 16,
                        "model_channels": 128,
                        "attention_resolutions": [8, 4, 2],
                        "num_res_blocks": 2,
                        "channel_mult": [1, 2, 3, 5],
                        "num_head_channels": 32,
                        "extra_sa_layer": True,
                        "use_spatial_transformer": True,
                        "transformer_depth": 1,
                    },
                },
                "evaluation_params": {
                    "unconditional_guidance_scale": 3.5,
                    "ddim_sampling_steps": 200,
                    "n_candidates_per_samples": 1,
                },
                "cond_stage_config": {
                    "concat_lowpass_cond": {
                        "cond_stage_key": "lowpass_mel",
                        "conditioning_key": "concat",
                        "target": "audiosr.latent_diffusion.modules.encoders.modules.VAEFeatureExtract",
                        "params": {
                            "first_stage_config": {
                                "base_learning_rate": 0.000008,
                                "target": "audiosr.latent_encoder.autoencoder.AutoencoderKL",
                                "params": {
                                    "sampling_rate": 48000,
                                    "batchsize": 4,
                                    "monitor": "val/rec_loss",
                                    "image_key": "fbank",
                                    "subband": 1,
                                    "embed_dim": 16,
                                    "time_shuffle": 1,
                                    "ddconfig": {
                                        "double_z": True,
                                        "mel_bins": 256,
                                        "z_channels": 16,
                                        "resolution": 256,
                                        "downsample_time": False,
                                        "in_channels": 1,
                                        "out_ch": 1,
                                        "ch": 128,
                                        "ch_mult": [1, 2, 4, 8],
                                        "num_res_blocks": 2,
                                        "attn_resolutions": [],
                                        "dropout": 0.1,
                                    },
                                },
                            }
                        },
                    }
                },
            },
        },
    }