import contextlib
import importlib
from huggingface_hub import hf_hub_download
from inspect import isfunction
import os
import soundfile as sf
import time
import wave
import progressbar


def read_list(fname):
    """Read a UTF-8 text file and return its lines with newlines stripped."""
    with open(fname, "r", encoding="utf-8") as f:
        return [line.strip("\n") for line in f.readlines()]


def get_duration(fname):
    """Return the duration (in seconds) of a wav file."""
    with contextlib.closing(wave.open(fname, "r")) as f:
        frames = f.getnframes()
        rate = f.getframerate()
        return frames / float(rate)


def get_bit_depth(fname):
    """Return the bit depth (bits per sample) of a wav file."""
    with contextlib.closing(wave.open(fname, "r")) as f:
        return f.getsampwidth() * 8


def get_time():
    """Return the current local time as a 'dd_mm_YYYY_HH_MM_SS' string."""
    t = time.localtime()
    return time.strftime("%d_%m_%Y_%H_%M_%S", t)


def seed_everything(seed):
    """Seed python, numpy and torch RNGs for reproducible runs.

    Args:
        seed: integer seed applied to all RNGs.

    Note: setting PYTHONHASHSEED here only affects subprocesses spawned
    later, not the current interpreter.
    """
    import random, os
    import numpy as np
    import torch

    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    # BUGFIX: benchmark autotuning picks cuDNN algorithms nondeterministically,
    # defeating the `deterministic = True` line above; it must be disabled.
    torch.backends.cudnn.benchmark = False


def _build_wav_filename(name, index, multiple):
    """Build a safe output file name ending in .wav for one batch item.

    Args:
        name: requested base name (may be a path and may contain ".wav").
        index: batch index, appended when `multiple` is True.
        multiple: True when the batch holds more than one waveform.
    """
    base = os.path.basename(name)
    # Strip an existing extension (original behavior: cut at the first dot).
    if ".wav" in base:
        base = base.split(".")[0]
    # BUGFIX: the original single-file branch returned the bare stem with no
    # ".wav" suffix when the name already contained ".wav"; always append it.
    fname = "%s_%s.wav" % (base, index) if multiple else "%s.wav" % base
    # Avoid file names too long for most filesystems (255-byte limit).
    if len(fname) > 255:
        import hashlib

        # BUGFIX: hex(hash(...)) is randomized per process (str hash seeding)
        # and may start with '-'; use a stable digest instead.
        fname = hashlib.md5(fname.encode("utf-8")).hexdigest() + ".wav"
    return fname


def save_wave(waveform, savepath, name="outwav", samplerate=16000):
    """Write a batch of waveforms to disk as wav files.

    Args:
        waveform: array-like of shape (batch, 1, samples) — TODO confirm
            channel axis with callers; only channel 0 is written.
        savepath: destination directory.
        name: base file name, or a list with one name per batch item.
        samplerate: sample rate written to the wav header.
    """
    if not isinstance(name, list):
        name = [name] * waveform.shape[0]
    for i in range(waveform.shape[0]):
        fname = _build_wav_filename(name[i], i, waveform.shape[0] > 1)
        path = os.path.join(savepath, fname)
        print("Save audio to %s" % path)
        sf.write(path, waveform[i, 0], samplerate=samplerate)


def exists(x):
    """Return True when x is not None."""
    return x is not None


def default(val, d):
    """Return val if it is not None, otherwise d (calling d if callable)."""
    if exists(val):
        return val
    return d() if isfunction(d) else d


def count_params(model, verbose=False):
    """Return the total parameter count of a torch module.

    Args:
        model: any object exposing .parameters() (e.g. torch.nn.Module).
        verbose: when True, print the count in millions.
    """
    total_params = sum(p.numel() for p in model.parameters())
    if verbose:
        print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
    return total_params
def get_obj_from_str(string, reload=False):
    """Resolve a dotted path like "pkg.module.Class" to the named attribute.

    Args:
        string: fully-qualified dotted name; the final component is the
            attribute, everything before it is the module path.
        reload: when True, reload the module before the lookup.
    """
    module, cls = string.rsplit(".", 1)
    if reload:
        module_imp = importlib.import_module(module)
        importlib.reload(module_imp)
    return getattr(importlib.import_module(module, package=None), cls)


