"""TensorFlow implementation of the Bilma model for the Hugging Face API.

The module bundles the ``TFPreTrainedModel`` wrapper (``TFBilma``), the Keras
functional model builder (``bilma``), encoder/decoder layers, attention-mask
helpers and the sinusoidal positional encoding.
"""
import re
import unicodedata
from typing import Dict

import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import (
    Layer,
    Dense,
    concatenate,
    Input,
    add,
    Dropout,
    LayerNormalization,
    MultiHeadAttention,
    Embedding,
)
from tensorflow.keras.models import Model, load_model, Sequential
from transformers import TFPreTrainedModel, PreTrainedTokenizer, BatchEncoding

from .configuration_bilma import BilmaConfig

# copied from preprocessing.py
BLANK = ' '

RE_OPS = re.I | re.M | re.S
RE_USR = re.compile(r"""@\S+""", RE_OPS)
RE_TAG = re.compile(r"""#\S+""", RE_OPS)
RE_URL = re.compile(r"""(http|ftp|https)://\S+""", RE_OPS)
RE_NUM = re.compile(r"""[-+]?\d+\.?\d*""", RE_OPS)

SYMBOLS_ = "()[]¿?¡!{}~<>|"
SYMBOLS = set(";:,.@\\-\"/" + SYMBOLS_)


# ------------------
# Class declaration
# ------------------


class TFBilma(TFPreTrainedModel):
    """Hugging Face TF wrapper around the Keras model built by ``bilma``.

    ``call`` returns a one-entry dict keyed ``"logits"`` when the config has
    ``include_top``, ``"last_hidden_state"`` when there is neither a top nor
    an extra head, and ``"label"`` when an extra head (``add_head``) is used.
    """

    config_class = BilmaConfig
    main_input_name = "input_ids"
    # base_model_prefix = "bilma"

    def __init__(self, config):
        # NOTE(review): these plain attributes are assigned before
        # ``super().__init__`` as in the original; keep that order unless it
        # is confirmed the base class does not read them during init.
        self.seq_max_length = config.seq_max_length
        self.include_top = config.include_top
        self.add_head = config.add_head
        super().__init__(config)
        self.model = bilma(
            num_enc=config.num_hidden_layers,
            embed_dim=config.hidden_size,
            max_length=config.seq_max_length,
            num_heads=config.num_attention_heads,
            ff_dim=config.hidden_size,
            vocab_size=config.vocab_size,
            rate=config.hidden_dropout_prob,
            include_top=config.include_top,
            add_head=config.add_head,
            pooling=config.pooling,
        )

    @property
    def dummy_inputs(self) -> Dict[str, tf.Tensor]:
        """Return all-ones tensors matching ``input_signature``.

        Unknown dimensions are filled with 2, except an unknown leading
        (batch) dimension, which is set to 1.
        """
        dummies = {}
        for key, spec in self.input_signature.items():
            dummy_shape = [dim if dim is not None else 2 for dim in spec.shape]
            if spec.shape[0] is None:
                dummy_shape[0] = 1
            dummies[key] = tf.ones(shape=dummy_shape, dtype=spec.dtype)
        return dummies

    @property
    def input_signature(self) -> Dict[str, tf.TensorSpec]:
        """Declare a single int32 ``input_ids`` input of shape (None, seq_max_length)."""
        sig = {}
        sig["input_ids"] = tf.TensorSpec([None, self.seq_max_length], tf.int32, name="input_ids")
        return sig

    def call(self, inputs):
        """Run the wrapped Keras model.

        Args:
            inputs: a dict or ``BatchEncoding`` holding ``"input_ids"`` (cast
                to float32 before use), or a tensor passed through unchanged.

        Returns:
            dict with one of the keys ``"logits"``, ``"last_hidden_state"``
            or ``"label"`` (see the class docstring).
        """
        if isinstance(inputs, (dict, BatchEncoding)):
            # Presumably cast to match the Keras ``Input`` in ``bilma``, which
            # is created without an explicit dtype.
            ins = tf.cast(inputs["input_ids"], tf.float32)
        else:
            ins = inputs
        if self.include_top:
            return {"logits": self.model(ins)}
        if self.add_head is None:
            return {"last_hidden_state": self.model(ins)}
        return {"label": self.model(ins)}

    @staticmethod
    def get_loss_function():
        """Return the masked cross-entropy loss built by ``loss_function``.

        Static because it uses no instance state; it can be called on the
        class or on an instance.
        """
        # Fixed: previously called the misspelled ``loss_funtion`` (NameError).
        return loss_function()

    @staticmethod
    def get_acc_function():
        """Return the masked accuracy metric built by ``accuracy_function``."""
        return accuracy_function()


