fbrynpk committed on
Commit
d1de1d0
1 Parent(s): 0206028

Create function to call models

Browse files
Files changed (1) hide show
  1. model.py +94 -0
model.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import pickle

import numpy as np
import pandas as pd
import tensorflow as tf

# Inference hyperparameters — these must match the values used at training
# time, otherwise the loaded weights and tokenization will not line up.
MAX_LENGTH = 40      # maximum caption length in tokens (including [start]/[end])
BATCH_SIZE = 32
BUFFER_SIZE = 1000
EMBEDDING_DIM = 512  # embedding size shared by encoder/decoder
UNITS = 512          # decoder feed-forward width

# LOADING DATA
# Vocabulary pickled from the COCO training run; reusing it keeps inference
# tokenization identical to training.
# NOTE(review): pickle.load is unsafe on untrusted input — this file must
# come from a trusted source.
with open('vocabulary/vocab_coco.file', 'rb') as vocab_file:
    vocab = pickle.load(vocab_file)

# Maps caption strings to fixed-length (MAX_LENGTH) integer token sequences.
# standardize=None: the vocabulary was built on already-normalized text, so
# no extra lowercasing/punctuation stripping is applied here.
tokenizer = tf.keras.layers.TextVectorization(
    standardize=None,
    output_sequence_length=MAX_LENGTH,
    vocabulary=vocab,
)

# Inverse lookup (token id -> word), used to decode model predictions.
idx2word = tf.keras.layers.StringLookup(
    mask_token="",
    vocabulary=tokenizer.get_vocabulary(),
    invert=True,
)
27
+
28
def load_image_from_path(img_path):
    """Read a JPEG from disk and preprocess it for the InceptionV3 encoder.

    Args:
        img_path: Path to a JPEG image file.

    Returns:
        A float32 tensor of shape (299, 299, 3), scaled to [-1, 1] by
        InceptionV3's ``preprocess_input``.
    """
    img = tf.io.read_file(img_path)
    img = tf.io.decode_jpeg(img, channels=3)
    # tf.image.resize avoids allocating a new Resizing layer on every call;
    # both default to bilinear interpolation, so the output is unchanged.
    img = tf.image.resize(img, (299, 299))
    img = tf.keras.applications.inception_v3.preprocess_input(img)
    return img
34
+
35
+
36
def generate_caption(img, caption_model, add_noise=False):
    """Generate a caption for an image via greedy autoregressive decoding.

    Args:
        img: Either a path to a JPEG file or an already-preprocessed image
            tensor of shape (299, 299, 3) (as from ``load_image_from_path``).
        caption_model: Trained model exposing ``cnn_model``, ``encoder`` and
            ``decoder`` attributes (see ``get_caption_model``).
        add_noise: If True, perturb the image with Gaussian noise and min-max
            rescale it to [0, 1] — useful for robustness/variety checks.

    Returns:
        The generated caption string with the leading '[start] ' removed.
    """
    if isinstance(img, str):
        img = load_image_from_path(img)

    if add_noise:
        noise = tf.random.normal(img.shape) * 0.1
        img = img + noise
        # Min-max rescale so pixel values remain in a valid range after noise.
        img = (img - tf.reduce_min(img)) / (tf.reduce_max(img) - tf.reduce_min(img))

    # Encode the image once; every decoding step reuses this representation.
    img = tf.expand_dims(img, axis=0)
    img_embed = caption_model.cnn_model(img)
    img_encoded = caption_model.encoder(img_embed, training=False)

    # Greedy decoding: feed the caption so far, take the argmax token at
    # position i, stop at '[end]' or after MAX_LENGTH - 1 generated tokens.
    y_inp = '[start]'
    for i in range(MAX_LENGTH - 1):
        tokenized = tokenizer([y_inp])[:, :-1]
        mask = tf.cast(tokenized != 0, tf.int32)  # mask out padding positions
        pred = caption_model.decoder(
            tokenized, img_encoded, training=False, mask=mask)

        pred_idx = np.argmax(pred[0, i, :])
        pred_word = idx2word(pred_idx).numpy().decode('utf-8')
        if pred_word == '[end]':
            break

        y_inp += ' ' + pred_word

    # Strip the [start] marker before returning the caption to the caller.
    y_inp = y_inp.replace('[start] ', '')
    return y_inp
65
+
66
+
67
def get_caption_model():
    """Build the image-captioning model and load its trained COCO weights.

    Returns the assembled model with weights loaded from disk. The build
    steps below are order-dependent: the dummy forward passes must run
    before ``load_weights`` so that every layer's variables exist.

    NOTE(review): TransformerEncoderLayer, TransformerDecoderLayer,
    CNN_Encoder and ImageCaptioningModel are not defined in this file —
    presumably imported elsewhere in the project; verify before reuse.
    """
    encoder = TransformerEncoderLayer(EMBEDDING_DIM, 1)
    decoder = TransformerDecoderLayer(EMBEDDING_DIM, UNITS, 8)

    cnn_model = CNN_Encoder()

    caption_model = ImageCaptioningModel(
        cnn_model=cnn_model, encoder=encoder, decoder=decoder, image_aug=None,
    )

    # Replace the model's call with a pass-through so the dummy invocation
    # below only serves to build the model, without running training logic.
    def call_fn(batch, training):
        return batch

    caption_model.call = call_fn
    # Dummy inputs shaped like a real (image, caption) pair:
    # image (1, 299, 299, 3); caption token ids (1, 40) — 40 == MAX_LENGTH.
    sample_x, sample_y = tf.random.normal((1, 299, 299, 3)), tf.zeros((1, 40))

    caption_model((sample_x, sample_y))

    # Run each sub-module once so its weight variables are created —
    # required before load_weights can restore them.
    sample_img_embed = caption_model.cnn_model(sample_x)
    sample_enc_out = caption_model.encoder(sample_img_embed, training=False)
    caption_model.decoder(sample_y, sample_enc_out, training=False)

    # Weights path differs between local runs and the deployed layout;
    # fall back to the repository-prefixed path if the first is missing.
    try:
        caption_model.load_weights('models/trained_coco_weights.h5')
    except FileNotFoundError:
        caption_model.load_weights('image-caption-generator/models/trained_coco_weights.h5')

    return caption_model