import math
from typing import List, Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine.model import BaseModule
from mmengine.runner import CheckpointLoader, load_checkpoint

from mmseg.registry import MODELS
from mmseg.utils import ConfigType, OptConfigType

try:
    from ldm.modules.diffusionmodules.util import timestep_embedding
    from ldm.util import instantiate_from_config
    has_ldm = True
except ImportError:
    has_ldm = False


def register_attention_control(model, controller):
    """Register a controller that records attention maps inside a model.

    Args:
        model: The model whose ``CrossAttention`` layers are patched.
        controller: A callable invoked with each attention map produced.
    """

    def ca_forward(self, place_in_unet):
        """Build a patched forward method for a ``CrossAttention`` layer.

        Args:
            self: The ``CrossAttention`` layer being patched.
            place_in_unet: The location in the UNet (down/mid/up).

        Returns:
            The patched forward method.
        """

        def forward(x, context=None, mask=None):
            h = self.heads
            is_cross = context is not None
            context = context if is_cross else x

            q, k, v = self.to_q(x), self.to_k(context), self.to_v(context)
            # Split heads: (b, n, h*d) -> (b*h, n, d).
            q, k, v = (
                t.reshape(t.shape[0], t.shape[1], h, -1).permute(
                    0, 2, 1, 3).reshape(t.shape[0] * h, t.shape[1], -1)
                for t in (q, k, v))

            sim = torch.matmul(q, k.transpose(-2, -1)) * self.scale

            if mask is not None:
                # (b, ...) -> (b*h, 1, j), matching the (b h) head packing.
                mask = mask.flatten(1).unsqueeze(1).repeat_interleave(
                    h, dim=0)
                max_neg_value = -torch.finfo(sim.dtype).max
                sim.masked_fill_(~mask, max_neg_value)

            attn = sim.softmax(dim=-1)
            # Average over heads, keeping the batch dim:
            # (b*h, n, d) -> (b, n, d).
            attn_mean = attn.view(-1, h, *attn.shape[1:]).mean(1)
            controller(attn_mean, is_cross, place_in_unet)

            out = torch.matmul(attn, v)
            # Merge heads back: (b*h, n, d) -> (b, n, h*d).
            out = out.view(-1, h, *out.shape[1:]).permute(0, 2, 1, 3)
            out = out.reshape(*out.shape[:2], -1)
            return self.to_out(out)

        return forward

    def register_recr(net_, count, place_in_unet):
        """Recursively patch the forward method of every ``CrossAttention``
        layer.

        Args:
            net_: The module currently being processed.
            count: The number of layers patched so far.
            place_in_unet: The location in the UNet (down/mid/up).

        Returns:
            The updated number of patched layers.
        """
        if net_.__class__.__name__ == 'CrossAttention':
            net_.forward = ca_forward(net_, place_in_unet)
            return count + 1
        for child in net_.children():
            count = register_recr(child, count, place_in_unet)
        return count

    cross_att_count = 0
    for name, module in model.diffusion_model.named_children():
        if 'input_blocks' in name:
            cross_att_count += register_recr(module, 0, 'down')
        elif 'output_blocks' in name:
            cross_att_count += register_recr(module, 0, 'up')
        elif 'middle_block' in name:
            cross_att_count += register_recr(module, 0, 'mid')

    controller.num_att_layers = cross_att_count


class AttentionStore:
    """Stores attention maps collected from the UNet model.

    Attributes:
        base_size (int): Base spatial size of the attention maps.
        max_size (int): Maximum spatial size of maps that are stored.
    """

    def __init__(self, base_size=64, max_size=None):
        """Initialize AttentionStore with default or custom sizes."""
        self.reset()
        self.base_size = base_size
        self.max_size = max_size or (base_size // 2)
        self.num_att_layers = -1

    @staticmethod
    def get_empty_store():
        """Return an empty store for holding attention maps."""
        return {
            key: []
            for key in [
                'down_cross', 'mid_cross', 'up_cross', 'down_self',
                'mid_self', 'up_self'
            ]
        }

    def reset(self):
        """Reset the step and attention stores to their initial states."""
        self.cur_step = 0
        self.cur_att_layer = 0
        self.step_store = self.get_empty_store()
        self.attention_store = {}

    def forward(self, attn, is_cross: bool, place_in_unet: str):
        """Process a single forward step, storing the attention map.

        Args:
            attn: The attention tensor.
            is_cross (bool): Whether it is a cross-attention map.
            place_in_unet (str): The location in the UNet (down/mid/up).

        Returns:
            The unmodified attention tensor.
        """
        key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
        if attn.shape[1] <= self.max_size**2:
            self.step_store[key].append(attn)
        return attn

    def between_steps(self):
        """Accumulate the current step's attention maps into the running
        store."""
        if not self.attention_store:
            self.attention_store = self.step_store
        else:
            for key in self.attention_store:
                self.attention_store[key] = [
                    stored + step for stored, step in zip(
                        self.attention_store[key], self.step_store[key])
                ]
        self.step_store = self.get_empty_store()

    def get_average_attention(self):
        """Return the attention maps collected for the current step."""
        return {key: list(self.step_store[key]) for key in self.step_store}

    def __call__(self, attn, is_cross: bool, place_in_unet: str):
        """Make the class instance callable as a controller."""
        return self.forward(attn, is_cross, place_in_unet)

    @property
    def num_uncond_att_layers(self):
        """Number of unconditional attention layers (defaults to 0)."""
        return 0

    def step_callback(self, x_t):
        """A placeholder step callback that returns the input unchanged."""
        return x_t
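

# A minimal usage sketch (illustrative, not part of the original API): the
# store is the `controller` passed to register_attention_control(). Each
# patched CrossAttention layer calls it once per forward pass, and the maps
# are read back keyed by place and attention type. Shapes here are made up.
def _attention_store_demo():
    store = AttentionStore(base_size=64, max_size=32)
    store.reset()
    # Fake head-averaged cross-attention: (batch, 16*16 queries, 77 tokens).
    attn = torch.rand(2, 256, 77)
    store(attn, is_cross=True, place_in_unet='up')
    # Maps with at most max_size**2 query positions are kept.
    assert store.get_average_attention()['up_cross'][0].shape == (2, 256, 77)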


class UNetWrapper(nn.Module):
    """A wrapper for UNet with optional attention mechanisms.

    Args:
        unet (nn.Module): The UNet model to wrap.
        use_attn (bool): Whether to collect attention maps. Defaults to True.
        base_size (int): Base input size for the attention store.
            Defaults to 512.
        max_attn_size (int, optional): Maximum size for the attention store.
            Defaults to None.
        attn_selector (str): The '+'-separated attention types to use.
            Defaults to 'up_cross+down_cross'.
    """

    def __init__(self,
                 unet,
                 use_attn=True,
                 base_size=512,
                 max_attn_size=None,
                 attn_selector='up_cross+down_cross'):
        super().__init__()

        assert has_ldm, 'To use UNetWrapper, please install required ' \
            'packages via `pip install -r requirements/optional.txt`.'

        self.unet = unet
        self.attention_store = AttentionStore(
            base_size=base_size // 8, max_size=max_attn_size)
        self.attn_selector = attn_selector.split('+')
        self.use_attn = use_attn
        self.init_sizes(base_size)
        if self.use_attn:
            register_attention_control(unet, self.attention_store)

    def init_sizes(self, base_size):
        """Initialize feature-map sizes derived from the base size."""
        self.size16 = base_size // 32
        self.size32 = base_size // 16
        self.size64 = base_size // 8

    def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
        """Forward pass through the wrapped UNet."""
        diffusion_model = self.unet.diffusion_model
        if self.use_attn:
            self.attention_store.reset()
        hs, emb, out_list = self._unet_forward(x, timesteps, context, y,
                                               diffusion_model)
        if self.use_attn:
            self._append_attn_to_output(out_list)
        return out_list[::-1]

    def _unet_forward(self, x, timesteps, context, y, diffusion_model):
        """Run the denoising UNet, collecting multi-scale features."""
        hs = []
        t_emb = timestep_embedding(
            timesteps, diffusion_model.model_channels, repeat_only=False)
        emb = diffusion_model.time_embed(t_emb)
        h = x.type(diffusion_model.dtype)
        # Encoder: keep every block output for the skip connections.
        for module in diffusion_model.input_blocks:
            h = module(h, emb, context)
            hs.append(h)
        h = diffusion_model.middle_block(h, emb, context)
        out_list = []
        # Decoder: collect intermediate features at selected blocks.
        for i_out, module in enumerate(diffusion_model.output_blocks):
            h = torch.cat([h, hs.pop()], dim=1)
            h = module(h, emb, context)
            if i_out in [1, 4, 7]:
                out_list.append(h)
        h = h.type(x.dtype)
        out_list.append(h)
        return hs, emb, out_list

    def _append_attn_to_output(self, out_list):
        """Concatenate stored attention maps onto the matching features."""
        avg_attn = self.attention_store.get_average_attention()
        attns = {self.size16: [], self.size32: [], self.size64: []}
        for k in self.attn_selector:
            for up_attn in avg_attn[k]:
                b, n, c = up_attn.shape
                size = int(math.sqrt(n))
                # (b, size*size, c) -> (b, c, size, size).
                up_attn = up_attn.transpose(-1, -2).reshape(b, c, size, size)
                attns[size].append(up_attn)
        attn16 = torch.stack(attns[self.size16]).mean(0)
        attn32 = torch.stack(attns[self.size32]).mean(0)
        attn64 = torch.stack(attns[self.size64]).mean(0) if len(
            attns[self.size64]) > 0 else None
        out_list[1] = torch.cat([out_list[1], attn16], dim=1)
        out_list[2] = torch.cat([out_list[2], attn32], dim=1)
        if attn64 is not None:
            out_list[3] = torch.cat([out_list[3], attn64], dim=1)


class TextAdapter(nn.Module):
    """A PyTorch Module that serves as a text adapter.

    Takes text embeddings and refines them with a residual MLP scaled by a
    factor gamma.
    """

    def __init__(self, text_dim=768):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(text_dim, text_dim), nn.GELU(),
            nn.Linear(text_dim, text_dim))

    def forward(self, texts, gamma):
        texts_after = self.fc(texts)
        texts = texts + gamma * texts_after
        return texts
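

# A small sketch (illustrative only) of the residual update above: with a
# per-channel gamma near zero, the adapter starts close to the identity, so
# the pretrained embeddings are only gently perturbed early in training. In
# VPD below, gamma is a learnable nn.Parameter; here it is a plain tensor.
def _text_adapter_demo():
    adapter = TextAdapter(text_dim=768)
    texts = torch.randn(150, 768)  # e.g. one embedding per class
    gamma = torch.full((768, ), 1e-4)
    out = adapter(texts, gamma)
    assert out.shape == texts.shape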


@MODELS.register_module()
class VPD(BaseModule):
    """VPD (Visual Perception Diffusion) model.

    .. _`VPD`: https://arxiv.org/abs/2303.02153

    Args:
        diffusion_cfg (dict): Configuration for the diffusion model.
        class_embed_path (str): Path to the class embeddings.
        unet_cfg (dict, optional): Configuration for the UNet wrapper.
        gamma (float, optional): Initial gamma for text adaptation.
            Defaults to 1e-4.
        class_embed_select (bool, optional): If True, enables class embedding
            selection. Defaults to False.
        pad_shape (Optional[Union[int, List[int]]], optional): Padding shape.
            Defaults to None.
        pad_val (Union[int, List[int]], optional): Padding value.
            Defaults to 0.
        init_cfg (dict, optional): Configuration for network initialization.
    """

    def __init__(self,
                 diffusion_cfg: ConfigType,
                 class_embed_path: str,
                 unet_cfg: OptConfigType = dict(),
                 gamma: float = 1e-4,
                 class_embed_select=False,
                 pad_shape: Optional[Union[int, List[int]]] = None,
                 pad_val: Union[int, List[int]] = 0,
                 init_cfg: OptConfigType = None):

        super().__init__(init_cfg=init_cfg)

        assert has_ldm, 'To use VPD model, please install required packages' \
            ' via `pip install -r requirements/optional.txt`.'

        if pad_shape is not None:
            if not isinstance(pad_shape, (list, tuple)):
                pad_shape = (pad_shape, pad_shape)

        self.pad_shape = pad_shape
        self.pad_val = pad_val

        # Build the stable diffusion model and wrap its denoising UNet.
        diffusion_checkpoint = diffusion_cfg.pop('checkpoint', None)
        sd_model = instantiate_from_config(diffusion_cfg)
        if diffusion_checkpoint is not None:
            load_checkpoint(sd_model, diffusion_checkpoint, strict=False)

        self.encoder_vq = sd_model.first_stage_model
        self.unet = UNetWrapper(sd_model.model, **unet_cfg)

        # Load class embeddings and build the text adapter.
        class_embeddings = CheckpointLoader.load_checkpoint(class_embed_path)
        text_dim = class_embeddings.size(-1)
        self.text_adapter = TextAdapter(text_dim=text_dim)
        self.class_embed_select = class_embed_select
        if class_embed_select:
            # Append the mean embedding so that class id -1 selects it.
            class_embeddings = torch.cat(
                (class_embeddings, class_embeddings.mean(dim=0,
                                                         keepdim=True)),
                dim=0)
        self.register_buffer('class_embeddings', class_embeddings)
        self.gamma = nn.Parameter(torch.ones(text_dim) * gamma)

    def forward(self, x):
        """Extract features from images."""

        # Prepare the cross-attention context from the class embeddings.
        if self.class_embed_select:
            if isinstance(x, (tuple, list)):
                x, class_ids = x[:2]
                class_ids = class_ids.tolist()
            else:
                class_ids = [-1] * x.size(0)
            class_embeddings = self.class_embeddings[class_ids]
            c_crossattn = self.text_adapter(class_embeddings, self.gamma)
            c_crossattn = c_crossattn.unsqueeze(1)
        else:
            class_embeddings = self.class_embeddings
            c_crossattn = self.text_adapter(class_embeddings, self.gamma)
            c_crossattn = c_crossattn.unsqueeze(0).repeat(x.size(0), 1, 1)

        # Pad to the input shape expected by the diffusion model.
        if self.pad_shape is not None:
            pad_width = max(0, self.pad_shape[1] - x.shape[-1])
            pad_height = max(0, self.pad_shape[0] - x.shape[-2])
            x = F.pad(x, (0, pad_width, 0, pad_height), value=self.pad_val)

        # Encode to latents and run a single denoising step at t=1.
        with torch.no_grad():
            latents = self.encoder_vq.encode(x).mode().detach()
        t = torch.ones((x.shape[0], ), device=x.device).long()
        outs = self.unet(latents, t, context=c_crossattn)

        return outs
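

# A hedged configuration sketch of how this backbone might be built via the
# mmseg registry. The paths, checkpoint name, and diffusion params below are
# placeholders, not verified files:
# >>> from mmseg.registry import MODELS
# >>> backbone = MODELS.build(
# ...     dict(
# ...         type='VPD',
# ...         diffusion_cfg=dict(
# ...             target='ldm.models.diffusion.ddpm.LatentDiffusion',
# ...             checkpoint='path/to/stable-diffusion.ckpt',  # placeholder
# ...             params=dict()),  # real LatentDiffusion params go here
# ...         class_embed_path='path/to/class_embeddings.pth',  # placeholder
# ...         unet_cfg=dict(use_attn=True, base_size=512),
# ...         pad_shape=512))
# >>> feats = backbone(torch.randn(1, 3, 512, 512))  # multi-scale features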