def instantiate_from_config(config):
    """Instantiate an object described by a {"target": ..., "params": ...} dict.

    The sentinel strings "__is_first_stage__" and "__is_unconditional__"
    yield None. Raises KeyError when "target" is missing.
    """
    if "target" not in config:
        if config == "__is_first_stage__":
            return None
        elif config == "__is_unconditional__":
            return None
        raise KeyError("Expected key `target` to instantiate.")
    # BUGFIX: the original wrapped this call in a bare `except:` that dropped
    # into ipdb.set_trace() — leftover debugging code that swallowed real
    # errors and required ipdb at runtime. Let exceptions propagate.
    return get_obj_from_str(config["target"])(**config.get("params", dict()))


def default_audioldm_config(model_name="audioldm2-full"):
    """Return the config dict matching a model name.

    Starts from the basic (16 kHz, CLAP+FlanT5) config, then applies
    substring-triggered overrides in order: "-large-" widens the UNet,
    "-speech-" swaps in phoneme conditioning, and "48k" / "t5" replace the
    whole config (so those take precedence over earlier tweaks).
    """
    basic_config = get_basic_config()
    if "-large-" in model_name:
        basic_config["model"]["params"]["unet_config"]["params"]["context_dim"] = [768, 1024, None]
        basic_config["model"]["params"]["unet_config"]["params"]["transformer_depth"] = 2
    if "-speech-" in model_name:
        basic_config["model"]["params"]["unet_config"]["params"]["context_dim"] = [768]
        basic_config["model"]["params"]["cond_stage_config"] = {
            "crossattn_audiomae_generated": {
                "cond_stage_key": "all",
                "conditioning_key": "crossattn",
                "target": "audioldm2.latent_diffusion.modules.encoders.modules.SequenceGenAudioMAECond",
                "params": {
                    "always_output_audiomae_gt": False,
                    "learnable": True,
                    "use_gt_mae_output": True,
                    "use_gt_mae_prob": 1,
                    "base_learning_rate": 0.0002,
                    "sequence_gen_length": 512,
                    "use_warmup": True,
                    "sequence_input_key": ["film_clap_cond1", "crossattn_vits_phoneme"],
                    "sequence_input_embed_dim": [512, 192],
                    "batchsize": 16,
                    "cond_stage_config": {
                        "film_clap_cond1": {
                            "cond_stage_key": "text",
                            "conditioning_key": "film",
                            "target": "audioldm2.latent_diffusion.modules.encoders.modules.CLAPAudioEmbeddingClassifierFreev2",
                            "params": {
                                "sampling_rate": 48000,
                                "embed_mode": "text",
                                "amodel": "HTSAT-base",
                            },
                        },
                        "crossattn_vits_phoneme": {
                            "cond_stage_key": "phoneme_idx",
                            "conditioning_key": "crossattn",
                            "target": "audioldm2.latent_diffusion.modules.encoders.modules.PhonemeEncoder",
                            "params": {
                                "vocabs_size": 183,
                                "pad_token_id": 0,
                                "pad_length": 310,
                            },
                        },
                        "crossattn_audiomae_pooled": {
                            "cond_stage_key": "ta_kaldi_fbank",
                            "conditioning_key": "crossattn",
                            "target": "audioldm2.latent_diffusion.modules.encoders.modules.AudioMAEConditionCTPoolRand",
                            "params": {
                                "regularization": False,
                                "no_audiomae_mask": True,
                                "time_pooling_factors": [1],
                                "freq_pooling_factors": [1],
                                "eval_time_pooling": 1,
                                "eval_freq_pooling": 1,
                                "mask_ratio": 0,
                            },
                        },
                    },
                },
            }
        }
    if "48k" in model_name:
        basic_config = get_audioldm_48k_config()
    if "t5" in model_name:
        basic_config = get_audioldm_crossattn_t5_config()
    return basic_config


