import copy
import math
import numpy as np
import torch
import torch.nn as nn
from omegaconf import OmegaConf
import models
from models.ldm.vqgan.quantizer import VectorQuantizer
class LDMBase(nn.Module):
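    """Latent diffusion base model.

    Wraps an encoder -> latent z -> decoder -> renderer pipeline, with an
    optional Gaussian posterior (VAE-style KL), optional vector quantization,
    latent layer norm, training-time latent noise augmentation (zaug), and an
    optional diffusion model over z (the "zdm"), each with optional EMA copies.
    """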
def __init__(
self,
encoder,
z_shape,
decoder,
renderer,
encoder_ema_rate=None,
decoder_ema_rate=None,
renderer_ema_rate=None,
z_gaussian=False,
z_gaussian_sample=True,
z_quantizer=False,
z_quantizer_n_embed=8192,
z_quantizer_beta=0.25,
z_layernorm=False,
zaug_p=None,
zaug_tmax=1.0,
zaug_tmax_always=False,
zaug_decoding_loss_type='all',
zaug_zdm_diffusion=None,
gt_noise_lb=None,
drop_z_p=0.0,
zdm_net=None,
zdm_diffusion=None,
zdm_sampler=None,
zdm_n_steps=None,
zdm_ema_rate=0.9999,
zdm_train_normalize=False,
zdm_class_cond=None,
zdm_force_guidance=None,
loss_config=None,
use_ema_encoder=False,
use_ema_decoder=False,
use_ema_renderer=False,
):
super().__init__()
self.loss_config = loss_config if loss_config is not None else dict()
self.encoder = models.make(encoder)
self.decoder = models.make(decoder)
self.renderer = models.make(renderer)
self.z_shape = tuple(z_shape)
self.z_gaussian = z_gaussian
self.z_gaussian_sample = z_gaussian_sample
self.z_quantizer = VectorQuantizer(
z_quantizer_n_embed,
z_shape[0],
beta=z_quantizer_beta,
remap=None,
sane_index_shape=False
) if z_quantizer else None
self.z_layernorm = nn.LayerNorm(
list(z_shape),
elementwise_affine=False
) if z_layernorm else None
self.zaug_p = zaug_p
self.zaug_tmax = zaug_tmax
self.zaug_tmax_always = zaug_tmax_always
self.zaug_decoding_loss_type = zaug_decoding_loss_type
if zaug_zdm_diffusion is not None:
self.zaug_zdm_diffusion = models.make(zaug_zdm_diffusion)
self.drop_z_p = drop_z_p
if self.drop_z_p > 0:
            self.drop_z_emb = nn.Parameter(torch.zeros(*self.z_shape), requires_grad=False)
self.gt_noise_lb = gt_noise_lb
# EMA models #
self.encoder_ema_rate = encoder_ema_rate
if self.encoder_ema_rate is not None:
self.encoder_ema = copy.deepcopy(self.encoder)
for p in self.encoder_ema.parameters():
p.requires_grad = False
self.decoder_ema_rate = decoder_ema_rate
if self.decoder_ema_rate is not None:
self.decoder_ema = copy.deepcopy(self.decoder)
for p in self.decoder_ema.parameters():
p.requires_grad = False
self.renderer_ema_rate = renderer_ema_rate
if self.renderer_ema_rate is not None:
self.renderer_ema = copy.deepcopy(self.renderer)
for p in self.renderer_ema.parameters():
p.requires_grad = False
# - #
# z DM #
if zdm_diffusion is not None:
self.zdm_diffusion = models.make(zdm_diffusion)
if OmegaConf.is_config(zdm_sampler):
zdm_sampler = OmegaConf.to_container(zdm_sampler, resolve=True)
zdm_sampler = copy.deepcopy(zdm_sampler)
if zdm_sampler.get('args') is None:
zdm_sampler['args'] = {}
zdm_sampler['args']['diffusion'] = self.zdm_diffusion
self.zdm_sampler = models.make(zdm_sampler)
self.zdm_n_steps = zdm_n_steps
self.zdm_net = models.make(zdm_net)
self.zdm_net_ema = copy.deepcopy(self.zdm_net)
for p in self.zdm_net_ema.parameters():
p.requires_grad = False
self.zdm_ema_rate = zdm_ema_rate
self.zdm_class_cond = zdm_class_cond
self.zdm_force_guidance = zdm_force_guidance
else:
self.zdm_diffusion = None
self.zdm_train_normalize = zdm_train_normalize
if zdm_train_normalize:
self.register_buffer('zdm_Ez_v', torch.tensor(0.))
self.register_buffer('zdm_Ez_n', torch.tensor(0.))
self.register_buffer('zdm_Ez2_v', torch.tensor(0.))
self.register_buffer('zdm_Ez2_n', torch.tensor(0.))
# - #
self.use_ema_encoder = use_ema_encoder
self.use_ema_decoder = use_ema_decoder
self.use_ema_renderer = use_ema_renderer
def get_parameters(self, name):
if name == 'encoder':
return self.encoder.parameters()
elif name == 'decoder':
p = list(self.decoder.parameters())
if self.z_quantizer is not None:
p += list(self.z_quantizer.parameters())
return p
elif name == 'renderer':
return self.renderer.parameters()
        elif name == 'zdm':
            return self.zdm_net.parameters()
        else:
            raise ValueError(f'Unknown parameter group: {name}')
def encode(self, x, return_loss=False, ret=None):
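        """Encode inputs to latents.

        Order of operations: encoder (optionally the EMA copy) -> optional
        Gaussian reparameterization -> optional layer norm -> optional
        training-time noise augmentation. Returns (z, kl_loss) when
        return_loss is True, else just z.
        """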
if self.use_ema_encoder:
self.swap_ema_encoder()
z = self.encoder(x)
if self.use_ema_encoder:
self.swap_ema_encoder()
if self.z_gaussian:
posterior = DiagonalGaussianDistribution(z)
if self.z_gaussian_sample:
z = posterior.sample()
else:
z = posterior.mode()
kl_loss = posterior.kl().mean()
if ret is not None:
ret['z_gau_mean_abs'] = posterior.mean.abs().mean().item()
ret['z_gau_std'] = posterior.std.mean().item()
else:
kl_loss = None
if self.z_layernorm is not None:
z = self.z_layernorm(z)
if (self.zaug_p is not None) and self.training:
assert self.z_layernorm is not None # ensure 0 mean 1 std
if self.zaug_tmax_always:
tz = torch.ones(z.shape[0], device=z.device) * self.zaug_tmax
else:
tz = torch.rand(z.shape[0], device=z.device) * self.zaug_tmax
zt, _ = self.zaug_zdm_diffusion.add_noise(z, tz)
mask_aug = (torch.rand(z.shape[0], device=z.device) < self.zaug_p).float()
z = mask_aug.view(-1, 1, 1, 1) * zt + (1 - mask_aug).view(-1, 1, 1, 1) * z
self._tz = tz
self._mask_aug = mask_aug
if return_loss:
return z, kl_loss
else:
return z
def decode(self, z, return_loss=False):
if self.z_quantizer is not None:
z, quant_loss, _ = self.z_quantizer(z)
else:
quant_loss = None
if self.use_ema_decoder:
self.swap_ema_decoder()
z_dec = self.decoder(z)
if self.use_ema_decoder:
self.swap_ema_decoder()
if return_loss:
return z_dec, quant_loss
else:
return z_dec
def render(self, z_dec, coord, cell):
raise NotImplementedError
def normalize_for_zdm(self, z):
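        """Standardize z with running estimates of E[z] and E[z^2] accumulated
        during training (see forward); identity unless zdm_train_normalize."""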
if self.zdm_train_normalize:
mean = self.zdm_Ez_v
var = self.zdm_Ez2_v - mean ** 2
return (z - mean) / torch.sqrt(var)
else:
return z
def denormalize_for_zdm(self, z):
if self.zdm_train_normalize:
mean = self.zdm_Ez_v
var = self.zdm_Ez2_v - mean ** 2
return z * torch.sqrt(var) + mean
else:
return z
def forward(self, data, mode, has_optimizer=None):
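        """Run one training/eval step.

        `mode` selects the returned tensor: 'z' returns the latent, 'z_dec'
        the decoded latent. `has_optimizer` maps component names to whether an
        optimizer is attached, which determines which parts of the graph run
        with gradients (see get_grad_plan).
        """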
grad = self.get_grad_plan(has_optimizer)
loss = torch.tensor(0., device=data['inp'].device)
loss_config = self.loss_config
ret = dict()
# Encoder
        if grad['encoder']:
            z, kl_loss = self.encode(data['inp'], return_loss=True, ret=ret)
            # KL regularization is currently disabled; to re-enable:
            # ret['kl_loss'] = kl_loss.item()
            # loss = loss + kl_loss * loss_config.get('kl_loss', 0.0)
        else:
            with torch.no_grad():
                z, kl_loss = self.encode(data['inp'], return_loss=True, ret=ret)
if self.training and self.drop_z_p > 0:
drop_mask = (torch.rand(z.shape[0], device=z.device) < self.drop_z_p).to(z.dtype)
z = drop_mask.view(-1, 1, 1, 1) * self.drop_z_emb.unsqueeze(0) + (1 - drop_mask).view(-1, 1, 1, 1) * z
# Z DM
if grad['zdm']:
if self.zdm_train_normalize and self.training:
self.zdm_Ez_v = (
self.zdm_Ez_v * (self.zdm_Ez_n / (self.zdm_Ez_n + 1))
+ z.mean().item() / (self.zdm_Ez_n + 1)
)
self.zdm_Ez_n = self.zdm_Ez_n + 1
self.zdm_Ez2_v = (
self.zdm_Ez2_v * (self.zdm_Ez2_n / (self.zdm_Ez2_n + 1))
+ (z ** 2).mean().item() / (self.zdm_Ez2_n + 1)
)
self.zdm_Ez2_n = self.zdm_Ez2_n + 1
ret['normalize_z_mean'] = self.zdm_Ez_v.item()
ret['normalize_z_std'] = math.sqrt((self.zdm_Ez2_v - self.zdm_Ez_v ** 2).item())
z_for_dm = self.normalize_for_zdm(z)
net_kwargs = dict()
if self.zdm_class_cond is not None:
net_kwargs['class_labels'] = data['class_labels']
zdm_loss = self.zdm_diffusion.loss(self.zdm_net, z_for_dm, net_kwargs=net_kwargs)
ret['zdm_loss'] = zdm_loss.item()
loss = loss + zdm_loss * loss_config.get('zdm_loss', 1.0)
if not self.training:
ret['zdm_ema_loss'] = self.zdm_diffusion.loss(self.zdm_net_ema, z_for_dm, net_kwargs=net_kwargs).item()
# Decoder
        if mode == 'z':
            ret_z = z
        elif mode == 'z_dec':
            if grad['decoder']:
                z_dec, quant_loss = self.decode(z, return_loss=True)
            else:
                with torch.no_grad():
                    z_dec, quant_loss = self.decode(z, return_loss=True)
            ret_z = z_dec
            # Quantizer commitment loss is currently disabled; to re-enable:
            # ret['quant_loss'] = quant_loss.item()
            # loss = loss + quant_loss * loss_config.get('quant_loss', 1.0)
        else:
            raise ValueError(f'Unknown mode: {mode}')
ret['loss'] = loss
return ret_z, ret
def get_grad_plan(self, has_optimizer):
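        """Decide which components run with gradients enabled.

        Requirements chain forward through the pipeline: if the encoder is
        being optimized, the decoder (and hence the renderer) must also run
        with gradients so the loss can backpropagate through them. The zdm is
        trained directly on the latents, so it is outside this chain.
        """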
if has_optimizer is None:
has_optimizer = dict()
grad = dict()
grad['encoder'] = has_optimizer.get('encoder', False)
grad['decoder'] = grad['encoder'] or has_optimizer.get('decoder', False)
grad['renderer'] = grad['decoder'] or has_optimizer.get('renderer', False)
grad['zdm'] = has_optimizer.get('zdm', False) # not in chain definition
return grad
def update_ema_fn(self, net_ema, net, rate):
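        """EMA update: ema <- rate * ema + (1 - rate) * net.

        lerp_(cur, 1 - rate) computes ema + (1 - rate) * (cur - ema), which is
        the same update.
        """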
if rate != 1:
for ema_p, cur_p in zip(net_ema.parameters(), net.parameters()):
ema_p.data.lerp_(cur_p.data, 1 - rate)
def update_ema(self):
if self.encoder_ema_rate is not None:
self.update_ema_fn(self.encoder_ema, self.encoder, self.encoder_ema_rate)
if self.decoder_ema_rate is not None:
self.update_ema_fn(self.decoder_ema, self.decoder, self.decoder_ema_rate)
if self.renderer_ema_rate is not None:
self.update_ema_fn(self.renderer_ema, self.renderer, self.renderer_ema_rate)
if (self.zdm_diffusion is not None) and (self.zdm_ema_rate is not None):
self.update_ema_fn(self.zdm_net_ema, self.zdm_net, self.zdm_ema_rate)
def generate_samples(
self,
batch_size,
n_steps,
net_kwargs=None,
uncond_net_kwargs=None,
ema=False,
guidance=1.0,
noise=None,
render_res=(256, 256),
return_z=False,
):
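        """Sample latents from the zdm and render them.

        Draws z with the configured sampler (optionally with the EMA network
        and classifier-free guidance), undoes training-time latent
        normalization, then decodes and renders at `render_res`. With
        return_z=True, the raw latents are returned instead.
        """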
if self.zdm_force_guidance is not None:
guidance = self.zdm_force_guidance
shape = (batch_size,) + self.z_shape
net = self.zdm_net if not ema else self.zdm_net_ema
z = self.zdm_sampler.sample(
net,
shape,
n_steps,
net_kwargs=net_kwargs,
uncond_net_kwargs=uncond_net_kwargs,
guidance=guidance,
noise=noise,
)
if return_z:
return z
if (self.zaug_p is not None) and self.zaug_tmax_always:
tz = torch.ones(z.shape[0], device=z.device) * self.zaug_tmax
z, _ = self.zaug_zdm_diffusion.add_noise(z, tz)
z = self.denormalize_for_zdm(z)
z_dec = self.decode(z)
        coord = torch.zeros(batch_size, 2, render_res[0], render_res[1], device=z_dec.device)
        cell = torch.zeros(batch_size, 2, render_res[0], render_res[1], device=z_dec.device)
        return self.render(z_dec, coord, cell)
def swap_ema_encoder(self):
_ = self.encoder
self.encoder = self.encoder_ema
self.encoder_ema = _
def swap_ema_decoder(self):
_ = self.decoder
self.decoder = self.decoder_ema
self.decoder_ema = _
def swap_ema_renderer(self):
_ = self.renderer
self.renderer = self.renderer_ema
self.renderer_ema = _
class DiagonalGaussianDistribution(object):
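    """Diagonal Gaussian parameterized by concatenated (mean, logvar) channels.

    `parameters` is split along dim=1 into mean and log-variance; logvar is
    clamped to [-30, 20] for numerical stability.
    """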
def __init__(self, parameters, deterministic=False):
self.parameters = parameters
self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
self.deterministic = deterministic
self.std = torch.exp(0.5 * self.logvar)
self.var = torch.exp(self.logvar)
if self.deterministic:
            self.var = self.std = torch.zeros_like(self.mean)
    def sample(self):
        return self.mean + self.std * torch.randn_like(self.mean)
def kl(self, other=None):
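        """KL divergence to a standard normal (or to `other`).

        For the standard-normal case this is the usual closed form, summed
        over the non-batch dims:
            KL(N(mu, sigma^2) || N(0, 1)) = 0.5 * sum(mu^2 + sigma^2 - 1 - log sigma^2)
        """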
if self.deterministic:
            return torch.tensor([0.])
else:
if other is None:
return 0.5 * torch.sum(
torch.pow(self.mean, 2)
+ self.var - 1.0 - self.logvar,
dim=[1, 2, 3])
else:
return 0.5 * torch.sum(
torch.pow(self.mean - other.mean, 2) / other.var
+ self.var / other.var - 1.0 - self.logvar + other.logvar,
dim=[1, 2, 3])
    def nll(self, sample, dims=(1, 2, 3)):
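        """Negative log-likelihood of `sample` under the Gaussian:
            0.5 * sum(log(2*pi) + log sigma^2 + (x - mu)^2 / sigma^2)
        """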
if self.deterministic:
            return torch.tensor([0.])
logtwopi = np.log(2.0 * np.pi)
return 0.5 * torch.sum(
logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
dim=dims)
def mode(self):
return self.mean
class LDMBaseAudio(nn.Module):
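    """Audio variant of LDMBase.

    Same encoder/decoder/renderer layout, but latents are
    [batch, latent_dim, n_frames] with layer norm applied per time step, and
    there is no built-in quantizer or latent diffusion model.
    """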
def __init__(
self,
encoder,
z_channels,
decoder,
renderer,
zaug_p=0.1,
zaug_tmax=1.0,
zaug_tmax_always=False,
zaug_decoding_loss_type='all',
zaug_zdm_diffusion={'name': 'fm', 'args': {'timescale': 1000.0}},
zdm_ema_rate=0.9999,
        loss_config=None,
encoder_ema_rate=None,
decoder_ema_rate=None,
renderer_ema_rate=None,
):
super().__init__()
        self.loss_config = loss_config if loss_config is not None else dict()
self.encoder = models.make(encoder)
self.decoder = models.make(decoder)
self.renderer = models.make(renderer)
self.z_layernorm = nn.LayerNorm(
z_channels, # e.g., 64
elementwise_affine=False
)
self.zaug_p = zaug_p
self.zaug_tmax = zaug_tmax
self.zaug_tmax_always = zaug_tmax_always
self.zaug_decoding_loss_type = zaug_decoding_loss_type
if zaug_zdm_diffusion is not None:
self.zaug_zdm_diffusion = models.make(zaug_zdm_diffusion)
# EMA models #
self.encoder_ema_rate = encoder_ema_rate
if self.encoder_ema_rate is not None:
self.encoder_ema = copy.deepcopy(self.encoder)
for p in self.encoder_ema.parameters():
p.requires_grad = False
self.decoder_ema_rate = decoder_ema_rate
if self.decoder_ema_rate is not None:
self.decoder_ema = copy.deepcopy(self.decoder)
for p in self.decoder_ema.parameters():
p.requires_grad = False
self.renderer_ema_rate = renderer_ema_rate
if self.renderer_ema_rate is not None:
self.renderer_ema = copy.deepcopy(self.renderer)
for p in self.renderer_ema.parameters():
p.requires_grad = False
        # - #
def get_grad_plan(self, has_optimizer):
if has_optimizer is None:
has_optimizer = dict()
grad = dict()
grad['encoder'] = has_optimizer.get('encoder', False)
grad['decoder'] = grad['encoder'] or has_optimizer.get('decoder', False)
grad['renderer'] = grad['decoder'] or has_optimizer.get('renderer', False)
return grad
    def normalize_latents(self, z):
        # z: [batch, latent_dim, n_frames] - n_frames can vary!
        z = z.transpose(-2, -1)  # [batch, n_frames, latent_dim]
        z = self.z_layernorm(z)  # normalize over latent_dim at each time step
        z = z.transpose(-2, -1)  # back to [batch, latent_dim, n_frames]
        return z
    def update_ema_fn(self, net_ema, net, rate):
        if rate != 1:
            for ema_p, cur_p in zip(net_ema.parameters(), net.parameters()):
                ema_p.data.lerp_(cur_p.data, 1 - rate)

    def update_ema(self):
        if self.encoder_ema_rate is not None:
            self.update_ema_fn(self.encoder_ema, self.encoder, self.encoder_ema_rate)
        if self.decoder_ema_rate is not None:
            self.update_ema_fn(self.decoder_ema, self.decoder, self.decoder_ema_rate)
        if self.renderer_ema_rate is not None:
            self.update_ema_fn(self.renderer_ema, self.renderer, self.renderer_ema_rate)
    def get_parameters(self, name):
        # Unlike LDMBase, this class has no z_quantizer or zdm_net, so only
        # the three core parameter groups are exposed.
        if name == 'encoder':
            return self.encoder.parameters()
        elif name == 'decoder':
            return self.decoder.parameters()
        elif name == 'renderer':
            return self.renderer.parameters()
        else:
            raise ValueError(f'Unknown parameter group: {name}')
    def encode(self, x):
        z = self.encoder(x)
        z = self.normalize_latents(z)
        if (self.zaug_p is not None) and self.training:
            assert self.z_layernorm is not None  # ensure zero mean, unit std
            if self.zaug_tmax_always:
                tz = torch.ones(z.shape[0], device=z.device) * self.zaug_tmax
            else:
                tz = torch.rand(z.shape[0], device=z.device) * self.zaug_tmax
            zt, _ = self.zaug_zdm_diffusion.add_noise(z, tz)
            mask_aug = (torch.rand(z.shape[0], device=z.device) < self.zaug_p).float()
            if z.dim() == 4:  # image latents: [batch, channels, height, width]
                mask_shape = (-1, 1, 1, 1)
            elif z.dim() == 3:  # audio latents: [batch, channels, n_frames]
                mask_shape = (-1, 1, 1)
            else:
                raise ValueError(f"Unsupported tensor dimension: {z.dim()}")
            z = mask_aug.view(*mask_shape) * zt + (1 - mask_aug).view(*mask_shape) * z
            self._tz = tz
            self._mask_aug = mask_aug
        return z
def decode(self, z):
z_dec = self.decoder(z)
return z_dec
def render(self, z_dec):
raise NotImplementedError
def forward(self, data, mode, has_optimizer=None):
loss = torch.tensor(0., device=data['inp'].device)
ret = dict()
z = self.encode(data['inp'])
z_dec = self.decode(z)
ret['loss'] = loss
return z_dec, ret
def generate_samples(
self,
batch_size,
n_steps,
net_kwargs=None,
uncond_net_kwargs=None,
ema=False,
guidance=1.0,
noise=None,
return_z=False,
):
        # NOTE: expects zdm attributes (zdm_force_guidance, z_shape, zdm_net,
        # zdm_net_ema, zdm_sampler, denormalize_for_zdm) to be provided by a
        # subclass; LDMBaseAudio itself does not define them.
        if getattr(self, 'zdm_force_guidance', None) is not None:
            guidance = self.zdm_force_guidance
shape = (batch_size,) + self.z_shape
net = self.zdm_net if not ema else self.zdm_net_ema
z = self.zdm_sampler.sample(
net,
shape,
n_steps,
net_kwargs=net_kwargs,
uncond_net_kwargs=uncond_net_kwargs,
guidance=guidance,
noise=noise,
)
if return_z:
return z
if (self.zaug_p is not None) and self.zaug_tmax_always:
tz = torch.ones(z.shape[0], device=z.device) * self.zaug_tmax
z, _ = self.zaug_zdm_diffusion.add_noise(z, tz)
z = self.denormalize_for_zdm(z)
z_dec = self.decode(z)
return self.render(z_dec)
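# A minimal usage sketch (illustrative only). The registry spec format
# {'name': ..., 'args': {...}} follows the models.make usage above; the entry
# names 'ldm_base_audio', 'my_encoder', etc. are hypothetical placeholders for
# whatever is actually registered in `models`.
#
# model = models.make({
#     'name': 'ldm_base_audio',
#     'args': {
#         'encoder': {'name': 'my_encoder', 'args': {}},
#         'z_channels': 64,
#         'decoder': {'name': 'my_decoder', 'args': {}},
#         'renderer': {'name': 'my_renderer', 'args': {}},
#     },
# })
# z_dec, ret = model({'inp': wav_batch}, mode='z_dec',
#                    has_optimizer={'encoder': True, 'decoder': True})
# ret['loss'].backward()
# optimizer.step()
# model.update_ema()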