EditordeAudio / audioldm /pipeline.py
hilamanor's picture
initial commit
e73da9c
import os
import argparse
import yaml
import torch
from torch import autocast
from tqdm import tqdm, trange
from audioldm import LatentDiffusion, seed_everything
from audioldm.utils import default_audioldm_config, get_duration, get_bit_depth, get_metadata, download_checkpoint
from audioldm.audio import wav_to_fbank, TacotronSTFT, read_wav_file
from audioldm.latent_diffusion.ddim import DDIMSampler
from einops import repeat
import os
def make_batch_for_text_to_audio(text, waveform=None, fbank=None, batchsize=1):
    """Assemble the 6-tuple batch consumed by the LatentDiffusion sampling helpers.

    Only ``text`` (and ``fbank`` / ``waveform`` when supplied) carry information;
    the other slots are zero/empty placeholders that keep the tuple layout.

    Args:
        text: Prompt, replicated ``batchsize`` times.
        waveform: Optional array/tensor, converted to float32 and broadcast to
            ``(batchsize, -1)``. When omitted, a ``(batchsize, 160000)`` zero
            placeholder is used.
        fbank: Optional mel filterbank, converted to float32 and broadcast to
            ``(batchsize, T, n_mels)`` keeping its own trailing dimensions.
            When omitted, a ``(batchsize, 1024, 64)`` zero placeholder is used.
        batchsize: Number of copies; values below 1 are clamped to 1.

    Returns:
        Tuple ``(fbank, stft, None, fname, waveform, text)``.
    """
    # Validate before replicating the prompt so every slot agrees on the size.
    if batchsize < 1:
        print("Warning: Batchsize must be at least 1. Batchsize is set to 1.")
        batchsize = 1
    text = [text] * batchsize

    if fbank is None:
        fbank = torch.zeros((batchsize, 1024, 64))  # Not used, here to keep the code format
    else:
        fbank = torch.as_tensor(fbank, dtype=torch.float32)
        # Keep the given time/mel sizes so clips whose frame count != 1024 work.
        fbank = fbank.expand(batchsize, -1, -1)
    assert fbank.size(0) == batchsize

    stft = torch.zeros((batchsize, 1024, 512))  # Not used

    if waveform is None:
        waveform = torch.zeros((batchsize, 160000))  # Not used
    else:
        waveform = torch.as_tensor(waveform, dtype=torch.float32)
        waveform = waveform.expand(batchsize, -1)
    assert waveform.size(0) == batchsize

    fname = [""] * batchsize  # Not used

    batch = (
        fbank,
        stft,
        None,
        fname,
        waveform,
        text,
    )
    return batch
def round_up_duration(duration):
    """Return a duration (seconds) snapped to the 2.5 s grid, one step above the nearest multiple.

    The input is rounded to the nearest multiple of 2.5 with Python's ``round``
    (ties go to even), and one extra 2.5 s step is added. The result is
    therefore always strictly greater than ``duration``.
    """
    step_seconds = 2.5
    nearest_steps = round(duration / step_seconds)
    return int(nearest_steps + 1) * step_seconds
def build_model(
    ckpt_path=None,
    config=None,
    model_name="audioldm-s-full"
):
    """Load an AudioLDM ``LatentDiffusion`` model ready for inference.

    Args:
        ckpt_path: Checkpoint path. When ``None``, the path registered for
            ``model_name`` in ``get_metadata()`` is used. If the resulting path
            does not exist, ``download_checkpoint(model_name)`` is called.
        config: Optional path to a YAML config file. When ``None``,
            ``default_audioldm_config(model_name)`` is used.
        model_name: Name of the pretrained model.

    Returns:
        The model in eval mode, moved to ``cuda:0`` when CUDA is available
        (otherwise CPU), with text as the conditioning key and embed mode.
    """
    print("Load AudioLDM: %s" % model_name)

    if ckpt_path is None:
        ckpt_path = get_metadata()[model_name]["path"]

    if not os.path.exists(ckpt_path):
        download_checkpoint(model_name)

    if torch.cuda.is_available():
        device = torch.device("cuda:0")
    else:
        device = torch.device("cpu")

    if config is not None:
        assert isinstance(config, str)
        # Close the file handle deterministically instead of leaking it.
        with open(config, "r") as config_file:
            config = yaml.load(config_file, Loader=yaml.FullLoader)
    else:
        config = default_audioldm_config(model_name)

    # Use text as condition instead of using waveform during training
    config["model"]["params"]["device"] = device
    config["model"]["params"]["cond_stage_key"] = "text"

    # No normalization here
    latent_diffusion = LatentDiffusion(**config["model"]["params"])

    # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
    checkpoint = torch.load(ckpt_path, map_location=device)
    latent_diffusion.load_state_dict(checkpoint["state_dict"])

    latent_diffusion.eval()
    latent_diffusion = latent_diffusion.to(device)

    latent_diffusion.cond_stage_model.embed_mode = "text"
    return latent_diffusion
