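"""Custom inference handler for Stable Diffusion inpainting built on KerasCV.

The handler receives pre-computed text-encoder embeddings together with a
base64-encoded image and mask, runs the guided diffusion loop with
RePaint-style resampling, and returns the resulting latent as a base64 string.
It follows the `EndpointHandler` convention used by Hugging Face Inference
Endpoints custom handlers.
"""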
from typing import Any, Dict
import base64
import math
import numpy as np
import tensorflow as tf
from tensorflow import keras
from keras_cv.models.generative.stable_diffusion.constants import _ALPHAS_CUMPROD
from keras_cv.models.generative.stable_diffusion.diffusion_model import DiffusionModel
class GroupNormalization(keras.layers.Layer):
"""GroupNormalization layer.
This layer is only here temporarily and will be removed
as we introduce GroupNormalization in core Keras.
"""
def __init__(
self,
groups=32,
axis=-1,
epsilon=1e-5,
**kwargs,
):
super().__init__(**kwargs)
self.groups = groups
self.axis = axis
self.epsilon = epsilon
def build(self, input_shape):
dim = input_shape[self.axis]
self.gamma = self.add_weight(
shape=(dim,),
name="gamma",
initializer="ones",
)
self.beta = self.add_weight(
shape=(dim,),
name="beta",
initializer="zeros",
)
def call(self, inputs):
input_shape = tf.shape(inputs)
reshaped_inputs = self._reshape_into_groups(inputs, input_shape)
normalized_inputs = self._apply_normalization(reshaped_inputs, input_shape)
return tf.reshape(normalized_inputs, input_shape)
def _reshape_into_groups(self, inputs, input_shape):
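        # Split the channel axis into (groups, channels_per_group) so that
        # moments can later be computed independently per group.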
group_shape = [input_shape[i] for i in range(inputs.shape.rank)]
group_shape[self.axis] = input_shape[self.axis] // self.groups
group_shape.insert(self.axis, self.groups)
group_shape = tf.stack(group_shape)
return tf.reshape(inputs, group_shape)
def _apply_normalization(self, reshaped_inputs, input_shape):
group_reduction_axes = list(range(1, reshaped_inputs.shape.rank))
axis = -2 if self.axis == -1 else self.axis - 1
group_reduction_axes.pop(axis)
mean, variance = tf.nn.moments(
reshaped_inputs, group_reduction_axes, keepdims=True
)
gamma, beta = self._get_reshaped_weights(input_shape)
return tf.nn.batch_normalization(
reshaped_inputs,
mean=mean,
variance=variance,
scale=gamma,
offset=beta,
variance_epsilon=self.epsilon,
)
def _get_reshaped_weights(self, input_shape):
broadcast_shape = self._create_broadcast_shape(input_shape)
gamma = tf.reshape(self.gamma, broadcast_shape)
beta = tf.reshape(self.beta, broadcast_shape)
return gamma, beta
def _create_broadcast_shape(self, input_shape):
broadcast_shape = [1] * input_shape.shape.rank
broadcast_shape[self.axis] = input_shape[self.axis] // self.groups
broadcast_shape.insert(self.axis, self.groups)
return broadcast_shape
class PaddedConv2D(keras.layers.Layer):
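    """Conv2D preceded by explicit zero padding, so an integer padding amount
    can be specified instead of Keras's "valid"/"same" modes."""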
def __init__(self, filters, kernel_size, padding=0, strides=1, **kwargs):
super().__init__(**kwargs)
self.padding2d = keras.layers.ZeroPadding2D(padding)
self.conv2d = keras.layers.Conv2D(filters, kernel_size, strides=strides)
def call(self, inputs):
x = self.padding2d(inputs)
return self.conv2d(x)
class AttentionBlock(keras.layers.Layer):
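    """Single-head self-attention over the spatial positions of a feature map,
    with 1x1 convolutions as the query/key/value and output projections."""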
def __init__(self, output_dim, **kwargs):
super().__init__(**kwargs)
self.output_dim = output_dim
self.norm = GroupNormalization(epsilon=1e-5)
self.q = PaddedConv2D(output_dim, 1)
self.k = PaddedConv2D(output_dim, 1)
self.v = PaddedConv2D(output_dim, 1)
self.proj_out = PaddedConv2D(output_dim, 1)
def call(self, inputs):
x = self.norm(inputs)
q, k, v = self.q(x), self.k(x), self.v(x)
# Compute attention
_, h, w, c = q.shape
q = tf.reshape(q, (-1, h * w, c)) # b, hw, c
k = tf.transpose(k, (0, 3, 1, 2))
k = tf.reshape(k, (-1, c, h * w)) # b, c, hw
y = q @ k
y = y * (c**-0.5)
y = keras.activations.softmax(y)
# Attend to values
v = tf.transpose(v, (0, 3, 1, 2))
v = tf.reshape(v, (-1, c, h * w))
y = tf.transpose(y, (0, 2, 1))
x = v @ y
x = tf.transpose(x, (0, 2, 1))
x = tf.reshape(x, (-1, h, w, c))
return self.proj_out(x) + inputs
class ResnetBlock(keras.layers.Layer):
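    """GroupNorm -> swish -> conv residual block; a 1x1 convolution projects
    the shortcut whenever the channel count changes."""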
def __init__(self, output_dim, **kwargs):
super().__init__(**kwargs)
self.output_dim = output_dim
self.norm1 = GroupNormalization(epsilon=1e-5)
self.conv1 = PaddedConv2D(output_dim, 3, padding=1)
self.norm2 = GroupNormalization(epsilon=1e-5)
self.conv2 = PaddedConv2D(output_dim, 3, padding=1)
def build(self, input_shape):
if input_shape[-1] != self.output_dim:
self.residual_projection = PaddedConv2D(self.output_dim, 1)
else:
self.residual_projection = lambda x: x
def call(self, inputs):
x = self.conv1(keras.activations.swish(self.norm1(inputs)))
x = self.conv2(keras.activations.swish(self.norm2(x)))
return x + self.residual_projection(inputs)
class ImageEncoder(keras.Sequential):
"""ImageEncoder is the VAE Encoder for StableDiffusion."""
def __init__(self, img_height=512, img_width=512, download_weights=True):
super().__init__(
[
keras.layers.Input((img_height, img_width, 3)),
PaddedConv2D(128, 3, padding=1),
ResnetBlock(128),
ResnetBlock(128),
PaddedConv2D(128, 3, padding=1, strides=2),
ResnetBlock(256),
ResnetBlock(256),
PaddedConv2D(256, 3, padding=1, strides=2),
ResnetBlock(512),
ResnetBlock(512),
PaddedConv2D(512, 3, padding=1, strides=2),
ResnetBlock(512),
ResnetBlock(512),
ResnetBlock(512),
AttentionBlock(512),
ResnetBlock(512),
GroupNormalization(epsilon=1e-5),
keras.layers.Activation("swish"),
PaddedConv2D(8, 3, padding=1),
PaddedConv2D(8, 1),
# TODO(lukewood): can this be refactored to be a Rescaling layer?
# Perhaps some sort of rescale and gather?
# Either way, we may need a lambda to gather the first 4 dimensions.
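                # The encoder emits 8 channels (mean and log-variance of the
                # latent distribution); keep the first 4 (the mean) and apply
                # Stable Diffusion's latent scaling factor of 0.18215.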
keras.layers.Lambda(lambda x: x[..., :4] * 0.18215),
]
)
if download_weights:
image_encoder_weights_fpath = keras.utils.get_file(
origin="https://huggingface.co/fchollet/stable-diffusion/resolve/main/vae_encoder.h5",
file_hash="c60fb220a40d090e0f86a6ab4c312d113e115c87c40ff75d11ffcf380aab7ebb",
)
self.load_weights(image_encoder_weights_fpath)
class EndpointHandler:
def __init__(self, path=""):
self.seed = None
img_height = 512
img_width = 512
self.img_height = round(img_height / 128) * 128
self.img_width = round(img_width / 128) * 128
self.MAX_PROMPT_LENGTH = 77
self.diffusion_model = DiffusionModel(self.img_height, self.img_width, self.MAX_PROMPT_LENGTH)
diffusion_model_weights_fpath = keras.utils.get_file(
origin="https://huggingface.co/fchollet/stable-diffusion/resolve/main/kcv_diffusion_model.h5",
file_hash="8799ff9763de13d7f30a683d653018e114ed24a6a819667da4f5ee10f9e805fe",
)
self.diffusion_model.load_weights(diffusion_model_weights_fpath)
self.image_encoder = ImageEncoder()
def _get_initial_diffusion_noise(self, batch_size, seed):
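        # The latent grid is 8x smaller than the image in each spatial
        # dimension and has 4 channels.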
if seed is not None:
return tf.random.stateless_normal(
(batch_size, self.img_height // 8, self.img_width // 8, 4),
seed=[seed, seed],
)
else:
return tf.random.normal(
(batch_size, self.img_height // 8, self.img_width // 8, 4)
)
def _get_initial_alphas(self, timesteps):
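        # Look up the cumulative-product alphas for the sampled timesteps;
        # `alphas_prev` is the same list shifted right, with alpha = 1.0 at t=0.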
alphas = [_ALPHAS_CUMPROD[t] for t in timesteps]
alphas_prev = [1.0] + alphas[:-1]
return alphas, alphas_prev
def _get_timestep_embedding(self, timestep, batch_size, dim=320, max_period=10000):
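        # Transformer-style sinusoidal embedding of the scalar diffusion
        # timestep, producing a (batch_size, dim) tensor.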
half = dim // 2
freqs = tf.math.exp(
-math.log(max_period) * tf.range(0, half, dtype=tf.float32) / half
)
args = tf.convert_to_tensor([timestep], dtype=tf.float32) * freqs
embedding = tf.concat([tf.math.cos(args), tf.math.sin(args)], 0)
embedding = tf.reshape(embedding, [1, -1])
return tf.repeat(embedding, batch_size, axis=0)
def _prepare_img_mask(self, image, mask, batch_size):
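        # Both inputs arrive as base64-encoded raw uint8 buffers: a 512x512x3
        # RGB image and a 512x512x1 binary mask. The image is rescaled to
        # [-1, 1] and encoded into a latent by the VAE encoder; the mask is
        # downsampled 8x with max pooling to match the latent resolution.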
image = base64.b64decode(image)
image = np.frombuffer(image, dtype="uint8")
image = np.reshape(image, (512, 512, 3))
image = tf.convert_to_tensor(image)
image = tf.squeeze(image)
image = tf.cast(image, dtype=tf.float32) / 255.0 * 2.0 - 1.0
image = tf.expand_dims(image, axis=0)
        known_x0 = self.image_encoder(image)
        # Tile the encoded latent across the batch so it matches the batched
        # diffusion latent.
        if batch_size > 1:
            known_x0 = tf.repeat(known_x0, batch_size, axis=0)
mask = base64.b64decode(mask)
mask = np.frombuffer(mask, dtype="uint8")
mask = np.reshape(mask, (512, 512, 1))
mask = tf.convert_to_tensor(mask)
mask = tf.expand_dims(mask, axis=0)
mask = tf.cast(
tf.nn.max_pool2d(mask, ksize=8, strides=8, padding="SAME"),
dtype=tf.float32,
)
mask = tf.squeeze(mask)
if mask.shape.rank == 2:
mask = tf.repeat(tf.expand_dims(mask, axis=0), batch_size, axis=0)
mask = tf.expand_dims(mask, axis=-1)
return known_x0, mask
def __call__(self, data: Dict[str, Any]) -> str:
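        """Runs inpainting for one request.

        `data["inputs"]` must hold four base64 strings: the conditional and
        unconditional text embeddings (float32, shape (batch_size, 77, 768)),
        the raw image bytes, and the raw mask bytes. Returns the final latent,
        base64-encoded.
        """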
# get inputs
inputs = data.pop("inputs", data)
batch_size = data.pop("batch_size", 1)
context = base64.b64decode(inputs[0])
context = np.frombuffer(context, dtype="float32")
context = np.reshape(context, (batch_size, 77, 768))
unconditional_context = base64.b64decode(inputs[1])
unconditional_context = np.frombuffer(unconditional_context, dtype="float32")
unconditional_context = np.reshape(unconditional_context, (batch_size, 77, 768))
num_steps = data.pop("num_steps", 25)
unconditional_guidance_scale = data.pop("unconditional_guidance_scale", 7.5)
num_resamples = data.pop("num_resamples", 1)
known_x0, mask = self._prepare_img_mask(inputs[2], inputs[3], batch_size)
latent = self._get_initial_diffusion_noise(batch_size, self.seed)
timesteps = tf.range(1, 1000, 1000 // num_steps)
alphas, alphas_prev = self._get_initial_alphas(timesteps)
progbar = keras.utils.Progbar(len(timesteps))
iteration = 0
for index, timestep in list(enumerate(timesteps))[::-1]:
a_t, a_prev = alphas[index], alphas_prev[index]
latent_prev = latent # Set aside the previous latent vector
t_emb = self._get_timestep_embedding(timestep, batch_size)
for resample_index in range(num_resamples):
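                # Classifier-free guidance: predict the noise with and without
                # text conditioning, then push the result toward the
                # conditional prediction.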
unconditional_latent = self.diffusion_model.predict_on_batch(
[latent, t_emb, unconditional_context]
)
latent = self.diffusion_model.predict_on_batch([latent, t_emb, context])
latent = unconditional_latent + unconditional_guidance_scale * (
latent - unconditional_latent
)
pred_x0 = (latent_prev - math.sqrt(1 - a_t) * latent) / math.sqrt(a_t)
latent = latent * math.sqrt(1.0 - a_prev) + math.sqrt(a_prev) * pred_x0
# Use known image (x0) to compute latent
if timestep > 1:
noise = tf.random.normal(tf.shape(known_x0), seed=self.seed)
else:
noise = 0.0
known_latent = (
math.sqrt(a_prev) * known_x0 + math.sqrt(1 - a_prev) * noise
)
                # Keep the (re-noised) known latent where mask == 1 and the
                # freshly generated latent elsewhere.
latent = mask * known_latent + (1 - mask) * latent
                # Re-noise the latent by one step (RePaint-style resampling) so
                # the next inner-loop pass can re-harmonize the known and
                # generated regions.
if resample_index < num_resamples - 1 and timestep > 1:
beta_prev = 1 - (a_t / a_prev)
latent_prev = tf.random.normal(
tf.shape(latent),
mean=latent * math.sqrt(1 - beta_prev),
stddev=math.sqrt(beta_prev),
seed=self.seed,
)
iteration += 1
progbar.update(iteration)
latent_b64 = base64.b64encode(latent.numpy().tobytes())
latent_b64str = latent_b64.decode()
return latent_b64str
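

# A minimal local smoke test. This is an illustrative sketch, not part of the
# deployed handler: it fabricates zero-valued text embeddings (a real client
# would send (batch, 77, 768) embeddings from the Stable Diffusion text
# encoder) and an all-ones mask, which under the blending rule above keeps the
# input pixels everywhere. Note that constructing the handler downloads the
# model weights.
if __name__ == "__main__":
    handler = EndpointHandler()

    context = np.zeros((1, 77, 768), dtype="float32")
    unconditional_context = np.zeros((1, 77, 768), dtype="float32")
    image = np.zeros((512, 512, 3), dtype="uint8")
    mask = np.ones((512, 512, 1), dtype="uint8")

    payload = {
        "inputs": [
            base64.b64encode(context.tobytes()).decode(),
            base64.b64encode(unconditional_context.tobytes()).decode(),
            base64.b64encode(image.tobytes()).decode(),
            base64.b64encode(mask.tobytes()).decode(),
        ],
        "batch_size": 1,
        "num_steps": 2,  # keep the smoke test fast
    }

    latent_b64str = handler(payload)
    print(f"received {len(latent_b64str)} base64 characters of latent")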