gcvit-tf / gcvit /layers /embedding.py
awsaf49's picture
lastest version
4a0cabe
raw
history blame contribute delete
814 Bytes
import tensorflow as tf
from .feature import ReduceSize
@tf.keras.utils.register_keras_serializable(package="gcvit")
class Stem(tf.keras.layers.Layer):
def __init__(self, dim, **kwargs):
super().__init__(**kwargs)
self.dim = dim
def build(self, input_shape):
self.pad = tf.keras.layers.ZeroPadding2D(1, name='pad')
self.proj = tf.keras.layers.Conv2D(self.dim, kernel_size=3, strides=2, name='proj')
self.conv_down = ReduceSize(keep_dim=True, name='conv_down')
super().build(input_shape)
def call(self, inputs, **kwargs):
x = self.pad(inputs)
x = self.proj(x)
x = self.conv_down(x)
return x
def get_config(self):
config = super().get_config()
config.update({'dim': self.dim})
return config