# Origin: primepake — commit 4f877a2 ("add training flowvae")
import copy
import math
import torch
import models
from omegaconf import OmegaConf
from models import register
from models.ldm.ldm_base import LDMBase, LDMBaseAudio
from models.ldm.vqgan.lpips import LPIPS
@register('dito')
class DiTo(LDMBase):
    """Latent diffusion model whose decoder ("renderer") is itself a diffusion
    model conditioned on the decoded latent ``z_dec`` plus per-pixel coordinate
    and scale maps.

    Args:
        render_diffusion: model spec for the renderer diffusion process.
        render_sampler: model spec for the sampler used at render time; its
            ``args.diffusion`` is filled in with ``render_diffusion``.
        render_n_steps: number of sampler steps used by :meth:`render`.
        renderer_guidance: classifier-free guidance weight; values > 1 enable
            an unconditional branch built from ``self.drop_z_emb``.
        lpips: if True, add an LPIPS perceptual loss on the one-step prediction.
        **kwargs: forwarded to :class:`LDMBase`.
    """

    def __init__(self, render_diffusion, render_sampler, render_n_steps,
                 renderer_guidance=1, lpips=False, **kwargs):
        super().__init__(**kwargs)
        self.render_diffusion = models.make(render_diffusion)
        # Deep-copy the sampler spec so injecting the diffusion object below
        # does not mutate the caller's (possibly shared) config.
        if OmegaConf.is_config(render_sampler):
            render_sampler = OmegaConf.to_container(render_sampler, resolve=True)
        render_sampler = copy.deepcopy(render_sampler)
        if render_sampler.get('args') is None:
            render_sampler['args'] = {}
        render_sampler['args']['diffusion'] = self.render_diffusion
        self.render_sampler = models.make(render_sampler)
        self.render_n_steps = render_n_steps
        self.renderer_guidance = renderer_guidance
        # EMA of the per-sample renderer loss in 10 uniform timestep buckets
        # over t in [0, 1]; used only for logging (see _monitor_t_losses).
        self.t_loss_monitor_v = [0 for _ in range(10)]
        self.t_loss_monitor_n = [0 for _ in range(10)]
        self.t_loss_monitor_decay = 0.99
        self.use_lpips = lpips
        if lpips:
            self.lpips_loss = LPIPS().eval()

    def render(self, z_dec, coord, scale):
        """Sample an RGB patch from the renderer diffusion conditioned on
        ``z_dec`` with coordinate/scale maps ``coord``/``scale``.

        The output shape is (B, 3, H, W) taken from ``coord``'s spatial size.
        """
        shape = (coord.size(0), 3, coord.size(2), coord.size(3))
        net_kwargs = {'coord': coord, 'scale': scale, 'z_dec': z_dec}
        if self.use_ema_renderer:
            self.swap_ema_renderer()
        if self.renderer_guidance > 1:
            # Classifier-free guidance: the unconditional branch replaces
            # z_dec with the learned drop embedding, broadcast over the batch.
            uncond_z_dec = self.drop_z_emb.unsqueeze(0).expand(z_dec.shape[0], -1, -1, -1)
            uncond_net_kwargs = {'coord': coord, 'scale': scale, 'z_dec': uncond_z_dec}
        else:
            uncond_net_kwargs = None
        try:
            ret = self.render_sampler.sample(
                net=self.renderer,
                shape=shape,
                n_steps=self.render_n_steps,
                net_kwargs=net_kwargs,
                uncond_net_kwargs=uncond_net_kwargs,
                guidance=self.renderer_guidance,
            )
        finally:
            # Always swap the EMA weights back, even if sampling raises.
            if self.use_ema_renderer:
                self.swap_ema_renderer()
        return ret

    def _sample_decoding_t(self, gt_patch):
        """Draw per-sample diffusion timesteps t in [0, 1] for the renderer loss.

        Applies the ``gt_noise_lb`` lower bound when set, then (during training
        with z-augmentation) overrides t for augmented samples according to
        ``zaug_decoding_loss_type``.
        """
        t = torch.rand(gt_patch.shape[0], device=gt_patch.device)
        if self.gt_noise_lb is not None:
            # Restrict t to [gt_noise_lb, 1].
            tmin = torch.ones_like(t) * self.gt_noise_lb
            tmax = torch.ones_like(t)
            t = tmin + (tmax - tmin) * torch.rand_like(tmin)
        if (self.zaug_p is not None) and self.training:
            tz = self._tz
            mask_aug = self._mask_aug
            typ = self.zaug_decoding_loss_type
            if typ == 'all':
                tmin = torch.zeros_like(tz)
                tmax = torch.ones_like(tz)
            elif typ == 'suffix':
                tmin = tz
                tmax = torch.ones_like(tz)
            elif typ == 'tz':
                tmin = tz
                tmax = tz
            elif typ == 'tmax':
                tmin = torch.ones_like(tz)
                tmax = torch.ones_like(tz)
            else:
                raise NotImplementedError
            t_aug = tmin + (tmax - tmin) * torch.rand_like(tmin)
            # Blend: augmented samples use t_aug, the rest keep t.
            t = mask_aug * t_aug + (1 - mask_aug) * t
        return t

    def _monitor_t_losses(self, loss, t, ret):
        """Update the per-timestep-bucket loss EMA and write bias-corrected
        values into ``ret`` under '_loss_t{q}' for logging."""
        m = len(self.t_loss_monitor_v)
        for i in range(len(loss)):
            q = min(math.floor(t[i].item() * m), m - 1)
            self.t_loss_monitor_v[q] = (self.t_loss_monitor_v[q] * self.t_loss_monitor_decay
                                        + loss[i].item() * (1 - self.t_loss_monitor_decay))
            self.t_loss_monitor_n[q] += 1
        for q in range(m):
            if self.t_loss_monitor_n[q] > 0:
                if self.t_loss_monitor_n[q] < 500:
                    # Bias correction for a freshly started EMA (as in Adam);
                    # negligible after ~500 updates.
                    r = 1 - math.pow(self.t_loss_monitor_decay, self.t_loss_monitor_n[q])
                else:
                    r = 1
                ret[f'_loss_t{q}'] = self.t_loss_monitor_v[q] / r

    def forward(self, data, mode, has_optimizer=None):
        """Dispatch on ``mode``:

        - 'z' / 'z_dec': delegate to the base class, return the latent.
        - 'pred': decode ``data['gt']``'s coord/scale channels into a patch.
        - 'loss': compute the renderer diffusion loss (plus optional LPIPS)
          on top of the base-class loss dict.

        ``data['gt']`` packs RGB in channels 0:3, coord in 3:5, scale in 5:7.
        """
        if mode in ['z', 'z_dec']:
            ret_z, _ = super().forward(data, mode=mode, has_optimizer=has_optimizer)
            return ret_z
        grad = self.get_grad_plan(has_optimizer)
        loss_config = self.loss_config
        if mode == 'pred':
            z_dec, ret = super().forward(data, mode='z_dec', has_optimizer=has_optimizer)
            coord = data['gt'][:, 3:5, ...]
            scale = data['gt'][:, 5:7, ...]
            if grad['renderer']:
                return self.render(z_dec, coord, scale)
            with torch.no_grad():
                return self.render(z_dec, coord, scale)
        if mode == 'loss':
            if not grad['renderer']:  # only the latent diffusion is training
                _, ret = super().forward(data, mode='z', has_optimizer=has_optimizer)
                return ret
            gt_patch = data['gt'][:, :3, ...]
            z_dec, ret = super().forward(data, mode='z_dec', has_optimizer=has_optimizer)
            net_kwargs = {'z_dec': z_dec}
            t = self._sample_decoding_t(gt_patch)
            if not self.use_lpips:
                loss, t = self.render_diffusion.loss(
                    net=self.renderer,
                    x=gt_patch,
                    t=t,
                    net_kwargs=net_kwargs,
                    return_loss_unreduced=True
                )
            else:
                loss, t, x_t, pred = self.render_diffusion.loss(
                    net=self.renderer,
                    x=gt_patch,
                    t=t,
                    net_kwargs=net_kwargs,
                    return_loss_unreduced=True,
                    return_all=True
                )
                # One-step reconstruction for the perceptual loss.
                # NOTE(review): assumes the diffusion parameterization satisfies
                # x_0 ≈ x_t + t * pred — confirm against render_diffusion.
                sample_pred = x_t + t.view(-1, 1, 1, 1) * pred
                lpips_loss = self.lpips_loss(sample_pred, gt_patch).mean()
                ret['lpips_loss'] = lpips_loss.item()
                lpips_loss_w = loss_config.get('lpips_loss', 1)
                ret['loss'] = ret['loss'] + lpips_loss * lpips_loss_w
            if self.training:
                self._monitor_t_losses(loss, t, ret)
            dae_loss = loss.mean()
            ret['dae_loss'] = dae_loss.item()
            dae_loss_w = loss_config.get('dae_loss', 1)
            ret['loss'] = ret['loss'] + dae_loss * dae_loss_w
            return ret
        # Previously an unknown mode silently returned None; fail loudly instead.
        raise ValueError(f'unknown mode: {mode}')
