# diffusion-med-coco / handler.py
# Author: detectivejoewest — revision 70f25a9 (verified)
# Imports grouped per convention (stdlib, third-party); the duplicate
# `import tensorflow as tf` that used to follow the Keras imports is removed.
from typing import Dict, Any

import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Dense, LayerNormalization, Conv2D, UpSampling2D, Embedding, MultiHeadAttention
from tensorflow.keras.saving import register_keras_serializable
from transformers import PreTrainedTokenizerFast

# BPE tokenizer shipped inside the model repository; used by process_text().
token2vec = PreTrainedTokenizerFast.from_pretrained("/repository/bpe")
# @title Config
def small_config():
    """Hyperparameters for the small model variant plus its DDPM noise schedule.

    Returns a dict holding architecture sizes, training/batch settings, and the
    linear beta schedule together with alpha = 1 - beta and the cumulative
    product alpha-bar (key "a").
    """
    steps = 500
    betas = np.linspace(1e-4, 0.02, steps)
    alphas = 1 - betas
    alpha_bars = np.cumprod(alphas)
    return {
        "filters": [128, 256],
        "hidden_dim": 384,
        "heads": 6,
        "layers": 8,
        "patch_size": 4,
        "batch_size": 64,
        "T": steps,
        "context_size": 8,
        "image_size": 128,
        "latent_shape": (32, 32, 4),
        "beta": betas,
        "alpha": alphas,
        "a": alpha_bars,
    }
def med_config():
    """Hyperparameters for the medium model variant plus its DDPM noise schedule.

    Same layout as small_config(), with a wider/deeper transformer and T = 1000
    diffusion steps.
    """
    steps = 1000
    betas = np.linspace(1e-4, 0.02, steps)
    alphas = 1 - betas
    alpha_bars = np.cumprod(alphas)
    return {
        "filters": [128, 256],
        "hidden_dim": 768,
        "heads": 12,
        "layers": 12,
        "patch_size": 4,
        "batch_size": 64,
        "T": steps,
        "context_size": 8,
        "image_size": 128,
        "latent_shape": (32, 32, 4),
        "beta": betas,
        "alpha": alphas,
        "a": alpha_bars,
    }
def large_config():
    """Hyperparameters for the large model variant plus its DDPM noise schedule.

    Same layout as small_config(); the biggest transformer of the three
    (hidden_dim 1024, 16 heads, 24 layers) with T = 1000 diffusion steps.
    """
    steps = 1000
    betas = np.linspace(1e-4, 0.02, steps)
    alphas = 1 - betas
    alpha_bars = np.cumprod(alphas)
    return {
        "filters": [128, 256],
        "hidden_dim": 1024,
        "heads": 16,
        "layers": 24,
        "patch_size": 4,
        "batch_size": 64,
        "T": steps,
        "context_size": 8,
        "image_size": 128,
        "latent_shape": (32, 32, 4),
        "beta": betas,
        "alpha": alphas,
        "a": alpha_bars,
    }
# Active configuration: this endpoint serves the medium model.
config = med_config()

# Unpack the config dict into the module-level names used throughout this file.
(filters, hidden_dim, heads, layers, patch_size, batch_size, T,
 context_size, image_size, latent_shape, beta, alpha, a) = (
    config[key] for key in (
        "filters", "hidden_dim", "heads", "layers", "patch_size",
        "batch_size", "T", "context_size", "image_size", "latent_shape",
        "beta", "alpha", "a",
    )
)
# @title ResBlock, UpBlock, DownBlock
@register_keras_serializable()
class ResBlock(tf.keras.layers.Layer):
    """Residual block: a 1x1 channel-aligning conv, then two pxp convs added back.

    Attribute names (reshape/conv1/conv2) are kept as-is so saved checkpoints
    restore onto the same weight paths.
    """

    def __init__(self, filters, p, **kwargs):
        super().__init__(**kwargs)
        self.filters = filters
        self.p = p
        # 1x1 conv aligns channel count so the residual add below is valid.
        self.reshape = Conv2D(filters, kernel_size=1, strides=1, padding="same")
        self.conv1 = Conv2D(filters, kernel_size=p, strides=1, padding="same", activation="swish")
        self.conv2 = Conv2D(filters, kernel_size=p, strides=1, padding="same")

    def call(self, x):
        # Project, then add the two-conv residual branch back onto the projection.
        h = self.reshape(x)
        return h + self.conv2(self.conv1(h))

    def get_config(self):
        base = super().get_config()
        base.update({"filters": self.filters, "p": self.p})
        return base
