# File size: 2,333 bytes
import tensorflow as tf
import numpy as np
from tensorflow import keras
import os
from typing import Dict, List, Any
import pickle
from PIL import Image
class PreTrainedPipeline():
    """Image-captioning inference pipeline for a saved Keras model.

    Loads a captioning model from ``<path>/model`` and exposes ``__call__``,
    which turns a PIL image into a caption string by autoregressive decoding
    (greedy by default, temperature sampling optionally).
    """

    # Target (height, width, channels) expected by the feature extractor.
    # NOTE(review): the original code referenced an undefined global
    # IMAGE_SHAPE; (224, 224, 3) matches the TensorFlow image-captioning
    # tutorial this code appears to derive from — confirm against the
    # trained model's input spec.
    IMAGE_SHAPE = (224, 224, 3)

    # Maximum decoding steps before giving up on seeing '[END]'
    # (50 in the original loop).
    MAX_CAPTION_LENGTH = 50

    def __init__(self, path: str):
        """Load the saved model and build the token <-> id lookup layers.

        Args:
            path: Directory containing a ``model`` subdirectory saved via
                ``keras.models.save_model``. The loaded model must expose a
                ``tokenizer`` attribute (with ``get_vocabulary()``) and a
                ``feature_extractor`` sub-model.
        """
        self.model = keras.models.load_model(os.path.join(path, "model"))
        # Build the vocabulary once and share it between both lookups so
        # word->index and index->word are guaranteed to stay consistent.
        vocabulary = self.model.tokenizer.get_vocabulary()
        self.word_to_index = tf.keras.layers.StringLookup(
            mask_token="",
            vocabulary=vocabulary)
        self.index_to_word = tf.keras.layers.StringLookup(
            mask_token="",
            vocabulary=vocabulary,
            invert=True)

    def load_image(self, img):
        """Resize a decoded image array/tensor to ``IMAGE_SHAPE[:-1]``.

        Args:
            img: Decoded pixel data of shape (H, W, C), e.g. the output of
                ``tf.keras.utils.img_to_array``.

        Returns:
            A float tensor of shape ``IMAGE_SHAPE``.

        Note:
            The original version was missing ``self`` (so method calls broke)
            and ran ``tf.io.decode_jpeg`` on already-decoded pixel data,
            which requires JPEG *bytes* and would raise at runtime. The
            decode step is dropped; callers pass decoded arrays.
        """
        return tf.image.resize(img, self.IMAGE_SHAPE[:-1])

    def __call__(self, inputs: "Image.Image", temperature: float = 0.0) -> str:
        """Generate a caption for a PIL image.

        Args:
            inputs (:obj:`PIL.Image`):
                The raw image representation as PIL.
                No transformation made whatsoever from the input. Make all
                necessary transformations here.
            temperature:
                0 (default) decodes greedily via argmax; any positive value
                samples from ``softmax(logits / temperature)``.

        Return:
            The generated caption as a plain string. (The original annotated
            ``List[Dict[str, Any]]`` per the HF pipeline convention but in
            fact returned a ``str``; the annotation is corrected here,
            behavior unchanged.)
        """
        img_array = tf.keras.utils.img_to_array(inputs)
        image = self.load_image(img_array)
        img_features = self.model.feature_extractor(image[tf.newaxis, ...])

        # Seed the decoder with the start-of-sequence token: (batch=1, seq=1).
        tokens = self.word_to_index([['[START]']])
        # Hoist the end-token lookup out of the loop — it is invariant.
        end_token = self.word_to_index('[END]')

        for _ in range(self.MAX_CAPTION_LENGTH):
            preds = self.model((img_features, tokens)).numpy()  # (batch, seq, vocab)
            preds = preds[:, -1, :]  # logits for the next token: (batch, vocab)
            if temperature == 0:
                # Greedy: pick the single most likely next token.
                next_token = tf.argmax(preds, axis=-1)[:, tf.newaxis]  # (batch, 1)
            else:
                # Temperature sampling from the scaled logits.
                next_token = tf.random.categorical(
                    preds / temperature, num_samples=1)  # (batch, 1)
            tokens = tf.concat([tokens, next_token], axis=1)  # (batch, seq+1)
            if next_token[0] == end_token:
                break

        # Drop the leading '[START]' and the final token (normally '[END]').
        # NOTE(review): if the loop exhausts MAX_CAPTION_LENGTH without
        # emitting '[END]', this also drops the last real word — preserved
        # from the original slicing to keep behavior identical.
        words = self.index_to_word(tokens[0, 1:-1])
        result = tf.strings.reduce_join(words, axis=-1, separator=' ')
        return result.numpy().decode()
|