Update pipeline.py
Browse files- pipeline.py +7 -4
pipeline.py
CHANGED
@@ -8,13 +8,16 @@ from PIL import Image
|
|
8 |
|
9 |
class PreTrainedPipeline():
|
10 |
def __init__(self, path=""):
|
|
|
|
|
|
|
11 |
self.word_to_index = tf.keras.layers.StringLookup(
|
12 |
mask_token="",
|
13 |
-
vocabulary=self.tokenizer.get_vocabulary())
|
14 |
|
15 |
self.index_to_word = tf.keras.layers.StringLookup(
|
16 |
mask_token="",
|
17 |
-
vocabulary=self.tokenizer.get_vocabulary(),
|
18 |
invert=True)
|
19 |
|
20 |
def load_image(img):
|
@@ -35,11 +38,11 @@ class PreTrainedPipeline():
|
|
35 |
"""
|
36 |
image = load_image(inputs)
|
37 |
initial = self.word_to_index([['[START]']]) # (batch, sequence)
|
38 |
-
img_features = self.feature_extractor(image[tf.newaxis, ...])
|
39 |
temperature = 0
|
40 |
tokens = initial # (batch, sequence)
|
41 |
for n in range(50):
|
42 |
-
preds = self((img_features, tokens)).numpy() # (batch, sequence, vocab)
|
43 |
preds = preds[:,-1, :] #(batch, vocab)
|
44 |
if temperature==0:
|
45 |
next = tf.argmax(preds, axis=-1)[:, tf.newaxis] # (batch, 1)
|
|
|
8 |
|
9 |
class PreTrainedPipeline():
|
10 |
def __init__(self, path=""):
|
11 |
+
|
12 |
+
self.model = keras.models.load_model(os.path.join(path, ""))
|
13 |
+
|
14 |
self.word_to_index = tf.keras.layers.StringLookup(
|
15 |
mask_token="",
|
16 |
+
vocabulary=self.model.tokenizer.get_vocabulary())
|
17 |
|
18 |
self.index_to_word = tf.keras.layers.StringLookup(
|
19 |
mask_token="",
|
20 |
+
vocabulary=self.model.tokenizer.get_vocabulary(),
|
21 |
invert=True)
|
22 |
|
23 |
def load_image(img):
|
|
|
38 |
"""
|
39 |
image = load_image(inputs)
|
40 |
initial = self.word_to_index([['[START]']]) # (batch, sequence)
|
41 |
+
img_features = self.model.feature_extractor(image[tf.newaxis, ...])
|
42 |
temperature = 0
|
43 |
tokens = initial # (batch, sequence)
|
44 |
for n in range(50):
|
45 |
+
preds = self.model((img_features, tokens)).numpy() # (batch, sequence, vocab)
|
46 |
preds = preds[:,-1, :] #(batch, vocab)
|
47 |
if temperature==0:
|
48 |
next = tf.argmax(preds, axis=-1)[:, tf.newaxis] # (batch, 1)
|