styledrop / libs /uvit_vq.py
zideliu's picture
StyleDrop init
28c6826
raw history blame
No virus
10.3 kB
import os
import torch
import torch.nn as nn
import math
from loguru import logger
import timm
from timm.models.layers import trunc_normal_
from timm.models.vision_transformer import PatchEmbed, Mlp
assert timm.__version__ == "0.3.2" # version check
import einops
import torch.utils.checkpoint
import torch.nn.functional as F
# xformers is an optional accelerator: when it cannot be imported the Attention
# module falls back to a plain PyTorch implementation.
try:
    import xformers
    import xformers.ops
    XFORMERS_IS_AVAILBLE = True
except Exception:
    # Deliberately broad (a broken install can raise more than ImportError),
    # but no longer a bare `except:` so KeyboardInterrupt / SystemExit still
    # propagate instead of being silently swallowed.
    XFORMERS_IS_AVAILBLE = False
class BertEmbeddings(nn.Module):
    """Sum of token and learned position embeddings, then LayerNorm and dropout.

    Args:
        vocab_size: number of rows in the token embedding table.
        hidden_size: width of every embedding vector.
        max_position_embeddings: number of rows in the position table, i.e. the
            longest sequence (dim 1 of ``input_ids``) that can be embedded.
        dropout: dropout probability applied to the normalised embeddings.
    """

    def __init__(self, vocab_size, hidden_size, max_position_embeddings, dropout=0.1):
        super().__init__()
        self.word_embeddings = nn.Embedding(vocab_size, hidden_size)
        self.position_embeddings = nn.Embedding(max_position_embeddings, hidden_size)

        # The attribute keeps its TensorFlow-style (non snake_case) name so that
        # checkpoints using the TensorFlow variable naming can still be loaded.
        self.LayerNorm = nn.LayerNorm(hidden_size, eps=1e-6)
        self.dropout = nn.Dropout(dropout)

        # Row vector 0..max_position_embeddings-1 with shape (1, max_len);
        # registered as a buffer so it is serialised with the module.
        all_positions = torch.arange(max_position_embeddings).expand((1, -1))
        self.register_buffer("position_ids", all_positions)

        # Re-initialise both tables (token table first) with N(0, 0.02^2).
        for table in (self.word_embeddings, self.position_embeddings):
            torch.nn.init.normal_(table.weight, std=.02)

    def forward(
            self, input_ids
    ):
        """Embed ``input_ids`` (indexed as ``[batch, seq]``) and return the result."""
        seq_length = input_ids.size()[1]
        # Only the first ``seq_length`` position indices are needed.
        positions = self.position_ids[:, :seq_length]

        token_part = self.word_embeddings(input_ids)
        position_part = self.position_embeddings(positions)

        normalised = self.LayerNorm(token_part + position_part)
        return self.dropout(normalised)
class MlmLayer(nn.Module):
    """Prediction head that scores features against a token embedding table.

    Features go through Linear -> GELU -> LayerNorm, are multiplied by the
    transposed embedding table passed to ``forward`` and shifted by a learned
    per-vocabulary-entry bias.

    Args:
        feat_emb_dim: size of the last dimension of the incoming features.
        word_emb_dim: width of the embedding table rows.
        vocab_size: number of entries in the bias (one per logit).
    """

    def __init__(self, feat_emb_dim, word_emb_dim, vocab_size):
        super().__init__()
        self.fc = nn.Linear(feat_emb_dim, word_emb_dim)
        self.gelu = nn.GELU()
        self.ln = nn.LayerNorm(word_emb_dim)
        # Shaped (1, 1, vocab_size) so it broadcasts over the leading dims.
        self.bias = nn.Parameter(torch.zeros(1, 1, vocab_size))

    def forward(self, x, word_embeddings):
        """Return logits for ``x`` given ``word_embeddings`` (rows = vocabulary)."""
        hidden = self.ln(self.gelu(self.fc(x)))
        # Swap the first two axes of the table so hidden @ table yields one
        # score per table row.
        table_t = word_embeddings.transpose(0, 1)
        return torch.matmul(hidden, table_t) + self.bias
def patchify(imgs, patch_size):
    """Cut a ``(B, C, H, W)`` batch into flattened square patches.

    Args:
        imgs: array/tensor whose last two axes are divisible by ``patch_size``.
        patch_size: side length of each square patch.

    Returns:
        A ``(B, h * w, patch_size * patch_size * C)`` result, where ``h`` and
        ``w`` are the number of patches along height and width.
    """
    pattern = 'B C (h p1) (w p2) -> B (h w) (p1 p2 C)'
    return einops.rearrange(imgs, pattern, p1=patch_size, p2=patch_size)
