### IMPORTS
import tensorflow as tf
import numpy as np
import einops
import tqdm
import collections
import re
import string
import pickle

print("import complete")

#=========================================================================================================================
### UTILITY FUNCTIONS
#=========================================================================================================================

IMAGE_SHAPE = (224, 224, 3)

@tf.keras.utils.register_keras_serializable()
def custom_standardization(s):
    s = tf.strings.lower(s)
    s = tf.strings.regex_replace(s, f'[{re.escape(string.punctuation)}]', '')
    s = tf.strings.join(['[START]', s, '[END]'], separator=' ')
    return s

def load_image(image_path):
    img = tf.io.read_file(image_path)
    img = tf.io.decode_jpeg(img, channels=3)
    img = tf.image.resize(img, IMAGE_SHAPE[:-1])
    return img

def load_image_obj(img):
    img = tf.image.resize(img, IMAGE_SHAPE[:-1])
    return img

def masked_loss(labels, preds):
    loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels, preds)
    # Ignore padding (label 0) and any numerically exploded entries.
    mask = (labels != 0) & (loss < 1e8)
    mask = tf.cast(mask, loss.dtype)
    loss = loss * mask
    loss = tf.reduce_sum(loss) / tf.reduce_sum(mask)
    return loss

def masked_acc(labels, preds):
    mask = tf.cast(labels != 0, tf.float32)
    preds = tf.argmax(preds, axis=-1)
    labels = tf.cast(labels, tf.int64)
    match = tf.cast(preds == labels, mask.dtype)
    acc = tf.reduce_sum(match * mask) / tf.reduce_sum(mask)
    return acc

print("utility complete")
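# --- Sanity check for the masked metrics (illustrative sketch only). ---
# Not part of the original pipeline; the shapes below (batch=2, seq=3,
# vocab=5) are arbitrary assumptions, and label 0 stands for padding.
_labels = tf.constant([[1, 2, 0], [3, 0, 0]], dtype=tf.int64)
_logits = tf.random.normal((2, 3, 5))
print("masked_loss:", masked_loss(_labels, _logits).numpy())
print("masked_acc :", masked_acc(_labels, _logits).numpy())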
#=========================================================================================================================
### MODEL CLASS
#=========================================================================================================================

mobilenet = tf.keras.applications.MobileNetV3Small(
    input_shape=IMAGE_SHAPE,
    include_top=False,
    include_preprocessing=True)
mobilenet.trainable = False

class SeqEmbedding(tf.keras.layers.Layer):
    def __init__(self, vocab_size, max_length, depth):
        super().__init__()
        self.pos_embedding = tf.keras.layers.Embedding(input_dim=max_length, output_dim=depth)
        self.token_embedding = tf.keras.layers.Embedding(
            input_dim=vocab_size,
            output_dim=depth,
            mask_zero=True)
        self.add = tf.keras.layers.Add()

    def call(self, seq):
        seq = self.token_embedding(seq)  # (batch, seq, depth)
        x = tf.range(tf.shape(seq)[1])   # (seq)
        x = x[tf.newaxis, :]             # (1, seq)
        x = self.pos_embedding(x)        # (1, seq, depth)
        return self.add([seq, x])

class CausalSelfAttention(tf.keras.layers.Layer):
    def __init__(self, **kwargs):
        super().__init__()
        self.mha = tf.keras.layers.MultiHeadAttention(**kwargs)
        # Use Add instead of + so the keras mask propagates through.
        self.add = tf.keras.layers.Add()
        self.layernorm = tf.keras.layers.LayerNormalization()

    def call(self, x):
        attn = self.mha(query=x, value=x, use_causal_mask=True)
        x = self.add([x, attn])
        return self.layernorm(x)

class CrossAttention(tf.keras.layers.Layer):
    def __init__(self, **kwargs):
        super().__init__()
        self.mha = tf.keras.layers.MultiHeadAttention(**kwargs)
        self.add = tf.keras.layers.Add()
        self.layernorm = tf.keras.layers.LayerNormalization()

    def call(self, x, y, **kwargs):
        attn, attention_scores = self.mha(
            query=x, value=y,
            return_attention_scores=True)
        self.last_attention_scores = attention_scores
        x = self.add([x, attn])
        return self.layernorm(x)

class FeedForward(tf.keras.layers.Layer):
    def __init__(self, units, dropout_rate=0.1):
        super().__init__()
        self.seq = tf.keras.Sequential([
            tf.keras.layers.Dense(units=2 * units, activation='relu'),
            tf.keras.layers.Dense(units=units),
            tf.keras.layers.Dropout(rate=dropout_rate),
        ])
        self.layernorm = tf.keras.layers.LayerNormalization()

    def call(self, x):
        x = x + self.seq(x)
        return self.layernorm(x)

class DecoderLayer(tf.keras.layers.Layer):
    def __init__(self, units, num_heads=1, dropout_rate=0.1):
        super().__init__()
        self.self_attention = CausalSelfAttention(num_heads=num_heads,
                                                  key_dim=units,
                                                  dropout=dropout_rate)
        self.cross_attention = CrossAttention(num_heads=num_heads,
                                              key_dim=units,
                                              dropout=dropout_rate)
        self.ff = FeedForward(units=units, dropout_rate=dropout_rate)

    def call(self, inputs, training=False):
        in_seq, out_seq = inputs  # image features, text input

        out_seq = self.self_attention(out_seq)
        out_seq = self.cross_attention(out_seq, in_seq)
        self.last_attention_scores = self.cross_attention.last_attention_scores
        out_seq = self.ff(out_seq)
        return out_seq

class TokenOutput(tf.keras.layers.Layer):
    def __init__(self, tokenizer, banned_tokens=('', '[UNK]', '[START]'), bias=None, **kwargs):
        super().__init__()
        self.dense = tf.keras.layers.Dense(
            units=tokenizer.vocabulary_size(), **kwargs)
        self.tokenizer = tokenizer
        self.banned_tokens = banned_tokens
        self.bias = bias

    def adapt(self, ds):
        counts = collections.Counter()
        vocab_dict = {name: id
                      for id, name in enumerate(self.tokenizer.get_vocabulary())}

        for tokens in tqdm.tqdm(ds):
            counts.update(tokens.numpy().flatten())

        counts_arr = np.zeros(shape=(self.tokenizer.vocabulary_size(),))
        counts_arr[np.array(list(counts.keys()), dtype=np.int32)] = list(counts.values())

        for token in self.banned_tokens:
            counts_arr[vocab_dict[token]] = 0

        total = counts_arr.sum()
        p = counts_arr / total
        p[counts_arr == 0] = 1.0
        log_p = np.log(p)  # log(1) == 0

        entropy = -(log_p * p).sum()

        print()
        print(f"Uniform entropy: {np.log(self.tokenizer.vocabulary_size()):0.2f}")
        print(f"Marginal entropy: {entropy:0.2f}")

        self.bias = log_p
        self.bias[counts_arr == 0] = -1e9

    def call(self, x):
        x = self.dense(x)
        return x + self.bias

    def get_config(self):
        config = super(TokenOutput, self).get_config()
        config.update({
            "tokenizer": self.tokenizer,
            "banned_tokens": self.banned_tokens,
            "bias": self.bias,
            "dense": self.dense
        })
        return config
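# --- Illustrative shape check for one DecoderLayer (a sketch, not part ---
# --- of the pipeline). The batch size and caption length are arbitrary ---
# --- assumptions; (49, 576) mirrors MobileNetV3Small's 7x7x576 feature ---
# --- map on 224x224 inputs after flattening.                           ---
_dec = DecoderLayer(units=256, num_heads=2)
_img_feats = tf.random.normal((1, 49, 576))  # stands in for flattened image features
_txt_embed = tf.random.normal((1, 5, 256))   # stands in for embedded caption tokens
print("decoder layer out:", _dec((_img_feats, _txt_embed)).shape)  # (1, 5, 256)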
class Captioner(tf.keras.Model):
    @classmethod
    def add_method(cls, fun):
        setattr(cls, fun.__name__, fun)
        return fun

    def __init__(self, tokenizer, feature_extractor, output_layer, num_layers=1,
                 units=256, max_length=50, num_heads=1, dropout_rate=0.1):
        super().__init__()
        self.feature_extractor = feature_extractor
        self.tokenizer = tokenizer
        self.word_to_index = tf.keras.layers.StringLookup(
            mask_token="",
            vocabulary=tokenizer.get_vocabulary())
        self.index_to_word = tf.keras.layers.StringLookup(
            mask_token="",
            vocabulary=tokenizer.get_vocabulary(),
            invert=True)

        self.seq_embedding = SeqEmbedding(
            vocab_size=tokenizer.vocabulary_size(),
            depth=units,
            max_length=max_length)

        self.decoder_layers = [
            DecoderLayer(units, num_heads=num_heads, dropout_rate=dropout_rate)
            for n in range(num_layers)]

        self.output_layer = output_layer

    def call(self, inputs):
        image, txt = inputs

        if image.shape[-1] == 3:
            # Apply the feature-extractor, if you get an RGB image.
            image = self.feature_extractor(image)

        # Flatten the feature map: (b, h, w, c) -> (b, h*w, c)
        image = einops.rearrange(image, 'b h w c -> b (h w) c')

        if txt.dtype == tf.string:
            # Apply the tokenizer if you get string inputs.
            txt = self.tokenizer(txt)

        txt = self.seq_embedding(txt)

        # Look at the image while decoding the caption.
        for dec_layer in self.decoder_layers:
            txt = dec_layer(inputs=(image, txt))

        txt = self.output_layer(txt)
        return txt

    def simple_gen(self, image, temperature=1):
        initial = self.word_to_index([['[START]']])  # (batch, sequence)
        img_features = self.feature_extractor(image[tf.newaxis, ...])

        tokens = initial  # (batch, sequence)
        for n in range(50):
            preds = self((img_features, tokens)).numpy()  # (batch, sequence, vocab)
            preds = preds[:, -1, :]  # (batch, vocab)
            if temperature == 0:
                next_token = tf.argmax(preds, axis=-1)[:, tf.newaxis]  # (batch, 1)
            else:
                next_token = tf.random.categorical(preds / temperature, num_samples=1)  # (batch, 1)
            tokens = tf.concat([tokens, next_token], axis=1)  # (batch, sequence)

            if next_token[0] == self.word_to_index('[END]'):
                break

        words = self.index_to_word(tokens[0, 1:-1])
        result = tf.strings.reduce_join(words, axis=-1, separator=' ')
        return result.numpy().decode()

    # def get_config(self):
    #     config = super().get_config()
    #     config.update({"feature_extractor": self.feature_extractor,
    #                    "tokenizer": self.tokenizer,
    #                    "word_to_index": self.word_to_index,
    #                    "index_to_word": self.index_to_word,
    #                    "outputlayer": self.output_layer,
    #                    "seq_embedding": self.seq_embedding,
    #                    "decoder_layers": self.decoder_layers
    #                    })
    #     return config

    # def build_from_config(self, config):
    #     return super().build_from_config(config)

# model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
#               loss=masked_loss,
#               metrics=[masked_acc])

print("model complete")
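# --- Temperature sampling sketch (illustrative only; the logits below ---
# --- are made up and unrelated to the model).                         ---
# tf.random.categorical treats its input as unnormalized log-probabilities,
# so dividing by a small temperature sharpens the distribution toward the
# argmax; simple_gen special-cases temperature == 0 as pure greedy decoding.
_toy_logits = tf.constant([[2.0, 1.0, 0.1]])
for _t in (0.1, 1.0):
    _pick = tf.random.categorical(_toy_logits / _t, num_samples=1)
    print(f"temperature={_t}: sampled index {_pick.numpy()[0, 0]}")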
#=========================================================================================================================
### LOAD FUNCTION
#=========================================================================================================================

def build():
    # Restore the TextVectorization layer from its pickled config and weights.
    filename = "model/tokenizer.pkl"
    token_meta = pickle.load(open(filename, 'rb'))
    tokenizer = tf.keras.layers.TextVectorization.from_config(token_meta["config"])
    tokenizer.set_weights(token_meta['weights'])
    print(tokenizer("build sentence"))  # smoke-test the restored tokenizer

    word_to_index = tf.keras.layers.StringLookup(
        mask_token="",
        vocabulary=tokenizer.get_vocabulary())
    index_to_word = tf.keras.layers.StringLookup(
        mask_token="",
        vocabulary=tokenizer.get_vocabulary(),
        invert=True)

    output_layer = TokenOutput(tokenizer, banned_tokens=('', '[UNK]', '[START]'))
    filename = "model/output_layer.pkl"
    bias = pickle.load(open(filename, 'rb'))
    output_layer.bias = bias

    load_model = Captioner(tokenizer, feature_extractor=mobilenet, output_layer=output_layer,
                           units=256, dropout_rate=0.5, num_layers=2, num_heads=2)
    load_model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
                       loss=masked_loss,
                       metrics=[masked_acc])

    # Run one generation pass to build the model's variables
    # before restoring the checkpoint.
    image_url = 'https://tensorflow.org/images/surf.jpg'
    image_path = tf.keras.utils.get_file('surf.jpg', origin=image_url)
    image = load_image(image_path)
    load_model.simple_gen(image)

    path = "model/captioner_weights"
    load_model.load_weights(path)
    return load_model

# loaded_model = build()
print("load function defined")

#=========================================================================================================================
### TEST RUN
#=========================================================================================================================

image_url = 'https://tensorflow.org/images/surf.jpg'
image_path = tf.keras.utils.get_file('surf.jpg', origin=image_url)
image = load_image(image_path)
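# --- Example usage (a sketch; assumes the pickled tokenizer, bias, and ---
# --- checkpoint exist under model/ as expected by build() above).      ---
# loaded_model = build()
# print(loaded_model.simple_gen(image, temperature=0.0))  # greedy caption
# print(loaded_model.simple_gen(image, temperature=0.5))  # mildly stochastic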