import logging

import tensorflow as tf
from tensorflow.keras.layers import LayerNormalization

from .layers import Identity

logger = logging.getLogger(__name__)
class MultiHeadSelfAttention(tf.keras.layers.Layer):
    def __init__(self, embed_dim, num_heads=8, attn_drop_rate=0.0, proj_drop=0.0, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.attn_drop_rate = attn_drop_rate
        self.proj_drop_rate = proj_drop
        if embed_dim % num_heads != 0:
            raise ValueError(
                f"embedding dimension = {embed_dim} should be divisible by number of heads = {num_heads}"
            )
        self.projection_dim = embed_dim // num_heads
        self.query_dense = tf.keras.layers.Dense(
            units=embed_dim,
            kernel_initializer=tf.keras.initializers.GlorotNormal(),
            bias_initializer="zeros",
        )
        self.key_dense = tf.keras.layers.Dense(
            units=embed_dim,
            kernel_initializer=tf.keras.initializers.GlorotNormal(),
            bias_initializer="zeros",
        )
        self.attn_drop = tf.keras.layers.Dropout(rate=attn_drop_rate)
        self.value_dense = tf.keras.layers.Dense(
            units=embed_dim,
            kernel_initializer=tf.keras.initializers.GlorotNormal(),
            bias_initializer="zeros",
        )
        self.combine_heads = tf.keras.layers.Dense(
            units=embed_dim,
            kernel_initializer=tf.keras.initializers.GlorotNormal(),
            bias_initializer="zeros",
        )
        self.proj_drop = tf.keras.layers.Dropout(rate=proj_drop)
    def attention(self, query, key, value, training=None):
        # Scaled dot-product attention: softmax(QK^T / sqrt(d_k)) V
        score = tf.matmul(query, key, transpose_b=True)
        dim_key = tf.cast(tf.shape(key)[-1], score.dtype)
        scaled_score = score / tf.math.sqrt(dim_key)
        weights = tf.nn.softmax(scaled_score, axis=-1)
        weights = self.attn_drop(weights, training=training)
        output = tf.matmul(weights, value)
        return output, weights
    def separate_heads(self, x, batch_size):
        # [B, N, embed_dim] -> [B, num_heads, N, projection_dim]
        x = tf.reshape(x, (batch_size, -1, self.num_heads, self.projection_dim))
        return tf.transpose(x, perm=[0, 2, 1, 3])
    def call(self, inputs, training=None):
        batch_size = tf.shape(inputs)[0]
        query = self.query_dense(inputs)
        key = self.key_dense(inputs)
        value = self.value_dense(inputs)
        query = self.separate_heads(query, batch_size)
        key = self.separate_heads(key, batch_size)
        value = self.separate_heads(value, batch_size)
        attention, weights = self.attention(query, key, value, training=training)
        attention = tf.transpose(attention, perm=[0, 2, 1, 3])
        concat_attention = tf.reshape(attention, (batch_size, -1, self.embed_dim))
        output = self.combine_heads(concat_attention)
        output = self.proj_drop(output, training=training)
        return output
    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "embed_dim": self.embed_dim,
                "num_heads": self.num_heads,
                "attn_drop_rate": self.attn_drop_rate,
                "proj_drop": self.proj_drop_rate,
            }
        )
        return config

    @classmethod
    def from_config(cls, config):
        return cls(**config)
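

# Illustrative shape check (a sketch with assumed sizes, not from the original
# source): self-attention preserves the [batch, tokens, embed_dim] shape, e.g.
#
#   mhsa = MultiHeadSelfAttention(embed_dim=192, num_heads=3)
#   out = mhsa(tf.random.normal([2, 197, 192]), training=False)
#   assert out.shape == (2, 197, 192)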
class TransformerBlock(tf.keras.layers.Layer):
    def __init__(
        self,
        embed_dim,
        num_heads,
        mlp_dim,
        drop_rate,
        attn_drop_rate,
        name="encoderblock",
        **kwargs,
    ):
        super().__init__(name=name, **kwargs)
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.mlp_dim = mlp_dim
        self.drop_rate = drop_rate
        self.attn_drop_rate = attn_drop_rate
        self.att = MultiHeadSelfAttention(
            embed_dim=embed_dim,
            num_heads=num_heads,
            attn_drop_rate=attn_drop_rate,
            proj_drop=drop_rate,
        )
        self.mlp = tf.keras.Sequential(
            [
                tf.keras.layers.Dense(
                    units=mlp_dim,
                    kernel_initializer=tf.keras.initializers.GlorotNormal(),
                    bias_initializer=tf.keras.initializers.RandomNormal(
                        mean=0.0, stddev=1e-6
                    ),
                ),
                # GELU is applied as a separate Lambda layer rather than as a
                # Dense activation.
                tf.keras.layers.Lambda(
                    lambda x: tf.keras.activations.gelu(x, approximate=True)
                ),
                tf.keras.layers.Dropout(rate=drop_rate),
                tf.keras.layers.Dense(
                    units=embed_dim,
                    kernel_initializer=tf.keras.initializers.GlorotUniform(),
                    bias_initializer=tf.keras.initializers.RandomNormal(
                        mean=0.0, stddev=1e-6
                    ),
                ),
                tf.keras.layers.Dropout(rate=drop_rate),
            ]
        )
        self.layernorm1 = LayerNormalization(epsilon=1e-6)
        self.layernorm2 = LayerNormalization(epsilon=1e-6)
    def call(self, inputs, training=None):
        # Pre-norm transformer block: LN -> MHSA -> residual, LN -> MLP -> residual
        inputs_norm = self.layernorm1(inputs)
        attn_output = self.att(inputs_norm, training=training)
        out1 = attn_output + inputs
        out1_norm = self.layernorm2(out1)
        mlp_output = self.mlp(out1_norm, training=training)
        return out1 + mlp_output
    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "embed_dim": self.embed_dim,
                "num_heads": self.num_heads,
                "mlp_dim": self.mlp_dim,
                "drop_rate": self.drop_rate,
                "attn_drop_rate": self.attn_drop_rate,
            }
        )
        return config

    @classmethod
    def from_config(cls, config):
        return cls(**config)
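

# Illustrative sketch (assumed values, not from the original source): one
# encoder block maps [batch, tokens, embed_dim] to the same shape, e.g.
#
#   block = TransformerBlock(embed_dim=192, num_heads=3, mlp_dim=768,
#                            drop_rate=0.1, attn_drop_rate=0.0)
#   y = block(tf.random.normal([2, 197, 192]), training=False)
#   assert y.shape == (2, 197, 192)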
class VisionTransformer(tf.keras.Model):
    def __init__(
        self,
        image_size,
        patch_size,
        num_layers,
        hidden_size,
        num_heads,
        mlp_dim,
        representation_size=None,
        channels=3,
        dropout_rate=0.1,
        attention_dropout_rate=0.0,
        num_classes=None,  # accepted for API compatibility; no classification head is built here
    ):
        super().__init__()
        num_patches = (image_size // patch_size) ** 2
        self.patch_dim = channels * patch_size ** 2
        self.patch_size = patch_size
        self.d_model = hidden_size
        self.num_layers = num_layers
        self.class_emb = self.add_weight(
            name="class_emb",
            shape=(1, 1, hidden_size),
            initializer=tf.keras.initializers.Zeros(),
            trainable=True,
        )
        self.pos_emb = self.add_weight(
            name="pos_emb",
            shape=(1, num_patches + 1, hidden_size),
            initializer=tf.keras.initializers.RandomNormal(
                mean=0.0, stddev=0.02, seed=None
            ),
            trainable=True,
        )
        self.pos_drop = tf.keras.layers.Dropout(rate=dropout_rate, name="pos_drop")
        # A Conv2D with kernel_size == strides == patch_size is equivalent to
        # splitting the image into non-overlapping patches and linearly
        # projecting each patch to hidden_size.
        self.embedding = tf.keras.layers.Conv2D(
            filters=hidden_size,
            kernel_size=self.patch_size,
            strides=self.patch_size,
            padding="valid",
            name="embedding",
        )
        self.enc_layers = [
            TransformerBlock(
                embed_dim=hidden_size,
                num_heads=num_heads,
                mlp_dim=mlp_dim,
                drop_rate=dropout_rate,
                attn_drop_rate=attention_dropout_rate,
                name=f"encoderblock_{i}",
            )
            for i in range(num_layers)
        ]
        self.norm = LayerNormalization(epsilon=1e-6, name="encoder_norm")
        self.extract_token = tf.keras.layers.Lambda(
            lambda x: x[:, 0], name="extract_token"
        )
        self.representation = (
            tf.keras.layers.Dense(
                units=representation_size,
                activation="tanh",
                name="pre_logits",
            )
            if representation_size  # guards against both None (the default) and 0
            else Identity(name="pre_logits")
        )
    def call(self, x, training=None):
        batch_size = tf.shape(x)[0]
        x = self.embedding(x)  # B x H/P x W/P x d_model
        x = tf.reshape(x, [batch_size, -1, self.d_model])  # B x N x d_model
        class_emb = tf.broadcast_to(self.class_emb, [batch_size, 1, self.d_model])
        # B x (N + 1) x d_model
        x = tf.concat([tf.cast(class_emb, x.dtype), x], axis=1)
        x = x + tf.cast(self.pos_emb, x.dtype)
        # https://github.com/google-research/vision_transformer/blob/39c905d2caf96a4306c9d78f05df36ddb3eb8ecb/vit_jax/models.py#L192
        x = self.pos_drop(x, training=training)
        for layer in self.enc_layers:
            x = layer(x, training=training)
        x = self.norm(x)
        # First token (the class token) is used for classification
        x = self.extract_token(x)
        x = self.representation(x)
        # B x 1 x 1 x d_model
        return x[:, tf.newaxis, tf.newaxis, :]
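

# Illustrative end-to-end sketch (assumed sizes, not from the original source):
# a ViT-Ti-like model on 224x224 inputs with 16x16 patches yields 196 patch
# tokens plus one class token, and returns pre-logits of shape B x 1 x 1 x 192:
#
#   vit = VisionTransformer(image_size=224, patch_size=16, num_layers=12,
#                           hidden_size=192, num_heads=3, mlp_dim=768)
#   feats = vit(tf.random.normal([2, 224, 224, 3]), training=False)
#   assert feats.shape == (2, 1, 1, 192)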
# Standard ViT encoder configurations: Ti(ny), S(mall), B(ase), L(arge).
KNOWN_MODELS = {
    "ti": {
        "num_layers": 12,
        "hidden_size": 192,
        "num_heads": 3,
        "mlp_dim": 768,
    },
    "s": {
        "num_layers": 12,
        "hidden_size": 384,
        "num_heads": 6,
        "mlp_dim": 1536,
    },
    "b": {
        "num_layers": 12,
        "hidden_size": 768,
        "num_heads": 12,
        "mlp_dim": 3072,
    },
    "l": {
        "num_layers": 24,
        "hidden_size": 1024,
        "num_heads": 16,
        "mlp_dim": 4096,
    },
}
def create_name_vit(architecture_name, **kwargs):
    # Parse names such as "ViT-B/16": base "b" selects the encoder
    # configuration, "16" is the patch size.
    base, patch_size = [l.lower() for l in architecture_name.split("-")[-1].split("/")]
    return VisionTransformer(
        patch_size=int(patch_size),
        **KNOWN_MODELS[base],
        **kwargs,
    )
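

if __name__ == "__main__":
    # Minimal smoke test (illustrative, not from the original source). Because
    # of the relative import above, run this as a module, e.g.
    # `python -m <package>.<this_module>`.
    model = create_name_vit("ViT-S/16", image_size=224)
    out = model(tf.random.normal([1, 224, 224, 3]), training=False)
    logger.info("pre-logits shape: %s", out.shape)  # (1, 1, 1, 384)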