@register_keras_serializable()
class DownBlock(tf.keras.layers.Layer):
    """A stack of 3x3 ResBlocks followed by 2x2 max-pooling (halves spatial size)."""

    def __init__(self, filters, **kwargs):
        super().__init__(**kwargs)
        self.filters = filters
        # One ResBlock per entry in `filters`.
        self.resBlocks = [ResBlock(f, p=3) for f in filters]
        self.pool = tf.keras.layers.MaxPool2D(pool_size=(2, 2))

    def call(self, x):
        for block in self.resBlocks:
            x = block(x)
        return self.pool(x)

    def get_config(self):
        base = super().get_config()
        base.update({"filters": self.filters})
        return base
@register_keras_serializable()
class UpBlock(tf.keras.layers.Layer):
    """Bilinear 2x upsampling followed by a stack of 3x3 ResBlocks."""

    def __init__(self, filters, **kwargs):
        super().__init__(**kwargs)
        self.filters = filters
        self.resBlocks = [ResBlock(f, p=3) for f in filters]
        self.upSample = UpSampling2D(size=2, interpolation="bilinear")

    def call(self, x):
        # Upsample first, then refine at the new resolution.
        x = self.upSample(x)
        for block in self.resBlocks:
            x = block(x)
        return x

    def get_config(self):
        base = super().get_config()
        base.update({"filters": self.filters})
        return base
# @title Encoder, Decoder
@register_keras_serializable()
class Encoder(tf.keras.Model):
    """VAE encoder: downsampling ResBlock stacks, then a 1x1 conv producing (mu, logvar).

    call() returns (z, mu, logvar) where z is a reparameterized sample.
    """
    def __init__(self, filters, latent_dim, **kwargs):
        super(Encoder, self).__init__(**kwargs)
        self.filters = filters
        self.latent_dim = latent_dim
        # One DownBlock (two ResBlocks + pool) per entry in `filters`.
        self.downBlocks = [DownBlock([f,f]) for f in filters]
        # 2*latent_dim channels: split below into mean and log-variance halves.
        self.latent_proj = Conv2D(latent_dim * 2, kernel_size=1, strides=1, padding="same", activation="linear")
    @tf.function
    def sample(self, mu, logvar):
        """Reparameterization trick: z = mu + exp(logvar/2) * eps, eps ~ N(0, I)."""
        eps = tf.random.normal(shape=tf.shape(mu))
        return eps * tf.exp(logvar * .5) + mu
    def call(self, x, training=1):
        for downBlock in self.downBlocks:
            x = downBlock(x)
        x = self.latent_proj(x)
        mu, logvar = tf.split(x, 2, axis=-1)
        z = self.sample(mu, logvar)
        return z, mu, logvar
    def get_config(self):
        config = super().get_config()
        config.update({
            "filters": self.filters,
            "latent_dim": self.latent_dim})
        return config
    def compute_output_shape(self, input_shape):
        # NOTE(review): the latent produced by call() is spatial (conv output split
        # along channels), yet this reports flat (batch, latent_dim) triples —
        # confirm nothing downstream relies on this value.
        return (input_shape[0], self.latent_dim), (input_shape[0], self.latent_dim), (input_shape[0], self.latent_dim)