class MyProgressBar:
    """Download progress hook (urllib reporthook signature)."""

    def __init__(self):
        # Lazily created on the first callback, once total_size is known.
        self.pbar = None

    def __call__(self, block_num, block_size, total_size):
        if not self.pbar:
            self.pbar = progressbar.ProgressBar(maxval=total_size)
            self.pbar.start()
        downloaded = block_num * block_size
        if downloaded < total_size:
            self.pbar.update(downloaded)
        else:
            self.pbar.finish()


def download_checkpoint(checkpoint_name="audioldm2-full"):
    """Fetch a checkpoint from the Hugging Face Hub and return its local path.

    Speech checkpoints all live in the shared "haoheliu/audioldm2-speech"
    repo; every other checkpoint has a repo named after itself.
    """
    if "audioldm2-speech" in checkpoint_name:
        model_id = "haoheliu/audioldm2-speech"
    else:
        model_id = "haoheliu/%s" % checkpoint_name
    checkpoint_path = hf_hub_download(
        repo_id=model_id, filename=checkpoint_name + ".pth"
    )
    return checkpoint_path


def get_basic_config():
    """Default 16 kHz AudioLDM2 config (CLAP + FlanT5 conditioning)."""
    return {
        "log_directory": "./log/audiomae_pred",
        "precision": "high",
        "data": {
            "train": [
                "audiocaps",
                "audioset",
                "wavcaps",
                "audiostock_music_250k",
                "free_to_use_sounds",
                "epidemic_sound_effects",
                "vggsound",
                "million_song_dataset",
            ],
            "val": "audiocaps",
            "test": "audiocaps",
            "class_label_indices": "audioset",
            "dataloader_add_ons": [
                "extract_kaldi_fbank_feature",
                "extract_vits_phoneme_and_flant5_text",
                "waveform_rs_48k",
            ],
        },
        "variables": {
            "sampling_rate": 16000,
            "mel_bins": 64,
            "latent_embed_dim": 8,
            "latent_t_size": 256,
            "latent_f_size": 16,
            "in_channels": 8,
            "optimize_ddpm_parameter": True,
            "warmup_steps": 5000,
        },
        "step": {
            "validation_every_n_epochs": 1,
            "save_checkpoint_every_n_steps": 5000,
            "limit_val_batches": 10,
            "max_steps": 1500000,
            "save_top_k": 2,
        },
        "preprocessing": {
            "audio": {
                "sampling_rate": 16000,
                "max_wav_value": 32768,
                "duration": 10.24,
            },
            "stft": {"filter_length": 1024, "hop_length": 160, "win_length": 1024},
            "mel": {"n_mel_channels": 64, "mel_fmin": 0, "mel_fmax": 8000},
        },
        "augmentation": {"mixup": 0},
        "model": {
            # NOTE(review): this key was commented out in the original, which
            # is inconsistent with get_audioldm_48k_config() /
            # get_audioldm_crossattn_t5_config() and would make
            # instantiate_from_config() raise KeyError; restored here.
            "target": "audioldm2.latent_diffusion.models.ddpm.LatentDiffusion",
            "params": {
                "first_stage_config": {
                    "base_learning_rate": 0.000008,
                    "target": "audioldm2.latent_encoder.autoencoder.AutoencoderKL",
                    "params": {
                        "sampling_rate": 16000,
                        "batchsize": 4,
                        "monitor": "val/rec_loss",
                        "image_key": "fbank",
                        "subband": 1,
                        "embed_dim": 8,
                        "time_shuffle": 1,
                        "lossconfig": {
                            "target": "audioldm2.latent_diffusion.modules.losses.LPIPSWithDiscriminator",
                            "params": {
                                "disc_start": 50001,
                                "kl_weight": 1000,
                                "disc_weight": 0.5,
                                "disc_in_channels": 1,
                            },
                        },
                        "ddconfig": {
                            "double_z": True,
                            "mel_bins": 64,
                            "z_channels": 8,
                            "resolution": 256,
                            "downsample_time": False,
                            "in_channels": 1,
                            "out_ch": 1,
                            "ch": 128,
                            "ch_mult": [1, 2, 4],
                            "num_res_blocks": 2,
                            "attn_resolutions": [],
                            "dropout": 0,
                        },
                    },
                },
                "base_learning_rate": 0.0001,
                "warmup_steps": 5000,
                "optimize_ddpm_parameter": True,
                "sampling_rate": 16000,
                "batchsize": 16,
                "linear_start": 0.0015,
                "linear_end": 0.0195,
                "num_timesteps_cond": 1,
                "log_every_t": 200,
                "timesteps": 1000,
                "unconditional_prob_cfg": 0.1,
                "parameterization": "eps",
                "first_stage_key": "fbank",
                "latent_t_size": 256,
                "latent_f_size": 16,
                "channels": 8,
                "monitor": "val/loss_simple_ema",
                "scale_by_std": True,
                "unet_config": {
                    "target": "audioldm2.latent_diffusion.modules.diffusionmodules.openaimodel.UNetModel",
                    "params": {
                        "image_size": 64,
                        "context_dim": [768, 1024],
                        "in_channels": 8,
                        "out_channels": 8,
                        "model_channels": 128,
                        "attention_resolutions": [8, 4, 2],
                        "num_res_blocks": 2,
                        "channel_mult": [1, 2, 3, 5],
                        "num_head_channels": 32,
                        "use_spatial_transformer": True,
                        "transformer_depth": 1,
                    },
                },
                "evaluation_params": {
                    "unconditional_guidance_scale": 3.5,
                    "ddim_sampling_steps": 200,
                    "n_candidates_per_samples": 3,
                },
                "cond_stage_config": {
                    "crossattn_audiomae_generated": {
                        "cond_stage_key": "all",
                        "conditioning_key": "crossattn",
                        "target": "audioldm2.latent_diffusion.modules.encoders.modules.SequenceGenAudioMAECond",
                        "params": {
                            "always_output_audiomae_gt": False,
                            "learnable": True,
                            "device": "cuda",
                            "use_gt_mae_output": True,
                            "use_gt_mae_prob": 0.0,
                            "base_learning_rate": 0.0002,
                            "sequence_gen_length": 8,
                            "use_warmup": True,
                            "sequence_input_key": [
                                "film_clap_cond1",
                                "crossattn_flan_t5",
                            ],
                            "sequence_input_embed_dim": [512, 1024],
                            "batchsize": 16,
                            "cond_stage_config": {
                                "film_clap_cond1": {
                                    "cond_stage_key": "text",
                                    "conditioning_key": "film",
                                    "target": "audioldm2.latent_diffusion.modules.encoders.modules.CLAPAudioEmbeddingClassifierFreev2",
                                    "params": {
                                        "sampling_rate": 48000,
                                        "embed_mode": "text",
                                        "amodel": "HTSAT-base",
                                    },
                                },
                                "crossattn_flan_t5": {
                                    "cond_stage_key": "text",
                                    "conditioning_key": "crossattn",
                                    "target": "audioldm2.latent_diffusion.modules.encoders.modules.FlanT5HiddenState",
                                },
                                "crossattn_audiomae_pooled": {
                                    "cond_stage_key": "ta_kaldi_fbank",
                                    "conditioning_key": "crossattn",
                                    "target": "audioldm2.latent_diffusion.modules.encoders.modules.AudioMAEConditionCTPoolRand",
                                    "params": {
                                        "regularization": False,
                                        "no_audiomae_mask": True,
                                        "time_pooling_factors": [8],
                                        "freq_pooling_factors": [8],
                                        "eval_time_pooling": 8,
                                        "eval_freq_pooling": 8,
                                        "mask_ratio": 0,
                                    },
                                },
                            },
                        },
                    },
                    "crossattn_flan_t5": {
                        "cond_stage_key": "text",
                        "conditioning_key": "crossattn",
                        "target": "audioldm2.latent_diffusion.modules.encoders.modules.FlanT5HiddenState",
                    },
                },
            },
        },
    }


