from tensorflow.keras.layers import Layer, InputSpec
import keras.utils.conv_utils as conv_utils
import tensorflow as tf
import tensorflow.keras.backend as K


def normalize_data_format(value):
    """Return a validated, lower-cased image data format string.

    Args:
        value: Either ``None`` or a string naming the data format. When
            ``None``, the backend default from ``K.image_data_format()`` is
            used.

    Returns:
        ``'channels_first'`` or ``'channels_last'``.

    Raises:
        ValueError: If the lower-cased value is not one of the two
            supported formats.
    """
    if value is None:
        value = K.image_data_format()
    data_format = value.lower()
    if data_format not in {'channels_first', 'channels_last'}:
        raise ValueError('The `data_format` argument must be one of '
                         '"channels_first", "channels_last". Received: ' +
                         str(value))
    return data_format


class BilinearUpSampling2D(Layer):
    """Upsample the spatial axes of a 4-D tensor with bilinear interpolation.

    The height is multiplied by ``size[0]`` and the width by ``size[1]``.
    The position of the spatial axes is selected by ``data_format``:
    axes (2, 3) for ``'channels_first'`` and axes (1, 2) for
    ``'channels_last'``.

    Args:
        size: Upsampling factors for (height, width). Passed through
            ``conv_utils.normalize_tuple(size, 2, 'size')``.
        data_format: ``'channels_first'``, ``'channels_last'`` or ``None``
            (use the backend default).
        **kwargs: Forwarded unchanged to ``Layer.__init__``.
    """

    def __init__(self, size=(2, 2), data_format=None, **kwargs):
        super(BilinearUpSampling2D, self).__init__(**kwargs)
        self.data_format = normalize_data_format(data_format)
        self.size = conv_utils.normalize_tuple(size, 2, 'size')
        self.input_spec = InputSpec(ndim=4)

    def compute_output_shape(self, input_shape):
        """Return the output shape for a given 4-D ``input_shape``.

        Spatial dimensions that are ``None`` stay ``None``; known ones are
        multiplied by the matching factor in ``self.size``.
        """
        channels_first = self.data_format == 'channels_first'
        h_axis, w_axis = (2, 3) if channels_first else (1, 2)

        in_height = input_shape[h_axis]
        in_width = input_shape[w_axis]
        height = self.size[0] * in_height if in_height is not None else None
        width = self.size[1] * in_width if in_width is not None else None

        if channels_first:
            return (input_shape[0], input_shape[1], height, width)
        return (input_shape[0], height, width, input_shape[3])

    def call(self, inputs):
        """Resize ``inputs`` bilinearly by the configured factors.

        Args:
            inputs: 4-D tensor laid out according to ``self.data_format``.

        Returns:
            The resized tensor, in the same layout as ``inputs``.
        """
        channels_first = self.data_format == 'channels_first'
        if channels_first:
            # tf.image.resize interprets 4-D input as
            # [batch, height, width, channels]; without this transpose the
            # channel and height axes would be resized instead of H and W.
            inputs = tf.transpose(inputs, [0, 2, 3, 1])

        # Prefer statically known dimensions so the target size is a plain
        # Python int where possible; fall back to the dynamic shape only for
        # dimensions that are unknown at graph-construction time.
        static_shape = K.int_shape(inputs)
        dynamic_shape = K.shape(inputs)
        target_size = []
        for axis, factor in zip((1, 2), self.size):
            dim = None if static_shape is None else static_shape[axis]
            if dim is None:
                dim = dynamic_shape[axis]
            target_size.append(factor * dim)

        outputs = tf.image.resize(inputs, target_size,
                                  method=tf.image.ResizeMethod.BILINEAR)

        if channels_first:
            # Restore the caller's NCHW layout.
            outputs = tf.transpose(outputs, [0, 3, 1, 2])
        return outputs

    def get_config(self):
        """Return the layer config, including ``size`` and ``data_format``."""
        config = super(BilinearUpSampling2D, self).get_config()
        config.update({'size': self.size, 'data_format': self.data_format})
        return config