import os
import argparse
import yaml
import torch
from torch import autocast
from tqdm import tqdm, trange
from audioldm import LatentDiffusion, seed_everything
from audioldm.utils import default_audioldm_config, get_duration, get_bit_depth, get_metadata, download_checkpoint
from audioldm.audio import wav_to_fbank, TacotronSTFT, read_wav_file
from audioldm.latent_diffusion.ddim import DDIMSampler
from einops import repeat


def _select_device():
    """Return ``cuda:0`` when CUDA is available, otherwise the CPU device."""
    if torch.cuda.is_available():
        return torch.device("cuda:0")
    return torch.device("cpu")


def _load_config(config, model_name=None):
    """Return the configuration dict used by the pipeline functions.

    Args:
        config: Path to a YAML config file, or ``None`` to use AudioLDM's
            default configuration.
        model_name: When ``config`` is ``None`` and this is given, it is
            forwarded to ``default_audioldm_config``.
    """
    if config is not None:
        assert type(config) is str
        # NOTE(review): FullLoader is not the safe loader; only pass trusted
        # config files here (yaml.safe_load would be stricter).
        with open(config, "r") as f:
            return yaml.load(f, Loader=yaml.FullLoader)
    if model_name is not None:
        return default_audioldm_config(model_name)
    return default_audioldm_config()


def _make_stft(config):
    """Build the ``TacotronSTFT`` front-end from ``config["preprocessing"]``."""
    preprocessing = config["preprocessing"]
    return TacotronSTFT(
        preprocessing["stft"]["filter_length"],
        preprocessing["stft"]["hop_length"],
        preprocessing["stft"]["win_length"],
        preprocessing["mel"]["n_mel_channels"],
        preprocessing["audio"]["sampling_rate"],
        preprocessing["mel"]["mel_fmin"],
        preprocessing["mel"]["mel_fmax"],
    )


def make_batch_for_text_to_audio(text, waveform=None, fbank=None, batchsize=1):
    """Assemble the batch tuple passed to the LatentDiffusion sampling calls.

    Slots that are not supplied are filled with zero / empty placeholders so
    the tuple keeps the layout the model code expects.

    Args:
        text: Prompt, replicated ``batchsize`` times.
        waveform: Optional audio samples. Converted to a float tensor and
            expanded along a leading batch dimension.
        fbank: Optional filterbank features. Converted to a float tensor and
            expanded along a leading batch dimension; the input's own
            trailing sizes (e.g. number of frames) are kept.
        batchsize: Number of items in the batch. Values below 1 are replaced
            by 1 after printing a warning.

    Returns:
        Tuple ``(fbank, stft, None, fname, waveform, text)``.
    """
    if batchsize < 1:
        # Clamp before building any field so every slot agrees on the size.
        print("Warning: Batchsize must be at least 1. Batchsize is set to 1.")
        batchsize = 1

    text = [text] * batchsize

    if fbank is None:
        fbank = torch.zeros((batchsize, 1024, 64))  # Not used, here to keep the code format
    else:
        fbank = torch.FloatTensor(fbank)
        # -1 keeps the existing sizes, so spectrograms whose length is not
        # 1024 frames (e.g. from durations other than 10 s) are accepted too.
        fbank = fbank.expand(batchsize, -1, -1)
    assert fbank.size(0) == batchsize

    stft = torch.zeros((batchsize, 1024, 512))  # Not used

    if waveform is None:
        waveform = torch.zeros((batchsize, 160000))  # Not used
    else:
        waveform = torch.FloatTensor(waveform)
        waveform = waveform.expand(batchsize, -1)
    assert waveform.size(0) == batchsize

    fname = [""] * batchsize  # Not used

    batch = (
        fbank,
        stft,
        None,
        fname,
        waveform,
        text,
    )
    return batch


