# StableVITON / cldm / cldm.py
# Commit 80ccb59 ("stableviton") by rlawjdghek — 5.92 kB
import os
from os.path import join as opj
import omegaconf
import cv2
import einops
import torch
import torch as th
import torch.nn as nn
import torchvision.transforms as T
import torch.nn.functional as F
import numpy as np
from ldm.models.diffusion.ddpm import LatentDiffusion
from ldm.util import instantiate_from_config
class ControlLDM(LatentDiffusion):
    """LatentDiffusion with an additional control branch.

    A control model built from ``control_stage_config`` consumes the
    ``c_concat`` hint tensors and returns a sequence of residuals. When their
    count matches ``control_scales``, each one is scaled; they are then passed
    to the diffusion UNet through its ``control=`` argument.
    """

    def __init__(
        self,
        control_stage_config,
        validation_config,
        control_key,
        only_mid_control,
        use_VAEdownsample=False,
        config_name="",
        control_scales=None,
        use_pbe_weight=False,
        u_cond_percent=0.0,
        img_H=512,
        img_W=384,
        always_learnable_param=False,
        *args,
        **kwargs
    ):
        """Build the model and instantiate the control branch.

        Args:
            control_stage_config: config passed to ``instantiate_from_config``
                to build ``self.control_model``. NOTE: its ``params`` mapping
                is mutated in place (``use_VAEdownsample`` is written into it).
            validation_config: stored as ``self.valid_config``.
            control_key: a single batch key, or a list/tuple/``ListConfig`` of
                keys, whose tensors become the ``c_concat`` hints.
            only_mid_control: forwarded to the control model and the UNet.
            use_VAEdownsample: if True, full-resolution hints are encoded with
                the first-stage encoder before reaching the control model.
            config_name: stored as ``self.config_name``.
            control_scales: per-residual multipliers; defaults to 13 ones.
            use_pbe_weight, u_cond_percent: stored on ``self``.
            img_H, img_W: full image size; in ``apply_model`` only hints whose
                spatial size equals this are VAE-encoded.
            always_learnable_param: if True, ``apply_model`` replaces the
                cross-attention context with
                ``get_unconditional_conditioning``.
            *args, **kwargs: forwarded to ``LatentDiffusion.__init__``;
                ``first_stage_key_cond`` is also read from ``kwargs``.
        """
        # Assigned before ``super().__init__`` (original ordering), presumably
        # because the parent constructor may read some of them -- keep order.
        self.control_stage_config = control_stage_config
        self.use_pbe_weight = use_pbe_weight
        self.u_cond_percent = u_cond_percent
        self.img_H = img_H
        self.img_W = img_W
        self.config_name = config_name
        self.always_learnable_param = always_learnable_param
        super().__init__(*args, **kwargs)
        # Forward the flag to the control model's constructor params.
        control_stage_config.params["use_VAEdownsample"] = use_VAEdownsample
        self.control_model = instantiate_from_config(control_stage_config)
        self.control_key = control_key
        self.only_mid_control = only_mid_control
        self.control_scales = [1.0] * 13 if control_scales is None else control_scales
        self.first_stage_key_cond = kwargs.get("first_stage_key_cond", None)
        self.valid_config = validation_config
        self.use_VAEDownsample = use_VAEdownsample

    def _prepare_control(self, batch, key, bs=None):
        """Return ``batch[key]`` (first ``bs`` items if given) on ``self.device``,
        rearranged from ``b h w c`` to a contiguous float ``b c h w`` tensor."""
        control = batch[key]
        if bs is not None:
            control = control[:bs]
        control = control.to(self.device)
        control = einops.rearrange(control, 'b h w c -> b c h w')
        return control.to(memory_format=torch.contiguous_format).float()

    @torch.no_grad()
    def get_input(self, batch, k, bs=None, *args, **kwargs):
        """Collect the latent input and the conditioning dict for a batch.

        ``k`` is ignored; ``self.first_stage_key`` is always used for ``x``.
        NOTE(review): ``bs`` truncates only the control hints, not ``x``,
        ``c`` or ``first_stage_cond`` -- confirm callers slice those.

        Returns:
            ``(x, cond_dict)`` where ``cond_dict`` has ``c_crossattn`` (a
            one-element list), ``c_concat`` (a list of hint tensors, one per
            control key) and, if configured, ``first_stage_cond`` (the
            per-key conditionings concatenated on dim 1).
        """
        x, c = super().get_input(batch, self.first_stage_key, *args, **kwargs)

        # OmegaConf's ListConfig is not a ``list`` subclass, so check it
        # explicitly; plain lists/tuples are accepted too. A single key is
        # wrapped so ``c_concat`` is always a list.
        if isinstance(self.control_key, (list, tuple, omegaconf.listconfig.ListConfig)):
            control_keys = self.control_key
        else:
            control_keys = [self.control_key]
        control = [self._prepare_control(batch, key, bs) for key in control_keys]

        cond_dict = dict(c_crossattn=[c], c_concat=control)
        if self.first_stage_key_cond is not None:
            first_stage_cond = []
            for key in self.first_stage_key_cond:
                if "mask" not in key:
                    cond, _ = super().get_input(batch, key, *args, **kwargs)
                else:
                    # Mask keys are fetched with ``no_latent=True``.
                    cond, _ = super().get_input(batch, key, no_latent=True, *args, **kwargs)
                first_stage_cond.append(cond)
            cond_dict["first_stage_cond"] = torch.cat(first_stage_cond, dim=1)
        return x, cond_dict

    def _encode_hint(self, h):
        """VAE-encode ``h`` (detached) if its spatial size is
        ``(self.img_H, self.img_W)``; otherwise return it unchanged."""
        if h.shape[2] == self.img_H and h.shape[3] == self.img_W:
            h = self.encode_first_stage(h)
            h = self.get_first_stage_encoding(h).detach()
        return h

    def apply_model(self, x_noisy, t, cond, *args, **kwargs):
        """Run the (optionally controlled) UNet on ``x_noisy`` at step ``t``.

        Returns:
            ``(eps, None)`` where ``eps`` is the diffusion model output.
        """
        assert isinstance(cond, dict)
        diffusion_model = self.model.diffusion_model

        cond_txt = torch.cat(cond["c_crossattn"], 1)
        if self.always_learnable_param:
            # The context is replaced wholesale, so projecting it first (as
            # before) was wasted work; proj_out keeps the batch dim, so the
            # batch size passed here is unchanged.
            cond_txt = self.get_unconditional_conditioning(cond_txt.shape[0])
        elif self.proj_out is not None and cond_txt.shape[-1] == 1024:
            cond_txt = self.proj_out(cond_txt)  # [BS x 1 x 768]

        if cond["c_concat"] is None:
            # No hints: plain UNet call; first_stage_cond is not appended here.
            eps = diffusion_model(
                x=x_noisy, timesteps=t, context=cond_txt,
                control=None, only_mid_control=self.only_mid_control,
            )
            return eps, None

        if "first_stage_cond" in cond:
            x_noisy = torch.cat([x_noisy, cond["first_stage_cond"]], dim=1)

        if self.use_VAEDownsample:
            hint = torch.cat([self._encode_hint(h) for h in cond["c_concat"]], dim=1)
        else:
            hint = cond["c_concat"]

        control, _ = self.control_model(
            x=x_noisy, hint=hint, timesteps=t, context=cond_txt,
            only_mid_control=self.only_mid_control,
        )
        # Scaling is skipped silently if the residual count does not match.
        if len(control) == len(self.control_scales):
            control = [c * scale for c, scale in zip(control, self.control_scales)]
        eps = diffusion_model(
            x=x_noisy, timesteps=t, context=cond_txt,
            control=control, only_mid_control=self.only_mid_control,
        )
        return eps, None

    @torch.no_grad()
    def get_unconditional_conditioning(self, N):
        """Return ``N`` copies of the unconditional context.

        Uses ``get_learned_conditioning([""] * N)`` unless
        ``self.kwargs["use_imageCLIP"]`` is truthy, in which case
        ``self.learnable_vector`` is repeated ``N`` times along dim 0.

        NOTE(review): because of ``@torch.no_grad()`` the result carries no
        gradient; if ``learnable_vector`` is meant to be trained through the
        ``always_learnable_param`` path in ``apply_model``, confirm intent.
        """
        if not self.kwargs["use_imageCLIP"]:
            return self.get_learned_conditioning([""] * N)
        else:
            return self.learnable_vector.repeat(N, 1, 1)

    def low_vram_shift(self, is_diffusing):
        """Swap sub-models between GPU and CPU to save memory.

        While diffusing, ``model`` and ``control_model`` go to CUDA and
        ``first_stage_model``/``cond_stage_model`` go to CPU; otherwise the
        placement is reversed.
        """
        if is_diffusing:
            self.model = self.model.cuda()
            self.control_model = self.control_model.cuda()
            self.first_stage_model = self.first_stage_model.cpu()
            self.cond_stage_model = self.cond_stage_model.cpu()
        else:
            self.model = self.model.cpu()
            self.control_model = self.control_model.cpu()
            self.first_stage_model = self.first_stage_model.cuda()
            self.cond_stage_model = self.cond_stage_model.cuda()