"""Keras implementation of the Bilma encoder and its Hugging Face wrapper.

The module bundles three pieces that were originally separate files:
preprocessing constants, the ``bilma`` model builder with its loss/metric,
and the generic transformer building blocks (encoder/decoder blocks, masks
and positional encoding).
"""
import re
import unicodedata
from typing import Dict

import numpy as np
from transformers import TFPreTrainedModel, PreTrainedTokenizer
from tensorflow.keras.models import Model, load_model, Sequential
from tensorflow.keras.layers import (
    Layer,
    Dense,
    concatenate,
    Input,
    add,
    Dropout,
    LayerNormalization,
    MultiHeadAttention,
    Embedding,
)
import tensorflow as tf

from configuration_bilma import BilmaConfig

# copied from preprocessing.py
BLANK = ' '
RE_OPS = re.I | re.M | re.S
RE_USR = re.compile(r"""@\S+""", RE_OPS)
RE_TAG = re.compile(r"""#\S+""", RE_OPS)
RE_URL = re.compile(r"""(http|ftp|https)://\S+""", RE_OPS)
RE_NUM = re.compile(r"""[-+]?\d+\.?\d*""", RE_OPS)
SYMBOLS_ = "()[]¿?¡!{}~<>|"
SYMBOLS = set(";:,.@\\-\"/" + SYMBOLS_)


# ------------------
# Class declaration
# ------------------


class TFBilma(TFPreTrainedModel):
    """``TFPreTrainedModel`` wrapper around the Keras model built by ``bilma``.

    The wrapped functional model is stored in ``self.model``. Depending on
    ``config.include_top`` the forward pass returns either vocabulary logits
    or the encoder output.
    """

    config_class = BilmaConfig
    main_input_name = "input_ids"
    #base_model_prefix = "bilma"

    def __init__(self, config):
        """Build the underlying Keras model from a ``BilmaConfig``."""
        # These are assigned before super().__init__() exactly as in the
        # original code, so the ``input_signature`` property is usable as
        # soon as the base class is initialised.
        self.seq_max_length = config.seq_max_length
        self.include_top = config.include_top
        super().__init__(config)

        self.model = bilma(
            num_enc=config.num_hidden_layers,
            embed_dim=config.hidden_size,
            max_length=config.seq_max_length,
            num_heads=config.num_attention_heads,
            ff_dim=config.hidden_size,
            vocab_size=config.vocab_size,
            rate=config.hidden_dropout_prob,
            include_top=config.include_top,
        )

    @property
    def dummy_inputs(self) -> Dict[str, tf.Tensor]:
        """Return all-ones tensors matching ``input_signature``.

        Unknown dimensions are replaced by 2, except an unknown leading
        (batch) dimension, which is replaced by 1.
        """
        dummies = {}
        for key, spec in self.input_signature.items():
            dummy_shape = [dim if dim is not None else 2 for dim in spec.shape]
            if spec.shape[0] is None:
                dummy_shape[0] = 1
            dummies[key] = tf.ones(shape=dummy_shape, dtype=spec.dtype)
        return dummies

    @property
    def input_signature(self) -> Dict[str, tf.TensorSpec]:
        """Return the single ``input_ids`` spec: int32, (None, seq_max_length)."""
        sig = {}
        sig["input_ids"] = tf.TensorSpec(
            [None, self.seq_max_length], tf.int32, name="input_ids"
        )
        return sig

    def call(self, inputs):
        """Run the wrapped model on ``inputs["input_ids"]``.

        Returns:
            A dict with key ``"logits"`` when ``include_top`` is true,
            otherwise a dict with key ``"last_hidden_state"``.
        """
        # The functional model's Input layer is declared without a dtype, so
        # the ids are cast to float32 before being fed to it.
        ins = tf.cast(inputs["input_ids"], tf.float32)
        if self.include_top:
            output = {"logits": self.model(ins)}
        else:
            output = {"last_hidden_state": self.model(ins)}
        return output


# copied from bilma_model.py
# --------------------------