# copied from bilma_model.py
# --------------------------

def loss_function(ignore_id=0):
    """Build a masked sparse categorical cross-entropy loss.

    Positions whose label equals ``ignore_id`` are excluded. For each row
    the loss is summed over axis 1 and divided by the number of kept
    positions (``divide_no_nan`` yields 0 for rows with none kept).

    Args:
        ignore_id: label value to ignore (default 0).

    Returns:
        A ``loss(real, pred)`` callable; ``pred`` holds logits.
    """
    loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction='none')

    # The inner name ``loss`` is what ``load`` registers in custom_objects.
    def loss(real, pred):
        mask = tf.math.logical_not(tf.math.equal(real, ignore_id))
        loss_ = loss_object(real, pred)
        mask = tf.cast(mask, dtype=loss_.dtype)
        loss_ *= mask
        sum_ = tf.reduce_sum(mask, axis=1)
        loss_ = tf.math.divide_no_nan(tf.reduce_sum(loss_, axis=1), sum_)
        return loss_

    return loss


def accuracy_function(ignore_id=0):
    """Build a masked token-accuracy metric.

    The predicted class is ``argmax`` over axis 2 of ``pred``. Accuracy is
    the number of correct non-ignored positions divided by the number of
    non-ignored positions across the whole batch.

    Args:
        ignore_id: label value to ignore (default 0).

    Returns:
        An ``acc_mlm(real, pred)`` callable.
    """
    # The inner name ``acc_mlm`` is what ``load`` registers in custom_objects.
    def acc_mlm(real, pred):
        accuracies = tf.equal(tf.cast(real, tf.int64), tf.argmax(pred, axis=2))

        mask = tf.math.logical_not(tf.math.equal(real, ignore_id))
        accuracies = tf.math.logical_and(mask, accuracies)

        accuracies = tf.cast(accuracies, dtype=tf.float32)
        mask = tf.cast(mask, dtype=tf.float32)
        return tf.math.divide_no_nan(tf.reduce_sum(accuracies), tf.reduce_sum(mask))

    return acc_mlm


def _prefix_mask(inputs, max_length):
    """Locate token id 3 and build a mask covering the columns before it.

    Returns ``(pos, C)``. ``pos`` is the column index of every id-3 token
    found by ``tf.where``. ``C`` is a float32 mask, reshaped to
    ``(-1, max_length, 1)``, that is 1 for columns ``< pos``.

    NOTE(review): assumes ``inputs`` is (batch, max_length) and every row
    contains exactly one id-3 token (presumably an end-of-sequence marker).
    Otherwise ``pos`` does not line up with the batch; confirm with the
    tokenizer.
    """
    p = tf.where(inputs == 3)
    pos = tf.transpose(p)[1]
    C = tf.sequence_mask(pos, maxlen=max_length, dtype=tf.float32)
    C = tf.reshape(C, (-1, max_length, 1))
    return pos, C


def mean_vectors(inputs, enc_vectors, max_length):
    """Average ``enc_vectors`` over the positions before token id 3 in each row."""
    pos, C = _prefix_mask(inputs, max_length)
    S = tf.reduce_sum(enc_vectors * C, 1)
    x = S / tf.expand_dims(tf.cast(pos, tf.float32), (1))
    return x


