Create pipeline.py
Browse files- pipeline.py +54 -0
pipeline.py
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tensorflow as tf
|
2 |
+
import numpy as np
|
3 |
+
from tensorflow import keras
|
4 |
+
import os
|
5 |
+
from typing import Dict, List, Any
|
6 |
+
import pickle
|
7 |
+
from PIL import Image
|
8 |
+
|
9 |
+
class PreTrainedPipeline():
    """Image-captioning inference pipeline (Hugging Face Inference API shape).

    Wraps token<->id lookups and a greedy decoding loop that generates a
    caption for a single input image.

    NOTE(review): as uploaded, this class cannot run. Nothing in this file
    loads a tokenizer, feature extractor, or decoder model from ``path`` —
    yet ``self.tokenizer``, ``self.feature_extractor``, and the model call
    are all used below. Each failure point is flagged inline.
    """

    def __init__(self, path=""):
        # path: directory containing the hosted model assets (unused here —
        # NOTE(review): presumably a model/tokenizer load step from `path`
        # is missing; confirm against the repo this was copied from).
        #
        # Maps token strings to integer ids using the tokenizer vocabulary.
        # NOTE(review): `self.tokenizer` is read before ever being assigned
        # anywhere in this file, so this line raises AttributeError.
        self.word_to_index = tf.keras.layers.StringLookup(
            mask_token="",
            vocabulary=self.tokenizer.get_vocabulary())

        # Inverse lookup: integer ids back to token strings.
        self.index_to_word = tf.keras.layers.StringLookup(
            mask_token="",
            vocabulary=self.tokenizer.get_vocabulary(),
            invert=True)

    def load_image(img):
        # Decode a JPEG byte string into an RGB tensor and resize it to the
        # model input resolution.
        # NOTE(review): defined without `self`, so calling it as an instance
        # method would bind the instance to `img`; `__call__` below instead
        # calls the bare name `load_image`, which is a NameError (no such
        # module-level function). Also `IMAGE_SHAPE` is not defined in this
        # file. Both need fixing before this can run.
        #img = tf.io.read_file(image_path)
        img = tf.io.decode_jpeg(img, channels=3)
        img = tf.image.resize(img, IMAGE_SHAPE[:-1])
        return img

    def __call__(self, inputs: "Image.Image") -> List[Dict[str, Any]]:
        """
        Args:
            inputs (:obj:`PIL.Image`):
                The raw image representation as PIL.
                No transformation made whatsoever from the input. Make all necessary transformations here.
        Return:
            A :obj:`list`:. The list contains items that are dicts should be liked {"label": "XXX", "score": 0.82}
            It is preferred if the returned list is in decreasing `score` order
        """
        # NOTE(review): the annotation and docstring promise a list of
        # {"label", "score"} dicts, but the last line returns a plain str.
        #
        # NOTE(review): `load_image` is not in scope as a bare name
        # (NameError); should presumably be `self.load_image`. Also `inputs`
        # is a PIL image, while load_image feeds it to tf.io.decode_jpeg,
        # which expects JPEG bytes — confirm the intended input format.
        image = load_image(inputs)
        # Seed the decoder with the start-of-caption token id.
        initial = self.word_to_index([['[START]']]) # (batch, sequence)
        # NOTE(review): `self.feature_extractor` is never assigned in this
        # file — AttributeError at runtime.
        img_features = self.feature_extractor(image[tf.newaxis, ...])
        # temperature == 0 selects the greedy argmax branch below; the
        # sampling branch is unreachable as written.
        temperature = 0
        tokens = initial # (batch, sequence)
        # Generate at most 50 tokens, stopping early on the [END] token.
        for n in range(50):
            # NOTE(review): `self(...)` re-invokes this very __call__ with a
            # tuple argument — infinite recursion. This line was presumably
            # meant to call the captioning model (e.g. `self.model(...)`),
            # which is never loaded in this file.
            preds = self((img_features, tokens)).numpy() # (batch, sequence, vocab)
            # Keep only the logits for the most recently generated position.
            preds = preds[:,-1, :] #(batch, vocab)
            if temperature==0:
                # Greedy decoding: pick the highest-logit token.
                # NOTE(review): `next` shadows the builtin of the same name.
                next = tf.argmax(preds, axis=-1)[:, tf.newaxis] # (batch, 1)
            else:
                # Temperature sampling (dead branch while temperature is 0).
                next = tf.random.categorical(preds/temperature, num_samples=1) # (batch, 1)
            tokens = tf.concat([tokens, next], axis=1) # (batch, sequence)

            # Stop once the model emits the end-of-caption token.
            if next[0] == self.word_to_index('[END]'):
                break
        # Drop the leading [START] and the trailing token, map ids back to
        # words, and join them into a single space-separated caption string.
        words = self.index_to_word(tokens[0, 1:-1])
        result = tf.strings.reduce_join(words, axis=-1, separator=' ')
        return result.numpy().decode()