import os
import pickle
from typing import Any, Dict, List

import numpy as np
import tensorflow as tf
from PIL import Image
from tensorflow import keras

# Target (height, width, channels) of the image fed to the feature extractor.
# NOTE(review): the original code referenced IMAGE_SHAPE without defining it
# anywhere in this file; 224x224x3 is an assumption -- confirm it against the
# input shape the saved model's feature extractor was built with.
IMAGE_SHAPE = (224, 224, 3)

# Upper bound on the number of decoding steps for one caption.
MAX_CAPTION_TOKENS = 50


class PreTrainedPipeline():
    """Image-captioning pipeline backed by a saved Keras model.

    The model is loaded from ``<path>/model``. Two ``StringLookup`` layers
    built from the model tokenizer's vocabulary map words to token ids and
    token ids back to words.
    """

    def __init__(self, path: str):
        """Load the model stored under ``path`` and build the vocab lookups.

        Args:
            path: Directory that contains the saved ``model`` sub-directory.
        """
        self.model = keras.models.load_model(os.path.join(path, "model"))
        vocabulary = self.model.tokenizer.get_vocabulary()
        self.word_to_index = tf.keras.layers.StringLookup(
            mask_token="", vocabulary=vocabulary)
        self.index_to_word = tf.keras.layers.StringLookup(
            mask_token="", vocabulary=vocabulary, invert=True)

    @staticmethod
    def load_image(img):
        """Return ``img`` resized to ``IMAGE_SHAPE[:-1]``.

        Args:
            img: Either encoded JPEG bytes (a string tensor), which are
                decoded to 3 channels first, or an already-decoded
                (height, width, channels) pixel array.

        Returns:
            The resized image tensor produced by ``tf.image.resize``.
        """
        img = tf.convert_to_tensor(img)
        # decode_jpeg only accepts encoded bytes; pixel arrays skip decoding.
        if img.dtype == tf.string:
            img = tf.io.decode_jpeg(img, channels=3)
        return tf.image.resize(img, IMAGE_SHAPE[:-1])

    def __call__(self, inputs: "Image.Image") -> str:
        """Generate a caption for a PIL image.

        Args:
            inputs (:obj:`PIL.Image`): The raw image. It is converted to RGB
                and resized here before being passed to the model.

        Returns:
            The generated caption as a single space-separated string.

        NOTE(review): the previous annotation/docstring promised a list of
        ``{"label": ..., "score": ...}`` dicts, but the code has always
        returned a plain string. The runtime return value is kept as-is;
        confirm what the hosting framework expects.
        """
        # Force 3 channels so RGBA / grayscale / palette inputs are accepted.
        pixels = tf.keras.utils.img_to_array(inputs.convert("RGB"))
        image = self.load_image(pixels)
        return self._generate_caption(image)

    def _generate_caption(self, image, temperature: float = 0.0,
                          max_tokens: int = MAX_CAPTION_TOKENS) -> str:
        """Decode a caption token by token for one preprocessed image.

        Args:
            image: Image tensor without a batch axis (one is added here).
            temperature: ``0`` selects the arg-max token at every step; any
                other value samples from the predictions divided by it.
            max_tokens: Maximum number of decoding steps.

        Returns:
            The decoded words joined with single spaces.
        """
        tokens = self.word_to_index([['[START]']])  # (batch, sequence)
        # Loop-invariant: look the end-token id up once, not on every step.
        end_id = int(self.word_to_index('[END]'))
        img_features = self.model.feature_extractor(image[tf.newaxis, ...])

        for _ in range(max_tokens):
            preds = self.model((img_features, tokens))  # (batch, sequence, vocab)
            preds = preds[:, -1, :]  # (batch, vocab)
            if temperature == 0:
                next_token = tf.argmax(preds, axis=-1)[:, tf.newaxis]  # (batch, 1)
            else:
                next_token = tf.random.categorical(
                    preds / temperature, num_samples=1)  # (batch, 1)

            # Stop before appending [END]: the slice below then never has to
            # guess whether the last token is a real word or the end marker
            # (the old `[1:-1]` slice dropped a real word whenever the step
            # limit was reached without [END]).
            if int(next_token[0, 0]) == end_id:
                break
            tokens = tf.concat([tokens, next_token], axis=1)  # (batch, sequence)

        words = self.index_to_word(tokens[0, 1:])  # drop the [START] token
        result = tf.strings.reduce_join(words, axis=-1, separator=' ')
        return result.numpy().decode()