hoangthan commited on
Commit
8ab6fdb
1 Parent(s): 7dd3871

Update pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +7 -4
pipeline.py CHANGED
@@ -8,13 +8,16 @@ from PIL import Image
8
 
9
  class PreTrainedPipeline():
10
  def __init__(self, path=""):
 
 
 
11
  self.word_to_index = tf.keras.layers.StringLookup(
12
  mask_token="",
13
- vocabulary=self.tokenizer.get_vocabulary())
14
 
15
  self.index_to_word = tf.keras.layers.StringLookup(
16
  mask_token="",
17
- vocabulary=self.tokenizer.get_vocabulary(),
18
  invert=True)
19
 
20
  def load_image(img):
@@ -35,11 +38,11 @@ class PreTrainedPipeline():
35
  """
36
  image = load_image(inputs)
37
  initial = self.word_to_index([['[START]']]) # (batch, sequence)
38
- img_features = self.feature_extractor(image[tf.newaxis, ...])
39
  temperature = 0
40
  tokens = initial # (batch, sequence)
41
  for n in range(50):
42
- preds = self((img_features, tokens)).numpy() # (batch, sequence, vocab)
43
  preds = preds[:,-1, :] #(batch, vocab)
44
  if temperature==0:
45
  next = tf.argmax(preds, axis=-1)[:, tf.newaxis] # (batch, 1)
 
8
 
9
  class PreTrainedPipeline():
10
  def __init__(self, path=""):
11
+
12
+ self.model = keras.models.load_model(os.path.join(path, ""))
13
+
14
  self.word_to_index = tf.keras.layers.StringLookup(
15
  mask_token="",
16
+ vocabulary=self.model.tokenizer.get_vocabulary())
17
 
18
  self.index_to_word = tf.keras.layers.StringLookup(
19
  mask_token="",
20
+ vocabulary=self.model.tokenizer.get_vocabulary(),
21
  invert=True)
22
 
23
  def load_image(img):
 
38
  """
39
  image = load_image(inputs)
40
  initial = self.word_to_index([['[START]']]) # (batch, sequence)
41
+ img_features = self.model.feature_extractor(image[tf.newaxis, ...])
42
  temperature = 0
43
  tokens = initial # (batch, sequence)
44
  for n in range(50):
45
+ preds = self.model((img_features, tokens)).numpy() # (batch, sequence, vocab)
46
  preds = preds[:,-1, :] #(batch, vocab)
47
  if temperature==0:
48
  next = tf.argmax(preds, axis=-1)[:, tf.newaxis] # (batch, 1)