guillermoruiz commited on
Commit
f1871ce
1 Parent(s): 866870d

Upload TFBilma

Browse files
Files changed (4) hide show
  1. config.json +0 -1
  2. configuration_bilma.py +14 -1
  3. modeling_bilma.py +70 -9
  4. tf_model.h5 +1 -1
config.json CHANGED
@@ -1,5 +1,4 @@
1
  {
2
- "_name_or_path": "bilma_AR",
3
  "add_head": null,
4
  "architectures": [
5
  "Bilma"
 
1
  {
 
2
  "add_head": null,
3
  "architectures": [
4
  "Bilma"
configuration_bilma.py CHANGED
@@ -6,7 +6,9 @@ class BilmaConfig(PretrainedConfig):
6
  def __init__(
7
  self,
8
  weights="AR",
9
- include_top=True,
 
 
10
  num_attention_heads: int = 4,
11
  num_hidden_layers: int = 2,
12
  seq_max_length: int = 280,
@@ -16,11 +18,20 @@ class BilmaConfig(PretrainedConfig):
16
  **kwargs,
17
  ):
18
  countries = ["AR"]
 
19
  if weights not in countries:
20
  raise ValueError(f"`weights` must be one of {countries}, got {weights}.")
 
 
 
 
 
 
21
  if weights is not None:
22
  self.weights = weights
23
  self.include_top = include_top
 
 
24
  self.num_attention_heads = 4
25
  self.num_hidden_layers = 2
26
  self.seq_max_length = 280
@@ -32,6 +43,8 @@ class BilmaConfig(PretrainedConfig):
32
 
33
  self.weights = weights
34
  self.include_top = include_top
 
 
35
  self.num_attention_heads = num_attention_heads
36
  self.num_hidden_layers = num_hidden_layers
37
  self.seq_max_length = seq_max_length
 
6
  def __init__(
7
  self,
8
  weights="AR",
9
+ include_top = True,
10
+ add_head = None,
11
+ pooling = None,
12
  num_attention_heads: int = 4,
13
  num_hidden_layers: int = 2,
14
  seq_max_length: int = 280,
 
18
  **kwargs,
19
  ):
20
  countries = ["AR"]
21
+ poolings = ["mean", "cls", "max"]
22
  if weights not in countries:
23
  raise ValueError(f"`weights` must be one of {countries}, got {weights}.")
24
+ if add_head is not None and include_top == True:
25
+ raise ValueError(f"To add a head, 'include_top' must be False")
26
+ if pooling is not None and include_top == True:
27
+ raise ValueError(f"To specify a pooling, 'include_top' must be False")
28
+ if pooling is not None and pooling not in poolings:
29
+ raise ValueError(f"`pooling` must be one of {poolings}, got {pooling}.")
30
  if weights is not None:
31
  self.weights = weights
32
  self.include_top = include_top
33
+ self.add_head = add_head
34
+ self.pooling = pooling
35
  self.num_attention_heads = 4
36
  self.num_hidden_layers = 2
37
  self.seq_max_length = 280
 
43
 
44
  self.weights = weights
45
  self.include_top = include_top
46
+ self.add_head = add_head
47
+ self.pooling = pooling
48
  self.num_attention_heads = num_attention_heads
49
  self.num_hidden_layers = num_hidden_layers
50
  self.seq_max_length = seq_max_length
modeling_bilma.py CHANGED
@@ -1,4 +1,5 @@
1
- from transformers import TFPreTrainedModel, PreTrainedTokenizer
 
2
  from tensorflow.keras.models import Model, load_model, Sequential
3
  from tensorflow.keras.layers import Layer, Dense, concatenate, Input, add, Dropout, LayerNormalization, MultiHeadAttention, Embedding
4
  import tensorflow as tf
@@ -38,6 +39,7 @@ class TFBilma(TFPreTrainedModel):
38
  def __init__(self, config):
39
  self.seq_max_length = config.seq_max_length
40
  self.include_top = config.include_top
 
41
  super().__init__(config)
42
 
43
  self.model = bilma(num_enc=config.num_hidden_layers,
@@ -47,7 +49,9 @@ class TFBilma(TFPreTrainedModel):
47
  ff_dim=config.hidden_size,
48
  vocab_size=config.vocab_size,
49
  rate=config.hidden_dropout_prob,
50
- include_top = config.include_top)
 
 
51
 
52
  @property
53
  def dummy_inputs(self) -> Dict[str, tf.Tensor]:
@@ -70,13 +74,26 @@ class TFBilma(TFPreTrainedModel):
70
 
71
 
72
  def call(self, inputs):
73
- ins = tf.cast(inputs["input_ids"], tf.float32)
 
 
 
74
  if self.include_top:
75
  output = {"logits":self.model(ins)}
76
  else:
77
- output = {"last_hidden_state":self.model(ins)}
 
 
 
78
  return output
79
 
 
 
 
 
 
 
 
80
  # copied from bilma_model.py
81
  # --------------------------
82
 
@@ -105,7 +122,40 @@ def accuracy_function(ignore_id=0):
105
  return tf.math.divide_no_nan(tf.reduce_sum(accuracies), tf.reduce_sum(mask))
106
  return acc_mlm
107
 
108
- def bilma(num_enc=6, embed_dim=300, max_length=50, num_heads=6, ff_dim=512, vocab_size=9739, rate=0.1, include_top=True):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
  capt_inputs_ids = Input(shape=(max_length, ), name='input_ids')
110
  capt_embedding = Embedding(vocab_size, embed_dim, mask_zero=False, name="bilma/embedding")
111
  capt_inputs = capt_embedding(capt_inputs_ids)
@@ -115,9 +165,22 @@ def bilma(num_enc=6, embed_dim=300, max_length=50, num_heads=6, ff_dim=512, voca
115
  if include_top:
116
  fin_output = Dense(vocab_size, use_bias=True, name="bilma/dense_final")(enc_output)
117
  else:
118
- fin_output = enc_output
 
 
 
 
 
 
 
 
 
 
 
 
 
119
 
120
- caption_model = Model(inputs=capt_inputs_ids, outputs=[fin_output], name="bilma_model")
121
  return caption_model
122
 
123
  def load(model_file):
@@ -132,7 +195,6 @@ def load(model_file):
132
  #
133
  # Copied from transformer_text.py
134
  # -------------------------------
135
-
136
  class EncoderBlock(Layer):
137
  def __init__(self, layer_num, patch_dim, num_heads, ff_dim, rate=0.1, **kwargs):
138
  super(EncoderBlock, self).__init__(**kwargs)
@@ -214,7 +276,6 @@ class DecoderBlock(Layer):
214
 
215
  return final_output, attn_output1, attn_encoder
216
 
217
-
218
  class Encoder(Layer):
219
  def __init__(self, n, embed_dim, max_length, num_heads, ff_dim, rate=0.1, **kwargs):
220
  super(Encoder, self).__init__(**kwargs)
 
1
+ from transformers import TFPreTrainedModel, PreTrainedTokenizer, BatchEncoding
2
+
3
  from tensorflow.keras.models import Model, load_model, Sequential
4
  from tensorflow.keras.layers import Layer, Dense, concatenate, Input, add, Dropout, LayerNormalization, MultiHeadAttention, Embedding
5
  import tensorflow as tf
 
39
  def __init__(self, config):
40
  self.seq_max_length = config.seq_max_length
41
  self.include_top = config.include_top
42
+ self.add_head = config.add_head
43
  super().__init__(config)
44
 
45
  self.model = bilma(num_enc=config.num_hidden_layers,
 
49
  ff_dim=config.hidden_size,
50
  vocab_size=config.vocab_size,
51
  rate=config.hidden_dropout_prob,
52
+ include_top = config.include_top,
53
+ add_head = config.add_head,
54
+ pooling = config.pooling)
55
 
56
  @property
57
  def dummy_inputs(self) -> Dict[str, tf.Tensor]:
 
74
 
75
 
76
  def call(self, inputs):
77
+ if isinstance(inputs, Dict) or isinstance(inputs, BatchEncoding):
78
+ ins = tf.cast(inputs["input_ids"], tf.float32)
79
+ else:
80
+ ins = inputs
81
  if self.include_top:
82
  output = {"logits":self.model(ins)}
83
  else:
84
+ if self.add_head is None:
85
+ output = {"last_hidden_state":self.model(ins)}
86
+ else:
87
+ output = {"label":self.model(ins)}
88
  return output
89
 
90
+ def get_loss_function():
91
+ return loss_funtion()
92
+
93
+ def get_acc_function():
94
+ return accuracy_function()
95
+
96
+
97
  # copied from bilma_model.py
98
  # --------------------------
99
 
 
122
  return tf.math.divide_no_nan(tf.reduce_sum(accuracies), tf.reduce_sum(mask))
123
  return acc_mlm
124
 
125
+ def mean_vectors(inputs, enc_vectors, max_length):
126
+ p = tf.where(inputs == 3)
127
+ pos = tf.transpose(p)[1]
128
+ C = tf.sequence_mask(pos, maxlen=max_length, dtype=tf.float32)
129
+ C = tf.reshape(C, (-1, max_length, 1))
130
+ S = tf.reduce_sum(enc_vectors * C, 1)
131
+ x = S / tf.expand_dims(tf.cast(pos, tf.float32), (1))
132
+ return x
133
+
134
+ def mean_diff_vectors(inputs, enc_vectors, max_length):
135
+ p = tf.where(inputs == 3)
136
+ pos = tf.transpose(p)[1]
137
+ C = tf.sequence_mask(pos, maxlen=max_length, dtype=tf.float32)
138
+ C = tf.reshape(C, (-1, max_length, 1))
139
+ vecs = enc_vectors * C
140
+ S = tf.reduce_sum(vecs, 1)
141
+ mu = S / tf.expand_dims(tf.cast(pos, tf.float32), (1))
142
+ x = tf.reduce_sum(mu - vecs, 1) / tf.expand_dims(tf.cast(pos, tf.float32), (1))
143
+ return x
144
+
145
+ def max_vectors(inputs, enc_vectors, max_length):
146
+ p = tf.where(inputs == 3)
147
+ pos = tf.transpose(p)[1]
148
+ C = tf.sequence_mask(pos, maxlen=max_length, dtype=tf.float32)
149
+ C = tf.reshape(C, (-1, max_length, 1))
150
+ x = tf.reduce_max(enc_vectors * C, 1)
151
+ return x
152
+
153
+ def cls_vectors(inputs, enc_vectors, max_length):
154
+ x = tf.squeeze(enc_vectors[:, 0:1, :], axis=1)
155
+ return x
156
+
157
+
158
+ def bilma(num_enc=6, embed_dim=300, max_length=50, num_heads=6, ff_dim=512, vocab_size=9739, rate=0.1, include_top=True, add_head=None, pooling=None):
159
  capt_inputs_ids = Input(shape=(max_length, ), name='input_ids')
160
  capt_embedding = Embedding(vocab_size, embed_dim, mask_zero=False, name="bilma/embedding")
161
  capt_inputs = capt_embedding(capt_inputs_ids)
 
165
  if include_top:
166
  fin_output = Dense(vocab_size, use_bias=True, name="bilma/dense_final")(enc_output)
167
  else:
168
+ x = enc_output
169
+ if pooling == "mean":
170
+ x = mean_vectors(capt_inputs_ids, x, max_length)
171
+ elif pooling == "cls":
172
+ x = cls_vectors(capt_inputs_ids, x, max_length)
173
+ elif pooling == "max":
174
+ x = max_vectors(capt_inputs_ids, x, max_length)
175
+
176
+ if add_head is None:
177
+ fin_output = x
178
+ else:
179
+ for i, m in enumerate(add_head[:-1]):
180
+ x = Dense(m, use_bias=True, activation="relu", name=f"bilma/dense_ex_{i}")(x)
181
+ fin_output = Dense(add_head[-1], use_bias=True, activation="softmax", name=f"bilma/dense_ex_final")(x)
182
 
183
+ caption_model = Model(inputs=capt_inputs_ids, outputs=fin_output, name="bilma_model")
184
  return caption_model
185
 
186
  def load(model_file):
 
195
  #
196
  # Copied from transformer_text.py
197
  # -------------------------------
 
198
  class EncoderBlock(Layer):
199
  def __init__(self, layer_num, patch_dim, num_heads, ff_dim, rate=0.1, **kwargs):
200
  super(EncoderBlock, self).__init__(**kwargs)
 
276
 
277
  return final_output, attn_output1, attn_encoder
278
 
 
279
  class Encoder(Layer):
280
  def __init__(self, n, embed_dim, max_length, num_heads, ff_dim, rate=0.1, **kwargs):
281
  super(Encoder, self).__init__(**kwargs)
tf_model.h5 CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:6f092e461b986ede156fd67fcfe8b28a7f360bbadb52b6dcefd19b913575865c
3
  size 156875820
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e8ca072bb7a05a66b756aea19821a7003fe14556d2331b00418c187ad4e14701
3
  size 156875820