# DiffIR2VR / model / cldm.py
# Repository file-view header (web scrape residue, kept for provenance):
#   jimmycv07's picture | first commit | 1de8821 | raw | history blame | 8.08 kB
from typing import Tuple, Set, List, Dict
import torch
from torch import nn
from model import (
ControlledUnetModel, ControlNet,
AutoencoderKL, FrozenOpenCLIPEmbedder
)
from utils.common import sliding_windows, count_vram_usage, gaussian_weights
def disabled_train(self: nn.Module) -> nn.Module:
    """No-op stand-in for a module's ``train`` method.

    Assign this function over ``module.train`` to pin the module's current
    train/eval mode: the call changes nothing and simply hands back the
    object it was given.
    """
    return self
class ControlLDM(nn.Module):
    """Latent diffusion model whose UNet is conditioned through a ControlNet.

    The module bundles four sub-networks (UNet, VAE, CLIP text encoder and
    ControlNet) and offers helpers to load their weights, to move images in
    and out of the VAE latent space (optionally tile by tile), and to build
    the conditioning dictionary consumed by :meth:`forward`.
    """

    def __init__(
        self,
        unet_cfg,
        vae_cfg,
        clip_cfg,
        controlnet_cfg,
        latent_scale_factor,
        VidToMe_cfg,
        latent_warp_cfg,
    ):
        """Build the sub-networks and store the video-consistency settings.

        Args:
            unet_cfg: Keyword arguments expanded into ``ControlledUnetModel``.
            vae_cfg: Keyword arguments expanded into ``AutoencoderKL``.
            clip_cfg: Keyword arguments expanded into
                ``FrozenOpenCLIPEmbedder``.
            controlnet_cfg: Keyword arguments expanded into ``ControlNet``.
            latent_scale_factor: Factor applied to latents after encoding and
                divided out again before decoding.
            VidToMe_cfg: Object exposing the attributes ``ToMe_period``,
                ``merge_ratio``, ``merge_global``, ``global_merge_ratio``,
                ``seed``, ``batch_size``, ``align_batch`` and ``global_rand``.
            latent_warp_cfg: Object exposing the attributes
                ``latent_control``, ``warp_period`` and ``merge_period``.
        """
        super().__init__()

        # Sub-module creation order is kept as-is: it determines the order of
        # the registered children and of the entries in ``state_dict()``.
        self.unet = ControlledUnetModel(**unet_cfg)
        self.vae = AutoencoderKL(**vae_cfg)
        self.clip = FrozenOpenCLIPEmbedder(**clip_cfg)
        self.controlnet = ControlNet(**controlnet_cfg)

        self.scale_factor = latent_scale_factor
        # One multiplier per ControlNet output; ``forward`` pairs them with
        # ``zip``, which silently truncates to the shorter sequence.
        # NOTE(review): 13 presumably equals the number of control tensors
        # returned by ``ControlNet`` - confirm against that class.
        self.control_scales = [1.0] * 13

        # Latent warping / merging settings.
        self.latent_control = latent_warp_cfg.latent_control
        self.latent_warp_period = latent_warp_cfg.warp_period
        self.latent_merge_period = latent_warp_cfg.merge_period
        self.controller = None

        # Token-merging (VidToMe) settings.
        self.ToMe_period = VidToMe_cfg.ToMe_period
        self.merge_ratio = VidToMe_cfg.merge_ratio
        self.merge_global = VidToMe_cfg.merge_global
        self.global_merge_ratio = VidToMe_cfg.global_merge_ratio
        self.seed = VidToMe_cfg.seed
        self.batch_size = VidToMe_cfg.batch_size
        self.align_batch = VidToMe_cfg.align_batch
        self.global_rand = VidToMe_cfg.global_rand

        if self.latent_control:
            # Imported lazily so the controller package is only required when
            # latent control is actually enabled.
            from controller.controller import AttentionControl

            self.controller = AttentionControl(
                warp_period=self.latent_warp_period,
                merge_period=self.latent_merge_period,
                ToMe_period=self.ToMe_period,
                merge_ratio=self.merge_ratio,
            )

        # NOTE(review): the source this was recovered from had lost its
        # indentation; this check is assumed to sit at ``__init__`` level
        # (independent of ``latent_control``) - confirm against upstream.
        if self.ToMe_period[0] == 0:
            print("[INFO] activate token merging ")
            self.activate_vidtome()

    def activate_vidtome(self):
        """Patch this model with ``vidtome`` using the stored settings.

        Only the first entry of ``self.merge_ratio`` is passed as the merge
        ratio.
        """
        import vidtome

        vidtome.apply_patch(
            self,
            self.merge_ratio[0],
            self.merge_global,
            self.global_merge_ratio,
            seed=self.seed,
            batch_size=self.batch_size,
            align_batch=self.align_batch,
            global_rand=self.global_rand,
        )

    @torch.no_grad()
    def load_pretrained_sd(self, sd: Dict[str, torch.Tensor]) -> Set[str]:
        """Load UNet, VAE and CLIP weights from one checkpoint and freeze them.

        Every key of each sub-module's own state dict is looked up in ``sd``
        under a fixed prefix (``model.diffusion_model``, ``first_stage_model``
        or ``cond_stage_model``).

        Args:
            sd: Checkpoint state dict containing the prefixed keys.

        Returns:
            The keys of ``sd`` that were not consumed by any sub-module.

        Raises:
            KeyError: If ``sd`` lacks a key that a sub-module expects.
        """
        module_map = {
            "unet": "model.diffusion_model",
            "vae": "first_stage_model",
            "clip": "cond_stage_model",
        }
        modules = [("unet", self.unet), ("vae", self.vae), ("clip", self.clip)]
        used = set()
        for name, module in modules:
            prefix = module_map[name]
            init_sd = {}
            for key in module.state_dict():
                target_key = f"{prefix}.{key}"
                init_sd[key] = sd[target_key].clone()
                used.add(target_key)
            module.load_state_dict(init_sd, strict=True)
        unused = set(sd.keys()) - used

        # NOTE: this is slightly different from the previous version, which did
        # not switch the UNet to eval mode or disable its requires_grad flags.
        for module in [self.vae, self.clip, self.unet]:
            module.eval()
            # Replace ``train`` with the no-op so the mode stays pinned.
            module.train = disabled_train
            for p in module.parameters():
                p.requires_grad = False
        return unused

    @torch.no_grad()
    def load_controlnet_from_ckpt(self, sd: Dict[str, torch.Tensor]) -> None:
        """Load ControlNet weights from ``sd``; keys must match exactly."""
        self.controlnet.load_state_dict(sd, strict=True)

    @torch.no_grad()
    def load_controlnet_from_unet(self) -> Tuple[Set[str], Set[str]]:
        """Initialise the ControlNet from the UNet's current weights.

        For every ControlNet parameter:

        * same key and same shape in the UNet -> copy the UNet tensor;
        * same key but different shape -> copy the UNet tensor and pad it with
          zeros along dimension 1 up to the ControlNet's size (the code
          unpacks the ControlNet tensor as a 4-D ``(oc, ic, h, w)`` weight);
        * key absent from the UNet -> keep the ControlNet's own value.

        Returns:
            A pair ``(init_with_new_zero, init_with_scratch)``: the keys that
            were zero-padded, and the keys that kept their original values.
        """
        unet_sd = self.unet.state_dict()
        scratch_sd = self.controlnet.state_dict()
        init_sd = {}
        init_with_new_zero = set()
        init_with_scratch = set()
        for key in scratch_sd:
            if key in unet_sd:
                this, target = scratch_sd[key], unet_sd[key]
                if this.size() == target.size():
                    init_sd[key] = target.clone()
                else:
                    d_ic = this.size(1) - target.size(1)
                    oc, _, h, w = this.size()
                    # Allocate the padding on the same device as ``target``;
                    # ``torch.cat`` requires all inputs on one device, so a
                    # default-device allocation fails once the UNet is on GPU.
                    zeros = torch.zeros(
                        (oc, d_ic, h, w),
                        dtype=target.dtype,
                        device=target.device,
                    )
                    init_sd[key] = torch.cat((target, zeros), dim=1)
                    init_with_new_zero.add(key)
            else:
                init_sd[key] = scratch_sd[key].clone()
                init_with_scratch.add(key)
        self.controlnet.load_state_dict(init_sd, strict=True)
        return init_with_new_zero, init_with_scratch

    def vae_encode(self, image: torch.Tensor, sample: bool=True, batch_size: int=0) -> torch.Tensor:
        """Encode ``image`` with the VAE and apply the latent scale factor.

        Args:
            image: Input passed to ``self.vae.encode``.
            sample: If true, call ``.sample()`` on the encoder output,
                otherwise ``.mode()``.
            batch_size: Forwarded unchanged to ``self.vae.encode``.

        Returns:
            The chosen latent multiplied by ``self.scale_factor``.
        """
        posterior = self.vae.encode(image, batch_size=batch_size)
        latent = posterior.sample() if sample else posterior.mode()
        return latent * self.scale_factor

    def vae_encode_tiled(self, image: torch.Tensor, tile_size: int, tile_stride: int, sample: bool=True) -> torch.Tensor:
        """Encode ``image`` tile by tile and blend the overlapping latents.

        Tile size and stride are given in image pixels and divided by 8 to
        obtain latent-space windows; the output buffer has 4 channels and
        1/8 of the input's height and width (both constants are hard-coded).
        Each tile's latent is accumulated with a weight map from
        ``gaussian_weights`` and the sum is normalised by the accumulated
        weights.

        Args:
            image: Tensor unpacked as ``(bs, _, h, w)``.
            tile_size: Tile edge length in pixels.
            tile_stride: Step between tiles in pixels.
            sample: Forwarded to :meth:`vae_encode`.

        Returns:
            A float32 tensor of shape ``(bs, 4, h // 8, w // 8)``.
        """
        bs, _, h, w = image.shape
        z = torch.zeros((bs, 4, h // 8, w // 8), dtype=torch.float32, device=image.device)
        count = torch.zeros_like(z, dtype=torch.float32)
        # Two leading axes are added so the weight map broadcasts over the
        # batch and channel dimensions.
        weights = gaussian_weights(tile_size // 8, tile_size // 8)[None, None]
        weights = torch.tensor(weights, dtype=torch.float32, device=image.device)
        tiles = sliding_windows(h // 8, w // 8, tile_size // 8, tile_stride // 8)
        for hi, hi_end, wi, wi_end in tiles:
            # Window coordinates are in latent units; scale by 8 for pixels.
            tile_image = image[:, :, hi * 8:hi_end * 8, wi * 8:wi_end * 8]
            z[:, :, hi:hi_end, wi:wi_end] += self.vae_encode(tile_image, sample=sample) * weights
            count[:, :, hi:hi_end, wi:wi_end] += weights
        z.div_(count)
        return z

    def vae_decode(self, z: torch.Tensor, batch_size: int=0) -> torch.Tensor:
        """Undo the latent scale factor and decode ``z`` with the VAE.

        Args:
            z: Scaled latent, as produced by :meth:`vae_encode`.
            batch_size: Forwarded unchanged to ``self.vae.decode``.
        """
        return self.vae.decode(z / self.scale_factor, batch_size=batch_size)

    @count_vram_usage
    def vae_decode_tiled(self, z: torch.Tensor, tile_size: int, tile_stride: int) -> torch.Tensor:
        """Decode latent ``z`` tile by tile and blend the overlapping images.

        Here tile size and stride are given in latent units; each decoded
        tile covers an 8x larger pixel window in a 3-channel output buffer
        (both constants are hard-coded). Tiles are accumulated with a weight
        map from ``gaussian_weights`` and normalised by the accumulated
        weights.

        Args:
            z: Tensor unpacked as ``(bs, _, h, w)``.
            tile_size: Tile edge length in latent units.
            tile_stride: Step between tiles in latent units.

        Returns:
            A float32 tensor of shape ``(bs, 3, h * 8, w * 8)``.
        """
        bs, _, h, w = z.shape
        image = torch.zeros((bs, 3, h * 8, w * 8), dtype=torch.float32, device=z.device)
        count = torch.zeros_like(image, dtype=torch.float32)
        # Two leading axes are added so the weight map broadcasts over the
        # batch and channel dimensions.
        weights = gaussian_weights(tile_size * 8, tile_size * 8)[None, None]
        weights = torch.tensor(weights, dtype=torch.float32, device=z.device)
        tiles = sliding_windows(h, w, tile_size, tile_stride)
        for hi, hi_end, wi, wi_end in tiles:
            tile_z = z[:, :, hi:hi_end, wi:wi_end]
            image[:, :, hi * 8:hi_end * 8, wi * 8:wi_end * 8] += self.vae_decode(tile_z) * weights
            count[:, :, hi * 8:hi_end * 8, wi * 8:wi_end * 8] += weights
        image.div_(count)
        return image

    def prepare_condition(self, clean: torch.Tensor, txt: List[str]) -> Dict[str, torch.Tensor]:
        """Build the conditioning dict for :meth:`forward`.

        Args:
            clean: Image tensor; it is mapped through ``clean * 2 - 1`` before
                encoding (presumably from [0, 1] to [-1, 1] - confirm with
                the caller).
            txt: Prompts passed to the CLIP encoder.

        Returns:
            ``{"c_txt": <text embedding>, "c_img": <VAE mode latent>}``. The
            image is encoded with a fixed ``batch_size=5``.
        """
        return dict(
            c_txt=self.clip.encode(txt),
            c_img=self.vae_encode(clean * 2 - 1, sample=False, batch_size=5),
        )

    @count_vram_usage
    def prepare_condition_tiled(self, clean: torch.Tensor, txt: List[str], tile_size: int, tile_stride: int) -> Dict[str, torch.Tensor]:
        """Tiled variant of :meth:`prepare_condition`.

        Identical output layout, but the image latent comes from
        :meth:`vae_encode_tiled` with the given ``tile_size`` and
        ``tile_stride`` (both in pixels).
        """
        return dict(
            c_txt=self.clip.encode(txt),
            c_img=self.vae_encode_tiled(clean * 2 - 1, tile_size, tile_stride, sample=False),
        )

    def forward(self, x_noisy, t, cond):
        """Run the ControlNet, scale its outputs, then run the UNet.

        Args:
            x_noisy: Noisy latent given to both the ControlNet and the UNet.
            t: Timesteps given to both networks.
            cond: Mapping with keys ``"c_txt"`` (used as ``context``) and
                ``"c_img"`` (used as the ControlNet ``hint``).

        Returns:
            The UNet output.
        """
        c_txt = cond["c_txt"]
        c_img = cond["c_img"]
        control = self.controlnet(
            x_noisy, hint=c_img,
            timesteps=t, context=c_txt
        )
        # Apply the per-output control strength.
        control = [c * scale for c, scale in zip(control, self.control_scales)]
        eps = self.unet(
            x_noisy, timesteps=t,
            context=c_txt, control=control, only_mid_control=False
        )
        return eps