# Spaces: Runtime error
import einops | |
import tensorflow as tf | |
from tensorflow.experimental import numpy as tnp | |
from tensorflow.keras import backend as K | |
from tensorflow.keras import layers | |
class BlockImages(layers.Layer):
    """Split each image in a batch into non-overlapping patches ("blocks").

    Uses the einops pattern ``n (gh fh) (gw fw) c -> n (gh gw) (fh fw) c``:
    the two spatial axes are cut into a ``gh x gw`` grid of ``fh x fw``
    patches, and the result groups all grid cells on axis 1 and all pixels
    of a patch on axis 2.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, x, patch_size):
        """Rearrange ``x`` into patches of size ``patch_size``.

        Args:
            x: Rank-4 tensor with axes ``(n, H, W, c)`` as named by the einops
                pattern. ``H`` must be divisible by ``patch_size[0]`` and
                ``W`` by ``patch_size[1]``.
            patch_size: Indexable ``(fh, fw)`` giving patch height and width.

        Returns:
            Tensor of shape ``(n, gh * gw, fh * fw, c)`` where
            ``gh = H // fh`` and ``gw = W // fw``.

        Raises:
            einops.EinopsError: if the spatial dims are not divisible by the
                patch size (or ``x`` is not rank 4).
        """
        # Only the patch dims are supplied; einops infers the grid dims
        # (gh, gw) from the actual input shape. Reading them from
        # K.int_shape(x) instead breaks when H/W are dynamic, because the
        # static shape entries are then None and `None // int` raises.
        return einops.rearrange(
            x,
            "n (gh fh) (gw fw) c -> n (gh gw) (fh fw) c",
            fh=patch_size[0],
            fw=patch_size[1],
        )

    def get_config(self):
        """Return the base layer config (this layer has no extra arguments)."""
        config = super().get_config().copy()
        return config
class UnblockImages(layers.Layer):
    """Reassemble patch-grouped tensors back into full images.

    Applies the einops pattern ``n (gh gw) (fh fw) c -> n (gh fh) (gw fw) c``,
    which is the inverse of the rearrangement performed by ``BlockImages``.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, x, grid_size, patch_size):
        """Stitch patches of ``x`` back into images.

        Args:
            x: Rank-4 tensor with axes ``(n, gh * gw, fh * fw, c)``.
            grid_size: Indexable ``(gh, gw)``, the number of patches along
                height and width.
            patch_size: Indexable ``(fh, fw)``, the patch height and width.

        Returns:
            Tensor of shape ``(n, gh * fh, gw * fw, c)``.
        """
        grid_h, grid_w = grid_size[0], grid_size[1]
        patch_h, patch_w = patch_size[0], patch_size[1]
        pattern = "n (gh gw) (fh fw) c -> n (gh fh) (gw fw) c"
        return einops.rearrange(
            x, pattern, gh=grid_h, gw=grid_w, fh=patch_h, fw=patch_w
        )

    def get_config(self):
        """Return the base layer config (no extra constructor arguments)."""
        base_config = super().get_config()
        return base_config.copy()
class SwapAxes(layers.Layer):
    """Layer that interchanges two axes of its input tensor."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, x, axis_one, axis_two):
        """Return ``x`` with ``axis_one`` and ``axis_two`` exchanged.

        Delegates to ``tf.experimental.numpy.swapaxes``, so the axis
        arguments follow NumPy ``swapaxes`` conventions.
        """
        swapped = tnp.swapaxes(x, axis_one, axis_two)
        return swapped

    def get_config(self):
        """Return the base layer config (no extra constructor arguments)."""
        return super().get_config().copy()
class Resizing(layers.Layer):
    """Resize input images to a fixed ``(height, width)``.

    Thin wrapper around ``tf.image.resize``; all constructor options are
    stored on the layer and round-tripped through ``get_config``.

    Args:
        height: Target output height.
        width: Target output width.
        antialias: Passed through to ``tf.image.resize``.
        method: Resize method passed through to ``tf.image.resize``.
    """

    def __init__(self, height, width, antialias=True, method="bilinear", **kwargs):
        super().__init__(**kwargs)
        self.height, self.width = height, width
        self.antialias = antialias
        self.method = method

    def call(self, x):
        """Return ``x`` resized to ``(self.height, self.width)``."""
        target_size = (self.height, self.width)
        return tf.image.resize(
            x,
            target_size,
            method=self.method,
            antialias=self.antialias,
        )

    def get_config(self):
        """Return the base config extended with this layer's arguments."""
        config = super().get_config().copy()
        config["height"] = self.height
        config["width"] = self.width
        config["antialias"] = self.antialias
        config["method"] = self.method
        return config