def duration_to_latent_t_size(duration):
    """Convert a duration in seconds to a number of latent time steps.

    Uses 25.6 latent steps per second, truncated toward zero.
    NOTE(review): 25.6 is presumably the 102.4 mel frames/s used elsewhere in
    this file divided by a 4x latent downsampling; confirm.
    """
    steps_per_second = 25.6
    latent_steps = duration * steps_per_second
    return int(latent_steps)
def set_cond_audio(latent_diffusion):
    """Switch ``latent_diffusion`` to waveform conditioning, in place.

    Sets the conditioning key to ``"waveform"`` and the conditioning model's
    embed mode to ``"audio"``, then returns the same object.
    """
    latent_diffusion.cond_stage_key = "waveform"
    cond_encoder = latent_diffusion.cond_stage_model
    cond_encoder.embed_mode = "audio"
    return latent_diffusion
def set_cond_text(latent_diffusion):
    """Switch ``latent_diffusion`` to text conditioning, in place.

    Sets the conditioning key to ``"text"`` and the conditioning model's
    embed mode to ``"text"``, then returns the same object.
    """
    latent_diffusion.cond_stage_key = "text"
    cond_encoder = latent_diffusion.cond_stage_model
    cond_encoder.embed_mode = "text"
    return latent_diffusion
def text_to_audio(
    latent_diffusion,
    text,
    original_audio_file_path = None,
    seed=42,
    ddim_steps=200,
    duration=10,
    batchsize=1,
    guidance_scale=2.5,
    n_candidate_gen_per_text=3,
    config=None,
):
    """Generate audio from a text prompt or from a reference audio file.

    When ``original_audio_file_path`` is given, that file is read and the model
    is switched to audio conditioning. Otherwise, text conditioning is used.

    Args:
        latent_diffusion: Model returned by ``build_model``.
        text: Prompt placed in the batch.
        original_audio_file_path: Optional reference audio file.
        seed: RNG seed passed to ``seed_everything``.
        ddim_steps: Number of DDIM sampling steps.
        duration: Target length in seconds; also sets ``latent_t_size``.
        batchsize: Batch size for ``make_batch_for_text_to_audio``.
        guidance_scale: Classifier-free guidance scale.
        n_candidate_gen_per_text: Forwarded to ``generate_sample``.
        config: Accepted for interface compatibility; not read here.

    Returns:
        Whatever ``latent_diffusion.generate_sample`` returns.
    """
    seed_everything(int(seed))

    reference_audio = None
    if original_audio_file_path is not None:
        sample_count = int(duration * 102.4) * 160
        reference_audio = read_wav_file(original_audio_file_path, sample_count)

    batch = make_batch_for_text_to_audio(text, waveform=reference_audio, batchsize=batchsize)
    latent_diffusion.latent_t_size = duration_to_latent_t_size(duration)

    if reference_audio is None:
        print("Generate audio using text %s" % text)
        latent_diffusion = set_cond_text(latent_diffusion)
    else:
        print("Generate audio that has similar content as %s" % original_audio_file_path)
        latent_diffusion = set_cond_audio(latent_diffusion)

    with torch.no_grad():
        generated = latent_diffusion.generate_sample(
            [batch],
            unconditional_guidance_scale=guidance_scale,
            ddim_steps=ddim_steps,
            n_candidate_gen_per_text=n_candidate_gen_per_text,
            duration=duration,
        )
    return generated
