# Provenance: uploaded by Kurokabe in commit 3be620b ("Upload 84 files"); file size 4.65 kB.
from typing import List
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, Model
from tensorflow_addons.layers import GroupNormalization
from .layers import ResnetBlock, AttentionBlock, Downsample
# @tf.keras.utils.register_keras_serializable()
class Encoder(layers.Layer):
    """Convolutional encoder that turns an image into a latent feature map.

    The layer is built from, in order:

    1. an input 3x3 convolution producing ``channels`` feature maps;
    2. one level per entry of ``channels_multiplier``; each level holds
       ``num_res_blocks`` ``ResnetBlock`` layers, an ``AttentionBlock`` after
       each of them when the current resolution is listed in
       ``attention_resolution``, and (for every level but the last) a
       ``Downsample`` layer, after which the tracked resolution is halved;
    3. a middle section: ``ResnetBlock`` -> ``AttentionBlock`` -> ``ResnetBlock``;
    4. an output head: group normalization, swish activation and a 3x3
       convolution producing ``z_channels`` feature maps.
    """

    def __init__(
        self,
        *,
        channels: int,
        channels_multiplier: List[int],
        num_res_blocks: int,
        attention_resolution: List[int],
        resolution: int,
        z_channels: int,
        dropout: float,
        **kwargs
    ):
        """Build every sub-layer of the encoder.

        All arguments are keyword-only and required (none has a default).

        Args:
            channels: Number of output channels of the first convolution; also
                the base value scaled by ``channels_multiplier``.
            channels_multiplier: One multiplier per level. Level ``i`` outputs
                ``channels * channels_multiplier[i]`` channels. Must contain at
                least one entry.
            num_res_blocks: Number of ``ResnetBlock`` layers in each level.
            attention_resolution: Resolutions at which an ``AttentionBlock``
                is inserted.
            resolution: Resolution assumed for the input; it is halved after
                each ``Downsample`` layer to decide where attention goes.
            z_channels: Number of output channels of the final convolution.
            dropout: Dropout ratio handed to every ``ResnetBlock``.
            **kwargs: Forwarded unchanged to ``layers.Layer.__init__``.

        Raises:
            ValueError: If ``channels_multiplier`` is empty.
        """
        super().__init__(**kwargs)

        self.channels = channels
        self.channels_multiplier = channels_multiplier
        self.num_resolutions = len(channels_multiplier)
        self.num_res_blocks = num_res_blocks
        self.attention_resolution = attention_resolution
        self.resolution = resolution
        self.z_channels = z_channels
        self.dropout = dropout

        # Without at least one level, `block_in` below would never be bound
        # and building the middle blocks would fail with an opaque
        # UnboundLocalError; fail early with a clear message instead.
        if self.num_resolutions == 0:
            raise ValueError(
                "channels_multiplier must contain at least one entry."
            )

        self.conv_in = layers.Conv2D(
            self.channels, kernel_size=3, strides=1, padding="same"
        )

        current_resolution = resolution
        # Prepending 1 makes entry `i` the multiplier of the *input* of level
        # `i`: the first level receives `channels` feature maps from `conv_in`,
        # every later level receives the previous level's output.
        in_channels_multiplier = (1,) + tuple(channels_multiplier)

        # Flat, ordered list of every layer applied between `conv_in` and the
        # middle section.
        # NOTE(review): the nesting below was reconstructed from a copy of the
        # file that had lost its indentation; the attention check is placed
        # inside the residual-block loop, as in the usual VQGAN encoder design.
        # Confirm against the original file.
        self.downsampling_list = []
        for i_level in range(self.num_resolutions):
            block_in = channels * in_channels_multiplier[i_level]
            block_out = channels * channels_multiplier[i_level]
            for _ in range(self.num_res_blocks):
                self.downsampling_list.append(
                    ResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        dropout=dropout,
                    )
                )
                block_in = block_out

                if current_resolution in attention_resolution:
                    self.downsampling_list.append(AttentionBlock(block_in))

            # The last level keeps its resolution: no Downsample after it.
            if i_level != self.num_resolutions - 1:
                self.downsampling_list.append(Downsample(block_in))
                current_resolution = current_resolution // 2

        # middle: operates on the channel count reached by the last level.
        self.mid = {}
        self.mid["block_1"] = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            dropout=dropout,
        )
        self.mid["attn_1"] = AttentionBlock(block_in)
        self.mid["block_2"] = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            dropout=dropout,
        )

        # end
        # NOTE(review): groups=32 presumably requires the final channel count
        # to be a multiple of 32 — confirm against GroupNormalization's docs.
        self.norm_out = GroupNormalization(groups=32, epsilon=1e-6)
        self.conv_out = layers.Conv2D(
            z_channels,
            kernel_size=3,
            strides=1,
            padding="same",
        )

    # Serialization support is currently disabled (kept for reference):
    # def get_config(self):
    #     config = super().get_config()
    #     config.update(
    #         {
    #             "channels": self.channels,
    #             "channels_multiplier": self.channels_multiplier,
    #             "num_res_blocks": self.num_res_blocks,
    #             "attention_resolution": self.attention_resolution,
    #             "resolution": self.resolution,
    #             "z_channels": self.z_channels,
    #             "dropout": self.dropout,
    #         }
    #     )
    #     return config

    def call(self, inputs, training=True, mask=None):
        """Run the encoder on ``inputs`` and return the latent feature map.

        Args:
            inputs: Tensor fed to the input convolution.
                Presumably a batch of images laid out for ``Conv2D`` —
                TODO confirm the expected shape with the caller.
            training: Accepted for Keras API compatibility; it is not read
                here nor passed explicitly to the sub-layers.
                NOTE(review): sub-layers presumably pick the training mode up
                from the Keras call context — confirm.
            mask: Accepted for Keras API compatibility; unused.

        Returns:
            The output of the final convolution, with ``z_channels`` channels.
        """
        h = self.conv_in(inputs)

        # Levels: residual / attention / downsample layers in build order.
        for downsampling in self.downsampling_list:
            h = downsampling(h)

        # middle
        h = self.mid["block_1"](h)
        h = self.mid["attn_1"](h)
        h = self.mid["block_2"](h)

        # end
        h = self.norm_out(h)
        h = keras.activations.swish(h)
        h = self.conv_out(h)
        return h