# Author: haoheliu
# Commit bdab1da: "first commit and add large model"
# Original file size: 23.9 kB
import os
import torch
import numpy as np
from tqdm import tqdm
from audioldm.utils import default, instantiate_from_config, save_wave
from audioldm.latent_diffusion.ddpm import DDPM
from audioldm.variational_autoencoder.distributions import DiagonalGaussianDistribution
from audioldm.latent_diffusion.util import noise_like
from audioldm.latent_diffusion.ddim import DDIMSampler
import os
def disabled_train(self, mode=True):
    """No-op replacement for a module's ``train`` method.

    It is assigned over ``train`` on frozen sub-models so that train/eval
    mode no longer changes. It ignores ``mode``, touches no state, and simply
    returns the first argument it receives.
    """
    del mode  # intentionally ignored
    return self
class LatentDiffusion(DDPM):
    """Main latent diffusion model.

    Extends ``DDPM`` with a frozen first-stage autoencoder
    (``encode_first_stage`` / ``decode_first_stage``, plus its ``vocoder`` used
    by ``mel_spectrogram_to_waveform``) and a conditioning model
    (``cond_stage_model``) whose output is routed to the denoiser by
    ``apply_model``.
    """

    def __init__(
        self,
        device="cuda",
        first_stage_config=None,
        cond_stage_config=None,
        num_timesteps_cond=None,
        cond_stage_key="image",
        cond_stage_trainable=False,
        concat_mode=True,
        cond_stage_forward=None,
        conditioning_key=None,
        scale_factor=1.0,
        scale_by_std=False,
        base_learning_rate=None,
        *args,
        **kwargs,
    ):
        """Configure the first-stage and conditioning models on top of ``DDPM``.

        Args:
            device: device used for batch inputs, the conditioning model and
                sampling noise.
            first_stage_config: config passed to ``instantiate_from_config``
                for the frozen first-stage model.
            cond_stage_config: config for the conditioning model, or one of the
                sentinels ``"__is_first_stage__"`` / ``"__is_unconditional__"``.
            num_timesteps_cond: length of the shortened conditioning schedule;
                defaults to 1 (no shortening). Must not exceed
                ``kwargs["timesteps"]``.
            cond_stage_key: batch key read as conditioning input by
                ``get_input``.
            cond_stage_trainable: if False the conditioning model is frozen and
                its output is computed eagerly in ``get_input``.
            concat_mode: selects ``"concat"`` (True) or ``"crossattn"`` (False)
                when ``conditioning_key`` is None.
            cond_stage_forward: optional name of a conditioning-model method to
                call instead of ``encode`` / ``__call__``.
            conditioning_key: conditioning mode forwarded to ``DDPM``.
            scale_factor: multiplier applied to first-stage latents.
            scale_by_std: if True, ``scale_factor`` is registered as a buffer.
            base_learning_rate: stored as ``self.learning_rate``.
            *args, **kwargs: forwarded to ``DDPM``. ``kwargs`` must contain
                ``"timesteps"``; ``ckpt_path`` and ``ignore_keys`` are removed
                (and ignored) before forwarding.
        """
        # These are set before super().__init__ — presumably because the base
        # constructor calls register_schedule(), which reads
        # num_timesteps_cond. TODO confirm against DDPM.
        self.device = device
        self.learning_rate = base_learning_rate
        self.num_timesteps_cond = default(num_timesteps_cond, 1)
        self.scale_by_std = scale_by_std
        assert self.num_timesteps_cond <= kwargs["timesteps"]
        # for backwards compatibility after implementation of DiffusionWrapper
        if conditioning_key is None:
            conditioning_key = "concat" if concat_mode else "crossattn"
        if cond_stage_config == "__is_unconditional__":
            conditioning_key = None
        # Strip checkpoint options so the base class does not act on them.
        kwargs.pop("ckpt_path", None)
        kwargs.pop("ignore_keys", [])
        super().__init__(conditioning_key=conditioning_key, *args, **kwargs)
        self.concat_mode = concat_mode
        self.cond_stage_trainable = cond_stage_trainable
        self.cond_stage_key = cond_stage_key
        self.cond_stage_key_orig = cond_stage_key
        try:
            self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
        except Exception:
            # Config missing or not shaped as expected: assume no downsampling.
            self.num_downs = 0
        if not scale_by_std:
            self.scale_factor = scale_factor
        else:
            self.register_buffer("scale_factor", torch.tensor(scale_factor))
        self.instantiate_first_stage(first_stage_config)
        self.instantiate_cond_stage(cond_stage_config)
        self.cond_stage_forward = cond_stage_forward
        self.clip_denoised = False

    def make_cond_schedule(
        self,
    ):
        """Build ``self.cond_ids`` (one conditioning step per diffusion step).

        The first ``num_timesteps_cond`` entries are spread evenly (rounded)
        over ``[0, num_timesteps - 1]``; all remaining entries are
        ``num_timesteps - 1``.
        """
        self.cond_ids = torch.full(
            size=(self.num_timesteps,),
            fill_value=self.num_timesteps - 1,
            dtype=torch.long,
        )
        ids = torch.round(
            torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)
        ).long()
        self.cond_ids[: self.num_timesteps_cond] = ids

    def register_schedule(
        self,
        given_betas=None,
        beta_schedule="linear",
        timesteps=1000,
        linear_start=1e-4,
        linear_end=2e-2,
        cosine_s=8e-3,
    ):
        """Register the base noise schedule.

        Also enables the shortened conditioning schedule when
        ``num_timesteps_cond > 1``.
        """
        super().register_schedule(
            given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s
        )
        self.shorten_cond_schedule = self.num_timesteps_cond > 1
        if self.shorten_cond_schedule:
            self.make_cond_schedule()

    def instantiate_first_stage(self, config):
        """Create the first-stage model and freeze it.

        The model is put in eval mode, its ``train`` is replaced with
        ``disabled_train``, and gradients are disabled for every parameter.
        """
        model = instantiate_from_config(config)
        self.first_stage_model = model.eval()
        self.first_stage_model.train = disabled_train
        for param in self.first_stage_model.parameters():
            param.requires_grad = False

    def instantiate_cond_stage(self, config):
        """Create ``self.cond_stage_model`` and move it to ``self.device``.

        When the conditioning stage is not trainable:

        * ``"__is_first_stage__"`` reuses the first-stage model;
        * ``"__is_unconditional__"`` sets the model to ``None``;
        * any other config builds a new model, frozen like the first stage.

        A trainable conditioning stage must be built from a real config.
        """
        if not self.cond_stage_trainable:
            if config == "__is_first_stage__":
                print("Using first stage also as cond stage.")
                self.cond_stage_model = self.first_stage_model
            elif config == "__is_unconditional__":
                print(f"Training {self.__class__.__name__} as an unconditional model.")
                self.cond_stage_model = None
            else:
                model = instantiate_from_config(config)
                self.cond_stage_model = model.eval()
                self.cond_stage_model.train = disabled_train
                for param in self.cond_stage_model.parameters():
                    param.requires_grad = False
        else:
            assert config != "__is_first_stage__"
            assert config != "__is_unconditional__"
            model = instantiate_from_config(config)
            self.cond_stage_model = model
        # An unconditional model has no conditioning network to move; calling
        # .to() on None would raise AttributeError.
        if self.cond_stage_model is not None:
            self.cond_stage_model = self.cond_stage_model.to(self.device)

    def get_first_stage_encoding(self, encoder_posterior):
        """Convert a first-stage encoder output into a scaled latent.

        A ``DiagonalGaussianDistribution`` is sampled; a tensor is used as-is.
        The result is multiplied by ``self.scale_factor``.

        Raises:
            NotImplementedError: for any other input type.
        """
        if isinstance(encoder_posterior, DiagonalGaussianDistribution):
            z = encoder_posterior.sample()
        elif isinstance(encoder_posterior, torch.Tensor):
            z = encoder_posterior
        else:
            raise NotImplementedError(
                f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented"
            )
        return self.scale_factor * z

    def get_learned_conditioning(self, c):
        """Run the conditioning model on raw conditioning input ``c``.

        If ``cond_stage_forward`` names a method, that method is called.
        Otherwise the model's ``encode`` is used when available (taking the
        mode of a Gaussian output), falling back to calling the model directly.
        """
        if self.cond_stage_forward is None:
            if hasattr(self.cond_stage_model, "encode") and callable(
                self.cond_stage_model.encode
            ):
                c = self.cond_stage_model.encode(c)
                if isinstance(c, DiagonalGaussianDistribution):
                    c = c.mode()
            else:
                if len(c) == 1:
                    # A single item is duplicated into a batch of two and only
                    # the first result kept — presumably because the model
                    # misbehaves at batch size 1. TODO confirm.
                    c = self.cond_stage_model([c[0], c[0]])
                    c = c[0:1]
                else:
                    c = self.cond_stage_model(c)
        else:
            assert hasattr(self.cond_stage_model, self.cond_stage_forward)
            c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
        return c

    @torch.no_grad()
    def get_input(
        self,
        batch,
        k,
        return_first_stage_encode=True,
        return_first_stage_outputs=False,
        force_c_encode=False,
        cond_key=None,
        return_original_cond=False,
        bs=None,
    ):
        """Read a batch and build the latent ``z`` and conditioning ``c``.

        Args:
            batch: batch passed to ``DDPM.get_input``.
            k: key of the first-stage input in ``batch``.
            return_first_stage_encode: if False, no encoding is done and ``z``
                is None.
            return_first_stage_outputs: also append ``x`` and its
                reconstruction ``decode_first_stage(z)``.
            force_c_encode: encode conditioning even when the conditioning
                stage is trainable.
            cond_key: conditioning key; defaults to ``self.cond_stage_key``.
            return_original_cond: also append the raw conditioning input.
            bs: optional cap on the number of items used.

        Returns:
            ``[z, c]``, optionally extended with ``x, xrec`` and then ``xc``.
        """
        x = super().get_input(batch, k)
        if bs is not None:
            x = x[:bs]
        x = x.to(self.device)
        if return_first_stage_encode:
            encoder_posterior = self.encode_first_stage(x)
            z = self.get_first_stage_encoding(encoder_posterior).detach()
        else:
            z = None
        if self.model.conditioning_key is not None:
            if cond_key is None:
                cond_key = self.cond_stage_key
            if cond_key != self.first_stage_key:
                if cond_key in ["caption", "coordinates_bbox"]:
                    xc = batch[cond_key]
                elif cond_key == "class_label":
                    xc = batch
                else:
                    # [bs, 1, 527]
                    xc = super().get_input(batch, cond_key)
                    if isinstance(xc, torch.Tensor):
                        xc = xc.to(self.device)
            else:
                xc = x
            if not self.cond_stage_trainable or force_c_encode:
                if isinstance(xc, (dict, list)):
                    c = self.get_learned_conditioning(xc)
                else:
                    c = self.get_learned_conditioning(xc.to(self.device))
            else:
                c = xc
            if bs is not None:
                c = c[:bs]
        else:
            c = None
            xc = None
            if self.use_positional_encodings:
                pos_x, pos_y = self.compute_latent_shifts(batch)
                c = {"pos_x": pos_x, "pos_y": pos_y}
        out = [z, c]
        if return_first_stage_outputs:
            xrec = self.decode_first_stage(z)
            out.extend([x, xrec])
        if return_original_cond:
            out.append(xc)
        return out

    @torch.no_grad()
    def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
        """Undo the latent scaling and decode ``z`` with the first-stage model.

        With ``predict_cids`` set, ``z`` is first mapped to codebook entries.
        If ``z`` is 4-D, it is treated as logits and reduced with argmax over
        dim 1. The resulting channels-last tensor is then permuted to
        channels-first. ``force_not_quantize`` is accepted but unused.
        """
        if predict_cids:
            if z.dim() == 4:
                z = torch.argmax(z.exp(), dim=1).long()
            z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
            # "b h w c -> b c h w" (previously an unimported einops.rearrange).
            z = z.permute(0, 3, 1, 2).contiguous()
        z = 1.0 / self.scale_factor * z
        return self.first_stage_model.decode(z)

    def mel_spectrogram_to_waveform(self, mel):
        """Vocode a mel batch into a NumPy waveform array.

        A 4-D input is expected as ``[bs, 1, t-steps, fbins]``. Its singleton
        channel axis is squeezed, and the time/frequency axes are swapped
        before calling ``first_stage_model.vocoder``.
        """
        # Mel: [bs, 1, t-steps, fbins]
        if len(mel.size()) == 4:
            mel = mel.squeeze(1)
        mel = mel.permute(0, 2, 1)
        waveform = self.first_stage_model.vocoder(mel)
        waveform = waveform.cpu().detach().numpy()
        return waveform

    @torch.no_grad()
    def encode_first_stage(self, x):
        """Encode ``x`` with the first-stage model (no gradient tracking)."""
        return self.first_stage_model.encode(x)

    def apply_model(self, x_noisy, t, cond, return_ids=False):
        """Run the denoiser on ``x_noisy`` at step ``t`` with conditioning.

        A dict ``cond`` is passed through as keyword arguments. Any other value
        is wrapped in a list and keyed by the conditioning mode:
        ``c_concat``, ``c_crossattn``, or ``c_film`` for any other mode.
        A tuple output is reduced to its first element unless ``return_ids``
        is set.
        """
        if isinstance(cond, dict):
            # hybrid case, cond is expected to be a dict
            pass
        else:
            if not isinstance(cond, list):
                cond = [cond]
            if self.model.conditioning_key == "concat":
                key = "c_concat"
            elif self.model.conditioning_key == "crossattn":
                key = "c_crossattn"
            else:
                key = "c_film"
            cond = {key: cond}
        x_recon = self.model(x_noisy, t, **cond)
        if isinstance(x_recon, tuple) and not return_ids:
            return x_recon[0]
        else:
            return x_recon

    def p_mean_variance(
        self,
        x,
        c,
        t,
        clip_denoised: bool,
        return_codebook_ids=False,
        quantize_denoised=False,
        return_x0=False,
        score_corrector=None,
        corrector_kwargs=None,
    ):
        """Compute the posterior mean/variance for one reverse step.

        The model output is turned into an ``x0`` estimate according to
        ``self.parameterization`` (``"eps"`` or ``"x0"``). It is optionally
        clamped to [-1, 1] and/or quantized, then passed to ``q_posterior``.

        Returns:
            ``(mean, variance, log_variance)``, plus ``logits`` when
            ``return_codebook_ids`` or the ``x0`` estimate when ``return_x0``.

        Raises:
            NotImplementedError: for an unknown parameterization.
        """
        t_in = t
        model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids)
        if score_corrector is not None:
            assert self.parameterization == "eps"
            model_out = score_corrector.modify_score(
                self, model_out, x, t, c, **corrector_kwargs
            )
        if return_codebook_ids:
            model_out, logits = model_out
        if self.parameterization == "eps":
            x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
        elif self.parameterization == "x0":
            x_recon = model_out
        else:
            raise NotImplementedError()
        if clip_denoised:
            x_recon.clamp_(-1.0, 1.0)
        if quantize_denoised:
            x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon)
        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(
            x_start=x_recon, x_t=x, t=t
        )
        if return_codebook_ids:
            return model_mean, posterior_variance, posterior_log_variance, logits
        elif return_x0:
            return model_mean, posterior_variance, posterior_log_variance, x_recon
        else:
            return model_mean, posterior_variance, posterior_log_variance

    @torch.no_grad()
    def p_sample(
        self,
        x,
        c,
        t,
        clip_denoised=False,
        repeat_noise=False,
        return_codebook_ids=False,
        quantize_denoised=False,
        return_x0=False,
        temperature=1.0,
        noise_dropout=0.0,
        score_corrector=None,
        corrector_kwargs=None,
    ):
        """Draw ``x_{t-1}`` from ``x_t``.

        Noise is scaled by ``temperature``, optionally dropped out, and
        suppressed for items where ``t == 0``. Returns the sample, or
        ``(sample, x0)`` when ``return_x0``.

        Raises:
            DeprecationWarning: if ``return_codebook_ids`` is set.
        """
        b, *_, device = *x.shape, x.device
        outputs = self.p_mean_variance(
            x=x,
            c=c,
            t=t,
            clip_denoised=clip_denoised,
            return_codebook_ids=return_codebook_ids,
            quantize_denoised=quantize_denoised,
            return_x0=return_x0,
            score_corrector=score_corrector,
            corrector_kwargs=corrector_kwargs,
        )
        if return_codebook_ids:
            raise DeprecationWarning("Support dropped.")
            model_mean, _, model_log_variance, logits = outputs
        elif return_x0:
            model_mean, _, model_log_variance, x0 = outputs
        else:
            model_mean, _, model_log_variance = outputs
        noise = noise_like(x.shape, device, repeat_noise) * temperature
        if noise_dropout > 0.0:
            noise = torch.nn.functional.dropout(noise, p=noise_dropout)
        # no noise when t == 0
        nonzero_mask = (
            (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))).contiguous()
        )
        if return_codebook_ids:
            return model_mean + nonzero_mask * (
                0.5 * model_log_variance
            ).exp() * noise, logits.argmax(dim=1)
        if return_x0:
            return (
                model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise,
                x0,
            )
        else:
            return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise

    @staticmethod
    def _truncate_cond(cond, batch_size):
        """Return ``cond`` restricted to its first ``batch_size`` items.

        Dict values (or each element of a list value) are sliced, list
        elements are sliced, and any other object is sliced directly.
        ``None`` is returned unchanged.
        """
        if cond is None:
            return None
        if isinstance(cond, dict):
            return {
                key: [item[:batch_size] for item in value]
                if isinstance(value, list)
                else value[:batch_size]
                for key, value in cond.items()
            }
        if isinstance(cond, list):
            return [item[:batch_size] for item in cond]
        return cond[:batch_size]

    @torch.no_grad()
    def progressive_denoising(
        self,
        cond,
        shape,
        verbose=True,
        callback=None,
        quantize_denoised=False,
        img_callback=None,
        mask=None,
        x0=None,
        temperature=1.0,
        noise_dropout=0.0,
        score_corrector=None,
        corrector_kwargs=None,
        batch_size=None,
        x_T=None,
        start_T=None,
        log_every_t=None,
    ):
        """Run the full DDPM reverse chain, recording ``x0`` estimates.

        When ``batch_size`` is given, ``shape`` excludes the batch dimension;
        otherwise ``shape[0]`` is the batch size. ``temperature`` may be a
        scalar or a per-step sequence indexed by the timestep.

        Returns:
            ``(img, intermediates)``, where ``intermediates`` holds the ``x0``
            estimate every ``log_every_t`` steps and at the first step.
        """
        if not log_every_t:
            log_every_t = self.log_every_t
        timesteps = self.num_timesteps
        if batch_size is not None:
            b = batch_size
            shape = [batch_size] + list(shape)
        else:
            b = batch_size = shape[0]
        if x_T is None:
            img = torch.randn(shape, device=self.device)
        else:
            img = x_T
        intermediates = []
        cond = self._truncate_cond(cond, batch_size)
        if start_T is not None:
            timesteps = min(timesteps, start_T)
        iterator = (
            tqdm(
                reversed(range(0, timesteps)),
                desc="Progressive Generation",
                total=timesteps,
            )
            if verbose
            else reversed(range(0, timesteps))
        )
        # Accept any real scalar (the old `type(...) == float` test rejected
        # ints and numpy floats, which then failed at temperature[i]).
        if isinstance(temperature, (int, float)):
            temperature = [temperature] * timesteps
        for i in iterator:
            ts = torch.full((b,), i, device=self.device, dtype=torch.long)
            if self.shorten_cond_schedule:
                assert self.model.conditioning_key != "hybrid"
                tc = self.cond_ids[ts].to(cond.device)
                cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
            img, x0_partial = self.p_sample(
                img,
                cond,
                ts,
                clip_denoised=self.clip_denoised,
                quantize_denoised=quantize_denoised,
                return_x0=True,
                temperature=temperature[i],
                noise_dropout=noise_dropout,
                score_corrector=score_corrector,
                corrector_kwargs=corrector_kwargs,
            )
            if mask is not None:
                assert x0 is not None
                img_orig = self.q_sample(x0, ts)
                img = img_orig * mask + (1.0 - mask) * img
            if i % log_every_t == 0 or i == timesteps - 1:
                intermediates.append(x0_partial)
            if callback:
                callback(i)
            if img_callback:
                img_callback(img, i)
        return img, intermediates

    @torch.no_grad()
    def p_sample_loop(
        self,
        cond,
        shape,
        return_intermediates=False,
        x_T=None,
        verbose=True,
        callback=None,
        timesteps=None,
        quantize_denoised=False,
        mask=None,
        x0=None,
        img_callback=None,
        start_T=None,
        log_every_t=None,
    ):
        """Run the DDPM reverse chain from ``x_T`` (or fresh noise).

        If ``mask`` is given, the regions where it is 1 are replaced at every
        step with ``x0`` noised to the current step. Returns the final sample,
        plus the list of logged samples when ``return_intermediates``.
        """
        if not log_every_t:
            log_every_t = self.log_every_t
        device = self.betas.device
        b = shape[0]
        if x_T is None:
            img = torch.randn(shape, device=device)
        else:
            img = x_T
        intermediates = [img]
        if timesteps is None:
            timesteps = self.num_timesteps
        if start_T is not None:
            timesteps = min(timesteps, start_T)
        iterator = (
            tqdm(reversed(range(0, timesteps)), desc="Sampling t", total=timesteps)
            if verbose
            else reversed(range(0, timesteps))
        )
        if mask is not None:
            assert x0 is not None
            # NOTE(review): only dim 2 is compared despite the comment below.
            assert x0.shape[2:3] == mask.shape[2:3]  # spatial size has to match
        for i in iterator:
            ts = torch.full((b,), i, device=device, dtype=torch.long)
            if self.shorten_cond_schedule:
                assert self.model.conditioning_key != "hybrid"
                tc = self.cond_ids[ts].to(cond.device)
                cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
            img = self.p_sample(
                img,
                cond,
                ts,
                clip_denoised=self.clip_denoised,
                quantize_denoised=quantize_denoised,
            )
            if mask is not None:
                img_orig = self.q_sample(x0, ts)
                img = img_orig * mask + (1.0 - mask) * img
            if i % log_every_t == 0 or i == timesteps - 1:
                intermediates.append(img)
            if callback:
                callback(i)
            if img_callback:
                img_callback(img, i)
        if return_intermediates:
            return img, intermediates
        return img

    @torch.no_grad()
    def sample(
        self,
        cond,
        batch_size=16,
        return_intermediates=False,
        x_T=None,
        verbose=True,
        timesteps=None,
        quantize_denoised=False,
        mask=None,
        x0=None,
        shape=None,
        **kwargs,
    ):
        """DDPM-sample ``batch_size`` latents conditioned on ``cond``.

        ``shape`` defaults to
        ``(batch_size, channels, latent_t_size, latent_f_size)``. ``cond`` is
        truncated to ``batch_size`` items. Extra ``kwargs`` go to
        ``p_sample_loop``.
        """
        if shape is None:
            shape = (batch_size, self.channels, self.latent_t_size, self.latent_f_size)
        cond = self._truncate_cond(cond, batch_size)
        return self.p_sample_loop(
            cond,
            shape,
            return_intermediates=return_intermediates,
            x_T=x_T,
            verbose=verbose,
            timesteps=timesteps,
            quantize_denoised=quantize_denoised,
            mask=mask,
            x0=x0,
            **kwargs,
        )

    @torch.no_grad()
    def sample_log(
        self,
        cond,
        batch_size,
        ddim,
        ddim_steps,
        unconditional_guidance_scale=1.0,
        unconditional_conditioning=None,
        use_plms=False,
        mask=None,
        **kwargs,
    ):
        """Sample latents with DDIM (``ddim`` and not ``use_plms``) or DDPM.

        Returns:
            ``(samples, intermediates)`` as produced by the chosen sampler.

        Raises:
            NotImplementedError: if classifier-free guidance
                (``unconditional_guidance_scale != 1``) is requested on the
                DDPM path, which does not support it.
        """
        if mask is not None:
            shape = (self.channels, mask.size()[-2], mask.size()[-1])
        else:
            shape = (self.channels, self.latent_t_size, self.latent_f_size)
        if ddim and not use_plms:
            ddim_sampler = DDIMSampler(self)
            samples, intermediates = ddim_sampler.sample(
                ddim_steps,
                batch_size,
                shape,
                cond,
                verbose=False,
                unconditional_guidance_scale=unconditional_guidance_scale,
                unconditional_conditioning=unconditional_conditioning,
                mask=mask,
                **kwargs,
            )
        else:
            # p_sample_loop accepts neither guidance arguments nor `eta`;
            # forwarding them previously made this branch always fail with
            # TypeError.
            if unconditional_guidance_scale != 1.0:
                raise NotImplementedError(
                    "Classifier-free guidance is only supported by the DDIM sampler"
                )
            kwargs.pop("eta", None)  # DDIM-only parameter
            samples, intermediates = self.sample(
                cond=cond,
                batch_size=batch_size,
                return_intermediates=True,
                mask=mask,
                **kwargs,
            )
        return samples, intermediates

    @torch.no_grad()
    def generate_sample(
        self,
        batchs,
        ddim_steps=200,
        ddim_eta=1.0,
        x_T=None,
        n_gen=1,
        unconditional_guidance_scale=1.0,
        unconditional_conditioning=None,
        name="waveform",
        use_plms=False,
        save=False,
        **kwargs,
    ):
        """Generate waveforms for every batch, keeping the best of ``n_gen`` tries.

        For each batch, ``n_gen`` candidates per item are sampled and vocoded.
        The candidate with the highest ``cond_stage_model.cos_similarity`` to
        the item's ``"text"`` is kept.

        Args:
            batchs: iterable of batches (each must provide ``"text"``).
            ddim_steps: DDIM steps; ``None`` selects the DDPM sampler.
            ddim_eta: DDIM eta.
            x_T: must be None.
            n_gen: number of candidates generated per item.
            unconditional_guidance_scale: guidance scale; when != 1 the
                unconditional conditioning is built by the conditioning model.
            unconditional_conditioning: unconditional conditioning for guidance.
            name, save: accepted but unused.
            use_plms: requires ``ddim_steps`` to be set.

        Returns:
            NumPy array of the selected waveforms for all batches, concatenated
            along axis 0 in input order.

        Raises:
            ValueError: if ``batchs`` is not iterable or yields no batches.
        """
        assert x_T is None
        try:
            batchs = iter(batchs)
        except TypeError:
            raise ValueError(
                "The first input argument should be an iterable object"
            ) from None
        if use_plms:
            assert ddim_steps is not None
        use_ddim = ddim_steps is not None
        selected = []
        with self.ema_scope("Plotting"):
            for batch in batchs:
                z, c = self.get_input(
                    batch,
                    self.first_stage_key,
                    return_first_stage_outputs=False,
                    force_c_encode=True,
                    return_original_cond=False,
                    bs=None,
                )
                text = super().get_input(batch, "text")
                n_items = z.shape[0]
                # Tile conditioning and text n_gen times: candidate j of item i
                # lands at index i + j * n_items.
                batch_size = n_items * n_gen
                c = torch.cat([c] * n_gen, dim=0)
                text = text * n_gen
                if unconditional_guidance_scale != 1.0:
                    unconditional_conditioning = (
                        self.cond_stage_model.get_unconditional_condition(batch_size)
                    )
                samples, _ = self.sample_log(
                    cond=c,
                    batch_size=batch_size,
                    x_T=x_T,
                    ddim=use_ddim,
                    ddim_steps=ddim_steps,
                    eta=ddim_eta,
                    unconditional_guidance_scale=unconditional_guidance_scale,
                    unconditional_conditioning=unconditional_conditioning,
                    use_plms=use_plms,
                )
                mel = self.decode_first_stage(samples)
                waveform = self.mel_spectrogram_to_waveform(mel)
                similarity = self.cond_stage_model.cos_similarity(
                    torch.FloatTensor(waveform).squeeze(1), text
                )
                best_index = [
                    i + torch.argmax(similarity[i::n_items]).item() * n_items
                    for i in range(n_items)
                ]
                # Keep every batch's result (previously only the last batch's
                # waveforms were returned).
                selected.append(waveform[best_index])
        if not selected:
            raise ValueError("generate_sample received no batches")
        return np.concatenate(selected, axis=0)