@register_keras_serializable()
class Decoder(tf.keras.Model):
    """VAE decoder: 1x1 conv out of latent space, upsampling ResBlock stacks, RGB projection."""
    def __init__(self, filters, img_size, **kwargs):
        super(Decoder, self).__init__(**kwargs)
        # NOTE(review): self.filters stores the reversed list, but upBlocks below
        # are built from the *un-reversed* argument; get_config re-reverses so
        # serialization round-trips. This matches the trained checkpoint — confirm
        # the intent before "fixing" the ordering.
        self.filters = filters[::-1]
        self.img_size = img_size
        self.undo_latent_proj = Conv2D(filters[0], kernel_size=1, strides=1, padding="same")
        self.upBlocks = [UpBlock([f,f]) for f in filters]
        # Final projection to 3 RGB channels; linear because outputs live in [-1, 1].
        self.conv_proj = Conv2D(3, kernel_size=3, padding="same", activation="linear")
    def call(self, z, training=1):
        z = self.undo_latent_proj(z)
        for upBlock in self.upBlocks:
            z = upBlock(z)
        x = self.conv_proj(z)
        return x
    def get_config(self):
        config = super().get_config()
        config.update({
            "filters": self.filters[::-1],
            "img_size": self.img_size})
        return config
    def compute_output_shape(self, input_shape):
        # Output is a full-resolution RGB image batch.
        return (input_shape[0], self.img_size, self.img_size, 3)
# @title Helper Functions
def process_text(text):
    """Tokenize `text` with the BPE tokenizer, zero-pad, and clip to `context_size` ids.

    Returns a numpy array of exactly `context_size` token ids.
    """
    ids = token2vec.encode(text)
    # Pad with id 0 up to the context length, then truncate to it.
    ids = ids + [0] * max(0, context_size - len(ids))
    return np.array(ids[:context_size])
def normalise_img(img_tensor):  # Maps [-1,1] to [0,1]
    """Return `img_tensor` rescaled from [-1, 1] to [0, 1].

    Fixed: the previous version used in-place ops (`*=`, `+=`) and therefore
    mutated a numpy caller's array; this builds a new array instead, consistent
    with prep_img() below which copies its input.
    """
    return img_tensor * 0.5 + 0.5
def prep_img(img_tensor):  # Maps [0,255] to [-1,1]
    """Return a copy of `img_tensor` rescaled from [0, 255] to [-1, 1]."""
    scaled = img_tensor.copy() / 127.5
    return scaled - 1
def noisify_img(img_tensor, t, a):  # Returns x_t and the noise used
    """Forward-diffuse `img_tensor` to step t: x_t = sqrt(a[t])*x_0 + sqrt(1-a[t])*eps.

    `a` is the cumulative alpha-bar schedule; returns (x_t, eps) with eps drawn
    from a standard normal of the same shape as the input.
    """
    epsilon = np.random.normal(0, 1, img_tensor.shape).astype(np.float32)
    x_t = np.sqrt(a[t]) * img_tensor + np.sqrt(1 - a[t]) * epsilon
    return x_t, epsilon
def denoise_step(x_t, eps_hat, t, a, beta):
    """
    Reverse one DDPM step: x_t -> x_{t-1}

    x_t: current noisy latent; eps_hat: the model's noise prediction for step t.
    a: cumulative alpha-bar schedule; beta: per-step variance schedule.
    """
    a_bar_t = tf.convert_to_tensor(a[t], dtype=tf.float32)
    a_bar_prev = tf.convert_to_tensor(a[t - 1] if t > 0 else 1.0, dtype=tf.float32)
    # Recover the single-step alpha_t from the two cumulative products.
    a_t = a_bar_t / a_bar_prev
    beta_t = tf.convert_to_tensor(beta[t], dtype=tf.float32)
    # Avoid NaNs with clamping
    sqrt_recip_a_t = tf.math.rsqrt(tf.maximum(a_t, 1e-5))
    sqrt_one_minus_ab = tf.sqrt(tf.maximum(1. - a_bar_t, 1e-5))
    # Posterior mean: (x_t - beta_t/sqrt(1-a_bar_t) * eps_hat) / sqrt(alpha_t).
    eps_term = (beta_t / sqrt_one_minus_ab) * eps_hat
    mean = sqrt_recip_a_t * (x_t - eps_term)
    # NOTE(review): standard DDPM adds noise for all t > 0; here both t == 1 and
    # t == 0 are deterministic — confirm this off-by-one is intended.
    if t > 1:
        noise = tf.random.normal(shape=x_t.shape)
        # sigma^2 = beta_t (the simpler of the two variance choices in DDPM).
        sigma = tf.sqrt(tf.maximum(beta_t, 1e-5))
        x_prev = mean + sigma * noise
    else:
        x_prev = mean
    return x_prev
