from omegaconf import OmegaConf
import torch as th
import torch
import math
import abc

from torch import nn, einsum
from einops import rearrange, repeat
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from transformers import CLIPTokenizer
from transformers.models.clip.modeling_clip import CLIPTextConfig, CLIPTextModel, CLIPTextTransformer  # , _expand_mask
from inspect import isfunction


def exists(val):
    return val is not None


def uniq(arr):
    return {el: True for el in arr}.keys()


def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d


def register_attention_control(model, controller):
    """Patch every CrossAttention block inside `model.diffusion_model` so that its
    head-averaged attention map is passed to `controller` on each forward call."""

    def ca_forward(self, place_in_unet):
        def forward(x, context=None, mask=None):
            h = self.heads

            q = self.to_q(x)
            is_cross = context is not None
            context = default(context, x)
            k = self.to_k(context)
            v = self.to_v(context)

            q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))

            sim = einsum('b i d, b j d -> b i j', q, k) * self.scale

            if exists(mask):
                mask = rearrange(mask, 'b ... -> b (...)')
                max_neg_value = -torch.finfo(sim.dtype).max
                mask = repeat(mask, 'b j -> (b h) () j', h=h)
                sim.masked_fill_(~mask, max_neg_value)

            # attention, what we cannot get enough of
            attn = sim.softmax(dim=-1)

            # average over heads before handing the map to the controller
            attn2 = rearrange(attn, '(b h) k c -> h b k c', h=h).mean(0)
            controller(attn2, is_cross, place_in_unet)

            out = einsum('b i j, b j d -> b i d', attn, v)
            out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
            return self.to_out(out)

        return forward

    class DummyController:
        def __call__(self, *args):
            return args[0]

        def __init__(self):
            self.num_att_layers = 0

    if controller is None:
        controller = DummyController()

    def register_recr(net_, count, place_in_unet):
        if net_.__class__.__name__ == 'CrossAttention':
            net_.forward = ca_forward(net_, place_in_unet)
            return count + 1
        elif hasattr(net_, 'children'):
            for net__ in net_.children():
                count = register_recr(net__, count, place_in_unet)
        return count

    cross_att_count = 0
    sub_nets = model.diffusion_model.named_children()
    for net in sub_nets:
        if "input_blocks" in net[0]:
            cross_att_count += register_recr(net[1], 0, "down")
        elif "output_blocks" in net[0]:
            cross_att_count += register_recr(net[1], 0, "up")
        elif "middle_block" in net[0]:
            cross_att_count += register_recr(net[1], 0, "mid")
    controller.num_att_layers = cross_att_count


class AttentionControl(abc.ABC):

    def step_callback(self, x_t):
        return x_t

    def between_steps(self):
        return

    @property
    def num_uncond_att_layers(self):
        return 0

    @abc.abstractmethod
    def forward(self, attn, is_cross: bool, place_in_unet: str):
        raise NotImplementedError

    def __call__(self, attn, is_cross: bool, place_in_unet: str):
        attn = self.forward(attn, is_cross, place_in_unet)
        return attn

    def reset(self):
        self.cur_step = 0
        self.cur_att_layer = 0

    def __init__(self):
        self.cur_step = 0
        self.num_att_layers = -1
        self.cur_att_layer = 0


class AttentionStore(AttentionControl):
    """Collects attention maps per UNet stage ("down"/"mid"/"up") and per type
    (cross/self), skipping maps with more than max_size**2 spatial positions."""

    @staticmethod
    def get_empty_store():
        return {"down_cross": [], "mid_cross": [], "up_cross": [],
                "down_self": [], "mid_self": [], "up_self": []}

    def forward(self, attn, is_cross: bool, place_in_unet: str):
        key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
        if attn.shape[1] <= self.max_size ** 2:  # avoid memory overhead
            self.step_store[key].append(attn)
        return attn

    def between_steps(self):
        if len(self.attention_store) == 0:
            self.attention_store = self.step_store
        else:
            for key in self.attention_store:
                for i in range(len(self.attention_store[key])):
                    self.attention_store[key][i] += self.step_store[key][i]
        self.step_store = self.get_empty_store()

    def get_average_attention(self):
        average_attention = {key: [item for item in self.step_store[key]] for key in self.step_store}
        return average_attention

    def reset(self):
        super(AttentionStore, self).reset()
        self.step_store = self.get_empty_store()
        self.attention_store = {}

    def __init__(self, base_size=64, max_size=None):
        super(AttentionStore, self).__init__()
        self.step_store = self.get_empty_store()
        self.attention_store = {}
        self.base_size = base_size
        if max_size is None:
            self.max_size = self.base_size // 2
        else:
            self.max_size = max_size
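
# Illustrative sanity check (a sketch, not part of the original module): mimic how
# UNetWrapper drives an AttentionStore -- reset, push a head-averaged cross-attention
# map of shape (batch, h*w, n_tokens), then read it back via get_average_attention()
# without calling between_steps(). The shapes below are made-up placeholders.
def _example_attention_store():
    store = AttentionStore(base_size=64)    # max_size defaults to base_size // 2 = 32
    store.reset()
    fake_attn = torch.rand(2, 32 * 32, 77)  # (batch, h*w, n_text_tokens)
    store(fake_attn, True, "up")            # stored under "up_cross" since 1024 <= 32**2
    maps = store.get_average_attention()
    assert maps["up_cross"][0].shape == fake_attn.shape
    return maps
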
def register_hier_output(model):
    self = model.diffusion_model
    from ldm.modules.diffusionmodules.util import checkpoint, timestep_embedding

    def forward(x, timesteps=None, context=None, y=None, **kwargs):
        """
        Apply the model to an input batch.
        :param x: an [N x C x ...] Tensor of inputs.
        :param timesteps: a 1-D batch of timesteps.
        :param context: conditioning plugged in via crossattn
        :param y: an [N] Tensor of labels, if class-conditional.
        :return: a list of intermediate feature maps plus the final output.
        """
        assert (y is not None) == (
            self.num_classes is not None
        ), "must specify y if and only if the model is class-conditional"
        hs = []
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
        emb = self.time_embed(t_emb)

        if self.num_classes is not None:
            assert y.shape == (x.shape[0],)
            emb = emb + self.label_emb(y)

        h = x.type(self.dtype)
        for module in self.input_blocks:
            if context.shape[1] == 2:
                h = module(h, emb, context[:, 0, :].unsqueeze(1))
            else:
                h = module(h, emb, context)
            hs.append(h)
        if context.shape[1] == 2:
            h = self.middle_block(h, emb, context[:, 0, :].unsqueeze(1))
        else:
            h = self.middle_block(h, emb, context)

        out_list = []
        for i_out, module in enumerate(self.output_blocks):
            h = th.cat([h, hs.pop()], dim=1)
            if context.shape[1] == 2:
                h = module(h, emb, context[:, 1, :].unsqueeze(1))
            else:
                h = module(h, emb, context)
            if i_out in [1, 4, 7]:
                out_list.append(h)
        h = h.type(x.dtype)

        out_list.append(h)
        return out_list

    self.forward = forward


class UNetWrapper(nn.Module):
    """Wraps a Stable Diffusion UNet so that a forward pass returns multi-scale
    feature maps with selected cross-attention maps concatenated along the
    channel dimension."""

    def __init__(self, unet, use_attn=True, base_size=512, max_attn_size=None,
                 attn_selector='up_cross+down_cross') -> None:
        super().__init__()
        self.unet = unet
        self.attention_store = AttentionStore(base_size=base_size // 8, max_size=max_attn_size)
        self.size16 = base_size // 32
        self.size32 = base_size // 16
        self.size64 = base_size // 8
        self.use_attn = use_attn
        if self.use_attn:
            register_attention_control(unet, self.attention_store)
        register_hier_output(unet)
        self.attn_selector = attn_selector.split('+')

    def forward(self, *args, **kwargs):
        if self.use_attn:
            self.attention_store.reset()
        out_list = self.unet(*args, **kwargs)
        if self.use_attn:
            avg_attn = self.attention_store.get_average_attention()
            attn16, attn32, attn64 = self.process_attn(avg_attn)
            out_list[1] = torch.cat([out_list[1], attn16], dim=1)
            out_list[2] = torch.cat([out_list[2], attn32], dim=1)
            if attn64 is not None:
                out_list[3] = torch.cat([out_list[3], attn64], dim=1)
        return out_list[::-1]

    def process_attn(self, avg_attn):
        attns = {self.size16: [], self.size32: [], self.size64: []}
        for k in self.attn_selector:
            for up_attn in avg_attn[k]:
                size = int(math.sqrt(up_attn.shape[1]))
                attns[size].append(rearrange(up_attn, 'b (h w) c -> b c h w', h=size))
        attn16 = torch.stack(attns[self.size16]).mean(0)
        attn32 = torch.stack(attns[self.size32]).mean(0)
        if len(attns[self.size64]) > 0:
            attn64 = torch.stack(attns[self.size64]).mean(0)
        else:
            attn64 = None
        return attn16, attn32, attn64
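
# Hedged usage sketch (assumptions, not part of the original module): wrap the UNet of
# a Stable Diffusion model instantiated through ldm. The config path is a hypothetical
# placeholder, checkpoint loading is omitted, and the `.model` attribute (ldm's
# DiffusionWrapper, which owns `.diffusion_model`) is an assumption about the
# surrounding ldm setup rather than something this file defines.
def _example_wrap_sd_unet(config_path="v1-inference.yaml"):
    from ldm.util import instantiate_from_config
    config = OmegaConf.load(config_path)
    sd_model = instantiate_from_config(config.model)
    # Wrap the DiffusionWrapper; its forward now yields a list of multi-scale features
    # (with cross-attention maps concatenated) instead of a single noise prediction.
    unet = UNetWrapper(sd_model.model, use_attn=True, base_size=512)
    return unet
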
class TextAdapter(nn.Module):
    def __init__(self, text_dim=768, hidden_dim=None):
        super().__init__()
        if hidden_dim is None:
            hidden_dim = text_dim
        self.fc = nn.Sequential(
            nn.Linear(text_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, text_dim)
        )

    def forward(self, latents, texts, gamma):
        n_class, channel = texts.shape
        bs = latents.shape[0]

        texts_after = self.fc(texts)
        texts = texts + gamma * texts_after
        texts = repeat(texts, 'n c -> b n c', b=bs)
        return texts


class TextAdapterRefer(nn.Module):
    def __init__(self, text_dim=768):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(text_dim, text_dim),
            nn.GELU(),
            nn.Linear(text_dim, text_dim)
        )

    def forward(self, latents, texts, gamma):
        texts_after = self.fc(texts)
        texts = texts + gamma * texts_after
        return texts


class TextAdapterDepth(nn.Module):
    def __init__(self, text_dim=768):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(text_dim, text_dim),
            nn.GELU(),
            nn.Linear(text_dim, text_dim)
        )

    def forward(self, latents, texts, gamma):
        # use the gamma to blend
        n_sen, channel = texts.shape
        bs = latents.shape[0]

        texts_after = self.fc(texts)
        texts = texts + gamma * texts_after
        texts = repeat(texts, 'n c -> n b c', b=1)
        return texts


class FrozenCLIPEmbedder(nn.Module):
    """Uses the CLIP transformer encoder for text (from Hugging Face)."""

    def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77, pool=True):
        super().__init__()
        self.tokenizer = CLIPTokenizer.from_pretrained(version)
        self.transformer = CLIPTextModel.from_pretrained(version)
        self.device = device
        self.max_length = max_length
        self.freeze()
        self.pool = pool

    def freeze(self):
        self.transformer = self.transformer.eval()
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text):
        batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
                                        return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
        tokens = batch_encoding["input_ids"].to(self.device)
        outputs = self.transformer(input_ids=tokens)

        if self.pool:
            z = outputs.pooler_output
        else:
            z = outputs.last_hidden_state
        return z

    def encode(self, text):
        return self(text)
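

# Quick, hedged end-to-end sketch of the text side (illustrative only): encode a few
# prompts with FrozenCLIPEmbedder (pooled, one 768-d vector per prompt) and blend them
# through TextAdapter before they would be handed to the UNet as `context`. The prompts,
# gamma value, and latent shape are placeholders invented for this example.
if __name__ == "__main__":
    embedder = FrozenCLIPEmbedder(device="cpu", pool=True)
    class_embeddings = embedder.encode(["a photo of a cat", "a photo of a dog"])  # (2, 768)
    adapter = TextAdapter(text_dim=768)
    gamma = nn.Parameter(torch.ones(768) * 1e-4)  # learnable blending scale
    dummy_latents = torch.randn(4, 4, 64, 64)     # (batch, channels, h, w) VAE latents
    context = adapter(dummy_latents, class_embeddings, gamma)
    print(context.shape)  # (4, 2, 768): per-sample, per-class text embeddings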