"""Gradio demo that captions an uploaded image at several sampling temperatures."""

import re
import string

import gradio as gr
import tensorflow as tf

from load_model import build

IMG_SHAPE = (224, 224, 3)

# Upper bound on the number of decoding steps for one caption.
MAX_CAPTION_TOKENS = 50

# Temperatures shown side by side in the UI; 0 selects greedy decoding.
TEMPERATURES = (0, 0.5, 1)


def custom_standardization(s):
    """Lowercase *s*, strip ASCII punctuation and wrap it in [START]/[END].

    NOTE(review): not referenced anywhere else in this file; presumably the
    model built by ``load_model.build`` needs it to be defined here --
    confirm before removing.
    """
    s = tf.strings.lower(s)
    s = tf.strings.regex_replace(s, f'[{re.escape(string.punctuation)}]', '')
    s = tf.strings.join(['[START]', s, '[END]'], separator=' ')
    return s


model = build()


def rescale(image):
    """Convert *image* to a tensor resized to the height/width in IMG_SHAPE."""
    return tf.image.resize(tf.convert_to_tensor(image), IMG_SHAPE[:-1])


def single_img_transcribe(image, temperature=1):
    """Generate one caption for *image*.

    Args:
        image: Anything ``tf.convert_to_tensor`` accepts as an image (the UI
            passes a PIL image).
        temperature: Sampling temperature. ``0`` picks the arg-max token at
            every step; any other value samples from the logits divided by
            the temperature.

    Returns:
        The caption as a space-separated string, without the ``[START]`` and
        ``[END]`` markers.
    """
    tokens = model.word_to_index([['[START]']])  # (batch, sequence)
    # Looked up once: the end marker does not change between decoding steps.
    end_token = model.word_to_index('[END]')
    img_features = model.feature_extractor(rescale(image)[tf.newaxis, ...])

    for _ in range(MAX_CAPTION_TOKENS):
        preds = model((img_features, tokens))  # (batch, sequence, vocab)
        preds = preds[:, -1, :]  # (batch, vocab)
        if temperature == 0:
            next_token = tf.argmax(preds, axis=-1)[:, tf.newaxis]  # (batch, 1)
        else:
            next_token = tf.random.categorical(
                preds / temperature, num_samples=1
            )  # (batch, 1)

        # Stop *before* appending [END]. The previous code appended it and
        # then always sliced off the last token, which also discarded a real
        # word whenever the step budget ran out without [END] being produced.
        if next_token[0, 0] == end_token:
            break
        tokens = tf.concat([tokens, next_token], axis=1)  # (batch, sequence)

    # Skip the leading [START]; [END] was never appended.
    words = model.index_to_word(tokens[0, 1:])
    result = tf.strings.reduce_join(words, axis=-1, separator=' ')
    return result.numpy().decode()


def img_transcribes(image):
    """Return one caption per entry in TEMPERATURES, in that order."""
    return [single_img_transcribe(image, t) for t in TEMPERATURES]


demo = gr.Interface(
    fn=img_transcribes,
    inputs=gr.Image(type="pil"),
    outputs=["text", "text", "text"],
)
demo.launch()