# @title Transformer Block
@register_keras_serializable()
class TransformerBlock(tf.keras.Layer):
    """Pre-norm transformer block: self-attention then a 4x-expansion MLP, both residual.

    NOTE(review): subclasses tf.keras.Layer while sibling layers use
    tf.keras.layers.Layer — confirm the deployed TF version exposes this alias.
    """
    def __init__(self, context_size, head_no, latent_dim, **kwargs):
        super().__init__(**kwargs)
        self.context_size = context_size
        self.head_no = head_no
        self.latent_dim = latent_dim
        # key_dim is the per-head width; output_shape projects back to latent_dim.
        self.attn = MultiHeadAttention(num_heads=head_no, key_dim=latent_dim//head_no, output_shape=latent_dim)
        self.mlp_up = Dense(latent_dim*4, activation="gelu")
        self.mlp_down = Dense(latent_dim)
        self.norm1 = LayerNormalization()
        self.norm2 = LayerNormalization()
    def call(self, x):
        # Self-attention sub-block (pre-norm, residual).
        normed = self.norm1(x)
        x = x + self.attn(normed, normed, normed)
        # Feed-forward sub-block (pre-norm, residual).
        normed = self.norm2(x)
        dx = self.mlp_up(normed)
        x = x + self.mlp_down(dx)
        return x
    def build(self, input_shape):
        super().build(input_shape)
    def compute_output_shape(self, input_shape):
        # Shape-preserving block.
        return input_shape
    def get_config(self):
        config = super().get_config()
        config.update({
            "context_size": self.context_size,
            "head_no": self.head_no,
            "latent_dim": self.latent_dim})
        return config
# @title AdaLN-Zero
@register_keras_serializable()
class AdaptiveLayerNorm(tf.keras.Layer):
    """LayerNorm whose scale and shift are predicted from a conditioning vector (AdaLN).

    Fixes vs the previous version: super().__init__() now runs before sublayers
    are attached (Keras attribute tracking must be initialised first), and `eps`
    is stored and serialised — get_config() used to drop it, so a reloaded layer
    silently fell back to the default epsilon.
    """

    def __init__(self, eps=1e-6, **kwargs):
        super().__init__(**kwargs)
        self.eps = eps
        # center/scale disabled: gamma/beta come from the conditioning input instead.
        self.layernorm = LayerNormalization(epsilon=eps, center=False, scale=False)

    def build(self, input_shape):
        # Input is (batch, num_patches, hidden_dim); one affine pair per channel.
        self.M = Dense(input_shape[2], use_bias=True, kernel_initializer='glorot_uniform', activation="linear")
        self.b = Dense(input_shape[2], use_bias=True, kernel_initializer='glorot_uniform', activation="linear")

    def call(self, x, cond):
        gamma = self.M(cond)
        beta = self.b(cond)
        x = self.layernorm(x)
        # (1 + gamma) so a zero conditioning path starts as the identity transform.
        x = x * (1 + tf.expand_dims(gamma, 1)) + tf.expand_dims(beta, 1)
        return x

    def get_config(self):
        config = super().get_config()
        # Persist eps; older saved configs without it load fine via the default.
        config.update({"eps": self.eps})
        return config
# @title Image Embedder, Unembedder
@register_keras_serializable()
class ImageEmbedder(tf.keras.Layer):
    """Patchify a latent image into a token sequence with learned position embeddings.

    A Dense 'reshaper' first maps channels to emb_dim, then a strided Conv2D
    folds each patch_size x patch_size patch into one token.
    """
    def __init__(self, latent_size, patch_size, emb_dim,**kwargs):
        super().__init__(**kwargs)
        self.emb_dim = emb_dim
        self.patch_size = patch_size
        self.latent_size = latent_size
        # One learned positional vector per patch; assumes latent_size is a
        # multiple of patch_size — TODO confirm against the configs in use.
        self.pos_emb = Embedding(input_dim=(latent_size // patch_size)**2 , output_dim=emb_dim, embeddings_initializer="glorot_uniform")
        self.reshaper = Dense(emb_dim, kernel_initializer="glorot_uniform")
        self.conv_expansion = Conv2D(emb_dim, kernel_size=patch_size, strides=patch_size, padding="same")
    def call(self, x):
        x = self.reshaper(x)
        x = self.conv_expansion(x)
        # Flatten the (H/p, W/p) patch grid into a single token axis.
        x = tf.reshape(x, shape=[tf.shape(x)[0], tf.shape(x)[1]*tf.shape(x)[2], tf.shape(x)[3]])
        positions = tf.range(start=0, limit=(self.latent_size // self.patch_size)**2, delta=1)
        embeddings = self.pos_emb(positions)
        # Broadcast-add position embeddings across the batch.
        x = embeddings + x
        return x
    def get_config(self):
        config = super().get_config()
        config.update({
            "latent_size" : self.latent_size,
            "patch_size": self.patch_size,
            "emb_dim": self.emb_dim})
        return config
@register_keras_serializable()
class ImageUnembedder(tf.keras.Layer):
    """Inverse of ImageEmbedder: token sequence back to a spatial latent image.

    Applies a conditioned AdaLN, projects each token to a full patch worth of
    channels, and unfolds patches spatially with depth_to_space.
    """
    def __init__(self, latent_size, patch_size, latent_dim, **kwargs):
        super().__init__(**kwargs)
        self.latent_dim = latent_dim
        self.patch_size = patch_size
        self.latent_size = latent_size
        self.AdaLN = AdaptiveLayerNorm()
        # Each token expands to a patch_size x patch_size patch of latent_dim channels.
        self.reshape_to_latent = Dense(patch_size*patch_size*latent_dim, kernel_initializer="glorot_uniform")
    def call(self, x, cond):
        x = self.AdaLN(x, cond)
        x = self.reshape_to_latent(x)
        # Reassemble the token axis into a (grid, grid, p*p*latent_dim) tensor...
        x = tf.reshape(x, shape=
            [tf.shape(x)[0],
            self.latent_size // self.patch_size,
            self.latent_size // self.patch_size,
            self.latent_dim*(self.patch_size**2)])
        # ...then unfold each token's channel block into its spatial patch.
        x = tf.nn.depth_to_space(x, block_size=self.patch_size)
        return x
    def get_config(self):
        config = super().get_config()
        config.update({
            "latent_size" : self.latent_size,
            "patch_size": self.patch_size,
            "latent_dim": self.latent_dim})
        return config
# @title LEGACY Prompt and Timestep Embedder
@register_keras_serializable()
class ConditioningEmbedder(tf.keras.layers.Layer):
    """Embed (timestep, prompt tokens) into a single conditioning vector.

    A fixed sinusoidal table embeds the timestep; a learned token embedding plus
    a CLS token and one TransformerBlock embed the prompt. The CLS position
    (with the timestep embedding added) is returned, shape (batch, emb_dim).
    """
    def __init__(self, emb_dim, T, context_size, vocab_size=100266, **kwargs):
        super().__init__(**kwargs)
        self.emb_dim = emb_dim
        self.T = T
        self.context_size = context_size
        self.vocab_size = vocab_size
        # Precompute the sinusoidal timestep table (T, emb_dim); constant, not trainable.
        positions = tf.range(T, dtype=tf.float32)[:, tf.newaxis]
        frequencies = tf.constant(10000 ** (-tf.range(0, emb_dim, 2, dtype=tf.float32) / emb_dim))
        angle_rates = positions * frequencies  # (T, emb_dim/2)
        sin_part = tf.sin(angle_rates)
        cos_part = tf.cos(angle_rates)
        emb = tf.stack([sin_part, cos_part], axis=-1)  # (T, emb_dim/2, 2)
        emb = tf.reshape(emb, [T, emb_dim])  # (T, emb_dim)
        self.t_embeddings = tf.constant(emb, dtype=tf.float32)
        self.prompt_emb = self.add_weight(shape=(vocab_size, emb_dim), initializer='glorot_uniform', name='prompt_emb', trainable=True)
        self.CLS = self.add_weight(shape=(emb_dim,), initializer='glorot_uniform', name='CLS', trainable=True)
        self.prompt_pos_enc = self.add_weight(shape=(1, context_size+1, emb_dim), initializer='glorot_uniform', name='prompt_pos_enc', trainable=True)
        # NOTE(review): head count is hard-coded to 6 rather than taken from the
        # active config — confirm this matches the saved checkpoint.
        self.transformer = TransformerBlock(context_size+1, head_no=6, latent_dim=emb_dim)
    def call(self, x):
        t, prompt_tokens = x
        # -- timestep embedding --
        t = tf.cast(tf.squeeze(t, axis=-1), tf.int32)  # (batch,)
        embedded_t = tf.gather(self.t_embeddings, t)  # (batch, emb_dim)
        embedded_t = embedded_t[:, tf.newaxis, :]  # (batch, 1, emb_dim)
        # -- prompt embedding path --
        embedded_prompt = tf.nn.embedding_lookup(
            self.prompt_emb, prompt_tokens)  # (batch, seq_len, emb_dim)
        cls_tok = tf.tile(self.CLS[None, None, :],
            [tf.shape(embedded_prompt)[0], 1, 1])
        embedded_prompt = tf.concat([cls_tok, embedded_prompt], axis=1)
        embedded_prompt += self.prompt_pos_enc
        embedded_prompt = self.transformer(embedded_prompt)  # (batch, seq_len+1, emb_dim)
        # Add the timestep embedding to every token (broadcasts along axis 1).
        embedded_prompt += embedded_t
        # Return only the CLS position.
        return embedded_prompt[:, 0, :]  # (batch, emb_dim)
    def get_config(self):
        config = super().get_config()
        config.update({
            "emb_dim": self.emb_dim,
            "T": self.T,
            "context_size": self.context_size,
            "vocab_size": self.vocab_size})
        return config
# @title DiT Block
class Gain(tf.keras.layers.Layer):
    """Per-channel gain predicted from a conditioning vector (the gating step of AdaLN-Zero).

    Fixed: __init__ now forwards **kwargs to the base Layer so standard Keras
    arguments (name, dtype, ...) are accepted, consistent with every sibling
    layer in this file. Calling Gain() with no arguments behaves as before.
    """

    def __init__(self, **kwargs):
        super(Gain, self).__init__(**kwargs)

    def build(self, input_shape):
        # One scale per channel of the (batch, tokens, channels) input.
        self.M = Dense(input_shape[2], use_bias=True, kernel_initializer='glorot_uniform')

    def call(self, x, cond):
        scale = self.M(cond)
        # Broadcast each sample's scale vector over the token axis.
        x *= tf.expand_dims(scale, 1)
        return x
@register_keras_serializable()
class DiTBlock(tf.keras.layers.Layer):
    """DiT transformer block: AdaLN-modulated attention and MLP sub-blocks, each
    passed through a conditioning-dependent Gain before its residual add."""
    def __init__(self, hidden_dim, heads, context_size, **kwargs):
        super().__init__(**kwargs)
        self.emb_dim = hidden_dim
        self.heads = heads
        self.context_size = context_size
        self.gain1 = Gain()
        self.gain2 = Gain()
        self.adaLN1 = AdaptiveLayerNorm()
        # key_dim is per-head width; output_shape projects back to the model dim.
        self.attn = MultiHeadAttention(num_heads=self.heads, key_dim=self.emb_dim//self.heads, output_shape=self.emb_dim)
        self.adaLN2 = AdaptiveLayerNorm()
        self.mlp_up = Dense(self.emb_dim*4, activation="gelu")
        self.mlp_down = Dense(self.emb_dim)
    def call(self, x, cond):
        # Attention sub-block: conditioned norm -> self-attention -> gated residual.
        R = self.adaLN1(x, cond)
        R = self.gain1(self.attn(R, R, R), cond)
        x = x + R
        # MLP sub-block: conditioned norm -> 4x MLP -> gated residual.
        R = self.adaLN2(x, cond)
        R = self.mlp_up(R)
        R = self.gain2(self.mlp_down(R), cond)
        x = x + R
        return x
    def get_config(self):
        config = super().get_config()
        config.update({"hidden_dim": self.emb_dim,
            "heads": self.heads,
            "context_size": self.context_size})
        return config
# Pretrained components load once at module import so every request reuses them.
encoder = tf.keras.models.load_model("/repository/encoder.keras")  # loaded but not referenced by the inference path below
decoder = tf.keras.models.load_model("/repository/decoder.keras")
diffuser = tf.keras.models.load_model("/repository/diffusion-med-coco.keras")
def inference(prompts):
    """Generate one image per prompt via the full reverse-diffusion loop.

    prompts: list of strings. Returns the decoder's image batch for the final
    denoised latents.

    Fixed: removed the unused `t_shape` local and the redundant `x_0` copy, and
    replaced the hard-coded (32, 32, 4) with the configured `latent_shape`,
    matching EndpointHandler.__call__ (the values are identical for the active
    config, so behavior is unchanged).
    """
    N = len(prompts)
    # Start from pure Gaussian noise in the VAE latent space.
    x_t = tf.random.normal(shape=(N, *latent_shape))
    texts = tf.convert_to_tensor([process_text(p) for p in prompts])
    for t in reversed(range(T)):
        t_batch = tf.convert_to_tensor([[t]] * N)
        eps_hat = diffuser([x_t, texts, t_batch])
        # denoise_step works on numpy arrays; convert back for the next model call.
        x_t = tf.convert_to_tensor(denoise_step(x_t.numpy(), eps_hat.numpy(), t, a, beta), dtype=tf.float32)
    return decoder(x_t)
class EndpointHandler:
    """Hugging Face Inference Endpoints entry point.

    __call__ accepts {"inputs": [prompt, ...]} and returns {"outputs": [...]}:
    one decoded image (as nested lists) per prompt.
    """
    def __init__(self, path="."):
        # Models and tokenizer are loaded at module import time above.
        pass
    def __call__(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        prompts = inputs["inputs"]
        N = len(prompts)
        # Start from Gaussian noise in the VAE latent space.
        x_t = tf.random.normal(shape=(N, *latent_shape))
        texts = tf.convert_to_tensor([process_text(p) for p in prompts])
        # Full reverse-diffusion loop: T diffuser calls per request.
        for t in reversed(range(T)):
            t_batch = tf.convert_to_tensor([[t]] * N)
            eps_hat = diffuser([x_t, texts, t_batch])
            # denoise_step works on numpy arrays; convert back for the next model call.
            x_t = tf.convert_to_tensor(
                denoise_step(x_t.numpy(), eps_hat.numpy(), t, a, beta), dtype=tf.float32
            )
        # Decode latents to images; tolist() makes the payload JSON-serializable.
        imgs = decoder(x_t)
        return {"outputs": imgs.numpy().tolist()}