# NOTE(review): removed non-code extraction artifacts ("Spaces:" / "Sleeping")
# that preceded the source in the original paste.
import copy
import math

import numpy as np
import torch
import torch.nn as nn
from omegaconf import OmegaConf

import models
from models.ldm.vqgan.quantizer import VectorQuantizer
class LDMBase(nn.Module):
    """Base latent-diffusion model: encoder -> latent z -> decoder -> renderer.

    The latent z (shape ``z_shape``) can optionally be:
      * modeled as a diagonal Gaussian posterior (``z_gaussian``),
      * vector-quantized (``z_quantizer``),
      * layer-normalized (``z_layernorm``),
      * noise-augmented during training (``zaug_*``),
      * randomly replaced by a frozen null embedding (``drop_z_p``).

    A diffusion model over z ("zDM") may be attached via the ``zdm_*``
    arguments, including an EMA copy of its network and optional running
    normalization statistics for the latents.
    """

    def __init__(
        self,
        encoder,
        z_shape,
        decoder,
        renderer,
        encoder_ema_rate=None,
        decoder_ema_rate=None,
        renderer_ema_rate=None,
        z_gaussian=False,
        z_gaussian_sample=True,
        z_quantizer=False,
        z_quantizer_n_embed=8192,
        z_quantizer_beta=0.25,
        z_layernorm=False,
        zaug_p=None,
        zaug_tmax=1.0,
        zaug_tmax_always=False,
        zaug_decoding_loss_type='all',
        zaug_zdm_diffusion=None,
        gt_noise_lb=None,
        drop_z_p=0.0,
        zdm_net=None,
        zdm_diffusion=None,
        zdm_sampler=None,
        zdm_n_steps=None,
        zdm_ema_rate=0.9999,
        zdm_train_normalize=False,
        zdm_class_cond=None,
        zdm_force_guidance=None,
        loss_config=None,
        use_ema_encoder=False,
        use_ema_decoder=False,
        use_ema_renderer=False,
    ):
        super().__init__()
        self.loss_config = loss_config if loss_config is not None else dict()

        # Core submodules are built from config specs via the model registry.
        self.encoder = models.make(encoder)
        self.decoder = models.make(decoder)
        self.renderer = models.make(renderer)
        self.z_shape = tuple(z_shape)

        # Latent-space options.
        self.z_gaussian = z_gaussian
        self.z_gaussian_sample = z_gaussian_sample
        self.z_quantizer = VectorQuantizer(
            z_quantizer_n_embed,
            z_shape[0],  # codebook dim = latent channel dim
            beta=z_quantizer_beta,
            remap=None,
            sane_index_shape=False,
        ) if z_quantizer else None
        self.z_layernorm = nn.LayerNorm(
            list(z_shape),
            elementwise_affine=False,
        ) if z_layernorm else None

        # Latent noise augmentation (applied inside encode() while training).
        self.zaug_p = zaug_p
        self.zaug_tmax = zaug_tmax
        self.zaug_tmax_always = zaug_tmax_always
        self.zaug_decoding_loss_type = zaug_decoding_loss_type
        if zaug_zdm_diffusion is not None:
            self.zaug_zdm_diffusion = models.make(zaug_zdm_diffusion)

        # Random replacement of whole latents with a frozen null embedding.
        self.drop_z_p = drop_z_p
        if self.drop_z_p > 0:
            self.drop_z_emb = nn.Parameter(
                torch.zeros(z_shape[0], z_shape[1], z_shape[2]),
                requires_grad=False,
            )
        self.gt_noise_lb = gt_noise_lb

        # EMA copies of the core submodules (frozen; updated by update_ema()).
        self.encoder_ema_rate = encoder_ema_rate
        if self.encoder_ema_rate is not None:
            self.encoder_ema = copy.deepcopy(self.encoder)
            for p in self.encoder_ema.parameters():
                p.requires_grad = False
        self.decoder_ema_rate = decoder_ema_rate
        if self.decoder_ema_rate is not None:
            self.decoder_ema = copy.deepcopy(self.decoder)
            for p in self.decoder_ema.parameters():
                p.requires_grad = False
        self.renderer_ema_rate = renderer_ema_rate
        if self.renderer_ema_rate is not None:
            self.renderer_ema = copy.deepcopy(self.renderer)
            for p in self.renderer_ema.parameters():
                p.requires_grad = False

        # Diffusion model over the latent space (zDM).
        if zdm_diffusion is not None:
            self.zdm_diffusion = models.make(zdm_diffusion)
            # Deep-copy the sampler spec before injecting the diffusion object
            # so the caller's config is never mutated.
            if OmegaConf.is_config(zdm_sampler):
                zdm_sampler = OmegaConf.to_container(zdm_sampler, resolve=True)
            zdm_sampler = copy.deepcopy(zdm_sampler)
            if zdm_sampler.get('args') is None:
                zdm_sampler['args'] = {}
            zdm_sampler['args']['diffusion'] = self.zdm_diffusion
            self.zdm_sampler = models.make(zdm_sampler)
            self.zdm_n_steps = zdm_n_steps
            self.zdm_net = models.make(zdm_net)
            self.zdm_net_ema = copy.deepcopy(self.zdm_net)
            for p in self.zdm_net_ema.parameters():
                p.requires_grad = False
            self.zdm_ema_rate = zdm_ema_rate
            self.zdm_class_cond = zdm_class_cond
            self.zdm_force_guidance = zdm_force_guidance
        else:
            self.zdm_diffusion = None

        # Running first/second moments of z, used to whiten latents for zDM.
        self.zdm_train_normalize = zdm_train_normalize
        if zdm_train_normalize:
            self.register_buffer('zdm_Ez_v', torch.tensor(0.))
            self.register_buffer('zdm_Ez_n', torch.tensor(0.))
            self.register_buffer('zdm_Ez2_v', torch.tensor(0.))
            self.register_buffer('zdm_Ez2_n', torch.tensor(0.))

        self.use_ema_encoder = use_ema_encoder
        self.use_ema_decoder = use_ema_decoder
        self.use_ema_renderer = use_ema_renderer

    def get_parameters(self, name):
        """Return the trainable parameters of the named component.

        ``name`` is one of 'encoder', 'decoder' (includes the optional
        quantizer), 'renderer', 'zdm'. Raises ValueError for anything else
        (the original silently returned None).
        """
        if name == 'encoder':
            return self.encoder.parameters()
        elif name == 'decoder':
            p = list(self.decoder.parameters())
            if self.z_quantizer is not None:
                p += list(self.z_quantizer.parameters())
            return p
        elif name == 'renderer':
            return self.renderer.parameters()
        elif name == 'zdm':
            return self.zdm_net.parameters()
        raise ValueError(f'unknown parameter group: {name}')

    def encode(self, x, return_loss=False, ret=None):
        """Encode ``x`` into a latent z.

        Applies, in order: (EMA) encoder, optional Gaussian posterior
        sampling, optional layer-norm, and optional train-time noise
        augmentation. Returns ``z`` or ``(z, kl_loss)`` when
        ``return_loss`` is True; ``kl_loss`` is None unless ``z_gaussian``.
        """
        if self.use_ema_encoder:
            self.swap_ema_encoder()
        z = self.encoder(x)
        if self.use_ema_encoder:
            self.swap_ema_encoder()

        if self.z_gaussian:
            # Encoder output parameterizes a diagonal Gaussian (mean, logvar).
            posterior = DiagonalGaussianDistribution(z)
            z = posterior.sample() if self.z_gaussian_sample else posterior.mode()
            kl_loss = posterior.kl().mean()
            if ret is not None:
                ret['z_gau_mean_abs'] = posterior.mean.abs().mean().item()
                ret['z_gau_std'] = posterior.std.mean().item()
        else:
            kl_loss = None

        if self.z_layernorm is not None:
            z = self.z_layernorm(z)

        if (self.zaug_p is not None) and self.training:
            assert self.z_layernorm is not None  # ensure 0 mean 1 std
            if self.zaug_tmax_always:
                tz = torch.ones(z.shape[0], device=z.device) * self.zaug_tmax
            else:
                tz = torch.rand(z.shape[0], device=z.device) * self.zaug_tmax
            zt, _ = self.zaug_zdm_diffusion.add_noise(z, tz)
            # Per-sample Bernoulli mask: 1 -> use the noised latent.
            mask_aug = (torch.rand(z.shape[0], device=z.device) < self.zaug_p).float()
            z = mask_aug.view(-1, 1, 1, 1) * zt + (1 - mask_aug).view(-1, 1, 1, 1) * z
            # Stash augmentation state for downstream loss weighting.
            self._tz = tz
            self._mask_aug = mask_aug

        if return_loss:
            return z, kl_loss
        return z

    def decode(self, z, return_loss=False):
        """Decode latent z (optionally quantized first) with the (EMA) decoder.

        Returns ``z_dec`` or ``(z_dec, quant_loss)`` when ``return_loss``;
        ``quant_loss`` is None when no quantizer is configured.
        """
        if self.z_quantizer is not None:
            z, quant_loss, _ = self.z_quantizer(z)
        else:
            quant_loss = None
        if self.use_ema_decoder:
            self.swap_ema_decoder()
        z_dec = self.decoder(z)
        if self.use_ema_decoder:
            self.swap_ema_decoder()
        if return_loss:
            return z_dec, quant_loss
        return z_dec

    def render(self, z_dec, coord, cell):
        """Render decoded latents at the given coordinates; subclasses implement."""
        raise NotImplementedError

    def normalize_for_zdm(self, z):
        """Whiten z with the running moments when zdm_train_normalize is on."""
        if self.zdm_train_normalize:
            mean = self.zdm_Ez_v
            var = self.zdm_Ez2_v - mean ** 2
            return (z - mean) / torch.sqrt(var)
        return z

    def denormalize_for_zdm(self, z):
        """Inverse of normalize_for_zdm."""
        if self.zdm_train_normalize:
            mean = self.zdm_Ez_v
            var = self.zdm_Ez2_v - mean ** 2
            return z * torch.sqrt(var) + mean
        return z

    def forward(self, data, mode, has_optimizer=None):
        """Encode (and optionally train the zDM), then return latents per ``mode``.

        Args:
            data: dict with 'inp' (and 'class_labels' when class-conditional).
            mode: 'z' returns the latent; 'z_dec' returns the decoded latent.
            has_optimizer: dict component-name -> bool; see get_grad_plan.

        Returns:
            (ret_z, ret) where ``ret`` carries scalar logs plus 'loss'.
        """
        grad = self.get_grad_plan(has_optimizer)
        loss = torch.tensor(0., device=data['inp'].device)
        loss_config = self.loss_config
        ret = dict()

        # Encoder.
        if grad['encoder']:
            z, kl_loss = self.encode(data['inp'], return_loss=True, ret=ret)
        else:
            with torch.no_grad():
                z, kl_loss = self.encode(data['inp'], return_loss=True, ret=ret)
        # NOTE(review): kl_loss is computed but deliberately not added to the
        # loss — the weighting code was commented out in the original.

        if self.training and self.drop_z_p > 0:
            # Randomly replace whole latents with the frozen null embedding.
            drop_mask = (torch.rand(z.shape[0], device=z.device) < self.drop_z_p).to(z.dtype)
            z = (drop_mask.view(-1, 1, 1, 1) * self.drop_z_emb.unsqueeze(0)
                 + (1 - drop_mask).view(-1, 1, 1, 1) * z)

        # Diffusion model over z.
        if grad['zdm']:
            if self.zdm_train_normalize and self.training:
                # Online update of the running moments of z.
                self.zdm_Ez_v = (
                    self.zdm_Ez_v * (self.zdm_Ez_n / (self.zdm_Ez_n + 1))
                    + z.mean().item() / (self.zdm_Ez_n + 1)
                )
                self.zdm_Ez_n = self.zdm_Ez_n + 1
                self.zdm_Ez2_v = (
                    self.zdm_Ez2_v * (self.zdm_Ez2_n / (self.zdm_Ez2_n + 1))
                    + (z ** 2).mean().item() / (self.zdm_Ez2_n + 1)
                )
                self.zdm_Ez2_n = self.zdm_Ez2_n + 1
                ret['normalize_z_mean'] = self.zdm_Ez_v.item()
                ret['normalize_z_std'] = math.sqrt((self.zdm_Ez2_v - self.zdm_Ez_v ** 2).item())
            z_for_dm = self.normalize_for_zdm(z)
            net_kwargs = dict()
            if self.zdm_class_cond is not None:
                net_kwargs['class_labels'] = data['class_labels']
            zdm_loss = self.zdm_diffusion.loss(self.zdm_net, z_for_dm, net_kwargs=net_kwargs)
            ret['zdm_loss'] = zdm_loss.item()
            loss = loss + zdm_loss * loss_config.get('zdm_loss', 1.0)
            if not self.training:
                ret['zdm_ema_loss'] = self.zdm_diffusion.loss(
                    self.zdm_net_ema, z_for_dm, net_kwargs=net_kwargs).item()

        # Output selection.
        if mode == 'z':
            ret_z = z
        elif mode == 'z_dec':
            if grad['decoder']:
                z_dec, quant_loss = self.decode(z, return_loss=True)
            else:
                with torch.no_grad():
                    z_dec, quant_loss = self.decode(z, return_loss=True)
            ret_z = z_dec
            # NOTE(review): quant_loss is computed but deliberately not added
            # to the loss — the weighting code was commented out originally.
        else:
            raise ValueError(f'unknown mode: {mode}')

        ret['loss'] = loss
        return ret_z, ret

    def get_grad_plan(self, has_optimizer):
        """Decide which components run with gradients.

        The encoder/decoder/renderer form a chain: optimizing an earlier
        component forces gradients through the later ones. 'zdm' is
        independent of that chain.
        """
        if has_optimizer is None:
            has_optimizer = dict()
        grad = dict()
        grad['encoder'] = has_optimizer.get('encoder', False)
        grad['decoder'] = grad['encoder'] or has_optimizer.get('decoder', False)
        grad['renderer'] = grad['decoder'] or has_optimizer.get('renderer', False)
        grad['zdm'] = has_optimizer.get('zdm', False)  # not in chain definition
        return grad

    def update_ema_fn(self, net_ema, net, rate):
        """In-place EMA update: ema <- rate * ema + (1 - rate) * net.

        ``rate == 1`` keeps the EMA copy frozen (no-op).
        """
        if rate != 1:
            for ema_p, cur_p in zip(net_ema.parameters(), net.parameters()):
                ema_p.data.lerp_(cur_p.data, 1 - rate)

    def update_ema(self):
        """Update every configured EMA copy with its own rate."""
        if self.encoder_ema_rate is not None:
            self.update_ema_fn(self.encoder_ema, self.encoder, self.encoder_ema_rate)
        if self.decoder_ema_rate is not None:
            self.update_ema_fn(self.decoder_ema, self.decoder, self.decoder_ema_rate)
        if self.renderer_ema_rate is not None:
            self.update_ema_fn(self.renderer_ema, self.renderer, self.renderer_ema_rate)
        if (self.zdm_diffusion is not None) and (self.zdm_ema_rate is not None):
            self.update_ema_fn(self.zdm_net_ema, self.zdm_net, self.zdm_ema_rate)

    def generate_samples(
        self,
        batch_size,
        n_steps,
        net_kwargs=None,
        uncond_net_kwargs=None,
        ema=False,
        guidance=1.0,
        noise=None,
        render_res=(256, 256),
        return_z=False,
    ):
        """Sample latents from the zDM and render them.

        Returns the raw latent when ``return_z`` is True; otherwise the
        rendered output at ``render_res``.
        """
        if self.zdm_force_guidance is not None:
            guidance = self.zdm_force_guidance
        shape = (batch_size,) + self.z_shape
        net = self.zdm_net_ema if ema else self.zdm_net
        z = self.zdm_sampler.sample(
            net,
            shape,
            n_steps,
            net_kwargs=net_kwargs,
            uncond_net_kwargs=uncond_net_kwargs,
            guidance=guidance,
            noise=noise,
        )
        if return_z:
            return z
        if (self.zaug_p is not None) and self.zaug_tmax_always:
            # Match train-time augmentation: always noise to tmax.
            tz = torch.ones(z.shape[0], device=z.device) * self.zaug_tmax
            z, _ = self.zaug_zdm_diffusion.add_noise(z, tz)
        z = self.denormalize_for_zdm(z)
        z_dec = self.decode(z)
        # NOTE(review): coord/cell are zero placeholders here — presumably the
        # subclass renderer interprets them; confirm against the renderer.
        coord = torch.zeros(batch_size, 2, render_res[0], render_res[1], device=z_dec.device)
        cell = torch.zeros(batch_size, 2, render_res[0], render_res[1], device=z_dec.device)
        return self.render(z_dec, coord, cell)

    def swap_ema_encoder(self):
        """Swap the live encoder with its EMA copy (call again to restore)."""
        self.encoder, self.encoder_ema = self.encoder_ema, self.encoder

    def swap_ema_decoder(self):
        """Swap the live decoder with its EMA copy (call again to restore)."""
        self.decoder, self.decoder_ema = self.decoder_ema, self.decoder

    def swap_ema_renderer(self):
        """Swap the live renderer with its EMA copy (call again to restore)."""
        self.renderer, self.renderer_ema = self.renderer_ema, self.renderer
