fbrynpk committed
Commit 388a840
1 Parent(s): ec13b06

Update Models Function

Files changed (1)
  1. model.py +229 -1
model.py CHANGED
@@ -10,7 +10,7 @@ EMBEDDING_DIM = 512
 UNITS = 512
 
 
-# LOADING DATA
+# LOAD VOCAB FOLDER
 vocab = pickle.load(open('vocabulary/vocab_coco.file', 'rb'))
 
 tokenizer = tf.keras.layers.TextVectorization(
@@ -25,6 +25,234 @@ idx2word = tf.keras.layers.StringLookup(
     invert = True
 )
 
+# CREATING MODEL BASED ON KERAS
+def CNN_Encoder():
+    inception_v3 = tf.keras.applications.InceptionV3(
+        include_top=False,
+        weights='imagenet'
+    )
+
+    output = inception_v3.output
+    output = tf.keras.layers.Reshape(
+        (-1, output.shape[-1]))(output)
+
+    cnn_model = tf.keras.models.Model(inception_v3.input, output)
+    return cnn_model
+
+
+class TransformerEncoderLayer(tf.keras.layers.Layer):
+
+    def __init__(self, embed_dim, num_heads):
+        super().__init__()
+        self.layer_norm_1 = tf.keras.layers.LayerNormalization()
+        self.layer_norm_2 = tf.keras.layers.LayerNormalization()
+        self.attention = tf.keras.layers.MultiHeadAttention(
+            num_heads=num_heads, key_dim=embed_dim)
+        self.dense = tf.keras.layers.Dense(embed_dim, activation="relu")
+
+
+    def call(self, x, training):
+        x = self.layer_norm_1(x)
+        x = self.dense(x)
+
+        attn_output = self.attention(
+            query=x,
+            value=x,
+            key=x,
+            attention_mask=None,
+            training=training
+        )
+
+        x = self.layer_norm_2(x + attn_output)
+        return x
+
+
+class Embeddings(tf.keras.layers.Layer):
+
+    def __init__(self, vocab_size, embed_dim, max_len):
+        super().__init__()
+        self.token_embeddings = tf.keras.layers.Embedding(
+            vocab_size, embed_dim)
+        self.position_embeddings = tf.keras.layers.Embedding(
+            max_len, embed_dim, input_shape=(None, max_len))
+
+
+    def call(self, input_ids):
+        length = tf.shape(input_ids)[-1]
+        position_ids = tf.range(start=0, limit=length, delta=1)
+        position_ids = tf.expand_dims(position_ids, axis=0)
+
+        token_embeddings = self.token_embeddings(input_ids)
+        position_embeddings = self.position_embeddings(position_ids)
+
+        return token_embeddings + position_embeddings
+
+
+class TransformerDecoderLayer(tf.keras.layers.Layer):
+
+    def __init__(self, embed_dim, units, num_heads):
+        super().__init__()
+        self.embedding = Embeddings(
+            tokenizer.vocabulary_size(), embed_dim, MAX_LENGTH)
+
+        self.attention_1 = tf.keras.layers.MultiHeadAttention(
+            num_heads=num_heads, key_dim=embed_dim, dropout=0.1
+        )
+        self.attention_2 = tf.keras.layers.MultiHeadAttention(
+            num_heads=num_heads, key_dim=embed_dim, dropout=0.1
+        )
+
+        self.layernorm_1 = tf.keras.layers.LayerNormalization()
+        self.layernorm_2 = tf.keras.layers.LayerNormalization()
+        self.layernorm_3 = tf.keras.layers.LayerNormalization()
+
+        self.ffn_layer_1 = tf.keras.layers.Dense(units, activation="relu")
+        self.ffn_layer_2 = tf.keras.layers.Dense(embed_dim)
+
+        self.out = tf.keras.layers.Dense(tokenizer.vocabulary_size(), activation="softmax")
+
+        self.dropout_1 = tf.keras.layers.Dropout(0.3)
+        self.dropout_2 = tf.keras.layers.Dropout(0.5)
+
+
+    def call(self, input_ids, encoder_output, training, mask=None):
+        embeddings = self.embedding(input_ids)
+
+        combined_mask = None
+        padding_mask = None
+
+        if mask is not None:
+            causal_mask = self.get_causal_attention_mask(embeddings)
+            padding_mask = tf.cast(mask[:, :, tf.newaxis], dtype=tf.int32)
+            combined_mask = tf.cast(mask[:, tf.newaxis, :], dtype=tf.int32)
+            combined_mask = tf.minimum(combined_mask, causal_mask)
+
+        attn_output_1 = self.attention_1(
+            query=embeddings,
+            value=embeddings,
+            key=embeddings,
+            attention_mask=combined_mask,
+            training=training
+        )
+
+        out_1 = self.layernorm_1(embeddings + attn_output_1)
+
+        attn_output_2 = self.attention_2(
+            query=out_1,
+            value=encoder_output,
+            key=encoder_output,
+            attention_mask=padding_mask,
+            training=training
+        )
+
+        out_2 = self.layernorm_2(out_1 + attn_output_2)
+
+        ffn_out = self.ffn_layer_1(out_2)
+        ffn_out = self.dropout_1(ffn_out, training=training)
+        ffn_out = self.ffn_layer_2(ffn_out)
+
+        ffn_out = self.layernorm_3(ffn_out + out_2)
+        ffn_out = self.dropout_2(ffn_out, training=training)
+        preds = self.out(ffn_out)
+        return preds
+
+
+    def get_causal_attention_mask(self, inputs):
+        input_shape = tf.shape(inputs)
+        batch_size, sequence_length = input_shape[0], input_shape[1]
+        i = tf.range(sequence_length)[:, tf.newaxis]
+        j = tf.range(sequence_length)
+        mask = tf.cast(i >= j, dtype="int32")
+        mask = tf.reshape(mask, (1, input_shape[1], input_shape[1]))
+        mult = tf.concat(
+            [tf.expand_dims(batch_size, -1), tf.constant([1, 1], dtype=tf.int32)],
+            axis=0
+        )
+        return tf.tile(mask, mult)
+
+
+class ImageCaptioningModel(tf.keras.Model):
+
+    def __init__(self, cnn_model, encoder, decoder, image_aug=None):
+        super().__init__()
+        self.cnn_model = cnn_model
+        self.encoder = encoder
+        self.decoder = decoder
+        self.image_aug = image_aug
+        self.loss_tracker = tf.keras.metrics.Mean(name="loss")
+        self.acc_tracker = tf.keras.metrics.Mean(name="accuracy")
+
+
+    def calculate_loss(self, y_true, y_pred, mask):
+        loss = self.loss(y_true, y_pred)
+        mask = tf.cast(mask, dtype=loss.dtype)
+        loss *= mask
+        return tf.reduce_sum(loss) / tf.reduce_sum(mask)
+
+
+    def calculate_accuracy(self, y_true, y_pred, mask):
+        accuracy = tf.equal(y_true, tf.argmax(y_pred, axis=2))
+        accuracy = tf.math.logical_and(mask, accuracy)
+        accuracy = tf.cast(accuracy, dtype=tf.float32)
+        mask = tf.cast(mask, dtype=tf.float32)
+        return tf.reduce_sum(accuracy) / tf.reduce_sum(mask)
+
+
+    def compute_loss_and_acc(self, img_embed, captions, training=True):
+        encoder_output = self.encoder(img_embed, training=training)
+        y_input = captions[:, :-1]
+        y_true = captions[:, 1:]
+        mask = (y_true != 0)
+        y_pred = self.decoder(
+            y_input, encoder_output, training=training, mask=mask
+        )
+        loss = self.calculate_loss(y_true, y_pred, mask)
+        acc = self.calculate_accuracy(y_true, y_pred, mask)
+        return loss, acc
+
+
+    def train_step(self, batch):
+        imgs, captions = batch
+
+        if self.image_aug:
+            imgs = self.image_aug(imgs)
+
+        img_embed = self.cnn_model(imgs)
+
+        with tf.GradientTape() as tape:
+            loss, acc = self.compute_loss_and_acc(
+                img_embed, captions
+            )
+
+        train_vars = (
+            self.encoder.trainable_variables + self.decoder.trainable_variables
+        )
+        grads = tape.gradient(loss, train_vars)
+        self.optimizer.apply_gradients(zip(grads, train_vars))
+        self.loss_tracker.update_state(loss)
+        self.acc_tracker.update_state(acc)
+
+        return {"loss": self.loss_tracker.result(), "acc": self.acc_tracker.result()}
+
+
+    def test_step(self, batch):
+        imgs, captions = batch
+
+        img_embed = self.cnn_model(imgs)
+
+        loss, acc = self.compute_loss_and_acc(
+            img_embed, captions, training=False
+        )
+
+        self.loss_tracker.update_state(loss)
+        self.acc_tracker.update_state(acc)
+
+        return {"loss": self.loss_tracker.result(), "acc": self.acc_tracker.result()}
+
+    @property
+    def metrics(self):
+        return [self.loss_tracker, self.acc_tracker]
+
 def load_image_from_path(img_path):
     img = tf.io.read_file(img_path)
     img = tf.io.decode_jpeg(img, channels=3)
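
The commit defines the model's building blocks but not the code that wires them together. A minimal assembly sketch follows; it is not part of this commit, and the num_heads values and the Adam optimizer are illustrative assumptions. Note that calculate_loss applies the padding mask per token, so the compiled loss must return per-token values (reduction='none'), and the decoder's output layer already applies a softmax, hence from_logits=False.

cnn_model = CNN_Encoder()
encoder = TransformerEncoderLayer(EMBEDDING_DIM, num_heads=1)         # num_heads assumed
decoder = TransformerDecoderLayer(EMBEDDING_DIM, UNITS, num_heads=8)  # num_heads assumed

caption_model = ImageCaptioningModel(
    cnn_model=cnn_model, encoder=encoder, decoder=decoder
)

caption_model.compile(
    optimizer=tf.keras.optimizers.Adam(),
    # calculate_loss does `loss = self.loss(y_true, y_pred)` and then multiplies
    # by the padding mask element-wise, so the loss must stay unreduced here;
    # a scalar-reduced loss would make that masking a no-op.
    loss=tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=False, reduction='none'
    )
)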