chansung committed
Commit aae13dc
1 Parent(s): 8e26b1f

add custom handler

__pycache__/handler.cpython-38.pyc ADDED
Binary file (11.2 kB)
 
handler.py ADDED
@@ -0,0 +1,331 @@
+ from typing import Dict, List, Any
+ import base64
+
+ import math
+ import numpy as np
+ import tensorflow as tf
+ from tensorflow import keras
+ from keras_cv.models.generative.stable_diffusion.constants import _ALPHAS_CUMPROD
+ from keras_cv.models.generative.stable_diffusion.diffusion_model import DiffusionModel
+
+ class GroupNormalization(tf.keras.layers.Layer):
+     """GroupNormalization layer.
+     This layer is only here temporarily and will be removed
+     as we introduce GroupNormalization in core Keras.
+     """
+
+     def __init__(
+         self,
+         groups=32,
+         axis=-1,
+         epsilon=1e-5,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+         self.groups = groups
+         self.axis = axis
+         self.epsilon = epsilon
+
+     def build(self, input_shape):
+         dim = input_shape[self.axis]
+         self.gamma = self.add_weight(
+             shape=(dim,),
+             name="gamma",
+             initializer="ones",
+         )
+         self.beta = self.add_weight(
+             shape=(dim,),
+             name="beta",
+             initializer="zeros",
+         )
+
+     def call(self, inputs):
+         input_shape = tf.shape(inputs)
+         reshaped_inputs = self._reshape_into_groups(inputs, input_shape)
+         normalized_inputs = self._apply_normalization(reshaped_inputs, input_shape)
+         return tf.reshape(normalized_inputs, input_shape)
+
+     def _reshape_into_groups(self, inputs, input_shape):
+         group_shape = [input_shape[i] for i in range(inputs.shape.rank)]
+         group_shape[self.axis] = input_shape[self.axis] // self.groups
+         group_shape.insert(self.axis, self.groups)
+         group_shape = tf.stack(group_shape)
+         return tf.reshape(inputs, group_shape)
+
+     def _apply_normalization(self, reshaped_inputs, input_shape):
+         group_reduction_axes = list(range(1, reshaped_inputs.shape.rank))
+         axis = -2 if self.axis == -1 else self.axis - 1
+         group_reduction_axes.pop(axis)
+         mean, variance = tf.nn.moments(
+             reshaped_inputs, group_reduction_axes, keepdims=True
+         )
+         gamma, beta = self._get_reshaped_weights(input_shape)
+         return tf.nn.batch_normalization(
+             reshaped_inputs,
+             mean=mean,
+             variance=variance,
+             scale=gamma,
+             offset=beta,
+             variance_epsilon=self.epsilon,
+         )
+
+     def _get_reshaped_weights(self, input_shape):
+         broadcast_shape = self._create_broadcast_shape(input_shape)
+         gamma = tf.reshape(self.gamma, broadcast_shape)
+         beta = tf.reshape(self.beta, broadcast_shape)
+         return gamma, beta
+
+     def _create_broadcast_shape(self, input_shape):
+         broadcast_shape = [1] * input_shape.shape.rank
+         broadcast_shape[self.axis] = input_shape[self.axis] // self.groups
+         broadcast_shape.insert(self.axis, self.groups)
+         return broadcast_shape
+
+ class PaddedConv2D(keras.layers.Layer):
+     """Conv2D preceded by an explicit, configurable zero padding."""
+
+     def __init__(self, filters, kernel_size, padding=0, strides=1, **kwargs):
+         super().__init__(**kwargs)
+         self.padding2d = keras.layers.ZeroPadding2D(padding)
+         self.conv2d = keras.layers.Conv2D(filters, kernel_size, strides=strides)
+
+     def call(self, inputs):
+         x = self.padding2d(inputs)
+         return self.conv2d(x)
+
+ class AttentionBlock(keras.layers.Layer):
+     """Single-head self-attention over the spatial positions of a feature map."""
+
+     def __init__(self, output_dim, **kwargs):
+         super().__init__(**kwargs)
+         self.output_dim = output_dim
+         self.norm = GroupNormalization(epsilon=1e-5)
+         self.q = PaddedConv2D(output_dim, 1)
+         self.k = PaddedConv2D(output_dim, 1)
+         self.v = PaddedConv2D(output_dim, 1)
+         self.proj_out = PaddedConv2D(output_dim, 1)
+
+     def call(self, inputs):
+         x = self.norm(inputs)
+         q, k, v = self.q(x), self.k(x), self.v(x)
+
+         # Compute scaled dot-product attention scores
+         _, h, w, c = q.shape
+         q = tf.reshape(q, (-1, h * w, c))  # b, hw, c
+         k = tf.transpose(k, (0, 3, 1, 2))
+         k = tf.reshape(k, (-1, c, h * w))  # b, c, hw
+         y = q @ k
+         y = y * (c**-0.5)
+         y = keras.activations.softmax(y)
+
+         # Attend to values
+         v = tf.transpose(v, (0, 3, 1, 2))
+         v = tf.reshape(v, (-1, c, h * w))
+         y = tf.transpose(y, (0, 2, 1))
+         x = v @ y
+         x = tf.transpose(x, (0, 2, 1))
+         x = tf.reshape(x, (-1, h, w, c))
+         # Project back and add the residual connection
+         return self.proj_out(x) + inputs
+
+ class ResnetBlock(keras.layers.Layer):
+     """Two norm/swish/conv stages with a (possibly projected) residual."""
+
+     def __init__(self, output_dim, **kwargs):
+         super().__init__(**kwargs)
+         self.output_dim = output_dim
+         self.norm1 = GroupNormalization(epsilon=1e-5)
+         self.conv1 = PaddedConv2D(output_dim, 3, padding=1)
+         self.norm2 = GroupNormalization(epsilon=1e-5)
+         self.conv2 = PaddedConv2D(output_dim, 3, padding=1)
+
+     def build(self, input_shape):
+         # Project the residual only when the channel count changes.
+         if input_shape[-1] != self.output_dim:
+             self.residual_projection = PaddedConv2D(self.output_dim, 1)
+         else:
+             self.residual_projection = lambda x: x
+
+     def call(self, inputs):
+         x = self.conv1(keras.activations.swish(self.norm1(inputs)))
+         x = self.conv2(keras.activations.swish(self.norm2(x)))
+         return x + self.residual_projection(inputs)
+
+ class ImageEncoder(keras.Sequential):
+     """ImageEncoder is the VAE Encoder for StableDiffusion."""
+
+     def __init__(self, img_height=512, img_width=512, download_weights=True):
+         super().__init__(
+             [
+                 keras.layers.Input((img_height, img_width, 3)),
+                 PaddedConv2D(128, 3, padding=1),
+                 ResnetBlock(128),
+                 ResnetBlock(128),
+                 PaddedConv2D(128, 3, padding=1, strides=2),
+                 ResnetBlock(256),
+                 ResnetBlock(256),
+                 PaddedConv2D(256, 3, padding=1, strides=2),
+                 ResnetBlock(512),
+                 ResnetBlock(512),
+                 PaddedConv2D(512, 3, padding=1, strides=2),
+                 ResnetBlock(512),
+                 ResnetBlock(512),
+                 ResnetBlock(512),
+                 AttentionBlock(512),
+                 ResnetBlock(512),
+                 GroupNormalization(epsilon=1e-5),
+                 keras.layers.Activation("swish"),
+                 PaddedConv2D(8, 3, padding=1),
+                 PaddedConv2D(8, 1),
+                 # TODO(lukewood): can this be refactored to be a Rescaling layer?
+                 # Perhaps some sort of rescale and gather?
+                 # Either way, we may need a lambda to gather the first 4 dimensions.
+                 keras.layers.Lambda(lambda x: x[..., :4] * 0.18215),
+             ]
+         )
+
+         if download_weights:
+             image_encoder_weights_fpath = keras.utils.get_file(
+                 origin="https://huggingface.co/fchollet/stable-diffusion/resolve/main/vae_encoder.h5",
+                 file_hash="c60fb220a40d090e0f86a6ab4c312d113e115c87c40ff75d11ffcf380aab7ebb",
+             )
+             self.load_weights(image_encoder_weights_fpath)
+
+ class EndpointHandler:
+     def __init__(self, path=""):
+         self.seed = None
+
+         # Snap the working resolution to a multiple of 128.
+         img_height = 512
+         img_width = 512
+         self.img_height = round(img_height / 128) * 128
+         self.img_width = round(img_width / 128) * 128
+
+         self.MAX_PROMPT_LENGTH = 77
+         self.diffusion_model = DiffusionModel(
+             self.img_height, self.img_width, self.MAX_PROMPT_LENGTH
+         )
+         # Load the pretrained diffusion model weights published by fchollet.
+         diffusion_model_weights_fpath = keras.utils.get_file(
+             origin="https://huggingface.co/fchollet/stable-diffusion/resolve/main/kcv_diffusion_model.h5",
+             file_hash="8799ff9763de13d7f30a683d653018e114ed24a6a819667da4f5ee10f9e805fe",
+         )
+         self.diffusion_model.load_weights(diffusion_model_weights_fpath)
+
+         self.image_encoder = ImageEncoder()
+
+     def _get_initial_diffusion_noise(self, batch_size, seed):
+         # The latent space is 8x smaller than pixel space, with 4 channels.
+         if seed is not None:
+             return tf.random.stateless_normal(
+                 (batch_size, self.img_height // 8, self.img_width // 8, 4),
+                 seed=[seed, seed],
+             )
+         else:
+             return tf.random.normal(
+                 (batch_size, self.img_height // 8, self.img_width // 8, 4)
+             )
+
+     def _get_initial_alphas(self, timesteps):
+         # Look up the cumulative-product alpha schedule for the sampled
+         # timesteps; alphas_prev is the same schedule shifted by one step.
+         alphas = [_ALPHAS_CUMPROD[t] for t in timesteps]
+         alphas_prev = [1.0] + alphas[:-1]
+
+         return alphas, alphas_prev
+
+     def _get_timestep_embedding(self, timestep, batch_size, dim=320, max_period=10000):
+         # Sinusoidal timestep embedding, as in DDPM.
+         half = dim // 2
+         freqs = tf.math.exp(
+             -math.log(max_period) * tf.range(0, half, dtype=tf.float32) / half
+         )
+         args = tf.convert_to_tensor([timestep], dtype=tf.float32) * freqs
+         embedding = tf.concat([tf.math.cos(args), tf.math.sin(args)], 0)
+         embedding = tf.reshape(embedding, [1, -1])
+         return tf.repeat(embedding, batch_size, axis=0)
+
+     def _prepare_img_mask(self, image, mask, batch_size):
+         # The image arrives as a base64-encoded raw buffer of a
+         # 512x512x3 uint8 array.
+         image = base64.b64decode(image)
+         image = np.frombuffer(image, dtype="uint8")
+         image = np.reshape(image, (512, 512, 3))
+         image = tf.convert_to_tensor(image)
+
+         # Normalize to [-1, 1] and encode into the latent space.
+         image = tf.squeeze(image)
+         image = tf.cast(image, dtype=tf.float32) / 255.0 * 2.0 - 1.0
+         image = tf.expand_dims(image, axis=0)
+         known_x0 = self.image_encoder(image)
+         if batch_size > 1:
+             known_x0 = tf.repeat(known_x0, batch_size, axis=0)
+
+         # The mask arrives as a base64-encoded raw buffer of a
+         # 512x512x1 uint8 array.
+         mask = base64.b64decode(mask)
+         mask = np.frombuffer(mask, dtype="uint8")
+         mask = np.reshape(mask, (512, 512, 1))
+         mask = tf.convert_to_tensor(mask)
+
+         # Downsample the mask by 8x so it matches the latent resolution.
+         mask = tf.expand_dims(mask, axis=0)
+         mask = tf.cast(
+             tf.nn.max_pool2d(mask, ksize=8, strides=8, padding="SAME"),
+             dtype=tf.float32,
+         )
+         mask = tf.squeeze(mask)
+         if mask.shape.rank == 2:
+             mask = tf.repeat(tf.expand_dims(mask, axis=0), batch_size, axis=0)
+         mask = tf.expand_dims(mask, axis=-1)
+
+         return known_x0, mask
+
+     def __call__(self, data: Dict[str, Any]) -> str:
+         # get inputs
+         inputs = data.pop("inputs", data)
+         batch_size = data.pop("batch_size", 1)
+
+         # inputs[0] and inputs[1] are base64-encoded float32 buffers holding
+         # the conditional and unconditional text embeddings.
+         context = base64.b64decode(inputs[0])
+         context = np.frombuffer(context, dtype="float32")
+         context = np.reshape(context, (batch_size, 77, 768))
+
+         unconditional_context = base64.b64decode(inputs[1])
+         unconditional_context = np.frombuffer(unconditional_context, dtype="float32")
+         unconditional_context = np.reshape(unconditional_context, (batch_size, 77, 768))
+
+         num_steps = data.pop("num_steps", 25)
+         unconditional_guidance_scale = data.pop("unconditional_guidance_scale", 7.5)
+         num_resamples = data.pop("num_resamples", 1)
+
+         known_x0, mask = self._prepare_img_mask(inputs[2], inputs[3], batch_size)
+
+         latent = self._get_initial_diffusion_noise(batch_size, self.seed)
+
+         timesteps = tf.range(1, 1000, 1000 // num_steps)
+         alphas, alphas_prev = self._get_initial_alphas(timesteps)
+
+         progbar = keras.utils.Progbar(len(timesteps))
+         iteration = 0
+
+         for index, timestep in list(enumerate(timesteps))[::-1]:
+             a_t, a_prev = alphas[index], alphas_prev[index]
+             latent_prev = latent  # Set aside the previous latent vector
+             t_emb = self._get_timestep_embedding(timestep, batch_size)
+
+             for resample_index in range(num_resamples):
+                 # Classifier-free guidance: push the conditional prediction
+                 # away from the unconditional one.
+                 unconditional_latent = self.diffusion_model.predict_on_batch(
+                     [latent, t_emb, unconditional_context]
+                 )
+                 latent = self.diffusion_model.predict_on_batch([latent, t_emb, context])
+                 latent = unconditional_latent + unconditional_guidance_scale * (
+                     latent - unconditional_latent
+                 )
+                 # Predict x0 from the noise estimate, then step to the
+                 # previous timestep.
+                 pred_x0 = (latent_prev - math.sqrt(1 - a_t) * latent) / math.sqrt(a_t)
+                 latent = latent * math.sqrt(1.0 - a_prev) + math.sqrt(a_prev) * pred_x0
+
+                 # Use known image (x0) to compute latent
+                 if timestep > 1:
+                     noise = tf.random.normal(tf.shape(known_x0), seed=self.seed)
+                 else:
+                     noise = 0.0
+                 known_latent = (
+                     math.sqrt(a_prev) * known_x0 + math.sqrt(1 - a_prev) * noise
+                 )
+                 # Use known latent in unmasked regions
+                 latent = mask * known_latent + (1 - mask) * latent
+                 # Resample latent (RePaint-style) to harmonize the known and
+                 # generated regions before the next denoising pass.
+                 if resample_index < num_resamples - 1 and timestep > 1:
+                     beta_prev = 1 - (a_t / a_prev)
+                     latent_prev = tf.random.normal(
+                         tf.shape(latent),
+                         mean=latent * math.sqrt(1 - beta_prev),
+                         stddev=math.sqrt(beta_prev),
+                         seed=self.seed,
+                     )
+
+             iteration += 1
+             progbar.update(iteration)
+
+         # Return the final latent as a base64-encoded float32 buffer.
+         latent_b64 = base64.b64encode(latent.numpy().tobytes())
+         latent_b64str = latent_b64.decode()
+
+         return latent_b64str
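
For reference, a minimal local smoke test of this handler could look like the sketch below. It is not part of the commit: the random arrays stand in for real CLIP text embeddings, the flat gray image and square mask stand in for a real photo and inpainting mask, and the b64 helper is defined here only for illustration.

import base64
import numpy as np
from handler import EndpointHandler

def b64(arr):
    # Raw array bytes, base64-encoded, as the handler expects.
    return base64.b64encode(arr.tobytes()).decode()

context = np.random.randn(1, 77, 768).astype("float32")
unconditional_context = np.random.randn(1, 77, 768).astype("float32")
image = np.full((512, 512, 3), 127, dtype="uint8")
mask = np.ones((512, 512, 1), dtype="uint8")
mask[128:384, 128:384] = 0  # mask == 0 marks the region to be generated

handler = EndpointHandler()
latent_b64str = handler({
    "inputs": [b64(context), b64(unconditional_context), b64(image), b64(mask)],
    "batch_size": 1,
    "num_steps": 25,
})

# The response is the final latent as a base64 string; decode it back to
# a (1, 64, 64, 4) float32 array before passing it to a VAE decoder.
latent = np.frombuffer(base64.b64decode(latent_b64str), dtype="float32")
latent = latent.reshape(1, 64, 64, 4)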
requirements.txt ADDED
@@ -0,0 +1,3 @@
+ keras-cv
+ tensorflow
+ tensorflow_datasets
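
To run the handler locally, the dependencies can be installed with, for example:

pip install keras-cv tensorflow tensorflow_datasets

Note that handler.py imports from keras_cv.models.generative.stable_diffusion, a module path used by earlier keras-cv releases, so pinning a keras-cv version contemporary with this commit may be necessary.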