Spaces:
Runtime error
Runtime error
File size: 2,668 Bytes
ef31f77 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 |
import einops
import tensorflow as tf
from tensorflow.experimental import numpy as tnp
from tensorflow.keras import backend as K
from tensorflow.keras import layers
@tf.keras.utils.register_keras_serializable("maxim")
class BlockImages(layers.Layer):
    """Rearrange a 4-D tensor into non-overlapping patches.

    The einops pattern used is ``n (gh fh) (gw fw) c -> n (gh gw) (fh fw) c``:
    axes 1 and 2 are each split into a grid factor (``gh``/``gw``) and a
    patch factor (``fh``/``fw``); the two grid factors are then merged into
    one axis and the two patch factors into another.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, x, patch_size):
        """Split ``x`` into patches of size ``patch_size``.

        Args:
            x: Tensor matching ``n (gh fh) (gw fw) c`` -- presumably
                ``(batch, height, width, channels)``; confirm against callers.
            patch_size: Indexable pair ``(fh, fw)``; only elements 0 and 1
                are read.

        Returns:
            Tensor laid out as ``n (gh gw) (fh fw) c``.
        """
        # Only the sizes of axes 1 and 2 are needed, so query the static
        # shape once instead of once per dimension.
        # NOTE(review): K.int_shape reports unknown dimensions as None, so
        # both axes presumably have to be statically known here -- confirm.
        static_shape = K.int_shape(x)
        height, width = static_shape[1], static_shape[2]

        # Floor division: if an axis is not an exact multiple of the patch
        # size, gh * fh no longer equals the axis length and the rearrange
        # below is expected to reject the input rather than crop it.
        grid_height = height // patch_size[0]
        grid_width = width // patch_size[1]

        return einops.rearrange(
            x,
            "n (gh fh) (gw fw) c -> n (gh gw) (fh fw) c",
            gh=grid_height,
            gw=grid_width,
            fh=patch_size[0],
            fw=patch_size[1],
        )

    def get_config(self):
        """Return a copy of the base layer config (no extra arguments)."""
        config = super().get_config().copy()
        return config
@tf.keras.utils.register_keras_serializable("maxim")
class UnblockImages(layers.Layer):
    """Reassemble patches into a single spatial layout.

    Applies the einops pattern ``n (gh gw) (fh fw) c -> n (gh fh) (gw fw) c``,
    i.e. the reverse axis grouping of the patch-splitting pattern.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, x, grid_size, patch_size):
        """Merge the patches of ``x`` back together.

        Args:
            x: Tensor matching ``n (gh gw) (fh fw) c``.
            grid_size: Indexable pair ``(gh, gw)``; elements 0 and 1 are read.
            patch_size: Indexable pair ``(fh, fw)``; elements 0 and 1 are read.

        Returns:
            Tensor laid out as ``n (gh fh) (gw fw) c``.
        """
        # Gather the named axis lengths first so the rearrange call itself
        # stays short.
        axis_lengths = {
            "gh": grid_size[0],
            "gw": grid_size[1],
            "fh": patch_size[0],
            "fw": patch_size[1],
        }
        return einops.rearrange(
            x,
            "n (gh gw) (fh fw) c -> n (gh fh) (gw fw) c",
            **axis_lengths,
        )

    def get_config(self):
        """Return a copy of the base layer config (no extra arguments)."""
        return super().get_config().copy()
@tf.keras.utils.register_keras_serializable("maxim")
class SwapAxes(layers.Layer):
    """Keras layer wrapper around ``tnp.swapaxes``."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, x, axis_one, axis_two):
        """Return ``x`` with ``axis_one`` and ``axis_two`` interchanged.

        Args:
            x: Input tensor.
            axis_one: First axis, forwarded unchanged to ``tnp.swapaxes``.
            axis_two: Second axis, forwarded unchanged to ``tnp.swapaxes``.
        """
        swapped = tnp.swapaxes(x, axis_one, axis_two)
        return swapped

    def get_config(self):
        """Return a copy of the base layer config (no extra arguments)."""
        return super().get_config().copy()
@tf.keras.utils.register_keras_serializable("maxim")
class Resizing(layers.Layer):
    """Keras layer wrapper around ``tf.image.resize`` with a fixed target size.

    The target size and resize options are stored on the instance and
    included in ``get_config``.
    """

    def __init__(self, height, width, antialias=True, method="bilinear", **kwargs):
        """Store the resize settings.

        Args:
            height: First element of the ``size`` passed to ``tf.image.resize``.
            width: Second element of the ``size`` passed to ``tf.image.resize``.
            antialias: Forwarded as ``antialias``; defaults to ``True``.
            method: Forwarded as ``method``; defaults to ``"bilinear"``.
            **kwargs: Passed through to the base layer constructor.
        """
        super().__init__(**kwargs)
        self.height, self.width = height, width
        self.antialias, self.method = antialias, method

    def call(self, x):
        """Resize ``x`` to ``(self.height, self.width)``."""
        target_size = (self.height, self.width)
        return tf.image.resize(
            x,
            size=target_size,
            method=self.method,
            antialias=self.antialias,
        )

    def get_config(self):
        """Return the base layer config extended with the resize settings."""
        # Unpacking into a new dict leaves the base config unmodified.
        return {
            **super().get_config(),
            "height": self.height,
            "width": self.width,
            "antialias": self.antialias,
            "method": self.method,
        }
|