fbrynpk's picture
Create function to call models
d1de1d0
raw
history blame
No virus
2.7 kB
import pickle
import tensorflow as tf
import pandas as pd
import numpy as np
# Module-wide hyper-parameters for tokenization and the captioning model.
MAX_LENGTH: int = 40      # tokenizer output length; also bounds caption generation steps
BATCH_SIZE: int = 32      # not used by the code shown here; presumably a training setting
BUFFER_SIZE: int = 1000   # not used by the code shown here; presumably a shuffle-buffer setting
EMBEDDING_DIM: int = 512  # embedding size given to the transformer encoder/decoder layers
UNITS: int = 512          # hidden size given to the transformer decoder layer
# LOADING DATA
# NOTE(review): pickle.load can execute arbitrary code while unpickling; only
# ever point this at the trusted vocabulary file shipped with the project.
with open('vocabulary/vocab_coco.file', 'rb') as vocab_file:  # closed after reading
    vocab = pickle.load(vocab_file)

# Maps caption text to integer token ids, padded/truncated to MAX_LENGTH.
tokenizer = tf.keras.layers.TextVectorization(
    standardize=None,  # captions are tokenized as stored: no lower-casing/punctuation stripping
    output_sequence_length=MAX_LENGTH,
    vocabulary=vocab,
)

# Inverse lookup (token id -> word) built from the tokenizer's own vocabulary,
# so ids produced by the decoder map back to the same words.
idx2word = tf.keras.layers.StringLookup(
    mask_token="",
    vocabulary=tokenizer.get_vocabulary(),
    invert=True,
)
def load_image_from_path(img_path):
    """Read an image file and preprocess it for the InceptionV3-based encoder.

    Args:
        img_path: Path to the image file (Python string or string tensor).
            The bytes are decoded with ``tf.io.decode_jpeg``, so the file is
            expected to be a JPEG.

    Returns:
        A float32 tensor of shape (299, 299, 3) that has been passed through
        ``tf.keras.applications.inception_v3.preprocess_input``.
    """
    raw_bytes = tf.io.read_file(img_path)
    img = tf.io.decode_jpeg(raw_bytes, channels=3)
    # Resize directly with tf.image.resize (bilinear, float32 output) rather
    # than instantiating a fresh Keras Resizing layer on every call.
    img = tf.image.resize(img, (299, 299))
    img = tf.keras.applications.inception_v3.preprocess_input(img)
    return img
def generate_caption(img, caption_model, add_noise=False):
    """Generate a caption for one image using greedy token-by-token decoding.

    Args:
        img: Either a file path (``str``), which is loaded with
            ``load_image_from_path``, or an already-preprocessed image tensor
            without a batch dimension (one is added here).
        caption_model: Model exposing ``cnn_model``, ``encoder`` and
            ``decoder`` attributes, such as the one built by
            ``get_caption_model``.
        add_noise: If true, add Gaussian noise (std 0.1) to the image and
            min-max rescale it before captioning.

    Returns:
        The generated caption as a space-separated string, without the
        ``[start]`` / ``[end]`` markers. Returns an empty string when the
        model predicts ``[end]`` as the very first token.
    """
    if isinstance(img, str):
        img = load_image_from_path(img)

    if add_noise:
        noise = tf.random.normal(img.shape) * 0.1
        img = img + noise
        # NOTE(review): this min-max rescales to [0, 1], whereas
        # inception_v3.preprocess_input (used by load_image_from_path) yields
        # values in [-1, 1]; confirm the range shift is intended.
        img_min = tf.reduce_min(img)
        img = (img - img_min) / (tf.reduce_max(img) - img_min)

    img = tf.expand_dims(img, axis=0)  # add the batch dimension
    img_embed = caption_model.cnn_model(img)
    img_encoded = caption_model.encoder(img_embed, training=False)

    # Keep generated words separately from the '[start]' marker so an empty
    # caption comes back as '' rather than the marker itself.
    words = []
    for i in range(MAX_LENGTH - 1):
        y_inp = ' '.join(['[start]'] + words)
        # Re-tokenize the whole prefix each step; dropping the last position
        # gives the decoder MAX_LENGTH - 1 input tokens.
        tokenized = tokenizer([y_inp])[:, :-1]
        mask = tf.cast(tokenized != 0, tf.int32)
        pred = caption_model.decoder(
            tokenized, img_encoded, training=False, mask=mask)
        # Greedy choice of the token predicted at position i, i.e. the one
        # following the i-th input token.
        pred_idx = np.argmax(pred[0, i, :])
        pred_word = idx2word(pred_idx).numpy().decode('utf-8')
        if pred_word == '[end]':
            break
        words.append(pred_word)

    return ' '.join(words)
def get_caption_model():
    """Build the image-captioning model and load its trained COCO weights.

    The weights are read from ``models/trained_coco_weights.h5`` relative to
    the current working directory. If that file cannot be opened, the
    function falls back to
    ``image-caption-generator/models/trained_coco_weights.h5``.

    Returns:
        The ``ImageCaptioningModel`` instance with weights loaded.

    Raises:
        Whatever ``load_weights`` raises if the fallback path also fails
        (for a missing file this is an ``OSError`` / ``FileNotFoundError``).
    """
    # NOTE(review): TransformerEncoderLayer, TransformerDecoderLayer,
    # CNN_Encoder and ImageCaptioningModel are neither defined nor imported in
    # this file; confirm where they are expected to come from.
    encoder = TransformerEncoderLayer(EMBEDDING_DIM, 1)
    decoder = TransformerDecoderLayer(EMBEDDING_DIM, UNITS, 8)
    cnn_model = CNN_Encoder()
    caption_model = ImageCaptioningModel(
        cnn_model=cnn_model, encoder=encoder, decoder=decoder, image_aug=None,
    )

    # Swap the model's forward pass for an identity function. Presumably this
    # lets the dummy call below mark the model as built without running its
    # real call logic; confirm.
    def call_fn(batch, training):
        return batch

    caption_model.call = call_fn

    # Push dummy inputs through the model and each sub-model so all variables
    # exist before load_weights is called.
    sample_x = tf.random.normal((1, 299, 299, 3))
    sample_y = tf.zeros((1, MAX_LENGTH))
    caption_model((sample_x, sample_y))
    sample_img_embed = caption_model.cnn_model(sample_x)
    sample_enc_out = caption_model.encoder(sample_img_embed, training=False)
    caption_model.decoder(sample_y, sample_enc_out, training=False)

    # The first path is relative to the working directory; the fallback covers
    # running from the parent directory. OSError (the base of
    # FileNotFoundError) is caught because some h5py versions report a missing
    # file as a plain OSError.
    try:
        caption_model.load_weights('models/trained_coco_weights.h5')
    except OSError:
        caption_model.load_weights(
            'image-caption-generator/models/trained_coco_weights.h5')
    return caption_model