def mean_diff_vectors(inputs, enc_vectors, max_length):
    """Sum of (row mean - masked vector) over the sequence axis, divided by ``pos``.

    ``enc_vectors`` is assumed to be (batch, max_length, dim), as produced in
    ``bilma``.

    NOTE(review): algebraically this equals ``(max_length - pos) / pos * mu``,
    because masked-out positions contribute ``mu`` and the kept positions'
    differences sum to zero. Confirm this is the intended statistic.
    """
    pos, C = _prefix_mask(inputs, max_length)
    vecs = enc_vectors * C
    S = tf.reduce_sum(vecs, 1)
    denom = tf.expand_dims(tf.cast(pos, tf.float32), (1))
    mu = S / denom
    # Fixed: ``mu`` is (batch, dim). Without the added sequence axis it was
    # broadcast against ``vecs`` with its batch axis aligned to the sequence
    # axis, which fails or mixes rows whenever batch != 1.
    x = tf.reduce_sum(tf.expand_dims(mu, 1) - vecs, 1) / denom
    return x


def max_vectors(inputs, enc_vectors, max_length):
    """Element-wise max of ``enc_vectors`` over the sequence axis, after masking.

    NOTE(review): masked positions are zeroed rather than excluded, so a
    feature whose kept values are all negative yields 0.
    """
    _, C = _prefix_mask(inputs, max_length)
    x = tf.reduce_max(enc_vectors * C, 1)
    return x


def cls_vectors(inputs, enc_vectors, max_length):
    """Return the vector at position 0 of each row (``inputs``/``max_length`` unused)."""
    x = tf.squeeze(enc_vectors[:, 0:1, :], axis=1)
    return x


def bilma(num_enc=6, embed_dim=300, max_length=50, num_heads=6, ff_dim=512,
          vocab_size=9739, rate=0.1, include_top=True, add_head=None,
          pooling=None):
    """Build the Bilma Keras functional model.

    Pipeline: token embedding -> ``Encoder`` -> either a vocabulary-sized
    Dense layer (``include_top``) or optional pooling (``"mean"``, ``"cls"``
    or ``"max"``) followed by an optional head.

    Args:
        num_enc: number of encoder blocks.
        embed_dim: embedding / hidden size.
        max_length: fixed input sequence length.
        num_heads: attention heads per block.
        ff_dim: hidden size of each block's feed-forward network.
        vocab_size: vocabulary size.
        rate: dropout rate passed to the encoder.
        include_top: if True, output per-position vocabulary logits.
        add_head: ``None`` or a sequence of layer widths. The last entry is
            the softmax output size; earlier entries are ReLU Dense layers.
        pooling: ``"mean"``, ``"cls"``, ``"max"``; any other value skips pooling.

    Returns:
        A ``tf.keras.Model`` named ``"bilma_model"``.
    """
    # Layer names are part of the saved-weight layout; do not rename them.
    capt_inputs_ids = Input(shape=(max_length, ), name='input_ids')
    capt_embedding = Embedding(vocab_size, embed_dim, mask_zero=False, name="bilma/embedding")
    capt_inputs = capt_embedding(capt_inputs_ids)

    enc = Encoder(num_enc, embed_dim, max_length, num_heads, ff_dim, rate=rate, name="bilma/encoder")
    enc_output = enc(capt_inputs)

    if include_top:
        fin_output = Dense(vocab_size, use_bias=True, name="bilma/dense_final")(enc_output)
    else:
        x = enc_output
        if pooling == "mean":
            x = mean_vectors(capt_inputs_ids, x, max_length)
        elif pooling == "cls":
            x = cls_vectors(capt_inputs_ids, x, max_length)
        elif pooling == "max":
            x = max_vectors(capt_inputs_ids, x, max_length)

        if add_head is None:
            fin_output = x
        else:
            for i, m in enumerate(add_head[:-1]):
                x = Dense(m, use_bias=True, activation="relu", name=f"bilma/dense_ex_{i}")(x)
            fin_output = Dense(add_head[-1], use_bias=True, activation="softmax", name=f"bilma/dense_ex_final")(x)

    caption_model = Model(inputs=capt_inputs_ids, outputs=fin_output, name="bilma_model")
    return caption_model