@register('dito_audio')
class DiToAudio(LDMBaseAudio):
    """Audio variant of DiTo: the decoder ("renderer") is a diffusion model
    conditioned only on the decoded latent ``z_dec``.

    Args:
        render_diffusion: model spec for the renderer diffusion process.
        render_sampler: model spec for the sampler; its ``args.diffusion`` is
            filled in with ``render_diffusion``.
        render_n_steps: number of sampler steps used by :meth:`render`.
        renderer_guidance: classifier-free guidance weight; values > 1 enable
            an unconditional branch built from ``self.drop_z_emb``.
        **kwargs: forwarded to :class:`LDMBaseAudio`.
    """

    def __init__(self, render_diffusion, render_sampler, render_n_steps,
                 renderer_guidance=1, **kwargs):
        super().__init__(**kwargs)
        self.render_diffusion = models.make(render_diffusion)
        # Deep-copy the sampler spec so injecting the diffusion object below
        # does not mutate the caller's (possibly shared) config.
        if OmegaConf.is_config(render_sampler):
            render_sampler = OmegaConf.to_container(render_sampler, resolve=True)
        render_sampler = copy.deepcopy(render_sampler)
        if render_sampler.get('args') is None:
            render_sampler['args'] = {}
        render_sampler['args']['diffusion'] = self.render_diffusion
        self.render_sampler = models.make(render_sampler)
        self.render_n_steps = render_n_steps
        self.renderer_guidance = renderer_guidance
        # EMA of the per-sample renderer loss in 10 uniform timestep buckets
        # over t in [0, 1]; used only for logging (see _monitor_t_losses).
        self.t_loss_monitor_v = [0 for _ in range(10)]
        self.t_loss_monitor_n = [0 for _ in range(10)]
        self.t_loss_monitor_decay = 0.99

    def render(self, z_dec):
        """Sample a waveform from the renderer diffusion conditioned on ``z_dec``.

        The frame count is z_dec.size(2) * 320 — presumably 320 is the
        latent-to-sample hop size; TODO confirm.
        """
        net_kwargs = {'z_dec': z_dec}
        n_frames = z_dec.size(2) * 320
        # NOTE(review): the second dim reuses the batch size z_dec.size(0);
        # by analogy with the image variant this looks like it should be a
        # channel count — confirm against the renderer's expected input shape.
        shape = (z_dec.size(0), z_dec.size(0), n_frames)
        if self.renderer_guidance > 1:
            # NOTE(review): this expand passes four sizes, which only works if
            # drop_z_emb.unsqueeze(0) is 4-D; for a 3-D audio latent it would
            # raise. Looks copied from the image variant — verify.
            uncond_z_dec = self.drop_z_emb.unsqueeze(0).expand(z_dec.shape[0], -1, -1, -1)
            uncond_net_kwargs = {'z_dec': uncond_z_dec}
        else:
            uncond_net_kwargs = None
        ret = self.render_sampler.sample(
            net=self.renderer,
            n_steps=self.render_n_steps,
            shape=shape,
            net_kwargs=net_kwargs,
            uncond_net_kwargs=uncond_net_kwargs,
            guidance=self.renderer_guidance,
        )
        return ret

    def _sample_decoding_t(self, gt_patch):
        """Draw per-sample diffusion timesteps t in [0, 1] for the renderer loss,
        with the z-augmentation override applied during training."""
        t = torch.rand(gt_patch.shape[0], device=gt_patch.device)
        if (self.zaug_p is not None) and self.training:
            tz = self._tz
            mask_aug = self._mask_aug
            typ = self.zaug_decoding_loss_type
            if typ == 'all':
                tmin = torch.zeros_like(tz)
                tmax = torch.ones_like(tz)
            elif typ == 'suffix':
                tmin = tz
                tmax = torch.ones_like(tz)
            elif typ == 'tz':
                tmin = tz
                tmax = tz
            elif typ == 'tmax':
                tmin = torch.ones_like(tz)
                tmax = torch.ones_like(tz)
            else:
                raise NotImplementedError
            t_aug = tmin + (tmax - tmin) * torch.rand_like(tmin)
            # Blend: augmented samples use t_aug, the rest keep t.
            t = mask_aug * t_aug + (1 - mask_aug) * t
        return t

    def _monitor_t_losses(self, loss, t, ret):
        """Update the per-timestep-bucket loss EMA and write bias-corrected
        values into ``ret`` under '_loss_t{q}' for logging."""
        m = len(self.t_loss_monitor_v)
        for i in range(len(loss)):
            q = min(math.floor(t[i].item() * m), m - 1)
            self.t_loss_monitor_v[q] = (self.t_loss_monitor_v[q] * self.t_loss_monitor_decay
                                        + loss[i].item() * (1 - self.t_loss_monitor_decay))
            self.t_loss_monitor_n[q] += 1
        for q in range(m):
            if self.t_loss_monitor_n[q] > 0:
                if self.t_loss_monitor_n[q] < 500:
                    # Bias correction for a freshly started EMA (as in Adam);
                    # negligible after ~500 updates.
                    r = 1 - math.pow(self.t_loss_monitor_decay, self.t_loss_monitor_n[q])
                else:
                    r = 1
                ret[f'_loss_t{q}'] = self.t_loss_monitor_v[q] / r

    def forward(self, data, mode, has_optimizer=None):
        """Dispatch on ``mode``:

        - 'z' / 'z_dec': delegate to the base class, return the latent.
        - 'pred': decode the latent of ``data`` into a waveform.
        - 'loss': compute the renderer diffusion loss on ``data['gt']`` on top
          of the base-class loss dict.
        """
        if mode in ['z', 'z_dec']:
            ret_z, _ = super().forward(data, mode=mode, has_optimizer=has_optimizer)
            return ret_z
        grad = self.get_grad_plan(has_optimizer)
        loss_config = self.loss_config
        if mode == 'pred':
            z_dec, ret = super().forward(data, mode='z_dec', has_optimizer=has_optimizer)
            if grad['renderer']:
                return self.render(z_dec)
            with torch.no_grad():
                return self.render(z_dec)
        if mode == 'loss':
            if not grad['renderer']:  # only the latent diffusion is training
                _, ret = super().forward(data, mode='z', has_optimizer=has_optimizer)
                return ret
            gt_patch = data['gt']
            z_dec, ret = super().forward(data, mode='z_dec', has_optimizer=has_optimizer)
            net_kwargs = {'z_dec': z_dec}
            t = self._sample_decoding_t(gt_patch)
            loss, t = self.render_diffusion.loss(
                net=self.renderer,
                x=gt_patch,
                t=t,
                net_kwargs=net_kwargs,
                return_loss_unreduced=True
            )
            if self.training:
                self._monitor_t_losses(loss, t, ret)
            dae_loss = loss.mean()
            ret['dae_loss'] = dae_loss.item()
            dae_loss_w = loss_config.get('dae_loss', 1)
            ret['loss'] = ret['loss'] + dae_loss * dae_loss_w
            return ret
        # Previously an unknown mode silently returned None; fail loudly instead.
        raise ValueError(f'unknown mode: {mode}')