def loss_function(ignore_id=0):
    """Build a sparse categorical cross-entropy loss that skips ``ignore_id``.

    The returned callable expects logits (``from_logits=True``) and yields one
    value per sample: the mean loss over the positions of that sample whose
    target differs from ``ignore_id`` (0 when there are no such positions).
    """
    loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True, reduction='none'
    )

    # The inner name ``loss`` is kept: ``load`` registers it under that key.
    def loss(real, pred):
        mask = tf.math.logical_not(tf.math.equal(real, ignore_id))
        loss_ = loss_object(real, pred)
        mask = tf.cast(mask, dtype=loss_.dtype)
        loss_ *= mask
        sum_ = tf.reduce_sum(mask, axis=1)
        # divide_no_nan keeps fully ignored rows at 0 instead of NaN.
        loss_ = tf.math.divide_no_nan(tf.reduce_sum(loss_, axis=1), sum_)
        return loss_

    return loss


def accuracy_function(ignore_id=0):
    """Build an accuracy metric that skips targets equal to ``ignore_id``.

    The returned callable yields a scalar: correct predictions (argmax over
    axis 2 of ``pred``) divided by the number of non-ignored positions in the
    whole batch (0 when every position is ignored).
    """

    # The inner name ``acc_mlm`` is kept: ``load`` registers it under that key.
    def acc_mlm(real, pred):
        accuracies = tf.equal(tf.cast(real, tf.int64), tf.argmax(pred, axis=2))
        mask = tf.math.logical_not(tf.math.equal(real, ignore_id))
        accuracies = tf.math.logical_and(mask, accuracies)
        accuracies = tf.cast(accuracies, dtype=tf.float32)
        mask = tf.cast(mask, dtype=tf.float32)
        return tf.math.divide_no_nan(
            tf.reduce_sum(accuracies), tf.reduce_sum(mask)
        )

    return acc_mlm


def bilma(num_enc=6, embed_dim=300, max_length=50, num_heads=6, ff_dim=512,
          vocab_size=9739, rate=0.1, include_top=True):
    """Build the Bilma functional Keras model.

    The model is: token embedding -> ``Encoder`` -> optional dense projection
    to ``vocab_size`` (only when ``include_top`` is true).

    Args:
        num_enc: Number of encoder blocks.
        embed_dim: Embedding / hidden dimension.
        max_length: Fixed input sequence length.
        num_heads: Attention heads per encoder block.
        ff_dim: Width of the feed-forward sub-layer.
        vocab_size: Vocabulary size for the embedding and the final layer.
        rate: Dropout rate forwarded to the encoder.
        include_top: Whether to append the final dense layer.

    Returns:
        A ``tf.keras.Model`` named ``"bilma_model"``.
    """
    capt_inputs_ids = Input(shape=(max_length, ), name='input_ids')
    capt_embedding = Embedding(
        vocab_size, embed_dim, mask_zero=False, name="bilma/embedding"
    )
    capt_inputs = capt_embedding(capt_inputs_ids)

    enc = Encoder(
        num_enc, embed_dim, max_length, num_heads, ff_dim,
        rate=rate, name="bilma/encoder",
    )
    enc_output = enc(capt_inputs)

    if include_top:
        fin_output = Dense(
            vocab_size, use_bias=True, name="bilma/dense_final"
        )(enc_output)
    else:
        fin_output = enc_output

    caption_model = Model(
        inputs=capt_inputs_ids, outputs=[fin_output], name="bilma_model"
    )
    return caption_model


def load(model_file):
    """Load a saved Keras model, registering this module's custom objects."""
    custom_objects = {
        "EncoderBlock": EncoderBlock,
        "Encoder": Encoder,
        "loss": loss_function(),
        "acc_mlm": accuracy_function(),
    }
    return load_model(model_file, custom_objects=custom_objects)


#
# Copied from transformer_text.py
# -------------------------------


