"""Keras layers and an inference wrapper for a GauGAN-style generator:
SPADE normalization, SPADE residual blocks, a downsampling block, a
Gaussian (reparameterization) sampler, and a GauganPredictor helper."""

from typing import Optional

import numpy as np
import tensorflow as tf
import tensorflow_addons as tfa
import keras
from keras import Model, Sequential, initializers
from keras.layers import Layer, Conv2D, LeakyReLU, Dropout


class SPADE(Layer):
    """Spatially-adaptive normalization: normalizes the input features and
    modulates them with per-pixel gamma/beta maps predicted from a mask."""

    def __init__(self, filters: int, epsilon: float = 1e-5, **kwargs):
        super().__init__(**kwargs)
        self.filters = filters
        self.epsilon = epsilon
        self.conv = Conv2D(128, 3, padding="same", activation="relu")
        self.conv_gamma = Conv2D(filters, 3, padding="same")
        self.conv_beta = Conv2D(filters, 3, padding="same")

    def build(self, input_shape):
        # Spatial size of the incoming feature map; the mask is resized to match it.
        self.resize_shape = input_shape[1:3]

    def call(self, input_tensor, raw_mask):
        mask = tf.image.resize(raw_mask, self.resize_shape, method="nearest")
        x = self.conv(mask)
        gamma = self.conv_gamma(x)
        beta = self.conv_beta(x)
        mean, var = tf.nn.moments(input_tensor, axes=(0, 1, 2), keepdims=True)
        std = tf.sqrt(var + self.epsilon)
        normalized = (input_tensor - mean) / std
        output = gamma * normalized + beta
        return output

    def get_config(self):
        # Only serialize constructor arguments; the sub-layers are rebuilt from them.
        config = super().get_config()
        config.update({"filters": self.filters, "epsilon": self.epsilon})
        return config


class ResBlock(Layer):
    """SPADE residual block with a learned skip connection whenever the
    number of output filters differs from the input channel count."""

    def __init__(self, filters: int, **kwargs):
        super().__init__(**kwargs)
        self.filters = filters

    def build(self, input_shape):
        input_filter = input_shape[-1]
        self.spade_1 = SPADE(input_filter)
        self.spade_2 = SPADE(self.filters)
        self.conv_1 = Conv2D(self.filters, 3, padding="same")
        self.conv_2 = Conv2D(self.filters, 3, padding="same")
        self.leaky_relu = LeakyReLU(0.2)
        self.learned_skip = False

        # When the channel count changes, the skip path needs its own SPADE + conv.
        if self.filters != input_filter:
            self.learned_skip = True
            self.spade_3 = SPADE(input_filter)
            self.conv_3 = Conv2D(self.filters, 3, padding="same")

    def call(self, input_tensor, mask):
        x = self.spade_1(input_tensor, mask)
        x = self.conv_1(self.leaky_relu(x))
        x = self.spade_2(x, mask)
        x = self.conv_2(self.leaky_relu(x))
        skip = (
            self.conv_3(self.leaky_relu(self.spade_3(input_tensor, mask)))
            if self.learned_skip
            else input_tensor
        )
        output = skip + x
        return output

    def get_config(self):
        config = super().get_config()
        config.update({"filters": self.filters})
        return config


class Downsample(Layer):
    """Strided Conv2D block with optional instance normalization, LeakyReLU
    activation, and dropout."""

    def __init__(
        self,
        channels: int,
        kernels: int,
        strides: int = 2,
        apply_norm: bool = True,
        apply_activation: bool = True,
        apply_dropout: bool = False,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.channels = channels
        self.kernels = kernels
        self.strides = strides
        self.apply_norm = apply_norm
        self.apply_activation = apply_activation
        self.apply_dropout = apply_dropout

    def build(self, input_shape):
        self.block = Sequential([
            Conv2D(
                self.channels,
                self.kernels,
                strides=self.strides,
                padding="same",
                use_bias=False,
                kernel_initializer=initializers.GlorotNormal(),
            )
        ])
        if self.apply_norm:
            self.block.add(tfa.layers.InstanceNormalization())
        if self.apply_activation:
            self.block.add(LeakyReLU(0.2))
        if self.apply_dropout:
            self.block.add(Dropout(0.5))

    def call(self, inputs):
        return self.block(inputs)

    def get_config(self):
        config = super().get_config()
        config.update({
            "channels": self.channels,
            "kernels": self.kernels,
            "strides": self.strides,
            "apply_norm": self.apply_norm,
            "apply_activation": self.apply_activation,
            "apply_dropout": self.apply_dropout,
        })
        return config


class GaussianSampler(Layer):
    """Reparameterization trick: draws z = mean + exp(0.5 * log_var) * eps,
    where the second input (`variance`) is interpreted as the log-variance."""

    def __init__(self, latent_dim: int, **kwargs):
        super().__init__(**kwargs)
        self.latent_dim = latent_dim

    def call(self, inputs):
        means, variance = inputs
        epsilon = tf.random.normal(
            shape=(tf.shape(means)[0], self.latent_dim), mean=0.0, stddev=1.0
        )
        samples = means + tf.exp(0.5 * variance) * epsilon
        return samples

    def get_config(self):
        config = super().get_config()
        config.update({"latent_dim": self.latent_dim})
        return config


class GauganPredictor:
    """Inference wrapper around a trained GauGAN generator and, optionally,
    the encoder used to sample a style latent from a reference image."""

    CLASSES = (
        'unknown', 'wall', 'sky', 'tree', 'road', 'grass', 'earth',
        'mountain', 'plant', 'water', 'sea', 'field', 'fence', 'rock',
        'sand', 'path', 'river', 'flower', 'hill', 'palm', 'tower',
        'dirt', 'land', 'waterfall', 'lake'
    )

    def __init__(self, model_g_path: str, model_e_path: Optional[str] = None) -> None:
        custom_objects = {
            'ResBlock': ResBlock,
            'Downsample': Downsample,
        }
        if model_e_path is not None:
            self.encoder: Model = keras.models.load_model(
                model_e_path, custom_objects=custom_objects)
            self.sampler = GaussianSampler(256)
        self.gen: Model = keras.models.load_model(
            model_g_path, custom_objects=custom_objects)

    def __call__(self, im: np.ndarray, z=None) -> np.ndarray:
        """Generates images from a segmentation mask (or batch of masks) `im`,
        drawing a random latent vector unless `z` is provided."""
        if len(im.shape) == 3:
            im = im[np.newaxis]
        if z is None:
            z = tf.random.normal((im.shape[0], 256))
        tmp = self.gen.predict_on_batch([z, im])
        # Map the generator's tanh output from [-1, 1] to uint8 [0, 255].
        x = np.array((tmp + 1) * 127.5, np.uint8)
        return x

    def predict_reference(self, im: np.ndarray, reference_im: np.ndarray) -> np.ndarray:
        """Generates images whose style is sampled from the encoder's posterior
        for `reference_im`; requires the predictor to be built with `model_e_path`."""
        if len(im.shape) == 3:
            im = im[np.newaxis]
            reference_im = reference_im[np.newaxis]
        mean, variance = self.encoder(reference_im)
        z = self.sampler([mean, variance])
        x = np.array((self.gen.predict_on_batch([z, im]) + 1) * 127.5, np.uint8)
        return x
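

# Usage sketch (not part of the library API): a minimal example of driving
# GauganPredictor, assuming the generator/encoder were saved with Keras
# `model.save(...)` and that the generator expects a one-hot mask with one
# channel per entry in GauganPredictor.CLASSES. The file names below
# ("generator.h5", "encoder.h5") and the 256x256 mask size are illustrative
# assumptions, not values defined elsewhere in this module.
if __name__ == "__main__":
    num_classes = len(GauganPredictor.CLASSES)

    # Hypothetical paths to previously trained models.
    predictor = GauganPredictor("generator.h5", model_e_path="encoder.h5")

    # Build a dummy one-hot mask: every pixel labelled "sky" (class index 2).
    labels = np.full((256, 256), 2, dtype=np.int32)
    mask = np.eye(num_classes, dtype=np.float32)[labels]  # (256, 256, num_classes)

    # Random-latent generation: returns a uint8 batch of shape (1, H, W, 3).
    generated = predictor(mask)
    print(generated.shape, generated.dtype)

    # Style-guided generation from a reference photo (requires the encoder).
    reference = np.zeros((256, 256, 3), dtype=np.float32)
    styled = predictor.predict_reference(mask, reference)
    print(styled.shape, styled.dtype)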