def load(model_file):
    """Load a saved Bilma Keras model, registering this module's custom objects."""
    custom_objects = {"EncoderBlock": EncoderBlock,
                      "Encoder": Encoder,
                      "loss": loss_function(),
                      "acc_mlm": accuracy_function(),
                      }
    return load_model(model_file, custom_objects=custom_objects)


#
# Copied from transformer_text.py
# -------------------------------

class EncoderBlock(Layer):
    """Post-norm transformer encoder block.

    Self-attention and a GELU feed-forward network are each followed by
    dropout, a residual add and layer normalization.
    """

    def __init__(self, layer_num, patch_dim, num_heads, ff_dim, rate=0.1, **kwargs):
        super(EncoderBlock, self).__init__(**kwargs)

        self.ln = layer_num
        self.p_d = patch_dim
        self.n_h = num_heads
        self.f_d = ff_dim
        self.rate = rate

        # Sub-layer names include ``layer_num`` and are part of the weight
        # layout; keep them stable.
        self.att = MultiHeadAttention(num_heads=num_heads, key_dim=patch_dim, name=f"bilma/MHA_{layer_num}")
        self.ffn = Sequential(
            [Dense(ff_dim, activation=tf.nn.gelu, name=f"bilma/dense1_{layer_num}"),
             Dense(patch_dim, name=f"bilma/dense2_{layer_num}")]
        )
        self.layernorm1 = LayerNormalization(epsilon=1e-6, name=f"ln1_{layer_num}")
        self.layernorm2 = LayerNormalization(epsilon=1e-6, name=f"ln2_{layer_num}")
        self.dropout1 = Dropout(rate)
        self.dropout2 = Dropout(rate)

    def get_config(self):
        config = super(EncoderBlock, self).get_config()
        config.update({"layer_num": self.ln, "patch_dim": self.p_d, "num_heads": self.n_h, "ff_dim": self.f_d, "rate": self.rate})
        return config

    def call(self, inputs, training=False):
        attn_output = self.att(inputs, inputs)
        attn_output = self.dropout1(attn_output, training=training)
        out1 = self.layernorm1(add([inputs, attn_output]))
        ffn_output = self.ffn(out1)
        ffn_output = self.dropout2(ffn_output, training=training)
        return self.layernorm2(add([out1, ffn_output]))


class DecoderBlock(Layer):
    """Transformer decoder block: masked self-attention, cross-attention and a
    GELU feed-forward network.

    ``call`` returns ``(output, self_attention_scores, cross_attention_scores)``.
    """

    def __init__(self, embed_dim, num_heads, ff_dim, rate=0.1, **kwargs):
        super(DecoderBlock, self).__init__(**kwargs)

        self.e_d = embed_dim
        self.n_h = num_heads
        self.f_d = ff_dim
        self.rate = rate

        self.att1 = MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
        self.att2 = MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
        self.ffn = Sequential(
            [Dense(ff_dim, activation=tf.nn.gelu),
             Dense(embed_dim), ]
        )
        self.layernorm1 = LayerNormalization(epsilon=1e-6)
        self.layernorm2 = LayerNormalization(epsilon=1e-6)
        self.dropout1 = Dropout(rate)
        self.dropout2 = Dropout(rate)
        self.dropout3 = Dropout(rate)

    def get_config(self):
        config = super(DecoderBlock, self).get_config()
        config.update({"embed_dim": self.e_d, "num_heads": self.n_h, "ff_dim": self.f_d, "rate": self.rate})
        return config

    def call(self, inputs, encoder_output, look_ahead_mask, padding_mask, training=None):
        y, attn_output1 = self.att1(inputs, inputs, attention_mask=look_ahead_mask, return_attention_scores=True)
        y = self.dropout1(y, training=training)
        y = add([inputs, y])
        out1 = self.layernorm1(y)

        y, attn_encoder = self.att2(out1, encoder_output, attention_mask=padding_mask, return_attention_scores=True)
        y = self.dropout2(y, training=training)
        y = add([out1, y])
        # NOTE(review): ``layernorm1`` is reused here, so both residuals share
        # one set of LayerNorm parameters. A standard decoder uses a separate
        # norm; left unchanged to preserve the existing weight layout.
        out2 = self.layernorm1(y)

        ffn_output = self.ffn(out2)
        ffn_output = self.dropout3(ffn_output, training=training)
        final_output = self.layernorm2(out2 + ffn_output)

        return final_output, attn_output1, attn_encoder


