# img-caption-demo / pipeline.py
# Author: hoangthan
# Last change: "Update pipeline.py" (revision 1199618)
import tensorflow as tf
import numpy as np
from tensorflow import keras
import os
from typing import Dict, List, Any
import pickle
from PIL import Image
class PreTrainedPipeline:
    """Image-captioning pipeline around a saved Keras captioning model.

    The saved model is used through these attributes (as in the original code):
    ``model.tokenizer.get_vocabulary()``, ``model.feature_extractor(images)``,
    and ``model((image_features, tokens))``, which returns per-position
    vocabulary scores shaped (batch, sequence, vocab).
    """

    # Size images are resized to before feature extraction.
    # NOTE(review): the original code referenced an undefined IMAGE_SHAPE.
    # 224x224x3 is an assumption (a common MobileNet-style extractor input).
    # Confirm against the feature extractor the model was trained with.
    DEFAULT_IMAGE_SHAPE = (224, 224, 3)

    def __init__(self, path: str, image_shape: tuple = DEFAULT_IMAGE_SHAPE):
        """Load the model from ``<path>/model`` and build token lookups.

        Args:
            path: Directory containing the saved ``model`` sub-directory.
            image_shape: (height, width[, channels]) the input image is
                resized to. Only the first two entries are used.
        """
        self.model = keras.models.load_model(os.path.join(path, "model"))
        self.image_shape = tuple(image_shape)
        # Fetch the vocabulary once and share it between both lookups.
        vocabulary = self.model.tokenizer.get_vocabulary()
        # Word -> id lookup, used to seed decoding and detect the end token.
        self.word_to_index = tf.keras.layers.StringLookup(
            mask_token="",
            vocabulary=vocabulary)
        # Id -> word lookup, used to turn generated ids back into text.
        self.index_to_word = tf.keras.layers.StringLookup(
            mask_token="",
            vocabulary=vocabulary,
            invert=True)

    @staticmethod
    def load_image(img, image_shape=DEFAULT_IMAGE_SHAPE):
        """Return ``img`` as a float32 tensor resized to ``image_shape[:2]``.

        Args:
            img: Either encoded image bytes (decoded here to 3 channels) or an
                already-decoded HxWxC array/tensor.
            image_shape: Target shape. Only (height, width) is used.

        Returns:
            A float32 tensor of shape (height, width, channels).
        """
        img = tf.convert_to_tensor(img)
        if img.dtype == tf.string:
            # Encoded bytes (JPEG/PNG/...). Decode to a single 3-channel frame.
            img = tf.io.decode_image(img, channels=3, expand_animations=False)
        return tf.image.resize(img, image_shape[:2])

    def __call__(self, inputs: "Image.Image", temperature: float = 0.0,
                 max_length: int = 50) -> str:
        """Generate a caption for ``inputs``.

        Args:
            inputs: The raw image as a PIL image. It is converted to RGB here.
                Encoded image bytes or an already-decoded HxWxC array are
                also accepted.
            temperature: 0 picks the highest-scoring token at each step. A
                positive value samples from ``scores / temperature``.
            max_length: Maximum number of tokens to generate.

        Returns:
            The caption as a space-joined string, without the ``[START]`` or
            trailing ``[END]`` tokens.

        Raises:
            ValueError: If ``temperature`` is negative.

        NOTE(review): the original docstring template asked for a list of
        {"label", "score"} dicts, but this pipeline returns a plain string.
        Confirm what the serving layer expects.
        """
        if temperature < 0:
            raise ValueError(f"temperature must be >= 0, got {temperature}")

        if isinstance(inputs, Image.Image):
            # Force 3 channels (drops alpha, expands grayscale/palette images).
            inputs = tf.keras.utils.img_to_array(inputs.convert("RGB"))
        image = self.load_image(inputs, self.image_shape)
        img_features = self.model.feature_extractor(image[tf.newaxis, ...])

        end_id = int(self.word_to_index("[END]"))
        tokens = self.word_to_index([["[START]"]])  # (batch=1, sequence)
        for _ in range(max_length):
            preds = self.model((img_features, tokens))  # (batch, sequence, vocab)
            preds = preds[:, -1, :]  # (batch, vocab): scores for the next token
            if temperature == 0:
                next_token = tf.argmax(preds, axis=-1)[:, tf.newaxis]  # (batch, 1)
            else:
                next_token = tf.random.categorical(
                    preds / temperature, num_samples=1)  # (batch, 1)
            tokens = tf.concat([tokens, next_token], axis=1)  # (batch, sequence)
            if int(next_token[0, 0]) == end_id:
                break

        # Drop [START]. Drop the last token only if it really is [END]; the
        # loop may instead have stopped at max_length on a real word.
        token_ids = tokens[0, 1:].numpy()
        if token_ids.size and token_ids[-1] == end_id:
            token_ids = token_ids[:-1]
        words = self.index_to_word(token_ids)
        result = tf.strings.reduce_join(words, axis=-1, separator=" ")
        return result.numpy().decode()