def style_transfer(
    latent_diffusion,
    text,
    original_audio_file_path,
    transfer_strength,
    seed=42,
    duration=10,
    batchsize=1,
    guidance_scale=2.5,
    ddim_steps=200,
    config=None,
):
    """Re-render an existing audio file toward a text prompt.

    The file is converted to a mel spectrogram and encoded into the latent
    space. It is then noised with ``sampler.stochastic_encode`` up to step
    ``int(transfer_strength * ddim_steps)`` and decoded with DDIM under the
    text condition.

    Args:
        latent_diffusion: Model returned by ``build_model``.
        text: Target prompt.
        original_audio_file_path: Source audio; must be 16-bit.
        transfer_strength: Fraction of ``ddim_steps`` used as the encode depth.
        seed: RNG seed passed to ``seed_everything``.
        duration: Seconds of audio to process. If it is not smaller than the
            file's duration, it is replaced by ``round_up_duration(file_duration)``.
        batchsize: Number of copies processed together.
        guidance_scale: Classifier-free guidance scale; when 1.0, no
            unconditional conditioning is computed.
        ddim_steps: Number of steps in the DDIM schedule.
        config: Optional YAML config path; defaults to ``default_audioldm_config()``.

    Returns:
        The output of ``first_stage_model.decode_to_waveform``.

    Raises:
        AssertionError: If the path is None or the file's bit depth is not 16.
    """
    if torch.cuda.is_available():
        device = torch.device("cuda:0")
    else:
        device = torch.device("cpu")

    assert original_audio_file_path is not None, "You need to provide the original audio file path"

    audio_file_duration = get_duration(original_audio_file_path)

    assert get_bit_depth(original_audio_file_path) == 16, "The bit depth of the original audio file %s must be 16" % original_audio_file_path

    # NOTE: upstream observed NaNs in the model output for durations above
    # ~20 s; the 20 s cap that addressed this is intentionally left disabled.
    if duration >= audio_file_duration:
        print("Warning: Duration you specified %s-seconds must equal or smaller than the audio file duration %ss" % (duration, audio_file_duration))
        duration = round_up_duration(audio_file_duration)
        print("Set new duration as %s-seconds" % duration)

    # Also sets cond_stage_model.embed_mode = "text".
    latent_diffusion = set_cond_text(latent_diffusion)

    if config is not None:
        assert type(config) is str
        # Close the file handle deterministically instead of leaking it.
        with open(config, "r") as config_file:
            config = yaml.load(config_file, Loader=yaml.FullLoader)
    else:
        config = default_audioldm_config()

    seed_everything(int(seed))

    fn_STFT = TacotronSTFT(
        config["preprocessing"]["stft"]["filter_length"],
        config["preprocessing"]["stft"]["hop_length"],
        config["preprocessing"]["stft"]["win_length"],
        config["preprocessing"]["mel"]["n_mel_channels"],
        config["preprocessing"]["audio"]["sampling_rate"],
        config["preprocessing"]["mel"]["mel_fmin"],
        config["preprocessing"]["mel"]["mel_fmax"],
    )

    mel, _, _ = wav_to_fbank(
        original_audio_file_path, target_length=int(duration * 102.4), fn_STFT=fn_STFT
    )
    # Add two leading singleton dims, then replicate along the first one.
    mel = mel.unsqueeze(0).unsqueeze(0).to(device)
    mel = repeat(mel, "1 ... -> b ...", b=batchsize)

    init_latent = latent_diffusion.get_first_stage_encoding(
        latent_diffusion.encode_first_stage(mel)
    )  # move to latent space, encode and sample
    # Guard against exploding encodings: clip to [-10, 10] only when some
    # value's magnitude exceeds 1e2.
    if torch.max(torch.abs(init_latent)) > 1e2:
        init_latent = torch.clip(init_latent, min=-10, max=10)

    sampler = DDIMSampler(latent_diffusion)
    sampler.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=1.0, verbose=False)

    # NOTE(review): transfer_strength == 1.0 yields t_enc == ddim_steps; confirm
    # DDIMSampler.stochastic_encode accepts that index.
    t_enc = int(transfer_strength * ddim_steps)
    prompts = text

    with torch.no_grad():
        # autocast("cuda") is a no-op (with a warning) without CUDA; enable it
        # only when running on a CUDA device.
        with autocast("cuda", enabled=device.type == "cuda"):
            with latent_diffusion.ema_scope():
                uc = None
                if guidance_scale != 1.0:
                    uc = latent_diffusion.cond_stage_model.get_unconditional_condition(
                        batchsize
                    )

                c = latent_diffusion.get_learned_conditioning([prompts] * batchsize)
                z_enc = sampler.stochastic_encode(
                    init_latent, torch.tensor([t_enc] * batchsize).to(device)
                )
                samples = sampler.decode(
                    z_enc,
                    c,
                    t_enc,
                    unconditional_guidance_scale=guidance_scale,
                    unconditional_conditioning=uc,
                )
                # Upstream noted that decoding the full latent produces NaNs,
                # so the last 3 steps along dim 2 are dropped before decoding.
                # (A redundant full decode whose result was discarded was removed.)
                x_samples = latent_diffusion.decode_first_stage(samples[:, :, :-3, :])
                waveform = latent_diffusion.first_stage_model.decode_to_waveform(
                    x_samples
                )

    return waveform