def round_up_duration(duration):
    """Return ``(round(duration / 2.5) + 1) * 2.5``.

    The result is a multiple of 2.5 that lies one 2.5-step above the nearest
    multiple of ``duration``.
    """
    return int(round(duration / 2.5) + 1) * 2.5


def build_model(
    ckpt_path=None,
    config=None,
    model_name="audioldm-s-full"
):
    """Instantiate a LatentDiffusion model, load its checkpoint, set eval mode.

    Args:
        ckpt_path: Checkpoint file. Defaults to the path recorded for
            ``model_name`` in ``get_metadata()``.
        config: Optional YAML config path; defaults to
            ``default_audioldm_config(model_name)``.
        model_name: Key into the AudioLDM metadata / default configs.

    Returns:
        The model on ``cuda:0`` if available (else CPU), with the conditioning
        encoder in ``"text"`` embed mode.
    """
    print("Load AudioLDM: %s" % model_name)

    if ckpt_path is None:
        ckpt_path = get_metadata()[model_name]["path"]

    if not os.path.exists(ckpt_path):
        # NOTE(review): this downloads the checkpoint registered for
        # model_name; if a custom ckpt_path was given and is missing, confirm
        # download_checkpoint actually writes to that path.
        download_checkpoint(model_name)

    device = _select_device()
    config = _load_config(config, model_name)

    # Use text as condition instead of using waveform during training
    config["model"]["params"]["device"] = device
    config["model"]["params"]["cond_stage_key"] = "text"

    # No normalization here
    latent_diffusion = LatentDiffusion(**config["model"]["params"])

    checkpoint = torch.load(ckpt_path, map_location=device)
    latent_diffusion.load_state_dict(checkpoint["state_dict"])

    latent_diffusion.eval()
    latent_diffusion = latent_diffusion.to(device)

    latent_diffusion.cond_stage_model.embed_mode = "text"
    return latent_diffusion


def duration_to_latent_t_size(duration):
    """Return the latent time size used for ``duration``: ``int(duration * 25.6)``."""
    return int(duration * 25.6)


def set_cond_audio(latent_diffusion):
    """Condition on the ``"waveform"`` key with the encoder in audio mode; returns the model."""
    latent_diffusion.cond_stage_key = "waveform"
    latent_diffusion.cond_stage_model.embed_mode = "audio"
    return latent_diffusion


def set_cond_text(latent_diffusion):
    """Condition on the ``"text"`` key with the encoder in text mode; returns the model."""
    latent_diffusion.cond_stage_key = "text"
    latent_diffusion.cond_stage_model.embed_mode = "text"
    return latent_diffusion


def text_to_audio(
    latent_diffusion,
    text,
    original_audio_file_path=None,
    seed=42,
    ddim_steps=200,
    duration=10,
    batchsize=1,
    guidance_scale=2.5,
    n_candidate_gen_per_text=3,
    config=None,
):
    """Generate audio from a text prompt, or from reference audio if given.

    When ``original_audio_file_path`` is set, the model is switched to
    waveform conditioning (``set_cond_audio``); otherwise it conditions on
    ``text``.

    Args:
        latent_diffusion: Model returned by ``build_model``.
        text: Prompt string.
        original_audio_file_path: Optional reference audio file.
        seed: Seed passed to ``seed_everything``.
        ddim_steps: Number of DDIM sampling steps.
        duration: Target length in seconds; sets ``latent_t_size``.
        batchsize: Batch size of the conditioning batch.
        guidance_scale: Classifier-free guidance scale.
        n_candidate_gen_per_text: Forwarded to ``generate_sample``.
        config: Unused; kept for interface compatibility.

    Returns:
        Whatever ``latent_diffusion.generate_sample`` returns.
    """
    seed_everything(int(seed))

    waveform = None
    if original_audio_file_path is not None:
        # presumably int(duration * 102.4) frames * 160 samples per frame.
        waveform = read_wav_file(original_audio_file_path, int(duration * 102.4) * 160)

    batch = make_batch_for_text_to_audio(text, waveform=waveform, batchsize=batchsize)

    latent_diffusion.latent_t_size = duration_to_latent_t_size(duration)

    if waveform is not None:
        print("Generate audio that has similar content as %s" % original_audio_file_path)
        latent_diffusion = set_cond_audio(latent_diffusion)
    else:
        print("Generate audio using text %s" % text)
        latent_diffusion = set_cond_text(latent_diffusion)

    with torch.no_grad():
        waveform = latent_diffusion.generate_sample(
            [batch],
            unconditional_guidance_scale=guidance_scale,
            ddim_steps=ddim_steps,
            n_candidate_gen_per_text=n_candidate_gen_per_text,
            duration=duration,
        )

    return waveform