class EncoderBlock(Layer):
    """Post-norm transformer encoder block: self-attention + feed-forward.

    Args:
        layer_num: Index of the block; used to build the sub-layer names.
        patch_dim: Model dimension; also used as ``key_dim`` of the attention.
        num_heads: Number of attention heads.
        ff_dim: Width of the first feed-forward dense layer.
        rate: Dropout rate applied after attention and after the FFN.
    """

    def __init__(self, layer_num, patch_dim, num_heads, ff_dim, rate=0.1, **kwargs):
        super(EncoderBlock, self).__init__(**kwargs)
        self.ln = layer_num
        self.p_d = patch_dim
        self.n_h = num_heads
        self.f_d = ff_dim
        self.rate = rate

        self.att = MultiHeadAttention(
            num_heads=num_heads, key_dim=patch_dim, name=f"bilma/MHA_{layer_num}"
        )
        self.ffn = Sequential(
            [Dense(ff_dim, activation=tf.nn.gelu, name=f"bilma/dense1_{layer_num}"),
             Dense(patch_dim, name=f"bilma/dense2_{layer_num}")]
        )
        self.layernorm1 = LayerNormalization(epsilon=1e-6, name=f"ln1_{layer_num}")
        self.layernorm2 = LayerNormalization(epsilon=1e-6, name=f"ln2_{layer_num}")
        self.dropout1 = Dropout(rate)
        self.dropout2 = Dropout(rate)

    def get_config(self):
        """Return the constructor arguments for Keras serialization."""
        config = super(EncoderBlock, self).get_config()
        config.update({"layer_num": self.ln, "patch_dim": self.p_d,
                       "num_heads": self.n_h, "ff_dim": self.f_d,
                       "rate": self.rate})
        return config

    def call(self, inputs, training=False):
        """Apply self-attention and the FFN, each with residual + layer norm."""
        attn_output = self.att(inputs, inputs)
        attn_output = self.dropout1(attn_output, training=training)
        out1 = self.layernorm1(add([inputs, attn_output]))

        ffn_output = self.ffn(out1)
        ffn_output = self.dropout2(ffn_output, training=training)
        return self.layernorm2(add([out1, ffn_output]))


class DecoderBlock(Layer):
    """Transformer decoder block: self-attention, cross-attention and FFN.

    Args:
        embed_dim: Model dimension; also used as ``key_dim`` of both attentions.
        num_heads: Number of attention heads.
        ff_dim: Width of the first feed-forward dense layer.
        rate: Dropout rate applied after each of the three sub-layers.
    """

    def __init__(self, embed_dim, num_heads, ff_dim, rate=0.1, **kwargs):
        super(DecoderBlock, self).__init__(**kwargs)
        self.e_d = embed_dim
        self.n_h = num_heads
        self.f_d = ff_dim
        self.rate = rate

        self.att1 = MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
        self.att2 = MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
        self.ffn = Sequential(
            [Dense(ff_dim, activation=tf.nn.gelu),
             Dense(embed_dim),]
        )
        self.layernorm1 = LayerNormalization(epsilon=1e-6)
        self.layernorm2 = LayerNormalization(epsilon=1e-6)
        self.dropout1 = Dropout(rate)
        self.dropout2 = Dropout(rate)
        self.dropout3 = Dropout(rate)

    def get_config(self):
        """Return the constructor arguments for Keras serialization."""
        config = super(DecoderBlock, self).get_config()
        config.update({"embed_dim": self.e_d, "num_heads": self.n_h,
                       "ff_dim": self.f_d, "rate": self.rate})
        return config

    def call(self, inputs, encoder_output, look_ahead_mask, padding_mask, training=None):
        """Run the block.

        Returns:
            A tuple ``(output, self_attention_scores, cross_attention_scores)``.
        """
        # Masked self-attention over the decoder input.
        y, attn_output1 = self.att1(
            inputs, inputs,
            attention_mask=look_ahead_mask,
            return_attention_scores=True,
        )
        y = self.dropout1(y, training=training)
        y = add([inputs, y])
        out1 = self.layernorm1(y)

        # Cross-attention: decoder states query the encoder output.
        y, attn_encoder = self.att2(
            out1, encoder_output,
            attention_mask=padding_mask,
            return_attention_scores=True,
        )
        y = self.dropout2(y, training=training)
        y = add([out1, y])
        # NOTE(review): layernorm1 is reused here, so both attention sub-layers
        # share one normalisation. This may be unintended, but giving this
        # step its own layer would add weights and break existing
        # checkpoints, so the behavior is preserved - confirm with the authors.
        out2 = self.layernorm1(y)

        ffn_output = self.ffn(out2)
        ffn_output = self.dropout3(ffn_output, training=training)
        final_output = self.layernorm2(out2 + ffn_output)

        return final_output, attn_output1, attn_encoder