class DiagonalGaussianDistribution(object):
    """Diagonal Gaussian parameterized by channel-concatenated (mean, logvar).

    ``parameters`` is split in half along dim=1: the first half is the mean,
    the second half the log-variance (clamped to [-30, 20] for numerical
    stability). With ``deterministic=True`` the variance is forced to zero so
    ``sample()`` returns the mean.
    """

    def __init__(self, parameters, deterministic=False):
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        # Clamp keeps exp() finite and the KL well-behaved.
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            # zeros_like already matches device and dtype of mean.
            self.var = self.std = torch.zeros_like(self.mean)

    def sample(self):
        """Reparameterized sample: mean + std * eps.

        Uses randn_like so the noise matches the mean's device AND dtype
        (the original randn(...).to(device) silently forced float32).
        """
        return self.mean + self.std * torch.randn_like(self.mean)

    def kl(self, other=None):
        """KL divergence to the standard normal (or to ``other``),
        summed over non-batch dims [1, 2, 3]."""
        if self.deterministic:
            return torch.tensor([0.0])
        if other is None:
            return 0.5 * torch.sum(
                torch.pow(self.mean, 2)
                + self.var - 1.0 - self.logvar,
                dim=[1, 2, 3])
        return 0.5 * torch.sum(
            torch.pow(self.mean - other.mean, 2) / other.var
            + self.var / other.var - 1.0 - self.logvar + other.logvar,
            dim=[1, 2, 3])

    def nll(self, sample, dims=(1, 2, 3)):
        """Negative log-likelihood of ``sample`` summed over ``dims``.

        ``dims`` defaults to an immutable tuple (was a mutable list default).
        """
        if self.deterministic:
            return torch.tensor([0.0])
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
            dim=dims)

    def mode(self):
        """Distribution mode (the mean)."""
        return self.mean