def style_transfer(
    latent_diffusion,
    text,
    original_audio_file_path,
    transfer_strength,
    seed=42,
    duration=10,
    batchsize=1,
    guidance_scale=2.5,
    ddim_steps=200,
    config=None,
):
    """Re-render an existing audio file towards a text prompt (DDIM img2img style).

    The file's mel spectrogram is encoded to a latent, noised for
    ``int(transfer_strength * ddim_steps)`` DDIM steps, then decoded with the
    text conditioning.

    Args:
        latent_diffusion: Model returned by ``build_model``.
        text: Target prompt.
        original_audio_file_path: Source audio; must have a bit depth of 16.
        transfer_strength: Fraction of ``ddim_steps`` used for noising.
        seed: Seed passed to ``seed_everything``.
        duration: Seconds of audio to process. If it is >= the file's
            duration, it is replaced by ``round_up_duration(file_duration)``.
        batchsize: Number of outputs.
        guidance_scale: Classifier-free guidance scale; 1.0 disables the
            unconditional branch.
        ddim_steps: Size of the DDIM schedule.
        config: Optional YAML config path; defaults to
            ``default_audioldm_config()``.

    Returns:
        Output of ``first_stage_model.decode_to_waveform``.
    """
    device = _select_device()

    assert original_audio_file_path is not None, "You need to provide the original audio file path"

    audio_file_duration = get_duration(original_audio_file_path)

    assert get_bit_depth(original_audio_file_path) == 16, "The bit depth of the original audio file %s must be 16" % original_audio_file_path

    if duration >= audio_file_duration:
        print("Warning: Duration you specified %s-seconds must equal or smaller than the audio file duration %ss" % (duration, audio_file_duration))
        duration = round_up_duration(audio_file_duration)
        print("Set new duration as %s-seconds" % duration)

    latent_diffusion = set_cond_text(latent_diffusion)
    config = _load_config(config)

    seed_everything(int(seed))
    # NOTE(review): unlike text_to_audio, latent_t_size is not updated here
    # (the original had that line commented out) — confirm this is intended.
    latent_diffusion.cond_stage_model.embed_mode = "text"

    fn_STFT = _make_stft(config)

    mel, _, _ = wav_to_fbank(
        original_audio_file_path, target_length=int(duration * 102.4), fn_STFT=fn_STFT
    )
    mel = mel.unsqueeze(0).unsqueeze(0).to(device)
    mel = repeat(mel, "1 ... -> b ...", b=batchsize)

    # move to latent space, encode and sample
    init_latent = latent_diffusion.get_first_stage_encoding(
        latent_diffusion.encode_first_stage(mel)
    )
    # Only when values are extreme (> 1e2) are they clipped to [-10, 10].
    if torch.max(torch.abs(init_latent)) > 1e2:
        init_latent = torch.clip(init_latent, min=-10, max=10)

    sampler = DDIMSampler(latent_diffusion)
    sampler.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=1.0, verbose=False)

    t_enc = int(transfer_strength * ddim_steps)
    prompts = text

    with torch.no_grad():
        with autocast("cuda"):
            with latent_diffusion.ema_scope():
                uc = None
                if guidance_scale != 1.0:
                    uc = latent_diffusion.cond_stage_model.get_unconditional_condition(
                        batchsize
                    )

                c = latent_diffusion.get_learned_conditioning([prompts] * batchsize)
                z_enc = sampler.stochastic_encode(
                    init_latent, torch.tensor([t_enc] * batchsize).to(device)
                )
                samples = sampler.decode(
                    z_enc,
                    c,
                    t_enc,
                    unconditional_guidance_scale=guidance_scale,
                    unconditional_conditioning=uc,
                )
                # Decoding the full latent was observed to produce NaNs, so the
                # last 3 positions along dim 2 are dropped before decoding.
                # (A redundant full decode whose result was discarded was removed.)
                x_samples = latent_diffusion.decode_first_stage(samples[:, :, :-3, :])

                waveform = latent_diffusion.first_stage_model.decode_to_waveform(
                    x_samples
                )

    return waveform