def get_audioldm_48k_config():
    """48 kHz AudioLDM config (CLAP FiLM conditioning only)."""
    return {
        "variables": {
            "sampling_rate": 48000,
            "latent_embed_dim": 16,
            "mel_bins": 256,
            "latent_t_size": 128,
            "latent_f_size": 32,
            "in_channels": 16,
            "optimize_ddpm_parameter": True,
            "warmup_steps": 5000,
        },
        "step": {
            "validation_every_n_epochs": 1,
            "save_checkpoint_every_n_steps": 5000,
            "limit_val_batches": 10,
            "max_steps": 1500000,
            "save_top_k": 2,
        },
        "preprocessing": {
            "audio": {
                "sampling_rate": 48000,
                "max_wav_value": 32768,
                "duration": 10.24,
            },
            "stft": {"filter_length": 2048, "hop_length": 480, "win_length": 2048},
            "mel": {"n_mel_channels": 256, "mel_fmin": 20, "mel_fmax": 24000},
        },
        "augmentation": {"mixup": 0},
        "model": {
            "target": "audioldm2.latent_diffusion.models.ddpm.LatentDiffusion",
            "params": {
                "first_stage_config": {
                    "base_learning_rate": 0.000008,
                    "target": "audioldm2.latent_encoder.autoencoder.AutoencoderKL",
                    "params": {
                        "sampling_rate": 48000,
                        "batchsize": 4,
                        "monitor": "val/rec_loss",
                        "image_key": "fbank",
                        "subband": 1,
                        "embed_dim": 16,
                        "time_shuffle": 1,
                        "lossconfig": {
                            "target": "audioldm2.latent_diffusion.modules.losses.LPIPSWithDiscriminator",
                            "params": {
                                "disc_start": 50001,
                                "kl_weight": 1000,
                                "disc_weight": 0.5,
                                "disc_in_channels": 1,
                            },
                        },
                        "ddconfig": {
                            "double_z": True,
                            "mel_bins": 256,
                            "z_channels": 16,
                            "resolution": 256,
                            "downsample_time": False,
                            "in_channels": 1,
                            "out_ch": 1,
                            "ch": 128,
                            "ch_mult": [1, 2, 4, 8],
                            "num_res_blocks": 2,
                            "attn_resolutions": [],
                            "dropout": 0,
                        },
                    },
                },
                "base_learning_rate": 0.0001,
                "warmup_steps": 5000,
                "optimize_ddpm_parameter": True,
                "sampling_rate": 48000,
                "batchsize": 16,
                "linear_start": 0.0015,
                "linear_end": 0.0195,
                "num_timesteps_cond": 1,
                "log_every_t": 200,
                "timesteps": 1000,
                "unconditional_prob_cfg": 0.1,
                "parameterization": "eps",
                "first_stage_key": "fbank",
                "latent_t_size": 128,
                "latent_f_size": 32,
                "channels": 16,
                "monitor": "val/loss_simple_ema",
                "scale_by_std": True,
                "unet_config": {
                    "target": "audioldm2.latent_diffusion.modules.diffusionmodules.openaimodel.UNetModel",
                    "params": {
                        "image_size": 64,
                        "extra_film_condition_dim": 512,
                        "context_dim": [None],
                        "in_channels": 16,
                        "out_channels": 16,
                        "model_channels": 128,
                        "attention_resolutions": [8, 4, 2],
                        "num_res_blocks": 2,
                        "channel_mult": [1, 2, 3, 5],
                        "num_head_channels": 32,
                        "use_spatial_transformer": True,
                        "transformer_depth": 1,
                    },
                },
                "evaluation_params": {
                    "unconditional_guidance_scale": 3.5,
                    "ddim_sampling_steps": 200,
                    "n_candidates_per_samples": 3,
                },
                "cond_stage_config": {
                    "film_clap_cond1": {
                        "cond_stage_key": "text",
                        "conditioning_key": "film",
                        "target": "audioldm2.latent_diffusion.modules.encoders.modules.CLAPAudioEmbeddingClassifierFreev2",
                        "params": {
                            "sampling_rate": 48000,
                            "embed_mode": "text",
                            "amodel": "HTSAT-base",
                        },
                    },
                },
            },
        },
    }


