from omegaconf import OmegaConf
import torch as th
import torch
import math
import abc
from torch import nn, einsum
from einops import rearrange, repeat
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from transformers import CLIPTokenizer
from transformers.models.clip.modeling_clip import CLIPTextConfig, CLIPTextModel, CLIPTextTransformer  # , _expand_mask
from inspect import isfunction
def exists(val):
    return val is not None


def uniq(arr):
    return {el: True for el in arr}.keys()


def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d
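# Illustrative behaviour of `default` (example values, not part of the original file):
# the fallback is called lazily when it is a function, otherwise returned as-is.
#   default(None, 3)          -> 3
#   default(None, lambda: 3)  -> 3
#   default(5, 3)             -> 5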
def register_attention_control(model, controller):
    def ca_forward(self, place_in_unet):
        def forward(x, context=None, mask=None):
            h = self.heads
            q = self.to_q(x)
            is_cross = context is not None
            context = default(context, x)
            k = self.to_k(context)
            v = self.to_v(context)
            q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))

            sim = einsum('b i d, b j d -> b i j', q, k) * self.scale

            if exists(mask):
                mask = rearrange(mask, 'b ... -> b (...)')
                max_neg_value = -torch.finfo(sim.dtype).max
                mask = repeat(mask, 'b j -> (b h) () j', h=h)
                sim.masked_fill_(~mask, max_neg_value)

            # attention, what we cannot get enough of
            attn = sim.softmax(dim=-1)
            attn2 = rearrange(attn, '(b h) k c -> h b k c', h=h).mean(0)
            controller(attn2, is_cross, place_in_unet)
            out = einsum('b i j, b j d -> b i d', attn, v)
            out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
            return self.to_out(out)

        return forward

    class DummyController:
        def __call__(self, *args):
            return args[0]

        def __init__(self):
            self.num_att_layers = 0

    if controller is None:
        controller = DummyController()

    def register_recr(net_, count, place_in_unet):
        if net_.__class__.__name__ == 'CrossAttention':
            net_.forward = ca_forward(net_, place_in_unet)
            return count + 1
        elif hasattr(net_, 'children'):
            for net__ in net_.children():
                count = register_recr(net__, count, place_in_unet)
        return count

    cross_att_count = 0
    sub_nets = model.diffusion_model.named_children()
    for net in sub_nets:
        if "input_blocks" in net[0]:
            cross_att_count += register_recr(net[1], 0, "down")
        elif "output_blocks" in net[0]:
            cross_att_count += register_recr(net[1], 0, "up")
        elif "middle_block" in net[0]:
            cross_att_count += register_recr(net[1], 0, "mid")
    controller.num_att_layers = cross_att_count
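# Usage sketch (illustrative, assuming an ldm-style wrapper whose
# `.diffusion_model` attribute is the UNet):
#
#   store = AttentionStore(base_size=64)
#   register_attention_control(sd_model.model, store)
#   # every CrossAttention layer in the UNet now reports its head-averaged
#   # attention map to `store` on each forward pass; passing controller=None
#   # installs a no-op DummyController instead.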
class AttentionControl(abc.ABC):

    def step_callback(self, x_t):
        return x_t

    def between_steps(self):
        return

    def num_uncond_att_layers(self):
        return 0

    @abc.abstractmethod
    def forward(self, attn, is_cross: bool, place_in_unet: str):
        raise NotImplementedError

    def __call__(self, attn, is_cross: bool, place_in_unet: str):
        attn = self.forward(attn, is_cross, place_in_unet)
        return attn

    def reset(self):
        self.cur_step = 0
        self.cur_att_layer = 0

    def __init__(self):
        self.cur_step = 0
        self.num_att_layers = -1
        self.cur_att_layer = 0
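# Note: concrete controllers only need to implement `forward`; the hooks
# installed by register_attention_control call the controller once per
# attention layer with (attn, is_cross, place_in_unet).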
class AttentionStore(AttentionControl):

    @staticmethod
    def get_empty_store():
        return {"down_cross": [], "mid_cross": [], "up_cross": [],
                "down_self": [], "mid_self": [], "up_self": []}

    def forward(self, attn, is_cross: bool, place_in_unet: str):
        key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
        if attn.shape[1] <= self.max_size ** 2:  # avoid memory overhead
            self.step_store[key].append(attn)
        return attn

    def between_steps(self):
        if len(self.attention_store) == 0:
            self.attention_store = self.step_store
        else:
            for key in self.attention_store:
                for i in range(len(self.attention_store[key])):
                    self.attention_store[key][i] += self.step_store[key][i]
        self.step_store = self.get_empty_store()

    def get_average_attention(self):
        average_attention = {key: [item for item in self.step_store[key]] for key in self.step_store}
        return average_attention

    def reset(self):
        super(AttentionStore, self).reset()
        self.step_store = self.get_empty_store()
        self.attention_store = {}

    def __init__(self, base_size=64, max_size=None):
        super(AttentionStore, self).__init__()
        self.step_store = self.get_empty_store()
        self.attention_store = {}
        self.base_size = base_size
        if max_size is None:
            self.max_size = self.base_size // 2
        else:
            self.max_size = max_size
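# Illustrative note: after one UNet forward pass, `step_store` maps
# "{down,mid,up}_{cross,self}" keys to lists of attention tensors of shape
# [batch, query_tokens, context_tokens]; maps with more than max_size**2 query
# tokens are skipped to limit memory. get_average_attention() returns the maps
# collected for the latest step.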
def register_hier_output(model):
    self = model.diffusion_model
    from ldm.modules.diffusionmodules.util import checkpoint, timestep_embedding

    def forward(x, timesteps=None, context=None, y=None, **kwargs):
        """
        Apply the model to an input batch and collect hierarchical features.
        :param x: an [N x C x ...] Tensor of inputs.
        :param timesteps: a 1-D batch of timesteps.
        :param context: conditioning plugged in via crossattn
        :param y: an [N] Tensor of labels, if class-conditional.
        :return: a list of intermediate decoder feature maps (coarse to fine),
                 with the final UNet output as the last element.
        """
        assert (y is not None) == (
            self.num_classes is not None
        ), "must specify y if and only if the model is class-conditional"
        hs = []
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
        emb = self.time_embed(t_emb)

        if self.num_classes is not None:
            assert y.shape == (x.shape[0],)
            emb = emb + self.label_emb(y)

        h = x.type(self.dtype)
        for module in self.input_blocks:
            # a context of shape [B, 2, ...] carries separate conditioning for
            # the encoder/middle blocks (index 0) and the decoder blocks (index 1)
            if context.shape[1] == 2:
                h = module(h, emb, context[:, 0, :].unsqueeze(1))
            else:
                h = module(h, emb, context)
            hs.append(h)
        if context.shape[1] == 2:
            h = self.middle_block(h, emb, context[:, 0, :].unsqueeze(1))
        else:
            h = self.middle_block(h, emb, context)

        out_list = []
        for i_out, module in enumerate(self.output_blocks):
            h = th.cat([h, hs.pop()], dim=1)
            if context.shape[1] == 2:
                h = module(h, emb, context[:, 1, :].unsqueeze(1))
            else:
                h = module(h, emb, context)
            if i_out in [1, 4, 7]:
                out_list.append(h)
        h = h.type(x.dtype)

        out_list.append(h)
        return out_list

    self.forward = forward
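# Illustrative note: the patched forward returns
#   [h_after_output_block_1, h_after_output_block_4, h_after_output_block_7, h_final]
# i.e. decoder features at increasing spatial resolution rather than a single
# denoised tensor; UNetWrapper below turns this list into a feature pyramid.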
class UNetWrapper(nn.Module):
    def __init__(self, unet, use_attn=True, base_size=512, max_attn_size=None, attn_selector='up_cross+down_cross') -> None:
        super().__init__()
        self.unet = unet
        self.attention_store = AttentionStore(base_size=base_size // 8, max_size=max_attn_size)
        self.size16 = base_size // 32
        self.size32 = base_size // 16
        self.size64 = base_size // 8
        self.use_attn = use_attn
        if self.use_attn:
            register_attention_control(unet, self.attention_store)
        register_hier_output(unet)
        self.attn_selector = attn_selector.split('+')

    def forward(self, *args, **kwargs):
        if self.use_attn:
            self.attention_store.reset()
        out_list = self.unet(*args, **kwargs)
        if self.use_attn:
            avg_attn = self.attention_store.get_average_attention()
            attn16, attn32, attn64 = self.process_attn(avg_attn)
            out_list[1] = torch.cat([out_list[1], attn16], dim=1)
            out_list[2] = torch.cat([out_list[2], attn32], dim=1)
            if attn64 is not None:
                out_list[3] = torch.cat([out_list[3], attn64], dim=1)
        return out_list[::-1]

    def process_attn(self, avg_attn):
        attns = {self.size16: [], self.size32: [], self.size64: []}
        for k in self.attn_selector:
            for up_attn in avg_attn[k]:
                size = int(math.sqrt(up_attn.shape[1]))
                attns[size].append(rearrange(up_attn, 'b (h w) c -> b c h w', h=size))
        attn16 = torch.stack(attns[self.size16]).mean(0)
        attn32 = torch.stack(attns[self.size32]).mean(0)
        if len(attns[self.size64]) > 0:
            attn64 = torch.stack(attns[self.size64]).mean(0)
        else:
            attn64 = None
        return attn16, attn32, attn64
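# Usage sketch (illustrative; `sd_model`, the 512x512 input size, and the
# c_crossattn calling convention are assumptions, not defined in this file):
#
#   unet = UNetWrapper(sd_model.model, base_size=512)
#   latents = torch.randn(1, 4, 64, 64)                 # VAE latents at 512 / 8
#   t = torch.zeros(1, dtype=torch.long)
#   c = sd_model.get_learned_conditioning(["a photo"])  # [1, 77, 768]
#   feats = unet(latents, t, c_crossattn=[c])
#   # feats is a 4-level feature pyramid, finest first, with the averaged
#   # cross-attention maps concatenated to the matching scales.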
class TextAdapter(nn.Module):
    def __init__(self, text_dim=768, hidden_dim=None):
        super().__init__()
        if hidden_dim is None:
            hidden_dim = text_dim
        self.fc = nn.Sequential(
            nn.Linear(text_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, text_dim)
        )

    def forward(self, latents, texts, gamma):
        n_class, channel = texts.shape
        bs = latents.shape[0]

        texts_after = self.fc(texts)
        texts = texts + gamma * texts_after
        texts = repeat(texts, 'n c -> b n c', b=bs)
        return texts
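# Illustrative shape check: with one CLIP embedding per class and a latent
# batch of size B, the adapter broadcasts the residually adapted embeddings
# over the batch.
#
#   adapter = TextAdapter(text_dim=768)
#   out = adapter(torch.zeros(2, 4, 64, 64), torch.zeros(150, 768), gamma=1e-4)
#   assert out.shape == (2, 150, 768)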
class TextAdapterRefer(nn.Module):
    def __init__(self, text_dim=768):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(text_dim, text_dim),
            nn.GELU(),
            nn.Linear(text_dim, text_dim)
        )

    def forward(self, latents, texts, gamma):
        texts_after = self.fc(texts)
        texts = texts + gamma * texts_after
        return texts
class TextAdapterDepth(nn.Module):
    def __init__(self, text_dim=768):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(text_dim, text_dim),
            nn.GELU(),
            nn.Linear(text_dim, text_dim)
        )

    def forward(self, latents, texts, gamma):
        # use the gamma to blend
        n_sen, channel = texts.shape
        bs = latents.shape[0]

        texts_after = self.fc(texts)
        texts = texts + gamma * texts_after
        texts = repeat(texts, 'n c -> n b c', b=1)
        return texts
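# Note: unlike TextAdapter above, the depth variant returns the adapted
# embeddings with a singleton batch axis, [n_sen, 1, text_dim], regardless of
# the latent batch size; the caller presumably expands it as needed.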
class FrozenCLIPEmbedder(nn.Module):
    """Uses the CLIP transformer encoder for text (from Hugging Face)"""

    def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77, pool=True):
        super().__init__()
        self.tokenizer = CLIPTokenizer.from_pretrained(version)
        self.transformer = CLIPTextModel.from_pretrained(version)
        self.device = device
        self.max_length = max_length
        self.freeze()
        self.pool = pool

    def freeze(self):
        self.transformer = self.transformer.eval()
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text):
        batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
                                        return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
        tokens = batch_encoding["input_ids"].to(self.device)
        outputs = self.transformer(input_ids=tokens)

        if self.pool:
            z = outputs.pooler_output
        else:
            z = outputs.last_hidden_state
        return z

    def encode(self, text):
        return self(text)
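# Usage sketch (illustrative): encode a prompt either to a pooled [B, 768]
# vector (pool=True) or to per-token states [B, 77, 768] (pool=False).
#
#   clip = FrozenCLIPEmbedder(device="cpu", pool=False)
#   z = clip.encode(["a photo of a cat"])   # torch.Size([1, 77, 768])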