class Encoder(Layer):
    """Stack of ``n`` ``EncoderBlock`` layers with sinusoidal positions.

    Args:
        n: Number of encoder blocks.
        embed_dim: Model dimension.
        max_length: Number of positions precomputed for the encoding.
        num_heads: Attention heads per block.
        ff_dim: Feed-forward width per block.
        rate: Dropout rate used by every block.
    """

    def __init__(self, n, embed_dim, max_length, num_heads, ff_dim, rate=0.1, **kwargs):
        super(Encoder, self).__init__(**kwargs)
        self.n = n
        self.embed_dim = embed_dim
        self.max_length = max_length
        self.n_h = num_heads
        self.f_d = ff_dim
        self.rate = rate
        # Forward the configured rate; it used to be hard-coded to 0.1, which
        # silently ignored the ``rate`` argument reported by get_config().
        self._layers = [
            EncoderBlock(i, embed_dim, num_heads, ff_dim, rate=rate,
                         name=f"enc_block_{i}")
            for i in range(n)
        ]
        self.pe = positional_encoding(self.max_length, self.embed_dim)

    def get_config(self):
        """Return the constructor arguments for Keras serialization."""
        config = super(Encoder, self).get_config()
        config.update({"n": self.n, "embed_dim": self.embed_dim,
                       "max_length": self.max_length, "num_heads": self.n_h,
                       "ff_dim": self.f_d, "rate": self.rate})
        return config

    def call(self, x, training=False):
        """Scale the embeddings, add positions and run every block in order."""
        x *= tf.math.sqrt(tf.cast(self.embed_dim, tf.float32))
        # Only the first seq_len positions of the precomputed table are used.
        x = x + self.pe[:, :tf.shape(x)[1], :]
        for layer in self._layers:
            x = layer(x, training)
        return x


class Decoder(Layer):
    """Stack of ``n`` ``DecoderBlock`` layers with sinusoidal positions.

    Args:
        n: Number of decoder blocks.
        embed_dim: Model dimension.
        max_length: Number of positions precomputed for the encoding.
        num_heads: Attention heads per block.
        ff_dim: Feed-forward width per block.
        rate: Dropout rate used by every block.
    """

    def __init__(self, n, embed_dim, max_length, num_heads, ff_dim, rate=0.1, **kwargs):
        super(Decoder, self).__init__(**kwargs)
        self.n = n
        self.embed_dim = embed_dim
        self.max_length = max_length
        self.n_h = num_heads
        self.f_d = ff_dim
        self.rate = rate
        # Forward the configured rate; it used to be hard-coded to 0.1, which
        # silently ignored the ``rate`` argument reported by get_config().
        self._layers = [
            DecoderBlock(embed_dim, num_heads, ff_dim, rate=rate)
            for _ in range(n)
        ]
        self.pe = positional_encoding(self.max_length, self.embed_dim)

    def get_config(self):
        """Return the constructor arguments for Keras serialization."""
        config = super(Decoder, self).get_config()
        config.update({"n": self.n, "embed_dim": self.embed_dim,
                       "max_length": self.max_length, "num_heads": self.n_h,
                       "ff_dim": self.f_d, "rate": self.rate})
        return config

    def call(self, x, encoder_output, look_ahead_mask, padding_mask, training):
        """Scale, add positions and run every block; return the final states.

        The attention scores produced by the blocks are discarded.
        """
        x *= tf.math.sqrt(tf.cast(self.embed_dim, tf.float32))
        x = x + self.pe[:, :tf.shape(x)[1], :]

        for layer in self._layers:
            x, self_att, enc_att = layer(
                x, encoder_output, look_ahead_mask, padding_mask, training
            )

        return x


# =========================================
#   M A S K S
# =========================================


def create_padding_mask(seq):
    """
    For self-attention

    seq shape(bs, max_length, emb_dim)
    output shape (bs, max_length, max_length)

    A position counts as valid when any of its features is non-zero; entry
    [b, i, j] of the float32 result is 1.0 when position j of sample b is
    valid.
    """
    mask = tf.cast(tf.not_equal(seq, 0), tf.bool)
    mask = tf.reduce_any(mask, 2)
    # Repeat each sample's row so it can be viewed as a square matrix.
    mask = tf.repeat(mask, seq.shape[1], 0)
    mask = tf.reshape(mask, (-1, seq.shape[1], seq.shape[1]))
    return tf.cast(mask, tf.float32)


