import torch
from svd import StableVideoDiffusionPipeline
from diffusers import DDIMScheduler
from PIL import Image
import numpy as np
import torch.nn as nn
import torch.nn.functional as F


class StableVideoDiffusion:
    def __init__(
        self,
        device,
        fp16=True,
        t_range=[0.02, 0.98],
    ):
        super().__init__()

        self.guidance_type = [
            'sds',
            'pixel reconstruction',
            'latent reconstruction'
        ][1]

        self.device = device
        self.dtype = torch.float16 if fp16 else torch.float32

        # Create model
        pipe = StableVideoDiffusionPipeline.from_pretrained(
            "stabilityai/stable-video-diffusion-img2vid",
            torch_dtype=torch.float16, variant="fp16"
        )
        pipe.to(device)
        self.pipe = pipe

        self.num_train_timesteps = self.pipe.scheduler.config.num_train_timesteps if self.guidance_type == 'sds' else 25
        self.pipe.scheduler.set_timesteps(self.num_train_timesteps, device=device)  # set sigma for euler discrete scheduling

        self.min_step = int(self.num_train_timesteps * t_range[0])
        self.max_step = int(self.num_train_timesteps * t_range[1])
        self.alphas = self.pipe.scheduler.alphas_cumprod.to(self.device)  # for convenience

        self.embeddings = None
        self.image = None
        self.target_cache = None

    @torch.no_grad()
    def get_img_embeds(self, image):
        self.image = Image.fromarray(np.uint8(image * 255))

    def encode_image(self, image):
        image = image * 2 - 1
        latents = self.pipe._encode_vae_image(image, self.device, num_videos_per_prompt=1, do_classifier_free_guidance=False)
        latents = self.pipe.vae.config.scaling_factor * latents
        return latents

    def refine(
        self,
        pred_rgb,
        steps=25,
        strength=0.8,
        min_guidance_scale: float = 1.0,
        max_guidance_scale: float = 3.0,
    ):
        # strength = 0.8
        batch_size = pred_rgb.shape[0]
        pred_rgb = pred_rgb.to(self.dtype)

        # interp to 512x512 to be fed into vae.
        pred_rgb_512 = F.interpolate(pred_rgb, (512, 512), mode="bilinear", align_corners=False)

        # encode image into latents with vae, requires grad!
        # latents = []
        # for i in range(batch_size):
        #     latent = self.encode_image(pred_rgb_512[i:i+1])
        #     latents.append(latent)
        # latents = torch.cat(latents, 0)
        latents = self.encode_image(pred_rgb_512)
        latents = latents.unsqueeze(0)

        if strength == 0:
            init_step = 0
            latents = torch.randn_like(latents)
        else:
            init_step = int(steps * strength)
            latents = self.pipe.scheduler.add_noise(latents, torch.randn_like(latents), self.pipe.scheduler.timesteps[init_step:init_step+1])

        target = self.pipe(
            image=self.image,
            height=512,
            width=512,
            latents=latents,
            denoise_beg=init_step,
            denoise_end=steps,
            output_type='frame',
            num_frames=batch_size,
            min_guidance_scale=min_guidance_scale,
            max_guidance_scale=max_guidance_scale,
            num_inference_steps=steps,
            decode_chunk_size=1
        ).frames[0]
        target = (target + 1) * 0.5
        target = target.permute(1, 0, 2, 3)
        return target

        # frames = self.pipe(
        #     image=self.image,
        #     height=512,
        #     width=512,
        #     latents=latents,
        #     denoise_beg=init_step,
        #     denoise_end=steps,
        #     num_frames=batch_size,
        #     min_guidance_scale=min_guidance_scale,
        #     max_guidance_scale=max_guidance_scale,
        #     num_inference_steps=steps,
        #     decode_chunk_size=1
        # ).frames[0]
        # export_to_gif(frames, f"tmp.gif")
        # raise

    def train_step(
        self,
        pred_rgb,
        step_ratio=None,
        min_guidance_scale: float = 1.0,
        max_guidance_scale: float = 3.0,
    ):
        batch_size = pred_rgb.shape[0]
        pred_rgb = pred_rgb.to(self.dtype)

        # interp to 512x512 to be fed into vae.
        pred_rgb_512 = F.interpolate(pred_rgb, (512, 512), mode="bilinear", align_corners=False)

        # encode image into latents with vae, requires grad!
        # latents = self.pipe._encode_image(pred_rgb_512, self.device, num_videos_per_prompt=1, do_classifier_free_guidance=True)
        latents = self.encode_image(pred_rgb_512)
        latents = latents.unsqueeze(0)

        if step_ratio is not None:
            # dreamtime-like
            # t = self.max_step - (self.max_step - self.min_step) * np.sqrt(step_ratio)
            t = np.round((1 - step_ratio) * self.num_train_timesteps).clip(self.min_step, self.max_step)
            t = torch.full((1,), t, dtype=torch.long, device=self.device)
        else:
            t = torch.randint(self.min_step, self.max_step + 1, (1,), dtype=torch.long, device=self.device)
        # print(t)

        w = (1 - self.alphas[t]).view(1, 1, 1, 1)

        if self.guidance_type == 'sds':
            # predict the noise residual with unet, NO grad!
            with torch.no_grad():
                t = self.num_train_timesteps - t.item()

                # add noise
                noise = torch.randn_like(latents)
                latents_noisy = self.pipe.scheduler.add_noise(latents, noise, self.pipe.scheduler.timesteps[t:t+1])  # t=0 noise; t=999 clean

                noise_pred = self.pipe(
                    image=self.image,
                    # image_embeddings=self.embeddings,
                    height=512,
                    width=512,
                    latents=latents_noisy,
                    output_type='noise',
                    denoise_beg=t,
                    denoise_end=t + 1,
                    min_guidance_scale=min_guidance_scale,
                    max_guidance_scale=max_guidance_scale,
                    num_frames=batch_size,
                    num_inference_steps=self.num_train_timesteps
                ).frames[0]

            grad = w * (noise_pred - noise)
            grad = torch.nan_to_num(grad)

            target = (latents - grad).detach()
            loss = 0.5 * F.mse_loss(latents.float(), target, reduction='sum') / latents.shape[1]
            print(loss.item())
            return loss

        elif self.guidance_type == 'pixel reconstruction':
            # pixel space reconstruction
            if self.target_cache is None:
                with torch.no_grad():
                    self.target_cache = self.pipe(
                        image=self.image,
                        height=512,
                        width=512,
                        output_type='frame',
                        num_frames=batch_size,
                        num_inference_steps=self.num_train_timesteps,
                        decode_chunk_size=1
                    ).frames[0]
                    self.target_cache = (self.target_cache + 1) * 0.5
                    self.target_cache = self.target_cache.permute(1, 0, 2, 3)

            loss = 0.5 * F.mse_loss(pred_rgb_512.float(), self.target_cache.detach().float(), reduction='sum') / latents.shape[1]
            print(loss.item())
            return loss

        elif self.guidance_type == 'latent reconstruction':
            # latent space reconstruction
            if self.target_cache is None:
                with torch.no_grad():
                    self.target_cache = self.pipe(
                        image=self.image,
                        height=512,
                        width=512,
                        output_type='latent',
                        num_frames=batch_size,
                        num_inference_steps=self.num_train_timesteps,
                    ).frames[0]

            loss = 0.5 * F.mse_loss(latents.float(), self.target_cache.detach().float(), reduction='sum') / latents.shape[1]
            print(loss.item())
            return loss
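

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module): a minimal
# optimization loop showing how the guidance class above might be driven.
# It assumes a CUDA device, the custom `svd` pipeline import above, and
# placeholder data -- the random reference image, 14-frame tensor, learning
# rate, and iteration budget are all hypothetical stand-ins, not values from
# the source.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    device = torch.device("cuda")
    guidance = StableVideoDiffusion(device, fp16=True)

    # Reference image in [0, 1], shape (H, W, 3); in practice a real photo,
    # here a random placeholder.
    ref_image = np.random.rand(512, 512, 3).astype(np.float32)
    guidance.get_img_embeds(ref_image)

    # Learnable stand-in for rendered video frames, shape (num_frames, 3, H, W)
    # in [0, 1]; 14 frames matches the SVD img2vid default.
    frames = torch.rand(14, 3, 256, 256, device=device, requires_grad=True)
    optimizer = torch.optim.Adam([frames], lr=1e-2)

    total_iters = 100
    for i in range(total_iters):
        optimizer.zero_grad()
        # step_ratio anneals the sampled timestep from max_step toward min_step.
        loss = guidance.train_step(frames.clamp(0, 1), step_ratio=i / total_iters)
        loss.backward()
        optimizer.step()

    # Optionally polish the converged frames with a short SDEdit-style pass.
    refined = guidance.refine(frames.detach().clamp(0, 1), steps=25, strength=0.8)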