import math
import torch
import numpy as np
from torch import nn


class Attention2D(nn.Module):
    def __init__(self, c, nhead, dropout=0.0):
        super().__init__()
        self.attn = torch.nn.MultiheadAttention(c, nhead, dropout=dropout, bias=True, batch_first=True)

    def forward(self, x, kv, self_attn=False):
        orig_shape = x.shape
        x = x.view(x.size(0), x.size(1), -1).permute(0, 2, 1)  # (B, C, H, W) -> (B, H*W, C)
        if self_attn:
            # Attend over the image tokens and the conditioning tokens jointly
            kv = torch.cat([x, kv], dim=1)
        x = self.attn(x, kv, kv, need_weights=False)[0]
        x = x.permute(0, 2, 1).view(*orig_shape)
        return x


class LayerNorm2d(nn.LayerNorm):
    """LayerNorm over the channel dimension of NCHW tensors."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, x):
        return super().forward(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)


class GlobalResponseNorm(nn.Module):
    """Taken from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105"""

    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim))
        self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim))

    def forward(self, x):
        Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
        Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
        return self.gamma * (x * Nx) + self.beta + x


class ResBlock(nn.Module):
    """ConvNeXt-style block: depthwise spatial conv followed by a channelwise MLP."""

    def __init__(self, c, c_skip=0, kernel_size=3, dropout=0.0):
        super().__init__()
        self.depthwise = nn.Conv2d(c + c_skip, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c)
        self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6)
        self.channelwise = nn.Sequential(
            nn.Linear(c, c * 4),
            nn.GELU(),
            GlobalResponseNorm(c * 4),
            nn.Dropout(dropout),
            nn.Linear(c * 4, c)
        )

    def forward(self, x, x_skip=None):
        x_res = x
        if x_skip is not None:
            x = torch.cat([x, x_skip], dim=1)
        x = self.norm(self.depthwise(x)).permute(0, 2, 3, 1)
        x = self.channelwise(x).permute(0, 3, 1, 2)
        return x + x_res


class AttnBlock(nn.Module):
    def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0):
        super().__init__()
        self.self_attn = self_attn
        self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6)
        self.attention = Attention2D(c, nhead, dropout)
        self.kv_mapper = nn.Sequential(
            nn.SiLU(),
            nn.Linear(c_cond, c)
        )

    def forward(self, x, kv):
        kv = self.kv_mapper(kv)
        x = x + self.attention(self.norm(x), kv, self_attn=self.self_attn)
        return x


class FeedForwardBlock(nn.Module):
    def __init__(self, c, dropout=0.0):
        super().__init__()
        self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6)
        self.channelwise = nn.Sequential(
            nn.Linear(c, c * 4),
            nn.GELU(),
            GlobalResponseNorm(c * 4),
            nn.Dropout(dropout),
            nn.Linear(c * 4, c)
        )

    def forward(self, x):
        x = x + self.channelwise(self.norm(x).permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        return x


class TimestepBlock(nn.Module):
    def __init__(self, c, c_timestep):
        super().__init__()
        self.mapper = nn.Linear(c_timestep, c * 2)

    def forward(self, x, t):
        # FiLM-style modulation: per-channel scale and shift from the timestep embedding
        a, b = self.mapper(t)[:, :, None, None].chunk(2, dim=1)
        return x * (1 + a) + b


class Paella(nn.Module):
    def __init__(self, c_in=256, c_out=256, num_labels=8192, c_r=64, patch_size=2, c_cond=1024,
                 c_hidden=[640, 1280, 1280], nhead=[-1, 16, 16], blocks=[6, 16, 6],
                 level_config=['CT', 'CTA', 'CTA'], clip_embd=1024, byt5_embd=1536, clip_seq_len=4,
                 kernel_size=3, dropout=0.1, self_attn=True):
        super().__init__()
        self.c_r = c_r
        self.c_cond = c_cond
        self.num_labels = num_labels
        if not isinstance(dropout, list):
            dropout = [dropout] * len(c_hidden)

        # CONDITIONING
        self.byt5_mapper = nn.Linear(byt5_embd, c_cond)
        self.clip_mapper = nn.Linear(clip_embd, c_cond * clip_seq_len)
        self.clip_image_mapper = nn.Linear(clip_embd, c_cond * clip_seq_len)
        self.seq_norm = nn.LayerNorm(c_cond, elementwise_affine=False, eps=1e-6)

        self.in_mapper = nn.Sequential(
            nn.Embedding(num_labels, c_in),
            nn.LayerNorm(c_in, elementwise_affine=False, eps=1e-6)
        )
        self.embedding = nn.Sequential(
            nn.PixelUnshuffle(patch_size),
            nn.Conv2d(c_in * (patch_size ** 2), c_hidden[0], kernel_size=1),
            LayerNorm2d(c_hidden[0], elementwise_affine=False, eps=1e-6)
        )

        def get_block(block_type, c_hidden, nhead, c_skip=0, dropout=0):
            if block_type == 'C':
                return ResBlock(c_hidden, c_skip, kernel_size=kernel_size, dropout=dropout)
            elif block_type == 'A':
                return AttnBlock(c_hidden, c_cond, nhead, self_attn=self_attn, dropout=dropout)
            elif block_type == 'F':
                return FeedForwardBlock(c_hidden, dropout=dropout)
            elif block_type == 'T':
                return TimestepBlock(c_hidden, c_r)
            else:
                raise Exception(f'Block type {block_type} not supported')

        # DOWN BLOCKS
        self.down_blocks = nn.ModuleList()
        for i in range(len(c_hidden)):
            down_block = nn.ModuleList()
            if i > 0:
                down_block.append(nn.Sequential(
                    LayerNorm2d(c_hidden[i - 1], elementwise_affine=False, eps=1e-6),
                    nn.Conv2d(c_hidden[i - 1], c_hidden[i], kernel_size=2, stride=2),
                ))
            for _ in range(blocks[i]):
                for block_type in level_config[i]:
                    down_block.append(get_block(block_type, c_hidden[i], nhead[i], dropout=dropout[i]))
            self.down_blocks.append(down_block)

        # UP BLOCKS
        self.up_blocks = nn.ModuleList()
        for i in reversed(range(len(c_hidden))):
            up_block = nn.ModuleList()
            for j in range(blocks[i]):
                for k, block_type in enumerate(level_config[i]):
                    up_block.append(get_block(block_type, c_hidden[i], nhead[i],
                                              c_skip=c_hidden[i] if i < len(c_hidden) - 1 and j == k == 0 else 0,
                                              dropout=dropout[i]))
            if i > 0:
                up_block.append(nn.Sequential(
                    LayerNorm2d(c_hidden[i], elementwise_affine=False, eps=1e-6),
                    nn.ConvTranspose2d(c_hidden[i], c_hidden[i - 1], kernel_size=2, stride=2),
                ))
            self.up_blocks.append(up_block)

        # OUTPUT
        self.clf = nn.Sequential(
            LayerNorm2d(c_hidden[0], elementwise_affine=False, eps=1e-6),
            nn.Conv2d(c_hidden[0], c_out * (patch_size ** 2), kernel_size=1),
            nn.PixelShuffle(patch_size),
        )
        self.out_mapper = nn.Sequential(
            LayerNorm2d(c_out, elementwise_affine=False, eps=1e-6),
            nn.Conv2d(c_out, num_labels, kernel_size=1, bias=False)
        )

        # --- WEIGHT INIT ---
        self.apply(self._init_weights)  # General init
        nn.init.normal_(self.byt5_mapper.weight, std=0.02)
        nn.init.normal_(self.clip_mapper.weight, std=0.02)
        nn.init.normal_(self.clip_image_mapper.weight, std=0.02)
        torch.nn.init.xavier_uniform_(self.embedding[1].weight, 0.02)
        nn.init.constant_(self.clf[1].weight, 0)
        nn.init.normal_(self.in_mapper[0].weight, std=np.sqrt(1 / num_labels))
        self.out_mapper[-1].weight.data = self.in_mapper[0].weight.data[:, :, None, None].clone()
        for level_block in self.down_blocks + self.up_blocks:
            for block in level_block:
                if isinstance(block, ResBlock) or isinstance(block, FeedForwardBlock):
                    block.channelwise[-1].weight.data *= np.sqrt(1 / sum(blocks))
                elif isinstance(block, TimestepBlock):
                    nn.init.constant_(block.mapper.weight, 0)

    def _init_weights(self, m):
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def gen_r_embedding(self, r, max_positions=10000):
        r = r * max_positions
        half_dim = self.c_r // 2
        emb = math.log(max_positions) / (half_dim - 1)
        emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp()
        emb = r[:, None] * emb[None, :]
        emb = torch.cat([emb.sin(), emb.cos()], dim=1)
        if self.c_r % 2 == 1:
            emb = nn.functional.pad(emb, (0, 1), mode='constant')
        return emb
    def gen_c_embeddings(self, byt5, clip, clip_image):
        seq = self.byt5_mapper(byt5)
        if clip is not None:
            clip = self.clip_mapper(clip).view(clip.size(0), -1, self.c_cond)
            seq = torch.cat([seq, clip], dim=1)
        if clip_image is not None:
            clip_image = self.clip_image_mapper(clip_image).view(clip_image.size(0), -1, self.c_cond)
            seq = torch.cat([seq, clip_image], dim=1)
        seq = self.seq_norm(seq)
        return seq

    def _down_encode(self, x, r_embed, c_embed):
        level_outputs = []
        for down_block in self.down_blocks:
            for block in down_block:
                if isinstance(block, ResBlock):
                    x = block(x)
                elif isinstance(block, AttnBlock):
                    x = block(x, c_embed)
                elif isinstance(block, TimestepBlock):
                    x = block(x, r_embed)
                else:
                    x = block(x)
            level_outputs.insert(0, x)
        return level_outputs

    def _up_decode(self, level_outputs, r_embed, c_embed):
        x = level_outputs[0]
        for i, up_block in enumerate(self.up_blocks):
            for j, block in enumerate(up_block):
                if isinstance(block, ResBlock):
                    # The first block of every level except the deepest receives the encoder skip
                    x = block(x, level_outputs[i] if j == 0 and i > 0 else None)
                elif isinstance(block, AttnBlock):
                    x = block(x, c_embed)
                elif isinstance(block, TimestepBlock):
                    x = block(x, r_embed)
                else:
                    x = block(x)
        return x

    def forward(self, x, r, byt5, clip=None, clip_image=None, x_cat=None):
        if x_cat is not None:
            x = torch.cat([x, x_cat], dim=1)
        # Process the conditioning embeddings
        r_embed = self.gen_r_embedding(r)
        c_embed = self.gen_c_embeddings(byt5, clip, clip_image)
        # Model Blocks
        x = self.embedding(self.in_mapper(x).permute(0, 3, 1, 2))
        level_outputs = self._down_encode(x, r_embed, c_embed)
        x = self._up_decode(level_outputs, r_embed, c_embed)
        x = self.out_mapper(self.clf(x))
        return x

    def add_noise(self, x, t, mask=None, random_x=None):
        # Replace a random subset of tokens (roughly a fraction t per sample) with uniformly sampled tokens
        if mask is None:
            mask = (torch.rand_like(x.float()) <= t[:, None, None]).long()
        if random_x is None:
            random_x = torch.randint_like(x, 0, self.num_labels)
        x = x * (1 - mask) + random_x * mask
        return x, mask
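
# --- USAGE SKETCH (not part of the original module) ---
# A minimal smoke test of the token-space noising and forward pass. The reduced
# hyperparameters below (c_hidden=[64, 128], num_labels=32, byt5_embd=48, etc.)
# are assumptions chosen only so the example runs quickly on CPU; the real model
# uses the much larger defaults defined in Paella.__init__. As implied by
# in_mapper's nn.Embedding, the input is a (B, H, W) LongTensor of token indices.
if __name__ == '__main__':
    model = Paella(c_in=16, c_out=16, num_labels=32, c_r=64, patch_size=2, c_cond=64,
                   c_hidden=[64, 128], nhead=[-1, 4], blocks=[1, 2], level_config=['CT', 'CTA'],
                   clip_embd=32, byt5_embd=48, clip_seq_len=2, dropout=0.0)
    model.eval()

    batch, height, width = 2, 16, 16
    tokens = torch.randint(0, 32, (batch, height, width))  # dummy discrete latent tokens
    t = torch.rand(batch)                                  # noise level per sample, in [0, 1]
    byt5 = torch.randn(batch, 6, 48)                       # dummy text embeddings
    clip = torch.randn(batch, 32)                          # dummy pooled CLIP embedding

    noised_tokens, mask = model.add_noise(tokens, t)       # randomly replace ~t of the tokens
    with torch.no_grad():
        logits = model(noised_tokens, t, byt5, clip=clip)
    print(logits.shape)  # expected: torch.Size([2, 32, 16, 16]), per-token logits over num_labels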