|
import einops |
|
import tensorflow as tf |
|
from tensorflow.experimental import numpy as tnp |
|
from tensorflow.keras import backend as K |
|
from tensorflow.keras import layers |
|
|
|
|
|
@tf.keras.utils.register_keras_serializable("maxim")
class BlockImages(layers.Layer):
    """Split a rank-4 tensor into a grid of flattened, non-overlapping patches.

    Rearranges ``(n, gh * fh, gw * fw, c)`` into ``(n, gh * gw, fh * fw, c)``,
    where ``(fh, fw)`` is the patch size and ``(gh, gw)`` the resulting grid.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, x, patch_size):
        """Block ``x`` into patches of size ``(patch_size[0], patch_size[1])``.

        Args:
            x: Rank-4 tensor ``(n, h, w, c)``. ``h`` must be divisible by
                ``patch_size[0]`` and ``w`` by ``patch_size[1]`` (einops raises
                otherwise).
            patch_size: Indexable whose first two entries are the patch
                height ``fh`` and width ``fw``.

        Returns:
            Tensor of shape ``(n, (h // fh) * (w // fw), fh * fw, c)``.
        """
        # Only the patch sizes are supplied. einops infers the grid sizes
        # (gh, gw) from the input itself. This keeps working when the spatial
        # dimensions are unknown (None) at trace time, whereas computing
        # `h // fh` from the static shape would raise a TypeError.
        return einops.rearrange(
            x,
            "n (gh fh) (gw fw) c -> n (gh gw) (fh fw) c",
            fh=patch_size[0],
            fw=patch_size[1],
        )

    def get_config(self):
        """Return the base layer config; this layer has no extra arguments."""
        config = super().get_config().copy()
        return config
|
|
|
|
|
@tf.keras.utils.register_keras_serializable("maxim")
class UnblockImages(layers.Layer):
    """Reassemble a grid of flattened patches into a rank-4 tensor.

    Rearranges ``(n, gh * gw, fh * fw, c)`` back into ``(n, gh * fh, gw * fw, c)``,
    i.e. the einops pattern used here is the reverse of the one in
    ``BlockImages``.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, x, grid_size, patch_size):
        """Unblock ``x`` given explicit grid and patch sizes.

        Args:
            x: Rank-4 tensor ``(n, gh * gw, fh * fw, c)``.
            grid_size: Indexable whose first two entries are ``gh`` and ``gw``.
            patch_size: Indexable whose first two entries are ``fh`` and ``fw``.

        Returns:
            Tensor of shape ``(n, gh * fh, gw * fw, c)``.
        """
        grid_rows = grid_size[0]
        grid_cols = grid_size[1]
        patch_rows = patch_size[0]
        patch_cols = patch_size[1]

        # Both grid sizes must be given explicitly: the merged `(gh gw)` axis
        # cannot be split without knowing at least one factor.
        axis_lengths = {
            "gh": grid_rows,
            "gw": grid_cols,
            "fh": patch_rows,
            "fw": patch_cols,
        }
        pattern = "n (gh gw) (fh fw) c -> n (gh fh) (gw fw) c"
        unblocked = einops.rearrange(x, pattern, **axis_lengths)
        return unblocked

    def get_config(self):
        """Return a copy of the base layer config (no extra arguments)."""
        return dict(super().get_config())
|
|
|
|
|
@tf.keras.utils.register_keras_serializable("maxim")
class SwapAxes(layers.Layer):
    """Keras layer that interchanges two axes of its input via ``tnp.swapaxes``."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, x, axis_one, axis_two):
        """Return ``x`` with axes ``axis_one`` and ``axis_two`` swapped.

        Args:
            x: Input tensor.
            axis_one: First axis index, passed through to ``tnp.swapaxes``.
            axis_two: Second axis index, passed through to ``tnp.swapaxes``.
        """
        swapped = tnp.swapaxes(x, axis_one, axis_two)
        return swapped

    def get_config(self):
        """Return a copy of the base layer config (no extra arguments)."""
        base_config = super().get_config()
        return dict(base_config)
|
|
|
|
|
@tf.keras.utils.register_keras_serializable("maxim")
class Resizing(layers.Layer):
    """Resize inputs to a fixed ``(height, width)`` using ``tf.image.resize``.

    Args:
        height: Target output height.
        width: Target output width.
        antialias: Forwarded unchanged to ``tf.image.resize``.
        method: Resize method forwarded unchanged to ``tf.image.resize``.
        **kwargs: Remaining keyword arguments for ``layers.Layer``.
    """

    def __init__(self, height, width, antialias=True, method="bilinear", **kwargs):
        super().__init__(**kwargs)
        self.height, self.width = height, width
        self.antialias, self.method = antialias, method

    def call(self, x):
        """Resize ``x`` to the configured target size."""
        target_size = (self.height, self.width)
        return tf.image.resize(
            x, size=target_size, antialias=self.antialias, method=self.method
        )

    def get_config(self):
        """Return the base config extended with this layer's constructor args."""
        config = dict(super().get_config())
        config["height"] = self.height
        config["width"] = self.width
        config["antialias"] = self.antialias
        config["method"] = self.method
        return config
|
|