class LDMBaseAudio(nn.Module):
    """Audio variant of the LDM base: encoder -> per-frame-normalized z -> decoder.

    Latents are shaped [batch, latent_dim, n_frames] (n_frames may vary);
    layer-norm runs over latent_dim for each frame. This class builds no
    latent diffusion model itself — ``generate_samples`` assumes a subclass
    supplies the zdm_* machinery.
    """

    # Default spec for the latent-noising diffusion. Kept as a class constant
    # and deep-copied per instance so the shared default is never mutated
    # (the original used a mutable dict default argument).
    _DEFAULT_ZAUG_ZDM_DIFFUSION = {'name': 'fm', 'args': {'timescale': 1000.0}}

    def __init__(
        self,
        encoder,
        z_channels,
        decoder,
        renderer,
        zaug_p=0.1,
        zaug_tmax=1.0,
        zaug_tmax_always=False,
        zaug_decoding_loss_type='all',
        zaug_zdm_diffusion=_DEFAULT_ZAUG_ZDM_DIFFUSION,
        zdm_ema_rate=0.9999,
        loss_config=None,
        encoder_ema_rate=None,
        decoder_ema_rate=None,
        renderer_ema_rate=None,
    ):
        super().__init__()
        # None default avoids the shared mutable {} the original used.
        self.loss_config = loss_config if loss_config is not None else {}
        self.encoder = models.make(encoder)
        self.decoder = models.make(decoder)
        self.renderer = models.make(renderer)
        # Normalizes over latent_dim (z_channels, e.g. 64) for each time step.
        self.z_layernorm = nn.LayerNorm(
            z_channels,
            elementwise_affine=False,
        )
        self.zaug_p = zaug_p
        self.zaug_tmax = zaug_tmax
        self.zaug_tmax_always = zaug_tmax_always
        self.zaug_decoding_loss_type = zaug_decoding_loss_type
        if zaug_zdm_diffusion is not None:
            if zaug_zdm_diffusion is self._DEFAULT_ZAUG_ZDM_DIFFUSION:
                # Protect the class-level default from downstream mutation.
                zaug_zdm_diffusion = copy.deepcopy(zaug_zdm_diffusion)
            self.zaug_zdm_diffusion = models.make(zaug_zdm_diffusion)
        self.zdm_ema_rate = zdm_ema_rate  # kept for subclasses; unused here

        # EMA copies of the core submodules (frozen; updated by update_ema()).
        self.encoder_ema_rate = encoder_ema_rate
        if self.encoder_ema_rate is not None:
            self.encoder_ema = copy.deepcopy(self.encoder)
            for p in self.encoder_ema.parameters():
                p.requires_grad = False
        self.decoder_ema_rate = decoder_ema_rate
        if self.decoder_ema_rate is not None:
            self.decoder_ema = copy.deepcopy(self.decoder)
            for p in self.decoder_ema.parameters():
                p.requires_grad = False
        self.renderer_ema_rate = renderer_ema_rate
        if self.renderer_ema_rate is not None:
            self.renderer_ema = copy.deepcopy(self.renderer)
            for p in self.renderer_ema.parameters():
                p.requires_grad = False

    def get_grad_plan(self, has_optimizer):
        """Decide which components run with gradients (encoder -> decoder ->
        renderer chain; optimizing an earlier stage forces the later ones)."""
        if has_optimizer is None:
            has_optimizer = dict()
        grad = dict()
        grad['encoder'] = has_optimizer.get('encoder', False)
        grad['decoder'] = grad['encoder'] or has_optimizer.get('decoder', False)
        grad['renderer'] = grad['decoder'] or has_optimizer.get('renderer', False)
        return grad

    def normalize_latents(self, z):
        """Layer-normalize z over latent_dim for each time step.

        z shape: [batch, latent_dim, n_frames] — n_frames can vary, so the
        norm is applied on the last dim after a transpose.
        """
        z = z.transpose(-2, -1)     # [batch, n_frames, latent_dim]
        z = self.z_layernorm(z)     # normalize over latent_dim per frame
        z = z.transpose(-2, -1)     # back to [batch, latent_dim, n_frames]
        return z

    def update_ema_fn(self, net_ema, net, rate):
        """In-place EMA update: ema <- rate * ema + (1 - rate) * net.

        Added here because update_ema() called it but the original class
        never defined it (AttributeError). ``rate == 1`` is a no-op.
        """
        if rate != 1:
            for ema_p, cur_p in zip(net_ema.parameters(), net.parameters()):
                ema_p.data.lerp_(cur_p.data, 1 - rate)

    def update_ema(self):
        """Update every configured EMA copy with its own rate."""
        if self.encoder_ema_rate is not None:
            self.update_ema_fn(self.encoder_ema, self.encoder, self.encoder_ema_rate)
        if self.decoder_ema_rate is not None:
            self.update_ema_fn(self.decoder_ema, self.decoder, self.decoder_ema_rate)
        if self.renderer_ema_rate is not None:
            self.update_ema_fn(self.renderer_ema, self.renderer, self.renderer_ema_rate)

    def get_parameters(self, name):
        """Return the trainable parameters of the named component.

        ``name`` is 'encoder', 'decoder', 'renderer', or 'zdm'. z_quantizer
        and zdm_net are not created by this class, so they are looked up
        defensively (they may be added by subclasses).
        """
        if name == 'encoder':
            return self.encoder.parameters()
        elif name == 'decoder':
            p = list(self.decoder.parameters())
            quantizer = getattr(self, 'z_quantizer', None)
            if quantizer is not None:
                p += list(quantizer.parameters())
            return p
        elif name == 'renderer':
            return self.renderer.parameters()
        elif name == 'zdm':
            return self.zdm_net.parameters()
        raise ValueError(f'unknown parameter group: {name}')

    def encode(self, x):
        """Encode input x, per-frame-normalize, and optionally noise-augment.

        Augmentation runs only in training mode and supports both 4-D image
        latents and 3-D audio latents.
        """
        z = self.encoder(x)
        z = self.normalize_latents(z)
        if (self.zaug_p is not None) and self.training:
            assert self.z_layernorm is not None  # ensure 0 mean 1 std
            if self.zaug_tmax_always:
                tz = torch.ones(z.shape[0], device=z.device) * self.zaug_tmax
            else:
                tz = torch.rand(z.shape[0], device=z.device) * self.zaug_tmax
            zt, _ = self.zaug_zdm_diffusion.add_noise(z, tz)
            # Per-sample Bernoulli mask: 1 -> use the noised latent.
            mask_aug = (torch.rand(z.shape[0], device=z.device) < self.zaug_p).float()
            if z.dim() == 4:    # Image: [batch, channels, height, width]
                mask_shape = (-1, 1, 1, 1)
            elif z.dim() == 3:  # Audio: [batch, channels, n_frames]
                mask_shape = (-1, 1, 1)
            else:
                raise ValueError(f"Unsupported tensor dimension: {z.dim()}")
            z = mask_aug.view(*mask_shape) * zt + (1 - mask_aug).view(*mask_shape) * z
            # Stash augmentation state for downstream loss weighting.
            self._tz = tz
            self._mask_aug = mask_aug
        return z

    def decode(self, z):
        """Decode latent z back to the feature domain."""
        return self.decoder(z)

    def render(self, z_dec):
        """Render decoded latents to audio; subclasses implement."""
        raise NotImplementedError

    def forward(self, data, mode, has_optimizer=None):
        """Encode then decode; returns (z_dec, ret) with a zero placeholder loss.

        ``mode`` and ``has_optimizer`` are accepted for interface parity with
        LDMBase but are currently unused here.
        """
        loss = torch.tensor(0., device=data['inp'].device)
        ret = dict()
        z = self.encode(data['inp'])
        z_dec = self.decode(z)
        ret['loss'] = loss
        return z_dec, ret

    def generate_samples(
        self,
        batch_size,
        n_steps,
        net_kwargs=None,
        uncond_net_kwargs=None,
        ema=False,
        guidance=1.0,
        noise=None,
        return_z=False,
    ):
        """Sample latents from a zDM and render them.

        NOTE(review): this method requires zdm_force_guidance, z_shape,
        zdm_net(_ema), zdm_sampler, and denormalize_for_zdm — none of which
        this class creates. A subclass must supply them, or the call raises
        AttributeError; confirm against the intended subclass.
        """
        if self.zdm_force_guidance is not None:
            guidance = self.zdm_force_guidance
        shape = (batch_size,) + self.z_shape
        net = self.zdm_net if not ema else self.zdm_net_ema
        z = self.zdm_sampler.sample(
            net,
            shape,
            n_steps,
            net_kwargs=net_kwargs,
            uncond_net_kwargs=uncond_net_kwargs,
            guidance=guidance,
            noise=noise,
        )
        if return_z:
            return z
        if (self.zaug_p is not None) and self.zaug_tmax_always:
            # Match train-time augmentation: always noise to tmax.
            tz = torch.ones(z.shape[0], device=z.device) * self.zaug_tmax
            z, _ = self.zaug_zdm_diffusion.add_noise(z, tz)
        z = self.denormalize_for_zdm(z)
        z_dec = self.decode(z)
        return self.render(z_dec)