import tensorflow as tf


def _check_window_size(window_size):
    """Raise ValueError if a statically-known ``window_size`` is not positive."""
    if isinstance(window_size, int) and window_size <= 0:
        raise ValueError(f"window_size must be a positive integer, got {window_size}")


def _check_divisible(size, window_size, name):
    """Raise ValueError if a statically-known ``size`` is not a multiple of ``window_size``.

    Without this check, floor division silently drops the remainder. The
    subsequent ``tf.reshape(..., (-1, ...))`` can then still succeed whenever
    the total element count happens to divide evenly, which produces
    misaligned windows instead of an error. Sizes that are only known at run
    time (tensors) are not checked here.
    """
    if isinstance(size, int) and isinstance(window_size, int) and size % window_size:
        raise ValueError(
            f"{name} ({size}) must be divisible by window_size ({window_size})"
        )


def _resolve_dims(x):
    """Return ``(height, width, channels)`` for the rank-4 input ``x`` (axes 1, 2, 3).

    Statically known sizes are returned as Python ints. Any size that is
    unknown at trace time (``None``) is replaced with the matching entry of
    ``tf.shape(x)``. Unpacking raises ``ValueError`` if ``x`` is not rank 4.
    """
    _, height, width, channels = x.shape
    static = (height, width, channels)
    if all(size is not None for size in static):
        return static
    dynamic = tf.shape(x)
    return tuple(
        dynamic[axis] if size is None else size
        for axis, size in enumerate(static, start=1)
    )


def window_partition(x, window_size):
    """Split ``x`` of shape ``(batch, height, width, channels)`` into square windows.

    Args:
        x: Rank-4 tensor laid out as (batch, height, width, channels).
        window_size: Side length of each non-overlapping square window.

    Returns:
        Tensor of shape ``(batch * (height // window_size) * (width // window_size),
        window_size, window_size, channels)``. Within each batch element, windows
        are ordered row by row (vertical window index outer, horizontal inner).

    Raises:
        ValueError: if ``x`` is not rank 4, if ``window_size`` is not positive,
            or if a statically-known height/width is not a multiple of
            ``window_size``.
    """
    _check_window_size(window_size)
    height, width, channels = _resolve_dims(x)
    _check_divisible(height, window_size, "height")
    _check_divisible(width, window_size, "width")

    patch_num_y = height // window_size
    patch_num_x = width // window_size
    # (B, H, W, C) -> (B, py, ws, px, ws, C): split each spatial axis into
    # (window index, offset inside window).
    x = tf.reshape(
        x, shape=(-1, patch_num_y, window_size, patch_num_x, window_size, channels)
    )
    # Bring both window indices in front of the intra-window offsets:
    # (B, py, px, ws, ws, C).
    x = tf.transpose(x, perm=(0, 1, 3, 2, 4, 5))
    # Fold the batch and window-grid axes together into one leading axis.
    return tf.reshape(x, shape=(-1, window_size, window_size, channels))


def window_reverse(windows, window_size, height, width, channels):
    """Reassemble windows produced by ``window_partition`` into full feature maps.

    Args:
        windows: Tensor of shape ``(batch * (height // window_size) *
            (width // window_size), window_size, window_size, channels)``, with
            windows in the order emitted by ``window_partition``.
        window_size: Side length of each square window.
        height: Height of the reconstructed map.
        width: Width of the reconstructed map.
        channels: Number of channels.

    Returns:
        Tensor of shape ``(batch, height, width, channels)``.

    Raises:
        ValueError: if ``window_size`` is not positive, or if a statically-known
            ``height``/``width`` is not a multiple of ``window_size``.
    """
    _check_window_size(window_size)
    _check_divisible(height, window_size, "height")
    _check_divisible(width, window_size, "width")

    patch_num_y = height // window_size
    patch_num_x = width // window_size
    # Unfold the leading axis back into (B, py, px, ws, ws, C).
    x = tf.reshape(
        windows,
        shape=(-1, patch_num_y, patch_num_x, window_size, window_size, channels),
    )
    # Interleave window indices with intra-window offsets: (B, py, ws, px, ws, C),
    # which is exactly the layout a (B, H, W, C) tensor has in memory.
    x = tf.transpose(x, perm=(0, 1, 3, 2, 4, 5))
    return tf.reshape(x, shape=(-1, height, width, channels))