# kakao-brain-vit / backbone/vit_model.py
import logging

import tensorflow as tf
from tensorflow.keras.layers import LayerNormalization

from .layers import Identity

logger = logging.getLogger(__name__)


class MultiHeadSelfAttention(tf.keras.layers.Layer):
    def __init__(self, embed_dim, num_heads=8, attn_drop_rate=0.0, proj_drop=0.0, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        # Keep the dropout rates so get_config() can serialize them.
        self.attn_drop_rate = attn_drop_rate
        self.proj_drop_rate = proj_drop
        if embed_dim % num_heads != 0:
            raise ValueError(
                f"embedding dimension = {embed_dim} should be divisible "
                f"by number of heads = {num_heads}"
            )
        self.projection_dim = embed_dim // num_heads
self.query_dense = tf.keras.layers.Dense(
units=embed_dim,
kernel_initializer=tf.keras.initializers.GlorotNormal(),
bias_initializer="zeros",
)
self.key_dense = tf.keras.layers.Dense(
units=embed_dim,
kernel_initializer=tf.keras.initializers.GlorotNormal(),
bias_initializer="zeros",
)
self.attn_drop = tf.keras.layers.Dropout(rate=attn_drop_rate)
self.value_dense = tf.keras.layers.Dense(
units=embed_dim,
kernel_initializer=tf.keras.initializers.GlorotNormal(),
bias_initializer="zeros",
)
self.combine_heads = tf.keras.layers.Dense(
units=embed_dim,
kernel_initializer=tf.keras.initializers.GlorotNormal(),
bias_initializer="zeros",
)
self.proj_drop = tf.keras.layers.Dropout(rate=proj_drop)
    def attention(self, query, key, value, training=False):
        # Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V.
        score = tf.matmul(query, key, transpose_b=True)
        dim_key = tf.cast(tf.shape(key)[-1], score.dtype)
        scaled_score = score / tf.math.sqrt(dim_key)
        weights = tf.nn.softmax(scaled_score, axis=-1)
        # Forward the training flag so attention dropout is disabled at inference.
        weights = self.attn_drop(weights, training=training)
        output = tf.matmul(weights, value)
        return output, weights
def separate_heads(self, x, batch_size):
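        # (B, N, embed_dim) -> (B, num_heads, N, projection_dim)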
x = tf.reshape(x, (batch_size, -1, self.num_heads, self.projection_dim))
return tf.transpose(x, perm=[0, 2, 1, 3])
def call(self, inputs, training):
batch_size = tf.shape(inputs)[0]
query = self.query_dense(inputs)
key = self.key_dense(inputs)
value = self.value_dense(inputs)
query = self.separate_heads(query, batch_size)
key = self.separate_heads(key, batch_size)
value = self.separate_heads(value, batch_size)
        attention, weights = self.attention(query, key, value, training=training)
attention = tf.transpose(attention, perm=[0, 2, 1, 3])
concat_attention = tf.reshape(attention, (batch_size, -1, self.embed_dim))
output = self.combine_heads(concat_attention)
output = self.proj_drop(output, training=training)
return output
    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "embed_dim": self.embed_dim,
                "num_heads": self.num_heads,
                "attn_drop_rate": self.attn_drop_rate,
                "proj_drop": self.proj_drop_rate,
            }
        )
        return config
@classmethod
def from_config(cls, config):
return cls(**config)
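
# Usage sketch (illustrative, not part of the original module): the layer is
# shape-preserving over (batch, tokens, embed_dim), e.g.
#   mhsa = MultiHeadSelfAttention(embed_dim=192, num_heads=3)
#   y = mhsa(tf.zeros((2, 197, 192)), training=False)  # -> (2, 197, 192)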
class TransformerBlock(tf.keras.layers.Layer):
    def __init__(
        self,
        embed_dim,
        num_heads,
        mlp_dim,
        drop_rate,
        attn_drop_rate,
        name="encoderblock",
        **kwargs,
    ):
        super().__init__(name=name, **kwargs)
        # Keep the constructor arguments so get_config() can serialize them.
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.mlp_dim = mlp_dim
        self.drop_rate = drop_rate
        self.attn_drop_rate = attn_drop_rate
        self.att = MultiHeadSelfAttention(
            embed_dim=embed_dim,
            num_heads=num_heads,
            attn_drop_rate=attn_drop_rate,
            proj_drop=drop_rate,
        )
self.mlp = tf.keras.Sequential(
[
                tf.keras.layers.Dense(
                    units=mlp_dim,
                    kernel_initializer=tf.keras.initializers.GlorotNormal(),
                    bias_initializer=tf.keras.initializers.RandomNormal(
                        mean=0.0, stddev=1e-6
                    ),
                ),
                # Approximate (tanh) GELU, matching the reference JAX ViT.
                tf.keras.layers.Lambda(
                    lambda x: tf.keras.activations.gelu(x, approximate=True)
                ),
tf.keras.layers.Dropout(rate=drop_rate),
tf.keras.layers.Dense(
units=embed_dim,
kernel_initializer=tf.keras.initializers.GlorotUniform(),
bias_initializer=tf.keras.initializers.RandomNormal(
mean=0.0, stddev=1e-6
),
),
tf.keras.layers.Dropout(rate=drop_rate),
]
)
self.layernorm1 = LayerNormalization(epsilon=1e-6)
self.layernorm2 = LayerNormalization(epsilon=1e-6)
    def call(self, inputs, training):
        # Pre-norm residual block: x + MHSA(LN(x)), then x + MLP(LN(x)).
        inputs_norm = self.layernorm1(inputs)
        attn_output = self.att(inputs_norm, training=training)
        out1 = attn_output + inputs
        out1_norm = self.layernorm2(out1)
        mlp_output = self.mlp(out1_norm, training=training)
        return out1 + mlp_output
    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "embed_dim": self.embed_dim,
                "num_heads": self.num_heads,
                "mlp_dim": self.mlp_dim,
                "drop_rate": self.drop_rate,
                "attn_drop_rate": self.attn_drop_rate,
            }
        )
        return config
@classmethod
def from_config(cls, config):
return cls(**config)
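
# Usage sketch (illustrative): one shape-preserving encoder block, e.g.
#   block = TransformerBlock(embed_dim=192, num_heads=3, mlp_dim=768,
#                            drop_rate=0.1, attn_drop_rate=0.0)
#   y = block(tf.zeros((2, 197, 192)), training=False)  # -> (2, 197, 192)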
class VisionTransformer(tf.keras.Model):
def __init__(
self,
image_size,
patch_size,
num_layers,
hidden_size,
num_heads,
mlp_dim,
representation_size=None,
channels=3,
dropout_rate=0.1,
attention_dropout_rate=0.0,
        num_classes=None,  # unused: this backbone returns pre-logits only
):
super(VisionTransformer, self).__init__()
num_patches = (image_size // patch_size) ** 2
self.patch_dim = channels * patch_size ** 2
self.patch_size = patch_size
self.d_model = hidden_size
self.num_layers = num_layers
self.class_emb = self.add_weight(
"class_emb",
shape=(1, 1, hidden_size),
initializer=tf.keras.initializers.Zeros(),
trainable=True,
)
self.pos_emb = self.add_weight(
"pos_emb",
shape=(1, num_patches + 1, hidden_size),
initializer=tf.keras.initializers.RandomNormal(
mean=0.0, stddev=0.02, seed=None
),
trainable=True,
)
self.pos_drop = tf.keras.layers.Dropout(rate=dropout_rate, name="pos_drop")
self.embedding = tf.keras.layers.Conv2D(
filters=hidden_size,
kernel_size=self.patch_size,
strides=self.patch_size,
padding="valid",
name="embedding",
)
self.enc_layers = [
TransformerBlock(
embed_dim=hidden_size,
num_heads=num_heads,
mlp_dim=mlp_dim,
drop_rate=dropout_rate,
attn_drop_rate=attention_dropout_rate,
name=f"encoderblock_{i}",
)
for i in range(num_layers)
]
        self.norm = LayerNormalization(epsilon=1e-6, name="encoder_norm")
self.extract_token = tf.keras.layers.Lambda(
lambda x: x[:, 0], name="extract_token"
)
        # Truthiness check covers both the default None and 0; the original
        # `!= 0` test would build Dense(units=None) and crash for the default.
        self.representation = (
            tf.keras.layers.Dense(
                units=representation_size,
                activation="tanh",
                name="pre_logits",
            )
            if representation_size
            else Identity(name="pre_logits")
        )
def call(self, x, training):
batch_size = tf.shape(x)[0]
x = self.embedding(x)
x = tf.reshape(x, [batch_size, -1, self.d_model])
class_emb = tf.broadcast_to(self.class_emb, [batch_size, 1, self.d_model])
# B x (N + 1) x d_model
x = tf.concat([tf.cast(class_emb, x.dtype), x], axis=1)
x = x + tf.cast(self.pos_emb, x.dtype)
# https://github.com/google-research/vision_transformer/blob/39c905d2caf96a4306c9d78f05df36ddb3eb8ecb/vit_jax/models.py#L192
x = self.pos_drop(x, training=training)
        for layer in self.enc_layers:
            x = layer(x, training=training)
x = self.norm(x)
# First (class token) is used for classification
x = self.extract_token(x)
        x = self.representation(x)
        # Return shape (B, 1, 1, d_model).
        return x[:, tf.newaxis, tf.newaxis, :]
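
# Standard ViT sizes (Ti/S/B/L = Tiny/Small/Base/Large); depths, widths and
# head counts follow the ViT/DeiT papers.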
KNOWN_MODELS = {
"ti": {
"num_layers": 12,
"hidden_size": 192,
"num_heads": 3,
"mlp_dim": 768,
},
"s": {
"num_layers": 12,
"hidden_size": 384,
"num_heads": 6,
"mlp_dim": 1536,
},
"b": {
"num_layers": 12,
"hidden_size": 768,
"num_heads": 12,
"mlp_dim": 3072,
},
"l": {
"num_layers": 24,
"hidden_size": 1024,
"num_heads": 16,
"mlp_dim": 4096,
},
}
def create_name_vit(architecture_name, **kwargs):
    # Parse names like "ViT-B/16": size key before the "/", patch size after it.
    base, patch_size = [part.lower() for part in architecture_name.split("-")[-1].split("/")]
return VisionTransformer(
patch_size=int(patch_size),
**KNOWN_MODELS[base],
**kwargs,
)
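
if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the original module). Assumes
    # an architecture name of the form "ViT-Ti/16" and a 224x224 RGB input;
    # with the default representation_size the head is an Identity, so the
    # output is the raw class-token embedding. Run as
    # `python -m backbone.vit_model` so the relative import resolves.
    model = create_name_vit("ViT-Ti/16", image_size=224)
    out = model(tf.zeros((1, 224, 224, 3)), training=False)
    print(out.shape)  # expected: (1, 1, 1, 192)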