def create_cross_padding_mask(seq, target_seq):
    """
    For cross-attention

    seq shape(bs, k, image_features)
    target_seq(bs, max_length, emb_dim)
    output shape (bs, max_length, k)

    The boolean result marks, for every target position, whether that target
    position has any non-zero feature.
    """
    mask = tf.cast(tf.not_equal(target_seq, 0), tf.bool)
    mask = tf.reduce_any(mask, 2)
    mask = tf.repeat(mask, seq.shape[1], 0)
    mask = tf.reshape(mask, (-1, tf.shape(seq)[1], tf.shape(target_seq)[1]))
    mask = tf.transpose(mask, [0, 2, 1])
    return mask


def create_look_ahead_mask(seq):
    """
    seq shape(bs, max_length, emb_dim)

    output 2D matrix of shape (bs, max_length, max_length) with ones
    on the diagonal and below.
    """
    size = seq.shape[1]
    mask = tf.linalg.band_part(tf.ones((size, size)), -1, 0)
    mask = tf.expand_dims(mask, 0)
    mask = tf.repeat(mask, tf.shape(seq)[0], 0)
    return mask


def create_masks(seq, target_seq):
    """Return ``(decoder_mask, cross_att_mask)`` with look-ahead masking.

    The decoder mask is the padding mask of ``target_seq`` multiplied by the
    lower-triangular look-ahead mask.
    """
    decoder_mask = create_padding_mask(target_seq)
    decoder_mask *= create_look_ahead_mask(target_seq)
    cross_att_mask = create_cross_padding_mask(seq, target_seq)
    return decoder_mask, cross_att_mask


def create_masks_looking_ahead(seq, target_seq):
    """Return ``(decoder_mask, cross_att_mask)`` without look-ahead masking.

    Same as ``create_masks`` except that the decoder mask is only the padding
    mask, so later positions are not hidden.
    """
    decoder_mask = create_padding_mask(target_seq)
    cross_att_mask = create_cross_padding_mask(seq, target_seq)
    return decoder_mask, cross_att_mask


# =========================================
#   P O S I T I O N A L   E N C O D I N G
# =========================================


def get_angles(pos, i, d_model):
    """Return ``pos / 10000 ** (2 * (i // 2) / d_model)`` (broadcast)."""
    angle_rates = 1 / np.power(10000, (2 * (i // 2)) / np.float32(d_model))
    return pos * angle_rates


@tf.autograph.experimental.do_not_convert
def positional_encoding(position, d_model):
    """Build the sinusoidal positional-encoding table.

    Args:
        position: Number of positions to encode.
        d_model: Encoding dimension.

    Returns:
        A float32 tensor of shape ``(1, position, d_model)``.
    """
    angle_rads = get_angles(np.arange(position)[:, np.newaxis],
                            np.arange(d_model)[np.newaxis, :],
                            d_model)

    # apply sin to even indices in the array; 2i
    angle_rads[:, 0::2] = np.sin(angle_rads[:, 0::2])

    # apply cos to odd indices in the array; 2i+1
    angle_rads[:, 1::2] = np.cos(angle_rads[:, 1::2])

    pos_encoding = angle_rads[np.newaxis, ...]

    return tf.cast(pos_encoding, dtype=tf.float32)


class PatchEncoder(Layer):
    """Project patches with a dense layer and add a learned position embedding.

    Args:
        num_patches: Number of positions in the learned embedding table.
        projection_dim: Output dimension of the projection and the embedding.
    """

    def __init__(self, num_patches, projection_dim, **kwargs):
        super(PatchEncoder, self).__init__(**kwargs)
        self.num_patches = num_patches
        self.projection_dim = projection_dim
        self.projection = Dense(units=projection_dim)
        self.position_embedding = Embedding(
            input_dim=num_patches, output_dim=projection_dim
        )

    def get_config(self):
        """Return the constructor arguments for Keras serialization."""
        config = super(PatchEncoder, self).get_config()
        config.update({"num_patches": self.num_patches,
                       "projection_dim": self.projection_dim})
        return config

    def call(self, patch):
        """Return ``projection(patch)`` plus the embeddings of all positions."""
        positions = tf.range(start=0, limit=self.num_patches, delta=1)
        encoded = self.projection(patch) + self.position_embedding(positions)
        return encoded