Spaces:
Runtime error
Runtime error
File size: 1,638 Bytes
5f69179 06f1f85 47e33db 06f1f85 6f9105d 06f1f85 6f9105d e140fd2 06f1f85 6f9105d 06f1f85 e140fd2 06f1f85 e140fd2 06f1f85 eb0524f 06f1f85 7a0bbe6 06f1f85 7a0bbe6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 |
import re
import string
import gradio as gr
import tensorflow as tf
from load_model import build
# Target image shape; the first two entries are used as the resize size
# (presumably height, width, channels -- confirm against the model's input).
IMG_SHAPE = (224, 224, 3)
def custom_standardization(s):
    """Lowercase *s*, strip ASCII punctuation, and wrap it in [START]/[END].

    Args:
        s: A string tensor (or value accepted by ``tf.strings`` ops).

    Returns:
        A string tensor of the form ``"[START] <cleaned text> [END]"``.
    """
    # Character class matching every character in string.punctuation.
    punctuation_class = f'[{re.escape(string.punctuation)}]'

    lowered = tf.strings.lower(s)
    cleaned = tf.strings.regex_replace(lowered, punctuation_class, '')
    return tf.strings.join(['[START]', cleaned, '[END]'], separator=' ')
# Model object returned by load_model.build(); created once at import time.
model = build()


def rescale(image):
    """Convert *image* to a tensor and resize it to the first two dims of IMG_SHAPE."""
    as_tensor = tf.convert_to_tensor(image)
    target_size = IMG_SHAPE[:-1]
    return tf.image.resize(as_tensor, target_size)
def single_img_transcribe(image, temperature=1, max_length=50):
    """Generate one caption for *image* by decoding token by token.

    Starting from the ``[START]`` token, the model is called repeatedly and
    the next token is chosen from the predictions at the last sequence
    position until ``[END]`` is produced or ``max_length`` tokens have been
    generated.

    Args:
        image: Image accepted by ``tf.convert_to_tensor`` (resized via
            ``rescale`` before feature extraction).
        temperature: ``0`` selects the arg-max token at every step; any other
            value samples with ``tf.random.categorical`` from the predictions
            divided by ``temperature``.
        max_length: Maximum number of tokens to generate after ``[START]``.

    Returns:
        The generated words joined by single spaces, as a Python ``str``.
    """
    start_tokens = model.word_to_index([['[START]']])  # (batch, sequence)
    # Loop-invariant: look the end-of-sequence id up once.
    end_token = model.word_to_index('[END]')
    img_features = model.feature_extractor(rescale(image)[tf.newaxis, ...])

    tokens = start_tokens  # (batch, sequence)
    reached_end = False
    for _ in range(max_length):
        preds = model((img_features, tokens)).numpy()  # (batch, sequence, vocab)
        preds = preds[:, -1, :]  # (batch, vocab)
        if temperature == 0:
            next_token = tf.argmax(preds, axis=-1)[:, tf.newaxis]  # (batch, 1)
        else:
            next_token = tf.random.categorical(
                preds / temperature, num_samples=1
            )  # (batch, 1)
        tokens = tf.concat([tokens, next_token], axis=1)  # (batch, sequence)

        if next_token[0] == end_token:
            reached_end = True
            break

    # Always drop the leading [START]. Drop the trailing token only when it
    # really is [END]; if the length cap was hit first, the last token is a
    # regular word and must be kept.
    caption_tokens = tokens[0, 1:-1] if reached_end else tokens[0, 1:]
    words = model.index_to_word(caption_tokens)
    result = tf.strings.reduce_join(words, axis=-1, separator=' ')
    return result.numpy().decode()
def img_transcribes(image):
    """Return three captions for *image*, one per temperature 0, 0.5 and 1.

    Args:
        image: Image forwarded unchanged to ``single_img_transcribe``.

    Returns:
        A list of three caption strings, in temperature order.
    """
    return [single_img_transcribe(image, temp) for temp in [0, 0.5, 1]]
# Web UI: one PIL image in, three text boxes out (one per caption returned
# by img_transcribes). Launched immediately when the module runs.
demo = gr.Interface(
    fn=img_transcribes,
    inputs=gr.Image(type="pil"),
    outputs=["text", "text", "text"],
)
demo.launch()
|