"""
==========================================================================
* step(): one DDPM step with background smoothing
* inference(): translate one batch with FRESCO and background smoothing
==========================================================================
"""
# `visualize` (used by step() when visualize_pipeline=True) is not defined in
# this file; presumably it comes from this star import.
from src.utils import *
from src.flow_utils import warp_tensor
import torch
import torchvision
import gc


def step(pipe, model_output, timestep, sample, generator, repeat_noise=False,
         visualize_pipeline=False, flows=None, occs=None, saliency=None):
    """
    One DDPM reverse step (x_t -> x_{t-1}) with optional background smoothing.

    * background smoothing: warp the background region of the previous frame
      to the current frame. It is applied to the predicted x_0 in image space
      (VAE decode -> warp_tensor -> VAE encode).

    Args:
        pipe: diffusion pipeline. Its ``scheduler`` provides
            ``previous_timestep``, ``alphas_cumprod`` and ``one``. Its ``vae``
            decodes/encodes x_0 for smoothing and visualization.
        model_output: predicted noise epsilon, with the same shape as
            ``sample``. In ``inference`` this is passed after classifier-free
            guidance.
        timestep: current timestep t.
        sample: current noisy latents x_t, indexed as (frames, ...) and 4-D,
            since the noise is repeated with ``.repeat(n, 1, 1, 1)``.
        generator: torch.Generator used to draw the per-step noise.
        repeat_noise: if True, reuse the noise drawn for the first frame for
            every frame in the batch (can help with static backgrounds).
        visualize_pipeline: debug flag; decodes the predicted x_0 and shows it.
        flows, occs, saliency: optical flows, occlusion masks and saliency
            mask, forwarded to ``warp_tensor``. Background smoothing runs only
            when all three are provided.

    Returns:
        tuple ``(pred_prev_sample, pred_original_sample)``: x_{t-1} and the
        (possibly smoothed) predicted x_0.
    """
    scheduler = pipe.scheduler

    # 1. get previous step value (=t-1)
    prev_timestep = scheduler.previous_timestep(timestep)

    # 2. compute alphas, betas
    alpha_prod_t = scheduler.alphas_cumprod[timestep]
    alpha_prod_t_prev = scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else scheduler.one
    beta_prod_t = 1 - alpha_prod_t
    beta_prod_t_prev = 1 - alpha_prod_t_prev
    current_alpha_t = alpha_prod_t / alpha_prod_t_prev
    current_beta_t = 1 - current_alpha_t

    # 3. compute predicted original sample from predicted noise, also called
    # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
    pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)

    # [HACK] background smoothing:
    #   decode the feature,
    #   warp the feature of f_{i-1},
    #   fuse the warped f_{i-1} with f_{i} in the non-salient region (i.e., background),
    #   encode the fused feature.
    if saliency is not None and flows is not None and occs is not None:
        image = pipe.vae.decode(pred_original_sample / pipe.vae.config.scaling_factor).sample
        image = warp_tensor(image, flows, occs, saliency, unet_chunk_size=1)
        # NOTE(review): no generator is passed to .sample(), so this draw is not
        # controlled by `generator`/`seed` -- confirm whether that is intended.
        pred_original_sample = pipe.vae.config.scaling_factor * pipe.vae.encode(image).latent_dist.sample()

    # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
    # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
    pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * current_beta_t) / beta_prod_t
    current_sample_coeff = current_alpha_t ** (0.5) * beta_prod_t_prev / beta_prod_t

    # 5. Compute predicted previous sample µ_t
    # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
    pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample

    # 6. Scale fresh Gaussian noise by the posterior std. The clamp keeps the
    # std strictly positive when beta_prod_t_prev == 0 (final step).
    variance = beta_prod_t_prev / beta_prod_t * current_beta_t
    variance = torch.clamp(variance, min=1e-20)
    noise = (variance ** 0.5) * torch.randn(
        model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype
    )

    # [HACK] background smoothing: applying the same noise to every frame
    # could be good for a static background. The full-shape draw above is kept
    # so the generator state advances identically either way.
    if repeat_noise:
        noise = noise[0:1].repeat(model_output.shape[0], 1, 1, 1)

    if visualize_pipeline:  # for debug
        image = pipe.vae.decode(pred_original_sample / pipe.vae.config.scaling_factor).sample
        viz = torchvision.utils.make_grid(torch.clamp(image, -1, 1), image.shape[0], 1)
        visualize(viz.cpu(), 90)

    pred_prev_sample = pred_prev_sample + noise

    return (pred_prev_sample, pred_original_sample)


