CondViT-B16-cat / module.py
Slep's picture
Upload CondViTForEmbedding
11f6a98 verified
raw
history blame
No virus
4.78 kB
import torch
from torch import nn
from collections import OrderedDict
import logging
# Module-level logger; not referenced by any of the definitions below.
logger = logging.getLogger(__name__)
class LayerNorm(nn.LayerNorm):
    """LayerNorm that normalizes in the dtype of its own weight.

    If the incoming tensor's dtype differs from ``self.weight.dtype`` (for
    example fp16 activations with fp32 parameters), the input is cast to the
    weight dtype, normalized, and the result is cast back to the input dtype.
    Otherwise the stock ``nn.LayerNorm`` forward is used unchanged.
    """

    def forward(self, x: torch.Tensor):
        input_dtype = x.dtype
        param_dtype = self.weight.dtype
        if input_dtype == param_dtype:
            return super().forward(x)
        normalized = super().forward(x.type(param_dtype))
        return normalized.type(input_dtype)
class QuickGELU(nn.Module):
    """Sigmoid-based GELU approximation: ``x * sigmoid(1.702 * x)``."""

    def forward(self, x: torch.Tensor):
        gate = (x * 1.702).sigmoid()
        return x * gate
class ResidualAttentionBlock(nn.Module):
    """Pre-norm transformer block: self-attention then MLP, each residual.

    ``forward`` computes ``x = x + attn(ln_1(x))`` and then
    ``x = x + mlp(ln_2(x))``. The MLP is Linear(d, 4d) -> QuickGELU ->
    Linear(4d, d). The attention layer is an ``nn.MultiheadAttention`` built
    with its default ``batch_first=False``, so inputs are sequence-first,
    i.e. (L, N, d_model).

    Args:
        d_model: embedding width.
        n_head: number of attention heads.
        attn_mask: optional mask forwarded to ``nn.MultiheadAttention``.
            A boolean mask is kept boolean (True marks positions that may not
            be attended); any other mask is cast to the activation dtype and
            added to the attention scores.
    """

    def __init__(
        self,
        d_model: int,
        n_head: int,
        attn_mask: torch.Tensor = None,
    ):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        self.mlp = nn.Sequential(
            OrderedDict(
                [
                    ("c_fc", nn.Linear(d_model, d_model * 4)),
                    ("gelu", QuickGELU()),
                    ("c_proj", nn.Linear(d_model * 4, d_model)),
                ]
            )
        )
        self.ln_2 = LayerNorm(d_model)
        # Stored as a plain attribute (not a registered buffer), so
        # module.to(...) does not move it; attention() relocates it lazily.
        self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor):
        """Self-attention with ``x`` as query, key and value; returns the output only."""
        if self.attn_mask is not None:
            # A boolean mask must stay boolean: casting it to a float dtype
            # would turn "masked" (True) into an additive +1.0 bias instead of
            # excluding the position. Float masks follow the activation dtype.
            mask_dtype = (
                torch.bool if self.attn_mask.dtype == torch.bool else x.dtype
            )
            # Cache the converted mask so later calls avoid re-transferring it.
            self.attn_mask = self.attn_mask.to(dtype=mask_dtype, device=x.device)
        return self.attn(
            x,
            x,
            x,
            need_weights=False,
            attn_mask=self.attn_mask,
        )[0]

    def forward(self, x: torch.Tensor):
        x = x + self.attention(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x
class Transformer(nn.Module):
    """A stack of ``layers`` ResidualAttentionBlock modules run in sequence.

    Args:
        width: embedding width used by every block.
        layers: number of blocks in the stack.
        heads: attention heads per block.
        attn_mask: optional mask object passed to every block (the same
            tensor is shared by all of them).
    """

    def __init__(
        self,
        width: int,
        layers: int,
        heads: int,
        attn_mask: torch.Tensor = None,
    ):
        super().__init__()
        self.width = width
        self.layers = layers
        blocks = (
            ResidualAttentionBlock(width, heads, attn_mask)
            for _ in range(layers)
        )
        self.resblocks = nn.Sequential(*blocks)

    def forward(self, x: torch.Tensor):
        # Delegate to the Sequential so hooks registered on it still fire.
        out = self.resblocks(x)
        return out
class ConditionalViT(nn.Module):
    """Vision transformer with an optional learned category token.

    Images are split into non-overlapping ``patch_size`` patches by ``conv1``.
    A class token is prepended and, when category indices are supplied, a
    category token is appended. Positional embeddings are added, and the
    sequence is run through the transformer. The normalized class-token
    output is projected to ``output_dim``.

    Args:
        input_resolution: image side length the positional table is sized for.
        patch_size: patch side length (conv kernel and stride).
        width: transformer embedding width.
        layers: number of transformer blocks.
        heads: attention heads per block.
        output_dim: size of the final projected embedding.
        n_categories: number of rows in the category embedding table.
    """

    def __init__(
        self,
        input_resolution: int,
        patch_size: int,
        width: int,
        layers: int,
        heads: int,
        output_dim: int,
        n_categories: int,
    ):
        super().__init__()
        self.input_resolution = input_resolution
        self.output_dim = output_dim

        # Patch embedding: kernel == stride == patch_size, RGB input, no bias.
        self.conv1 = nn.Conv2d(
            3,
            width,
            kernel_size=patch_size,
            stride=patch_size,
            bias=False,
        )

        scale = width**-0.5
        num_patches = (input_resolution // patch_size) ** 2
        self.class_embedding = nn.Parameter(scale * torch.randn(width))

        self.n_categories = n_categories
        self.c_embedding = nn.Embedding(self.n_categories, width)
        # Single positional row shared by the (optional) category token.
        self.c_pos_embedding = nn.Parameter(scale * torch.randn(1, width))
        # One row per patch plus one for the class token.
        self.positional_embedding = nn.Parameter(
            scale * torch.randn(num_patches + 1, width)
        )

        self.ln_pre = LayerNorm(width)
        self.transformer = Transformer(width, layers, heads)
        self.ln_post = LayerNorm(width)

        # Not used by forward(); presumably kept for checkpoint compatibility.
        self.logit_scale = torch.nn.Parameter(torch.ones([]) * 4.6052)
        self.proj = nn.Parameter(scale * torch.randn(width, output_dim))

    def forward(self, imgs: torch.Tensor, c: torch.Tensor = None):
        """Embed a batch of images, optionally conditioned on categories.

        Args:
            imgs: batch of images for ``conv1`` (3 channels).
            c: optional category indices looked up in ``c_embedding``
                (presumably shape (N,)); if None, no category token is added.

        Returns:
            The class-token output after ``ln_post``, multiplied by ``proj``:
            shape (N, output_dim).
        """
        patches = self.conv1(imgs)  # [N, width, grid, grid]
        n_imgs, channels = patches.shape[0], patches.shape[1]
        # [N, width, grid**2] -> [N, grid**2, width]
        patches = patches.reshape(n_imgs, channels, -1).permute(0, 2, 1)

        # Sequence order: class token, patch tokens, then optional category.
        seq_parts = [self.class_embedding.tile(n_imgs, 1, 1), patches]  # NLD
        pos_parts = [self.positional_embedding]  # LD
        if c is not None:
            seq_parts.append(self.c_embedding(c).unsqueeze(1))  # ND -> N1D
            pos_parts.append(self.c_pos_embedding)  # 1D

        seq = torch.cat(seq_parts, dim=1)
        seq = seq + torch.cat(pos_parts, dim=0).unsqueeze(0)  # broadcast over N
        seq = self.ln_pre(seq)

        # The transformer is sequence-first: NLD -> LND -> NLD.
        seq = self.transformer(seq.permute(1, 0, 2)).permute(1, 0, 2)

        cls_out = self.ln_post(seq[:, 0, :])
        return cls_out @ self.proj