import os

import tensorflow as tf
from tensorflow import keras
from PIL import Image

# Input resolution expected by the feature extractor.
# NOTE: assumed value; adjust to match the saved model's expected input size.
IMAGE_SHAPE = (224, 224, 3)


class PreTrainedPipeline:
    def __init__(self, path: str):
        # Load the saved captioning model (feature extractor, tokenizer, decoder).
        self.model = keras.models.load_model(os.path.join(path, "model"))

        # Lookup layers mapping between token strings and vocabulary indices.
        self.word_to_index = tf.keras.layers.StringLookup(
            mask_token="",
            vocabulary=self.model.tokenizer.get_vocabulary())

        self.index_to_word = tf.keras.layers.StringLookup(
            mask_token="",
            vocabulary=self.model.tokenizer.get_vocabulary(),
            invert=True)

    def load_image(self, img):
        # `img` is an already-decoded (H, W, C) image array; resize it to the
        # resolution the feature extractor expects.
        img = tf.image.resize(img, IMAGE_SHAPE[:-1])
        return img

    def __call__(self, inputs: "Image.Image") -> str:
        """
        Args:
            inputs (:obj:`PIL.Image`):
                The raw image as a PIL image. No transformation is applied to
                the input; all necessary preprocessing happens here.
        Return:
            A :obj:`str`: the caption generated for the image.
        """
        # Preprocess the PIL image into the tensor the feature extractor expects.
        img_array = tf.keras.utils.img_to_array(inputs)
        image = self.load_image(img_array)

        # Start decoding from the [START] token.
        initial = self.word_to_index([['[START]']])  # (batch, sequence)
        img_features = self.model.feature_extractor(image[tf.newaxis, ...])

        temperature = 0  # 0 => greedy decoding (argmax); > 0 => sampling
        tokens = initial  # (batch, sequence)
        for _ in range(50):
            preds = self.model((img_features, tokens)).numpy()  # (batch, sequence, vocab)
            preds = preds[:, -1, :]  # (batch, vocab)
            if temperature == 0:
                next_token = tf.argmax(preds, axis=-1)[:, tf.newaxis]  # (batch, 1)
            else:
                next_token = tf.random.categorical(preds / temperature, num_samples=1)  # (batch, 1)
            tokens = tf.concat([tokens, next_token], axis=1)  # (batch, sequence)

            # Stop once the model emits the [END] token.
            if next_token[0] == self.word_to_index('[END]'):
                break

        # Drop the [START]/[END] tokens and join the remaining words.
        words = self.index_to_word(tokens[0, 1:-1])
        result = tf.strings.reduce_join(words, axis=-1, separator=' ')
        return result.numpy().decode()
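

# A minimal local usage sketch, not part of the original file. It assumes the
# repository checkout is the current directory (so the saved model lives under
# "./model", as loaded in __init__) and that a test image "example.jpg" exists;
# both paths are hypothetical.
if __name__ == "__main__":
    # Instantiate the pipeline from the current directory.
    pipeline = PreTrainedPipeline(path=".")

    # Open a test image as RGB and print the generated caption string.
    caption = pipeline(Image.open("example.jpg").convert("RGB"))
    print(caption)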