@torch.no_grad()
def inference(pipe, controlnet, frescoProc,
              imgs, prompt_embeds, edges, timesteps,
              cond_scale=[0.7]*20, num_inference_steps=20, num_warmup_steps=6,
              do_classifier_free_guidance=True, seed=0, guidance_scale=7.5, use_controlnet=True,
              record_latents=None, propagation_mode=False, visualize_pipeline=False,
              flows=None, occs=None, saliency=None, repeat_noise=False,
              num_intraattn_steps=1, step_interattn_end=350, bg_smoothing_steps=[16, 17]):
    """
    video-to-video translation inference pipeline with FRESCO
    * add controlnet and SDEdit
    * add FRESCO-guided attention
    * add FRESCO-guided optimization
    * add background smoothing
    * add support for inter-batch long video translation

    [input of the original pipe]
    pipe: base diffusion model
    imgs: a batch of the input frames, unpacked as (B, C, H, W)
    prompt_embeds: prompts
    num_inference_steps: number of DDPM steps (used for the progress-bar total)
    timesteps: generated by pipe.scheduler.set_timesteps(num_inference_steps)
    do_classifier_free_guidance: cfg, should always be true
    guidance_scale: cfg scale
    seed: seed of the generator used for the initial latents and per-step noise

    [input of SDEdit]
    num_warmup_steps: skip the first num_warmup_steps DDPM steps; a negative
        value starts from pure noise (no SDEdit)

    [input of controlnet]
    use_controlnet: bool, whether to use controlnet
    controlnet: controlnet model
    edges: input for controlnet (edge/stroke/depth, etc.)
    cond_scale: controlnet scale per step, indexed by (step index + num_warmup_steps)

    [input of FRESCO]
    frescoProc: FRESCO attention controller
    flows: optical flows
    occs: occlusion mask
    num_intraattn_steps: spatial-guided attention is disabled once the
        post-warmup step index reaches num_intraattn_steps (presumably it is
        enabled by the caller beforehand -- confirm)
    step_interattn_end: temporal-guided attention is disabled once the
        timestep drops below step_interattn_end, i.e. it applies in
        [step_interattn_end, 1000]

    [input for background smoothing]
    saliency: saliency mask
    repeat_noise: bool, use the same initial noise for all frames
    bg_smoothing_steps: apply background smoothing at these step indices,
        counted over the full schedule (i.e. including warmup steps)

    [input for long video translation]
    record_latents: list shared across batches. In the first batch it is
        filled with one tensor per denoising step, holding the latents of the
        first and last frame. With propagation_mode it is read (restored into
        latents[0:2]) and overwritten in place. Pass the same list object to
        every batch. Defaults to a fresh empty list.
    propagation_mode: bool, whether this is not the first batch

    [output]
    latents: a batch of latents of the translated frames

    Raises:
        ValueError: if propagation_mode is set but record_latents holds fewer
            entries than there are denoising steps.
    """
    if record_latents is None:
        # A fresh list per call; a shared `[]` default would be mutated below
        # and leak latents from one call into the next.
        record_latents = []

    gc.collect()
    torch.cuda.empty_cache()

    device = pipe._execution_device
    noise_scheduler = pipe.scheduler
    generator = torch.Generator(device=device).manual_seed(seed)
    B, C, H, W = imgs.shape
    latents = pipe.prepare_latents(
        B,
        pipe.unet.config.in_channels,
        H,
        W,
        prompt_embeds.dtype,
        device,
        generator,
        latents=None,
    )

    if repeat_noise:
        # NOTE(review): repeat_noise only affects this initial noise; it is not
        # forwarded to step(), so per-step noise is still drawn per frame.
        latents = latents[0:1].repeat(B, 1, 1, 1).detach()

    if num_warmup_steps < 0:
        latents_init = latents.detach()
        num_warmup_steps = 0
    else:
        # SDEdit: use the noisy latent of the images as the input rather than pure Gaussian noise
        latent_x0 = pipe.vae.config.scaling_factor * pipe.vae.encode(imgs.to(pipe.unet.dtype)).latent_dist.sample()
        latents_init = noise_scheduler.add_noise(latent_x0, latents, timesteps[num_warmup_steps]).detach()

    denoise_timesteps = timesteps[num_warmup_steps:]
    num_denoise_steps = len(denoise_timesteps)

    if propagation_mode and len(record_latents) < num_denoise_steps:
        raise ValueError(
            f"propagation_mode needs {num_denoise_steps} recorded latents, "
            f"got {len(record_latents)}; pass the record_latents list filled by the first batch"
        )

    # SDEdit: run num_inference_steps - num_warmup_steps steps
    with pipe.progress_bar(total=num_inference_steps - num_warmup_steps) as progress_bar:
        latents = latents_init
        for i, t in enumerate(denoise_timesteps):
            # [HACK] control the steps to apply spatial/temporal-guided attention
            if i >= num_intraattn_steps:
                frescoProc.controller.disable_intraattn()
            if t < step_interattn_end:
                frescoProc.controller.disable_interattn()

            # [HACK] record and restore latents from the previous batch
            if propagation_mode:
                # restore latents from the previous batch and record the latents of the current batch
                latents[0:2] = record_latents[i].detach().clone()
                record_latents[i] = latents[[0, len(latents) - 1]].detach().clone()
            else:
                # first batch: record_latents[i] = [x_{1,t}, x_{N,t}]
                record_latents.append(latents[[0, len(latents) - 1]].detach().clone())

            # expand the latents if we are doing classifier-free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents

            if use_controlnet:
                down_block_res_samples, mid_block_res_sample = controlnet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    controlnet_cond=edges,
                    conditioning_scale=cond_scale[i + num_warmup_steps],
                    guess_mode=False,
                    return_dict=False,
                )
            else:
                down_block_res_samples, mid_block_res_sample = None, None

            # predict the noise residual
            noise_pred = pipe.unet(
                latent_model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                cross_attention_kwargs=None,
                down_block_additional_residuals=down_block_res_samples,
                mid_block_additional_residual=mid_block_res_sample,
                return_dict=False,
            )[0]

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            # [HACK] background smoothing
            # Note: bg_smoothing_steps should be rescaled based on num_inference_steps;
            # the default [16, 17] assumes num_inference_steps=20.
            if i + num_warmup_steps in bg_smoothing_steps:
                latents = step(pipe, noise_pred, t, latents, generator,
                               visualize_pipeline=visualize_pipeline,
                               flows=flows, occs=occs, saliency=saliency)[0]
            else:
                latents = step(pipe, noise_pred, t, latents, generator,
                               visualize_pipeline=visualize_pipeline)[0]

            # `i` counts post-warmup steps, so the last one is num_denoise_steps - 1
            if i == num_denoise_steps - 1 or (i + 1) % pipe.scheduler.order == 0:
                progress_bar.update()

    return latents