Arijit-hazra commited on
Commit
6f9105d
1 Parent(s): b8e7778

adding a function to rescale the image

Browse files
Files changed (1) hide show
  1. app.py +4 -1
app.py CHANGED
@@ -4,6 +4,7 @@ import gradio as gr
4
  import tensorflow as tf
5
  from load_model import build
6
 
 
7
 
8
 
9
  def custom_standardization(s):
@@ -14,9 +15,11 @@ def custom_standardization(s):
14
 
15
  model = build()
16
 
 
 
17
  def single_img_transcribe(image, temperature=1):
18
  initial = model.word_to_index([['[START]']]) # (batch, sequence)
19
- img_features = model.feature_extractor(tf.convert_to_tensor(image)[tf.newaxis, ...])
20
 
21
  tokens = initial # (batch, sequence)
22
  for n in range(50):
 
4
  import tensorflow as tf
5
  from load_model import build
6
 
7
+ IMG_SHAPE = (224,224,3)
8
 
9
 
10
  def custom_standardization(s):
 
15
 
16
  model = build()
17
 
18
+ rescale = lambda image : tf.image.resize(tf.convert_to_tensor(image), IMG_SHAPE[:-1])
19
+
20
  def single_img_transcribe(image, temperature=1):
21
  initial = model.word_to_index([['[START]']]) # (batch, sequence)
22
+ img_features = model.feature_extractor(rescale(image)[tf.newaxis, ...])
23
 
24
  tokens = initial # (batch, sequence)
25
  for n in range(50):