def unpatchify(x, channels=3, flatten=False):
    """Reassemble flattened square patches laid out on a square grid.

    Args:
        x: ``(B, num_patches, patch_size**2 * channels)`` input; ``num_patches``
            must be a perfect square.
        channels: number of channels stored in each patch.
        flatten: when true return ``(B, H * W, C)`` (pixels as a sequence),
            otherwise return ``(B, C, H, W)``.

    Raises:
        AssertionError: if the patch count is not a perfect square or the last
            dimension is not ``patch_size**2 * channels``.
    """
    patch_dim = x.shape[2]
    num_patches = x.shape[1]
    # Both sizes are recovered from the shape and validated just below.
    patch_size = int((patch_dim // channels) ** 0.5)
    h = w = int(num_patches ** .5)
    assert h * w == num_patches and patch_size ** 2 * channels == patch_dim

    target = 'B (h p1 w p2) C' if flatten else 'B C (h p1) (w p2)'
    return einops.rearrange(x, 'B (h w) (p1 p2 C) -> ' + target, h=h, p1=patch_size, p2=patch_size)
class Attention(nn.Module):
    """Multi-head self-attention with an optional xformers fast path.

    Args:
        dim: size of the last dimension of the input (and of the output).
        num_heads: number of attention heads; each head has ``dim // num_heads``
            channels.
        qkv_bias: whether the fused q/k/v projection has a bias.
        qk_scale: multiplier applied to the q.k scores; a falsy value selects
            the default ``head_dim ** -0.5``.
        attn_drop: dropout probability on the attention weights.
        proj_drop: dropout probability on the output projection.

    Both execution paths honour ``qk_scale`` and ``attn_drop``. With the default
    arguments the xformers call is made exactly as
    ``memory_efficient_attention(q, k, v, p=0.0)``.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """Apply self-attention to ``x`` of shape ``(B, N, C)``; returns ``(B, N, C)``."""
        B, N, C = x.shape
        if XFORMERS_IS_AVAILBLE:
            qkv = self.qkv(x)
            qkv = einops.rearrange(qkv, 'B L (K H D) -> K B L H D', K=3, H=self.num_heads)
            q, k, v = qkv[0], qkv[1], qkv[2]  # B L H D
            # The fused kernel applies dropout whenever p > 0, so mirror
            # nn.Dropout and only enable it while training.
            drop_p = self.attn_drop.p if self.training else 0.0
            if self.scale != (C // self.num_heads) ** -0.5:
                # A custom qk_scale was requested: forward it so this path
                # matches the plain PyTorch path below.
                x = xformers.ops.memory_efficient_attention(q, k, v, p=drop_p, scale=self.scale)
            else:
                # Default scale: omit the keyword so the call stays valid for
                # xformers builds whose signature has no ``scale`` argument.
                x = xformers.ops.memory_efficient_attention(q, k, v, p=drop_p)
            x = einops.rearrange(x, 'B L H D -> B L (H D)', H=self.num_heads)
        else:
            qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
            q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
            attn = (q @ k.transpose(-2, -1)) * self.scale
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = (attn @ v).transpose(1, 2).reshape(B, N, C)

        x = self.proj(x)
        x = self.proj_drop(x)
        return x
class Block(nn.Module):
    """Pre-norm transformer block with an optional long-skip input.

    Args:
        dim: channel width of the tokens.
        num_heads: number of attention heads.
        mlp_ratio: hidden width of the MLP as a multiple of ``dim``.
        qkv_bias: forwarded to ``Attention``.
        qk_scale: forwarded to ``Attention``.
        act_layer: activation class used inside the MLP.
        norm_layer: normalisation class used before attention and MLP.
        skip: when true, create a linear layer that fuses a skip tensor
            (concatenated on the last axis) back down to ``dim`` channels.
        use_checkpoint: when true, run the block through
            ``torch.utils.checkpoint.checkpoint``.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm, skip=False, use_checkpoint=False):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale)
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer)
        if skip:
            self.skip_linear = nn.Linear(2 * dim, dim)
        else:
            self.skip_linear = None
        self.use_checkpoint = use_checkpoint

    def forward(self, x, skip=None):
        """Run the block, through activation checkpointing if configured."""
        if not self.use_checkpoint:
            return self._forward(x, skip)
        return torch.utils.checkpoint.checkpoint(self._forward, x, skip)

    def _forward(self, x, skip=None):
        """Fuse the skip tensor (if this block has a skip layer), then attention and MLP."""
        if self.skip_linear is not None:
            fused_input = torch.cat([x, skip], dim=-1)
            x = self.skip_linear(fused_input)
        # Residual connections around attention and MLP, each on normalised input.
        x = x + self.attn(self.norm1(x))
        return x + self.mlp(self.norm2(x))
