# Spaces: Runtime error
# File size: 5,942 Bytes
# Commits: 80ccb59, 8df522a
import os
from os.path import join as opj
import omegaconf
import cv2
import einops
import torch
import torch as th
import torch.nn as nn
import torchvision.transforms as T
import torch.nn.functional as F
import numpy as np
from ldm.models.diffusion.ddpm import LatentDiffusion
from ldm.util import instantiate_from_config
class ControlLDM(LatentDiffusion):
    """Latent-diffusion model driven by an auxiliary control network.

    A control model is instantiated from ``control_stage_config``. During
    ``apply_model`` its outputs are (optionally) scaled by
    ``control_scales`` and handed to the diffusion model through the
    ``control`` keyword.
    """

    def __init__(
        self,
        control_stage_config,
        validation_config,
        control_key,
        only_mid_control,
        use_VAEdownsample=False,
        config_name="",
        control_scales=None,
        use_pbe_weight=False,
        u_cond_percent=0.0,
        img_H=512,
        img_W=384,
        always_learnable_param=False,
        *args,
        **kwargs
    ):
        """Store the configuration and build the control model.

        Args:
            control_stage_config: Config passed to ``instantiate_from_config``
                to build ``self.control_model``. Its ``params`` mapping is
                updated in place with ``use_VAEdownsample``.
            validation_config: Stored as ``self.valid_config``.
            control_key: Batch key of the control input, or an omegaconf
                ``ListConfig`` of such keys.
            only_mid_control: Forwarded to the control and diffusion models.
            use_VAEdownsample: When true, ``apply_model`` encodes full-size
                hints with the first-stage model before concatenating them.
            config_name: Stored as ``self.config_name``.
            control_scales: Per-output multipliers for the control model's
                outputs; defaults to thirteen ``1.0`` values.
            use_pbe_weight: Stored as ``self.use_pbe_weight``.
            u_cond_percent: Stored as ``self.u_cond_percent``.
            img_H: Hint height that triggers first-stage encoding.
            img_W: Hint width that triggers first-stage encoding.
            always_learnable_param: When true, ``apply_model`` replaces the
                cross-attention context with the unconditional conditioning.
            *args: Forwarded to the parent constructor.
            **kwargs: Forwarded to the parent constructor;
                ``first_stage_key_cond`` is additionally read from here.
        """
        # These attributes are assigned before the parent constructor runs,
        # exactly as in the original; presumably the parent reads some of
        # them during its own setup -- verify before moving them.
        # NOTE(review): the device is hard-coded to CUDA; confirm that the
        # base class permits assigning ``device`` and that a GPU is present.
        self.device = torch.device("cuda")
        self.control_stage_config = control_stage_config
        self.use_pbe_weight = use_pbe_weight
        self.u_cond_percent = u_cond_percent
        self.img_H = img_H
        self.img_W = img_W
        self.config_name = config_name
        self.always_learnable_param = always_learnable_param

        super().__init__(*args, **kwargs)

        # The control model has to see the same down-sampling switch.
        control_stage_config.params["use_VAEdownsample"] = use_VAEdownsample
        self.control_model = instantiate_from_config(control_stage_config)
        self.control_key = control_key
        self.only_mid_control = only_mid_control
        self.control_scales = (
            [1.0] * 13 if control_scales is None else control_scales
        )
        self.first_stage_key_cond = kwargs.get("first_stage_key_cond")
        self.valid_config = validation_config
        self.use_VAEDownsample = use_VAEdownsample

    def _load_control(self, batch, key, bs):
        """Return ``batch[key]`` as a contiguous ``b c h w`` tensor on ``self.device``.

        The first ``bs`` samples are kept when ``bs`` is not ``None``.
        """
        tensor = batch[key]
        if bs is not None:
            tensor = tensor[:bs]
        tensor = tensor.to(self.device)
        # Batch entries are laid out channels-last; move channels to dim 1.
        tensor = einops.rearrange(tensor, 'b h w c -> b c h w')
        return tensor.to(memory_format=torch.contiguous_format)

    @torch.no_grad()
    def get_input(self, batch, k, bs=None, *args, **kwargs):
        """Build the diffusion input and its conditioning dictionary.

        Args:
            batch: Mapping from keys to batched tensors.
            k: Accepted for interface compatibility; ``self.first_stage_key``
                is always used instead.
            bs: Optional number of samples to keep from each control tensor.
                It is not forwarded to the parent ``get_input`` call.
            *args: Forwarded to the parent ``get_input``.
            **kwargs: Forwarded to the parent ``get_input``.

        Returns:
            ``(x, cond_dict)`` where ``cond_dict`` holds ``c_crossattn`` (a
            one-element list), ``c_concat`` (a list of control tensors) and,
            when ``self.first_stage_key_cond`` is set, ``first_stage_cond``.
        """
        x, c = super().get_input(batch, self.first_stage_key, *args, **kwargs)

        # A single key and a list of keys both end up as a list of tensors.
        if isinstance(self.control_key, omegaconf.listconfig.ListConfig):
            control_keys = self.control_key
        else:
            control_keys = [self.control_key]
        hints = [self._load_control(batch, key, bs) for key in control_keys]

        cond_dict = {"c_crossattn": [c], "c_concat": hints}

        if self.first_stage_key_cond is not None:
            parent_get_input = super().get_input
            extra_conds = []
            for key in self.first_stage_key_cond:
                if "mask" in key:
                    # Keys containing "mask" are fetched with no_latent=True.
                    cond, _ = parent_get_input(
                        batch, key, no_latent=True, *args, **kwargs
                    )
                else:
                    cond, _ = parent_get_input(batch, key, *args, **kwargs)
                extra_conds.append(cond)
            cond_dict["first_stage_cond"] = torch.cat(extra_conds, dim=1)

        return x, cond_dict

    def apply_model(self, x_noisy, t, cond, *args, **kwargs):
        """Run the diffusion model, optionally guided by the control model.

        Args:
            x_noisy: Noisy input handed to the diffusion model.
            t: Timesteps handed to the diffusion and control models.
            cond: Dict with ``c_crossattn`` and ``c_concat`` entries and an
                optional ``first_stage_cond`` entry.
            *args: Unused.
            **kwargs: Unused.

        Returns:
            ``(eps, None)`` where ``eps`` is the diffusion model's output.
        """
        assert isinstance(cond, dict)
        unet = self.model.diffusion_model

        context = torch.cat(cond["c_crossattn"], 1)
        if self.proj_out is not None and context.shape[-1] == 1024:
            context = self.proj_out(context)  # [BS x 1 x 768]
        if self.always_learnable_param:
            context = self.get_unconditional_conditioning(context.shape[0])

        # Without any hint the diffusion model runs uncontrolled.
        if cond['c_concat'] is None:
            eps = unet(
                x=x_noisy,
                timesteps=t,
                context=context,
                control=None,
                only_mid_control=self.only_mid_control,
            )
            return eps, None

        if "first_stage_cond" in cond:
            x_noisy = torch.cat([x_noisy, cond["first_stage_cond"]], dim=1)

        if self.use_VAEDownsample:
            hint_parts = []
            for part in cond["c_concat"]:
                # Only hints at full image resolution go through the
                # first-stage encoder; the others are appended unchanged.
                if part.shape[2] == self.img_H and part.shape[3] == self.img_W:
                    encoded = self.encode_first_stage(part)
                    part = self.get_first_stage_encoding(encoded).detach()
                hint_parts.append(part)
        else:
            hint_parts = cond["c_concat"]
        hint = torch.cat(hint_parts, dim=1)

        control, _ = self.control_model(
            x=x_noisy,
            hint=hint,
            timesteps=t,
            context=context,
            only_mid_control=self.only_mid_control,
        )
        # Scales are applied only when there is exactly one per output.
        if len(control) == len(self.control_scales):
            control = [
                residual * weight
                for residual, weight in zip(control, self.control_scales)
            ]

        eps = unet(
            x=x_noisy,
            timesteps=t,
            context=context,
            control=control,
            only_mid_control=self.only_mid_control,
        )
        return eps, None

    @torch.no_grad()
    def get_unconditional_conditioning(self, N):
        """Return the unconditional conditioning for ``N`` samples.

        When ``self.kwargs["use_imageCLIP"]`` is truthy the learnable vector
        is repeated ``N`` times along the first dimension; otherwise ``N``
        empty prompts are passed to ``get_learned_conditioning``.
        """
        if self.kwargs["use_imageCLIP"]:
            return self.learnable_vector.repeat(N, 1, 1)
        return self.get_learned_conditioning([""] * N)

    def low_vram_shift(self, is_diffusing):
        """Swap sub-models between GPU and CPU.

        While diffusing, ``model`` and ``control_model`` go to CUDA and
        ``first_stage_model`` and ``cond_stage_model`` go to the CPU;
        otherwise the placement is reversed.
        """
        if is_diffusing:
            denoiser_move, encoder_move = "cuda", "cpu"
        else:
            denoiser_move, encoder_move = "cpu", "cuda"

        # The move order below is the same in both directions and is kept
        # as-is rather than reordered.
        placement = (
            ("model", denoiser_move),
            ("control_model", denoiser_move),
            ("first_stage_model", encoder_move),
            ("cond_stage_model", encoder_move),
        )
        for attr_name, move in placement:
            module = getattr(self, attr_name)
            setattr(self, attr_name, getattr(module, move)())