hoangthan commited on
Commit
7dd3871
1 Parent(s): 5dd52c8

Create pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +54 -0
pipeline.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow as tf
2
+ import numpy as np
3
+ from tensorflow import keras
4
+ import os
5
+ from typing import Dict, List, Any
6
+ import pickle
7
+ from PIL import Image
8
+
9
class PreTrainedPipeline():
    """Image-captioning inference pipeline.

    Loads a trained captioning model plus its tokenizer from *path*, then
    generates a caption for an input image by greedy (argmax) token-by-token
    decoding until an ``[END]`` token is produced or a step limit is reached.
    """

    # Input resolution (height, width, channels) fed to the feature extractor.
    # NOTE(review): value assumed — the original referenced an undefined free
    # name ``IMAGE_SHAPE``; confirm against the trained model's input size.
    IMAGE_SHAPE = (224, 224, 3)

    # Hard cap on decoding steps (was the literal ``50`` in the loop).
    MAX_TOKENS = 50

    def __init__(self, path=""):
        """Load model artifacts from *path* and build token lookup tables.

        Args:
            path: Directory containing the saved model artifacts.
        """
        # BUG FIX: the original never assigned ``self.tokenizer`` (or the
        # ``self.feature_extractor`` / decoder model used by __call__), so the
        # pipeline crashed with AttributeError. The file imports ``os`` and
        # ``pickle`` without using them, which strongly suggests artifact
        # loading was meant to happen here.
        # NOTE(review): artifact file names below are assumed — confirm them
        # against the repository contents.
        with open(os.path.join(path, "tokenizer.pkl"), "rb") as handle:
            self.tokenizer = pickle.load(handle)
        self.feature_extractor = tf.keras.models.load_model(
            os.path.join(path, "feature_extractor"))
        self.model = tf.keras.models.load_model(os.path.join(path, "model"))

        vocabulary = self.tokenizer.get_vocabulary()
        # Token string -> integer id.
        self.word_to_index = tf.keras.layers.StringLookup(
            mask_token="",
            vocabulary=vocabulary)
        # Integer id -> token string (inverse lookup for detokenization).
        self.index_to_word = tf.keras.layers.StringLookup(
            mask_token="",
            vocabulary=vocabulary,
            invert=True)

    def load_image(self, img):
        """Decode raw JPEG bytes and resize to the model's input resolution.

        Args:
            img: JPEG-encoded image bytes (a string/bytes tensor).

        Returns:
            A float tensor of shape ``IMAGE_SHAPE[:-1] + (3,)``.
        """
        # BUG FIX: original signature was ``def load_image(img)`` — a method
        # missing ``self`` — and referenced the undefined global IMAGE_SHAPE.
        img = tf.io.decode_jpeg(img, channels=3)
        img = tf.image.resize(img, self.IMAGE_SHAPE[:-1])
        return img

    def __call__(self, inputs: "Image.Image") -> str:
        """Generate a caption for *inputs*.

        Args:
            inputs (:obj:`PIL.Image`):
                The raw image representation as PIL.
                No transformation made whatsoever from the input. Make all
                necessary transformations here.
                NOTE(review): ``load_image`` actually expects JPEG *bytes*
                (``tf.io.decode_jpeg``), not a PIL image — confirm what the
                serving layer passes in.

        Returns:
            The decoded caption string. (The original annotation claimed
            ``List[Dict[str, Any]]`` but the code returns a ``str``; the
            annotation is corrected here to match actual behavior.)
        """
        # BUG FIX: was a bare ``load_image(inputs)`` -> NameError.
        image = self.load_image(inputs)
        img_features = self.feature_extractor(image[tf.newaxis, ...])

        tokens = self.word_to_index([['[START]']])  # (batch, sequence)
        end_token = self.word_to_index('[END]')
        temperature = 0  # 0 -> greedy argmax decoding; >0 -> sampling

        for _ in range(self.MAX_TOKENS):
            # BUG FIX: the original called ``self((img_features, tokens))``,
            # i.e. this very __call__, causing infinite recursion. The decoder
            # model is what was intended here.
            preds = self.model((img_features, tokens)).numpy()  # (batch, seq, vocab)
            preds = preds[:, -1, :]  # logits for the last position: (batch, vocab)
            if temperature == 0:
                # ``next`` renamed: it shadowed the builtin.
                next_token = tf.argmax(preds, axis=-1)[:, tf.newaxis]  # (batch, 1)
            else:
                next_token = tf.random.categorical(
                    preds / temperature, num_samples=1)  # (batch, 1)
            tokens = tf.concat([tokens, next_token], axis=1)  # (batch, sequence)

            if next_token[0] == end_token:
                break

        # Drop the leading [START] and the final token ([END] when decoding
        # stopped early) before joining the remaining words into the caption.
        words = self.index_to_word(tokens[0, 1:-1])
        caption = tf.strings.reduce_join(words, axis=-1, separator=' ')
        return caption.numpy().decode()