def super_resolution_and_inpainting(
    latent_diffusion,
    text,
    original_audio_file_path = None,
    seed=42,
    ddim_steps=200,
    duration=None,
    batchsize=1,
    guidance_scale=2.5,
    n_candidate_gen_per_text=3,
    time_mask_ratio_start_and_end=(0.10, 0.15), # regenerate the 10% to 15% of the time steps in the spectrogram
    # time_mask_ratio_start_and_end=(1.0, 1.0), # no inpainting
    # freq_mask_ratio_start_and_end=(0.75, 1.0), # regenerate the higher 75% to 100% mel bins
    freq_mask_ratio_start_and_end=(1.0, 1.0), # no super-resolution
    config=None,
):
    """Regenerate masked time/frequency regions of an audio file's mel spectrogram.

    The mel spectrogram of the file is placed in the batch's fbank slot. It is
    passed to ``latent_diffusion.generate_sample_masked`` together with the
    time and frequency mask ratio ranges, under text conditioning.

    Args:
        latent_diffusion: Model returned by ``build_model``.
        text: Prompt for the regenerated regions.
        original_audio_file_path: Source audio file (required).
        seed: RNG seed passed to ``seed_everything``.
        ddim_steps: Number of DDIM sampling steps.
        duration: Seconds of audio to process; ``None`` means 10 s.
        batchsize: Batch size for ``make_batch_for_text_to_audio``.
        guidance_scale: Classifier-free guidance scale.
        n_candidate_gen_per_text: Forwarded to ``generate_sample_masked``.
        time_mask_ratio_start_and_end: (start, end) time-ratio range to regenerate.
        freq_mask_ratio_start_and_end: (start, end) mel-bin-ratio range to regenerate.
        config: Optional YAML config path; defaults to ``default_audioldm_config()``.

    Returns:
        Whatever ``latent_diffusion.generate_sample_masked`` returns.

    Raises:
        ValueError: If ``original_audio_file_path`` is None.
    """
    if original_audio_file_path is None:
        raise ValueError("You need to provide the original audio file path")
    if duration is None:
        # The old default crashed in int(None * 102.4). Use the same 10 s
        # default as text_to_audio / style_transfer (int(10 * 102.4) == 1024 frames).
        duration = 10

    seed_everything(int(seed))

    if config is not None:
        assert type(config) is str
        # Close the file handle deterministically instead of leaking it.
        with open(config, "r") as config_file:
            config = yaml.load(config_file, Loader=yaml.FullLoader)
    else:
        config = default_audioldm_config()

    fn_STFT = TacotronSTFT(
        config["preprocessing"]["stft"]["filter_length"],
        config["preprocessing"]["stft"]["hop_length"],
        config["preprocessing"]["stft"]["win_length"],
        config["preprocessing"]["mel"]["n_mel_channels"],
        config["preprocessing"]["audio"]["sampling_rate"],
        config["preprocessing"]["mel"]["mel_fmin"],
        config["preprocessing"]["mel"]["mel_fmax"],
    )

    mel, _, _ = wav_to_fbank(
        original_audio_file_path, target_length=int(duration * 102.4), fn_STFT=fn_STFT
    )

    # mel[None, ...] adds a leading batch dim before replication.
    batch = make_batch_for_text_to_audio(text, fbank=mel[None, ...], batchsize=batchsize)

    latent_diffusion = set_cond_text(latent_diffusion)

    with torch.no_grad():
        waveform = latent_diffusion.generate_sample_masked(
            [batch],
            unconditional_guidance_scale=guidance_scale,
            ddim_steps=ddim_steps,
            n_candidate_gen_per_text=n_candidate_gen_per_text,
            duration=duration,
            time_mask_ratio_start_and_end=time_mask_ratio_start_and_end,
            freq_mask_ratio_start_and_end=freq_mask_ratio_start_and_end
        )
    return waveform