File size: 1,638 Bytes
5f69179
 
06f1f85
47e33db
06f1f85
 
6f9105d
06f1f85
 
 
 
 
 
 
 
 
 
6f9105d
 
e140fd2
06f1f85
6f9105d
06f1f85
 
 
e140fd2
06f1f85
 
 
 
 
 
 
 
 
 
 
 
 
 
e140fd2
06f1f85
 
eb0524f
06f1f85
 
2fb8345
06f1f85
 
ccd1fd6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import re
import string
import gradio as gr
import tensorflow as tf
from load_model import build

# (height, width, channels). Input images are resized to IMG_SHAPE[:-1]
# before being passed to the model's feature extractor.
IMG_SHAPE = (224, 224, 3)


def custom_standardization(s):
    """Normalise caption text for the tokenizer.

    Lower-cases *s* with ``tf.strings.lower``, deletes every character in
    ``string.punctuation``, and wraps the result as ``"[START] <text> [END]"``.
    """
    punctuation_class = f'[{re.escape(string.punctuation)}]'
    lowered = tf.strings.lower(s)
    without_punct = tf.strings.regex_replace(lowered, punctuation_class, '')
    return tf.strings.join(['[START]', without_punct, '[END]'], separator=' ')
    
# Captioning model from load_model.build(). The code below uses its
# feature_extractor, word_to_index and index_to_word attributes and calls it
# directly on (image_features, tokens).
model = build()


def rescale(image):
    """Convert *image* to a tensor and resize it to ``IMG_SHAPE[:-1]``.

    Despite the name, pixel values are NOT rescaled (e.g. to [0, 1]); only the
    spatial size changes. ``tf.image.resize`` with its default bilinear method
    returns float32 regardless of the input dtype.
    """
    return tf.image.resize(tf.convert_to_tensor(image), IMG_SHAPE[:-1])

def single_img_transcribe(image, temperature=1, max_length=50):
    """Generate a caption for one image, decoding one token at a time.

    Args:
        image: Image accepted by ``rescale`` (the Gradio interface passes a
            PIL image). It is resized to ``IMG_SHAPE[:-1]`` before feature
            extraction.
        temperature: 0 picks the arg-max token at every step (greedy);
            any other value samples from ``preds / temperature`` with
            ``tf.random.categorical``.
        max_length: Maximum number of tokens to generate. The default of 50
            matches the previously hard-coded limit.

    Returns:
        The generated words joined by single spaces. The leading ``[START]``
        token is omitted, and so is the trailing ``[END]`` token when one was
        produced.
    """
    # Look up the [END] id once rather than on every decoding step.
    end_id = int(model.word_to_index('[END]'))

    tokens = model.word_to_index([['[START]']])  # (batch, sequence)
    img_features = model.feature_extractor(rescale(image)[tf.newaxis, ...])

    hit_end = False
    for _ in range(max_length):
        preds = model((img_features, tokens)).numpy()  # (batch, sequence, vocab)
        preds = preds[:, -1, :]  # (batch, vocab): scores for the next position only
        if temperature == 0:
            next_token = tf.argmax(preds, axis=-1)[:, tf.newaxis]  # (batch, 1)
        else:
            # NOTE(review): tf.random.categorical expects unnormalized
            # log-probabilities; this assumes the model returns logits.
            next_token = tf.random.categorical(preds / temperature, num_samples=1)  # (batch, 1)
        tokens = tf.concat([tokens, next_token], axis=1)  # (batch, sequence)

        if int(next_token[0, 0]) == end_id:
            hit_end = True
            break

    # Always drop [START]. Drop the final token only if it really is [END];
    # an unconditional 1:-1 slice would discard a real word whenever
    # generation stops at max_length.
    generated = tokens[0, 1:-1] if hit_end else tokens[0, 1:]
    words = model.index_to_word(generated)
    result = tf.strings.reduce_join(words, axis=-1, separator=' ')
    return result.numpy().decode()

def img_transcribes(image):
    """Caption *image* at temperatures 0, 0.5 and 1, in that order.

    Returns a list of three caption strings, one per Gradio text output.
    """
    temperatures = (0, 0.5, 1)
    return [single_img_transcribe(image, temp) for temp in temperatures]
    
# Demo UI: one PIL image in, three captions out. The labels must stay in sync
# with the temperatures used by img_transcribes (0, 0.5, 1, in that order).
# Bound to `demo` because Gradio's reload mode looks the app up by that name.
demo = gr.Interface(
    fn=img_transcribes,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Textbox(label="temperature = 0 (greedy)"),
        gr.Textbox(label="temperature = 0.5"),
        gr.Textbox(label="temperature = 1"),
    ],
)
demo.launch()