from typing import Any, Dict, Tuple

import tensorflow as tf
import tensorflow_addons as tfda


class ResidualBlock(tf.keras.layers.Layer):
    """Two-convolution residual block with an identity skip connection.

    Forward pass (see ``call``)::

        x -> [pad_1] -> conv_1 -> relu -> bn_1 -> [pad_2] -> conv_2 -> bn_2
        out = relu(x_branch + x)

    The skip connection is a plain element-wise add, so the input's channel
    count must equal ``filter_num`` and both convolutions must preserve the
    spatial size.

    Args:
        filter_num: Number of output filters of both convolutions.
        filter_size: Kernel size of both convolutions. Must be odd when
            ``padding="reflect"`` so that symmetric reflection padding keeps
            the spatial size unchanged.
        seed: Stored on the layer and in its config. NOTE(review): it is not
            currently passed to any initializer, so it has no effect.
        name: Optional layer name.
        padding: ``"default"`` uses Keras ``"same"`` (zero) padding inside the
            convolutions; ``"reflect"`` reflection-pads ``filter_size // 2``
            pixels on each spatial side and then uses ``"valid"`` convolutions.
        instance_normalization: If True, use ``tfa`` InstanceNormalization
            layers; otherwise BatchNormalization.
        **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``dtype``,
            ``trainable``), so that ``from_config(get_config())`` works.

    Raises:
        RuntimeError: If ``padding`` is not ``"default"`` or ``"reflect"``.
        ValueError: If ``padding="reflect"`` and ``filter_size`` is even.
    """

    def __init__(self, filter_num: int, filter_size: int, seed: Any = None, name=None,
                 padding="default", instance_normalization: bool = False, **kwargs):
        # Validate once, before any sublayer is built.
        if padding not in ("default", "reflect"):
            raise RuntimeError("Non valid padding type.")
        if padding == "reflect" and filter_size % 2 == 0:
            # Symmetric padding + "valid" conv changes the spatial size for even
            # kernels, which would break the residual add in `call`.
            raise ValueError(
                "Reflection padding requires an odd filter_size, got %d." % filter_size
            )
        super(ResidualBlock, self).__init__(name=name, **kwargs)
        self.filter_num = filter_num
        self.filter_size = filter_size
        self.seed = seed
        self.padding_type = padding
        self.instance_normalization = instance_normalization

        # Sublayers are created in the same order and under the same attribute
        # names as before, so existing checkpoints and auto-generated layer
        # names keep matching.
        # "linear" activation is an identity; its output feeds the skip branch.
        self.activation_1 = tf.keras.layers.Activation("linear", trainable=False)
        if padding == "reflect":
            self.pad_1 = self._make_reflection_padding()
        self.conv_1 = self._make_conv()
        self.activation_2 = tf.keras.layers.Activation("relu")
        self.bn_1 = self._make_norm()
        self.bn_2 = self._make_norm()
        if padding == "reflect":
            self.pad_2 = self._make_reflection_padding()
        self.conv_2 = self._make_conv()
        self.activation_3 = tf.keras.layers.Activation("relu")
        # Built once here rather than inside `call`, which would create a new
        # layer object on every call/trace.
        self.skip_add = tf.keras.layers.Add()

    def _make_reflection_padding(self) -> "ReflectionPadding2D":
        """Build reflection padding that keeps the size after a "valid" conv."""
        pad = self.filter_size // 2  # == 1 for the common 3x3 kernel
        return ReflectionPadding2D(padding=(pad, pad))

    def _make_conv(self) -> tf.keras.layers.Conv2D:
        """Build one convolution; "valid" when explicit reflection padding precedes it."""
        return tf.keras.layers.Conv2D(
            filters=self.filter_num,
            kernel_size=self.filter_size,
            padding="valid" if self.padding_type == "reflect" else "same",
            trainable=True,
        )

    def _make_norm(self) -> tf.keras.layers.Layer:
        """Build one normalization layer of the configured kind."""
        if self.instance_normalization:
            return tfda.layers.InstanceNormalization(trainable=True)
        return tf.keras.layers.BatchNormalization(trainable=True)

    def call(self, inputs, *args, **kwargs):
        """Apply the block to ``inputs`` and return the activated residual sum.

        NOTE(review): ``training`` is not forwarded explicitly to the
        normalization layers; this presumably relies on Keras propagating the
        outer call's training mode to nested layers — confirm if needed.
        """
        identity = self.activation_1(inputs)
        x = identity
        if self.padding_type == "reflect":
            x = self.pad_1(x)
        x = self.conv_1(x)
        # ReLU is applied before normalization here (conv -> relu -> norm);
        # kept as-is to preserve the model's behavior.
        x = self.activation_2(x)
        x = self.bn_1(x)
        if self.padding_type == "reflect":
            x = self.pad_2(x)
        x = self.conv_2(x)
        x = self.bn_2(x)
        residual = self.skip_add([x, identity])
        return self.activation_3(residual)

    def get_config(self) -> Dict[str, Any]:
        """Return the constructor arguments so the layer can be re-created."""
        config = super(ResidualBlock, self).get_config()
        config.update({
            "filter_num": self.filter_num,
            "filter_size": self.filter_size,
            "seed": self.seed,
            "padding": self.padding_type,
            "instance_normalization": self.instance_normalization,
        })
        return config


class ReflectionPadding2D(tf.keras.layers.Layer):
    """Reflection-pad the height and width axes of a 4-D tensor.

    The padding spec built in ``call`` treats axis 1 as height and axis 2 as
    width, leaving batch (axis 0) and channels (axis 3) unpadded.

    Args:
        padding: ``(pad_width, pad_height)``. NOTE: width comes first, which
            is the reverse of ``tf.keras.layers.ZeroPadding2D``'s order.
        **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, padding: Tuple[int, int], **kwargs):
        super(ReflectionPadding2D, self).__init__(**kwargs)
        self.pad_width, self.pad_height = padding

    def call(self, inputs, *args, **kwargs):
        """Return ``inputs`` reflection-padded on the spatial axes."""
        padding_tensor = tf.constant([
            [0, 0],  # Batch
            [self.pad_height, self.pad_height],  # Height
            [self.pad_width, self.pad_width],  # Width
            [0, 0]  # Channels
        ])
        return tf.pad(inputs, padding_tensor, mode="REFLECT")

    def get_config(self) -> Dict[str, Any]:
        """Return the constructor arguments so the layer can be re-created."""
        config = super(ReflectionPadding2D, self).get_config()
        # Same (width, height) order that __init__ unpacks.
        config.update({"padding": (self.pad_width, self.pad_height)})
        return config