def super_resolution_and_inpainting(
    latent_diffusion,
    text,
    original_audio_file_path=None,
    seed=42,
    ddim_steps=200,
    duration=None,
    batchsize=1,
    guidance_scale=2.5,
    n_candidate_gen_per_text=3,
    time_mask_ratio_start_and_end=(0.10, 0.15),  # regenerate the 10% to 15% of the time steps in the spectrogram
    # time_mask_ratio_start_and_end=(1.0, 1.0), # no inpainting
    # freq_mask_ratio_start_and_end=(0.75, 1.0), # regenerate the higher 75% to 100% mel bins
    freq_mask_ratio_start_and_end=(1.0, 1.0),  # no super-resolution
    config=None,
):
    """Regenerate masked time/frequency regions of an existing audio file.

    Args:
        latent_diffusion: Model returned by ``build_model``.
        text: Prompt used as conditioning.
        original_audio_file_path: Source audio file (required).
        seed: Seed passed to ``seed_everything``.
        ddim_steps: Number of DDIM sampling steps.
        duration: Seconds of audio to process (required).
        batchsize: Batch size of the conditioning batch.
        guidance_scale: Classifier-free guidance scale.
        n_candidate_gen_per_text: Forwarded to ``generate_sample_masked``.
        time_mask_ratio_start_and_end: (start, end) fractions of the time
            axis to regenerate.
        freq_mask_ratio_start_and_end: (start, end) fractions of the mel
            axis to regenerate.
        config: Optional YAML config path; defaults to
            ``default_audioldm_config()``.

    Returns:
        Whatever ``latent_diffusion.generate_sample_masked`` returns.

    Raises:
        ValueError: If ``original_audio_file_path`` or ``duration`` is None.
    """
    # Both have None defaults but are required; fail early with a clear
    # message instead of a TypeError deep inside wav_to_fbank setup.
    if original_audio_file_path is None:
        raise ValueError("You need to provide the original audio file path")
    if duration is None:
        raise ValueError("You need to provide the duration (in seconds) to process")

    seed_everything(int(seed))
    config = _load_config(config)

    fn_STFT = _make_stft(config)

    mel, _, _ = wav_to_fbank(
        original_audio_file_path, target_length=int(duration * 102.4), fn_STFT=fn_STFT
    )

    batch = make_batch_for_text_to_audio(text, fbank=mel[None, ...], batchsize=batchsize)

    # NOTE(review): latent_t_size is not updated here (the original had the
    # update commented out) — confirm this is intended.
    latent_diffusion = set_cond_text(latent_diffusion)

    with torch.no_grad():
        waveform = latent_diffusion.generate_sample_masked(
            [batch],
            unconditional_guidance_scale=guidance_scale,
            ddim_steps=ddim_steps,
            n_candidate_gen_per_text=n_candidate_gen_per_text,
            duration=duration,
            time_mask_ratio_start_and_end=time_mask_ratio_start_and_end,
            freq_mask_ratio_start_and_end=freq_mask_ratio_start_and_end
        )

    return waveform