def get_audioldm_crossattn_t5_config():
    """16 kHz AudioLDM config conditioned on FlanT5 cross-attention only."""
    return {
        "variables": {
            "sampling_rate": 16000,
            "mel_bins": 64,
            "latent_embed_dim": 8,
            "latent_t_size": 256,
            "latent_f_size": 16,
            "in_channels": 8,
            "optimize_ddpm_parameter": True,
            "warmup_steps": 5000,
        },
        "step": {
            "validation_every_n_epochs": 1,
            "save_checkpoint_every_n_steps": 5000,
            "max_steps": 1500000,
            "save_top_k": 2,
        },
        "preprocessing": {
            "audio": {
                "sampling_rate": 16000,
                "max_wav_value": 32768,
                "duration": 10.24,
            },
            "stft": {"filter_length": 1024, "hop_length": 160, "win_length": 1024},
            "mel": {"n_mel_channels": 64, "mel_fmin": 0, "mel_fmax": 8000},
        },
        "augmentation": {"mixup": 0},
        "model": {
            "target": "audioldm2.latent_diffusion.models.ddpm.LatentDiffusion",
            "params": {
                "first_stage_config": {
                    "base_learning_rate": 0.000008,
                    "target": "audioldm2.latent_encoder.autoencoder.AutoencoderKL",
                    "params": {
                        "sampling_rate": 16000,
                        "batchsize": 4,
                        "monitor": "val/rec_loss",
                        "image_key": "fbank",
                        "subband": 1,
                        "embed_dim": 8,
                        "time_shuffle": 1,
                        "lossconfig": {
                            "target": "audioldm2.latent_diffusion.modules.losses.LPIPSWithDiscriminator",
                            "params": {
                                "disc_start": 50001,
                                "kl_weight": 1000,
                                "disc_weight": 0.5,
                                "disc_in_channels": 1,
                            },
                        },
                        "ddconfig": {
                            "double_z": True,
                            "mel_bins": 64,
                            "z_channels": 8,
                            "resolution": 256,
                            "downsample_time": False,
                            "in_channels": 1,
                            "out_ch": 1,
                            "ch": 128,
                            "ch_mult": [1, 2, 4],
                            "num_res_blocks": 2,
                            "attn_resolutions": [],
                            "dropout": 0,
                        },
                    },
                },
                "base_learning_rate": 0.0001,
                "warmup_steps": 5000,
                "optimize_ddpm_parameter": True,
                "sampling_rate": 16000,
                "batchsize": 16,
                "linear_start": 0.0015,
                "linear_end": 0.0195,
                "num_timesteps_cond": 1,
                "log_every_t": 200,
                "timesteps": 1000,
                "unconditional_prob_cfg": 0.1,
                "parameterization": "eps",
                "first_stage_key": "fbank",
                "latent_t_size": 256,
                "latent_f_size": 16,
                "channels": 8,
                "monitor": "val/loss_simple_ema",
                "scale_by_std": True,
                "unet_config": {
                    "target": "audioldm2.latent_diffusion.modules.diffusionmodules.openaimodel.UNetModel",
                    "params": {
                        "image_size": 64,
                        "context_dim": [1024],
                        "in_channels": 8,
                        "out_channels": 8,
                        "model_channels": 128,
                        "attention_resolutions": [8, 4, 2],
                        "num_res_blocks": 2,
                        "channel_mult": [1, 2, 3, 5],
                        "num_head_channels": 32,
                        "use_spatial_transformer": True,
                        "transformer_depth": 1,
                    },
                },
                "evaluation_params": {
                    "unconditional_guidance_scale": 3.5,
                    "ddim_sampling_steps": 200,
                    "n_candidates_per_samples": 3,
                },
                "cond_stage_config": {
                    "crossattn_flan_t5": {
                        "cond_stage_key": "text",
                        "conditioning_key": "crossattn",
                        "target": "audioldm2.latent_diffusion.modules.encoders.modules.FlanT5HiddenState",
                    },
                },
            },
        },
    }