Arijit-hazra committed on
Commit
e140fd2
1 Parent(s): 9091369

Update app.py: rename the functions again and replace the explicit `model.call(...)` with a direct `model(...)` call

Browse files
Files changed (1) hide show
  1. app.py +4 -4
app.py CHANGED
@@ -14,13 +14,13 @@ def custom_standardization(s):
14
 
15
  model = build()
16
 
17
- def single_transcribe(image, temperature=1):
18
  initial = model.word_to_index([['[START]']]) # (batch, sequence)
19
  img_features = model.feature_extractor(image[tf.newaxis, ...])
20
 
21
  tokens = initial # (batch, sequence)
22
  for n in range(50):
23
- preds = model.call((img_features, tokens)).numpy() # (batch, sequence, vocab)
24
  preds = preds[:,-1, :] #(batch, vocab)
25
  if temperature==0:
26
  next = tf.argmax(preds, axis=-1)[:, tf.newaxis] # (batch, 1)
@@ -35,13 +35,13 @@ def single_transcribe(image, temperature=1):
35
  result = tf.strings.reduce_join(words, axis=-1, separator=' ')
36
  return result.numpy().decode()
37
 
38
- def transcribes(image):
39
  result = []
40
  for t in [0,0.5,1]:
41
  result.append(single_transcribe(image, t))
42
  return result
43
 
44
- gr.interface(fn=transcribes,
45
  inputs=gr.Image(type="pil"),
46
  outputs=["text","text","text"]
47
  ).launch()
 
14
 
15
  model = build()
16
 
17
+ def single_img_transcribe(image, temperature=1):
18
  initial = model.word_to_index([['[START]']]) # (batch, sequence)
19
  img_features = model.feature_extractor(image[tf.newaxis, ...])
20
 
21
  tokens = initial # (batch, sequence)
22
  for n in range(50):
23
+ preds = model((img_features, tokens)).numpy() # (batch, sequence, vocab)
24
  preds = preds[:,-1, :] #(batch, vocab)
25
  if temperature==0:
26
  next = tf.argmax(preds, axis=-1)[:, tf.newaxis] # (batch, 1)
 
35
  result = tf.strings.reduce_join(words, axis=-1, separator=' ')
36
  return result.numpy().decode()
37
 
38
+ def img_transcribes(image):
39
  result = []
40
  for t in [0,0.5,1]:
41
  result.append(single_transcribe(image, t))
42
  return result
43
 
44
+ gr.interface(fn=img_transcribes,
45
  inputs=gr.Image(type="pil"),
46
  outputs=["text","text","text"]
47
  ).launch()