Arijit-hazra commited on
Commit
06f1f85
1 Parent(s): 3d70562

Update app.py

Browse files

This is initial app

Files changed (1) hide show
  1. app.py +44 -0
app.py CHANGED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from load_model import build
3
+
4
+
5
+
6
+ def custom_standardization(s):
7
+ s = tf.strings.lower(s)
8
+ s = tf.strings.regex_replace(s, f'[{re.escape(string.punctuation)}]', '')
9
+ s = tf.strings.join(['[START]', s, '[END]'], separator=' ')
10
+ return s
11
+
12
+ model = build()
13
+
14
+ def simple_gen(image, temperature=1):
15
+ initial = model.word_to_index([['[START]']]) # (batch, sequence)
16
+ img_features = model.feature_extractor(image[tf.newaxis, ...])
17
+
18
+ tokens = initial # (batch, sequence)
19
+ for n in range(50):
20
+ preds = model((img_features, tokens)).numpy() # (batch, sequence, vocab)
21
+ preds = preds[:,-1, :] #(batch, vocab)
22
+ if temperature==0:
23
+ next = tf.argmax(preds, axis=-1)[:, tf.newaxis] # (batch, 1)
24
+ else:
25
+ next = tf.random.categorical(preds/temperature, num_samples=1) # (batch, 1)
26
+ tokens = tf.concat([tokens, next], axis=1) # (batch, sequence)
27
+
28
+ if next[0] == model.word_to_index('[END]'):
29
+ break
30
+
31
+ words = model.index_to_word(tokens[0, 1:-1])
32
+ result = tf.strings.reduce_join(words, axis=-1, separator=' ')
33
+ return result.numpy().decode()
34
+
35
+ def transcribes(image):
36
+ result = []
37
+ for t in [0,0.5,1]:
38
+ result.append(simple_gen(image, t))
39
+ return result
40
+
41
+ gr.interface(fn=transcribes,
42
+ inputs=gr.Image(type="pil"),
43
+ outputs=["text","text","text"]
44
+ ).launch()