dreamgaussian4d / guidance /svd_utils.py
jiaweir
init
21c4e64
raw
history blame contribute delete
No virus
8.13 kB
import torch
from svd import StableVideoDiffusionPipeline
from diffusers import DDIMScheduler
from PIL import Image
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
class StableVideoDiffusion:
def __init__(
self,
device,
fp16=True,
t_range=[0.02, 0.98],
):
super().__init__()
self.guidance_type = [
'sds',
'pixel reconstruction',
'latent reconstruction'
][1]
self.device = device
self.dtype = torch.float16 if fp16 else torch.float32
# Create model
pipe = StableVideoDiffusionPipeline.from_pretrained(
"stabilityai/stable-video-diffusion-img2vid", torch_dtype=torch.float16, variant="fp16"
)
pipe.to(device)
self.pipe = pipe
self.num_train_timesteps = self.pipe.scheduler.config.num_train_timesteps if self.guidance_type == 'sds' else 25
self.pipe.scheduler.set_timesteps(self.num_train_timesteps, device=device) # set sigma for euler discrete scheduling
self.min_step = int(self.num_train_timesteps * t_range[0])
self.max_step = int(self.num_train_timesteps * t_range[1])
self.alphas = self.pipe.scheduler.alphas_cumprod.to(self.device) # for convenience
self.embeddings = None
self.image = None
self.target_cache = None
@torch.no_grad()
def get_img_embeds(self, image):
self.image = Image.fromarray(np.uint8(image*255))
def encode_image(self, image):
image = image * 2 -1
latents = self.pipe._encode_vae_image(image, self.device, num_videos_per_prompt=1, do_classifier_free_guidance=False)
latents = self.pipe.vae.config.scaling_factor * latents
return latents
def refine(self,
pred_rgb,
steps=25, strength=0.8,
min_guidance_scale: float = 1.0,
max_guidance_scale: float = 3.0,
):
# strength = 0.8
batch_size = pred_rgb.shape[0]
pred_rgb = pred_rgb.to(self.dtype)
# interp to 512x512 to be fed into vae.
pred_rgb_512 = F.interpolate(pred_rgb, (512, 512), mode="bilinear", align_corners=False)
# encode image into latents with vae, requires grad!
# latents = []
# for i in range(batch_size):
# latent = self.encode_image(pred_rgb_512[i:i+1])
# latents.append(latent)
# latents = torch.cat(latents, 0)
latents = self.encode_image(pred_rgb_512)
latents = latents.unsqueeze(0)
if strength == 0:
init_step = 0
latents = torch.randn_like(latents)
else:
init_step = int(steps * strength)
latents = self.pipe.scheduler.add_noise(latents, torch.randn_like(latents), self.pipe.scheduler.timesteps[init_step:init_step+1])
target = self.pipe(
image=self.image,
height=512,
width=512,
latents=latents,
denoise_beg=init_step,
denoise_end=steps,
output_type='frame',
num_frames=batch_size,
min_guidance_scale=min_guidance_scale,
max_guidance_scale=max_guidance_scale,
num_inference_steps=steps,
decode_chunk_size=1
).frames[0]
target = (target + 1) * 0.5
target = target.permute(1,0,2,3)
return target
# frames = self.pipe(
# image=self.image,
# height=512,
# width=512,
# latents=latents,
# denoise_beg=init_step,
# denoise_end=steps,
# num_frames=batch_size,
# min_guidance_scale=min_guidance_scale,
# max_guidance_scale=max_guidance_scale,
# num_inference_steps=steps,
# decode_chunk_size=1
# ).frames[0]
# export_to_gif(frames, f"tmp.gif")
# raise
def train_step(
self,
pred_rgb,
step_ratio=None,
min_guidance_scale: float = 1.0,
max_guidance_scale: float = 3.0,
):
batch_size = pred_rgb.shape[0]
pred_rgb = pred_rgb.to(self.dtype)
# interp to 512x512 to be fed into vae.
pred_rgb_512 = F.interpolate(pred_rgb, (512, 512), mode="bilinear", align_corners=False)
# encode image into latents with vae, requires grad!
# latents = self.pipe._encode_image(pred_rgb_512, self.device, num_videos_per_prompt=1, do_classifier_free_guidance=True)
latents = self.encode_image(pred_rgb_512)
latents = latents.unsqueeze(0)
if step_ratio is not None:
# dreamtime-like
# t = self.max_step - (self.max_step - self.min_step) * np.sqrt(step_ratio)
t = np.round((1 - step_ratio) * self.num_train_timesteps).clip(self.min_step, self.max_step)
t = torch.full((1,), t, dtype=torch.long, device=self.device)
else:
t = torch.randint(self.min_step, self.max_step + 1, (1,), dtype=torch.long, device=self.device)
# print(t)
w = (1 - self.alphas[t]).view(1, 1, 1, 1)
if self.guidance_type == 'sds':
# predict the noise residual with unet, NO grad!
with torch.no_grad():
t = self.num_train_timesteps - t.item()
# add noise
noise = torch.randn_like(latents)
latents_noisy = self.pipe.scheduler.add_noise(latents, noise, self.pipe.scheduler.timesteps[t:t+1]) # t=0 noise;t=999 clean
noise_pred = self.pipe(
image=self.image,
# image_embeddings=self.embeddings,
height=512,
width=512,
latents=latents_noisy,
output_type='noise',
denoise_beg=t,
denoise_end=t + 1,
min_guidance_scale=min_guidance_scale,
max_guidance_scale=max_guidance_scale,
num_frames=batch_size,
num_inference_steps=self.num_train_timesteps
).frames[0]
grad = w * (noise_pred - noise)
grad = torch.nan_to_num(grad)
target = (latents - grad).detach()
loss = 0.5 * F.mse_loss(latents.float(), target, reduction='sum') / latents.shape[1]
print(loss.item())
return loss
elif self.guidance_type == 'pixel reconstruction':
# pixel space reconstruction
if self.target_cache is None:
with torch.no_grad():
self.target_cache = self.pipe(
image=self.image,
height=512,
width=512,
output_type='frame',
num_frames=batch_size,
num_inference_steps=self.num_train_timesteps,
decode_chunk_size=1
).frames[0]
self.target_cache = (self.target_cache + 1) * 0.5
self.target_cache = self.target_cache.permute(1,0,2,3)
loss = 0.5 * F.mse_loss(pred_rgb_512.float(), self.target_cache.detach().float(), reduction='sum') / latents.shape[1]
print(loss.item())
return loss
elif self.guidance_type == 'latent reconstruction':
# latent space reconstruction
if self.target_cache is None:
with torch.no_grad():
self.target_cache = self.pipe(
image=self.image,
height=512,
width=512,
output_type='latent',
num_frames=batch_size,
num_inference_steps=self.num_train_timesteps,
).frames[0]
loss = 0.5 * F.mse_loss(latents.float(), self.target_cache.detach().float(), reduction='sum') / latents.shape[1]
print(loss.item())
return loss