class Encoder(Layer):
    """Stack of ``n`` ``EncoderBlock`` layers with sinusoidal positional encoding.

    Inputs are scaled by ``sqrt(embed_dim)`` and the positional encoding is
    added before the blocks run.
    """

    def __init__(self, n, embed_dim, max_length, num_heads, ff_dim, rate=0.1, **kwargs):
        super(Encoder, self).__init__(**kwargs)
        self.n = n
        self.embed_dim = embed_dim
        self.max_length = max_length
        self.n_h = num_heads
        self.f_d = ff_dim
        self.rate = rate
        # Fixed: blocks previously received a hard-coded ``rate=0.1``, which
        # silently ignored the configured dropout rate.
        self._layers = [EncoderBlock(i, embed_dim, num_heads, ff_dim, rate=rate, name=f"enc_block_{i}") for i in range(n)]
        self.pe = positional_encoding(self.max_length, self.embed_dim)

    def get_config(self):
        config = super(Encoder, self).get_config()
        config.update({"n": self.n, "embed_dim": self.embed_dim, "max_length": self.max_length, "num_heads": self.n_h, "ff_dim": self.f_d, "rate": self.rate})
        return config

    def call(self, x, training=False):
        x *= tf.math.sqrt(tf.cast(self.embed_dim, tf.float32))
        # Slice the precomputed encoding to the runtime sequence length.
        x = x + self.pe[:, :tf.shape(x)[1], :]
        for layer in self._layers:
            x = layer(x, training=training)
        return x


class Decoder(Layer):
    """Stack of ``n`` ``DecoderBlock`` layers with sinusoidal positional encoding.

    Only the final hidden states are returned; per-block attention scores are
    discarded.
    """

    def __init__(self, n, embed_dim, max_length, num_heads, ff_dim, rate=0.1, **kwargs):
        super(Decoder, self).__init__(**kwargs)
        self.n = n
        self.embed_dim = embed_dim
        self.max_length = max_length
        self.n_h = num_heads
        self.f_d = ff_dim
        self.rate = rate
        # Fixed: pass the configured dropout rate instead of a hard-coded 0.1.
        self._layers = [DecoderBlock(embed_dim, num_heads, ff_dim, rate=rate) for _ in range(n)]
        self.pe = positional_encoding(self.max_length, self.embed_dim)

    def get_config(self):
        config = super(Decoder, self).get_config()
        config.update({"n": self.n, "embed_dim": self.embed_dim, "max_length": self.max_length, "num_heads": self.n_h, "ff_dim": self.f_d, "rate": self.rate})
        return config

    def call(self, x, encoder_output, look_ahead_mask, padding_mask, training):
        x *= tf.math.sqrt(tf.cast(self.embed_dim, tf.float32))
        x = x + self.pe[:, :tf.shape(x)[1], :]

        for layer in self._layers:
            x, _, _ = layer(x, encoder_output, look_ahead_mask, padding_mask, training=training)

        return x


# =========================================
#   M A S K S
# =========================================
def create_padding_mask(seq):
    """Build a self-attention padding mask.

    A position counts as valid when any of its features is non-zero.
    Row ``i`` of each output matrix repeats the sequence's validity vector.

    Args:
        seq: tensor of shape (bs, max_length, emb_dim).

    Returns:
        float32 tensor of shape (bs, max_length, max_length).
    """
    mask = tf.cast(tf.not_equal(seq, 0), tf.bool)
    mask = tf.reduce_any(mask, 2)
    mask = tf.repeat(mask, seq.shape[1], 0)
    mask = tf.reshape(mask, (-1, seq.shape[1], seq.shape[1]))
    return tf.cast(mask, tf.float32)