class UViT(nn.Module):
    """Transformer with long skip connections that predicts codebook-token logits.

    The input is a 2-D tensor of token ids. Ids are embedded, passed through
    ``depth // 2`` input blocks, one middle block and ``depth // 2`` output
    blocks (which consume the input blocks' activations in reverse order when
    ``skip`` is true), normalised, and scored against the token embedding table.

    Args:
        img_size: side length of the token grid; ``img_size ** 2`` positions
            (plus one when ``num_classes > 0``) are allocated.
        patch_size: when not 1, embedded tokens are downsampled with a
            ``PatchEmbed`` and upsampled again by a linear decoder.
        in_chans: stored on the instance; not otherwise used in this class.
        embed_dim: channel width of the transformer.
        depth: controls the number of blocks (see above).
        num_heads, mlp_ratio, qkv_bias, qk_scale, norm_layer: forwarded to
            every ``Block``.
        num_classes: when positive, the vocabulary is extended by this many
            entries and one extra leading position is reserved.
        use_checkpoint: forwarded to every ``Block``.
        skip: enable the long skip connections.
        codebook_size: number of codebook entries; id ``codebook_size`` is the
            mask token.
    """

    def __init__(self, img_size=16, patch_size=1, in_chans=8, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.,
                 qkv_bias=False, qk_scale=None, norm_layer=nn.LayerNorm, num_classes=-1,
                 use_checkpoint=False, skip=True, codebook_size=1024):
        super().__init__()
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.num_classes = num_classes
        self.in_chans = in_chans
        self.skip = skip
        logger.debug(f'codebook size in nnet: {codebook_size}')
        self.codebook_size = codebook_size

        # Vocabulary = codebook entries + 1 mask token (+ class ids if any).
        vocab_size = codebook_size + 1
        if num_classes > 0:
            self.extras = 1
            vocab_size = vocab_size + num_classes
        else:
            self.extras = 0

        self.token_emb = BertEmbeddings(vocab_size=vocab_size,
                                        hidden_size=embed_dim,
                                        max_position_embeddings=int(img_size ** 2) + self.extras,
                                        dropout=0.1)
        logger.debug(f'token emb weight shape: {self.token_emb.word_embeddings.weight.shape}')

        if patch_size == 1:
            self.patch_embed = None
            self.decoder_pred = None
        else:  # downsamp
            # NOTE(review): `input_shape` is passed through to PatchEmbed as-is;
            # confirm the installed PatchEmbed accepts this keyword.
            self.patch_embed = PatchEmbed(
                img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
                input_shape='bhwc')
            logger.debug(f'patch emb weight shape: {self.patch_embed.proj.weight.shape}')
            self.decoder_pred = nn.Linear(embed_dim, patch_size ** 2 * embed_dim, bias=True)

        def make_block(with_skip=False):
            # All blocks share the same hyper-parameters; only the output
            # blocks may carry a skip-fusion layer.
            return Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                norm_layer=norm_layer, skip=with_skip, use_checkpoint=use_checkpoint)

        half_depth = depth // 2
        self.in_blocks = nn.ModuleList([make_block() for _ in range(half_depth)])
        self.mid_block = make_block()
        self.out_blocks = nn.ModuleList([make_block(with_skip=skip) for _ in range(half_depth)])

        self.norm = norm_layer(embed_dim)
        self.mlm_layer = MlmLayer(feat_emb_dim=embed_dim, word_emb_dim=embed_dim, vocab_size=vocab_size)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise Linear weights (truncated normal) and reset LayerNorm affine params."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Return the set of names to exclude from weight decay."""
        # NOTE(review): nothing named 'pos_embed' is defined in this class; the
        # returned value is kept as-is for callers that rely on it.
        return {'pos_embed'}

    def forward(self, x, context=None):
        """Return logits over the codebook for every non-extra position.

        Args:
            x: 2-D tensor of token ids (cast to ``long`` before embedding).
            context: optional ids prepended along dim 1 after being shifted
                past the codebook and mask ids.

        Returns:
            The head output with the first ``self.extras`` positions dropped
            and only the first ``self.codebook_size`` logits kept.
        """
        assert len(x.shape) == 2
        if context is not None:
            context = context + self.codebook_size + 1  # shift, mask token is self.codebook_size
            x = torch.cat((context, x), dim=1)
        x = self.token_emb(x.long())

        n_extra = self.extras
        if self.patch_embed is not None:
            # Lay the non-extra tokens out on the grid, downsample, and glue
            # the extra tokens back in front.
            grid = x[:, n_extra:].reshape(-1, *self.patch_embed.img_size, self.embed_dim)
            featmap_downsampled = self.patch_embed(grid).reshape(x.shape[0], -1, self.embed_dim)
            x = torch.cat((x[:, :n_extra], featmap_downsampled), dim=1)

        # Activations of the input blocks, consumed last-in-first-out below.
        skips = []
        for blk in self.in_blocks:
            x = blk(x)
            if self.skip:
                skips.append(x)

        x = self.mid_block(x)

        for blk in self.out_blocks:
            x = blk(x, skips.pop()) if self.skip else blk(x)

        x = self.norm(x)
        if self.decoder_pred is not None:
            decoded = self.decoder_pred(x[:, n_extra:])
            featmap_upsampled = unpatchify(decoded, self.embed_dim, flatten=True)
            x = torch.cat((x[:, :n_extra], featmap_upsampled), dim=1)

        # The embedding table is used as detached data: no gradient reaches it
        # through the prediction head.
        word_embeddings = self.token_emb.word_embeddings.weight.data.detach()
        logits = self.mlm_layer(x, word_embeddings)
        return logits[:, n_extra:, :self.codebook_size]