def create_cross_padding_mask(seq, target_seq):
    """Build a cross-attention mask from the validity of ``target_seq`` positions.

    Args:
        seq: tensor of shape (bs, k, image_features).
        target_seq: tensor of shape (bs, max_length, emb_dim).

    Returns:
        bool tensor of shape (bs, max_length, k). Unlike
        ``create_padding_mask``, it is not cast to float.
    """
    mask = tf.cast(tf.not_equal(target_seq, 0), tf.bool)
    mask = tf.reduce_any(mask, 2)
    mask = tf.repeat(mask, seq.shape[1], 0)
    mask = tf.reshape(mask, (-1, tf.shape(seq)[1], tf.shape(target_seq)[1]))
    mask = tf.transpose(mask, [0, 2, 1])
    return mask


def create_look_ahead_mask(seq):
    """Build a causal (lower-triangular) mask.

    Args:
        seq: tensor of shape (bs, max_length, emb_dim).

    Returns:
        float32 tensor of shape (bs, max_length, max_length) with ones on and
        below the diagonal and zeros above it.
    """
    size = seq.shape[1]
    mask = tf.linalg.band_part(tf.ones((size, size)), -1, 0)
    mask = tf.expand_dims(mask, 0)
    mask = tf.repeat(mask, tf.shape(seq)[0], 0)
    return mask


def create_masks(seq, target_seq):
    """Return (causal + padding decoder mask, cross-attention mask)."""
    decoder_mask = create_padding_mask(target_seq)
    decoder_mask *= create_look_ahead_mask(target_seq)
    cross_att_mask = create_cross_padding_mask(seq, target_seq)
    return decoder_mask, cross_att_mask


def create_masks_looking_ahead(seq, target_seq):
    """Like ``create_masks`` but without the causal restriction on the decoder mask."""
    decoder_mask = create_padding_mask(target_seq)
    cross_att_mask = create_cross_padding_mask(seq, target_seq)
    return decoder_mask, cross_att_mask


# =========================================
#   P O S I T I O N A L   E N C O D I N G
# =========================================
def get_angles(pos, i, d_model):
    """Return ``pos / 10000 ** (2 * (i // 2) / d_model)`` (broadcast over arrays)."""
    angle_rates = 1 / np.power(10000, (2 * (i // 2)) / np.float32(d_model))
    return pos * angle_rates


@tf.autograph.experimental.do_not_convert
def positional_encoding(position, d_model):
    """Return a sinusoidal positional encoding of shape (1, position, d_model), float32."""
    angle_rads = get_angles(np.arange(position)[:, np.newaxis],
                            np.arange(d_model)[np.newaxis, :],
                            d_model)

    # apply sin to even indices in the array; 2i
    angle_rads[:, 0::2] = np.sin(angle_rads[:, 0::2])

    # apply cos to odd indices in the array; 2i+1
    angle_rads[:, 1::2] = np.cos(angle_rads[:, 1::2])

    pos_encoding = angle_rads[np.newaxis, ...]

    return tf.cast(pos_encoding, dtype=tf.float32)


class PatchEncoder(Layer):
    """Project patches with a Dense layer and add a learned embedding for
    positions ``0 .. num_patches - 1``."""

    def __init__(self, num_patches, projection_dim, **kwargs):
        super(PatchEncoder, self).__init__(**kwargs)
        self.num_patches = num_patches
        self.projection_dim = projection_dim
        self.projection = Dense(units=projection_dim)
        self.position_embedding = Embedding(
            input_dim=num_patches, output_dim=projection_dim
        )

    def get_config(self):
        config = super(PatchEncoder, self).get_config()
        config.update({"num_patches": self.num_patches, "projection_dim": self.projection_dim})
        return config

    def call(self, patch):
        positions = tf.range(start=0, limit=self.num_patches, delta=1)
        encoded = self.projection(patch) + self.position_embedding(positions)
        return encoded