# Web file-viewer residue (author, commit message, commit hash, size) - not code:
# reichenbach's picture
# Update app.py
# 21f578d
# raw
# history blame
# 4.06 kB
import os
# Best-effort runtime install of the TensorFlow packages imported below.
# `sys.executable -m pip` installs into the interpreter running this script
# (a bare `pip` on PATH may belong to another Python), and an argument list
# avoids going through a shell. check=False keeps the original behavior of
# not aborting when an install fails.
import subprocess
import sys

for _package in ("tensorflow", "tensorflow_hub", "tensorflow_text"):
    subprocess.run([sys.executable, "-m", "pip", "install", _package], check=False)
from huggingface_hub import from_pretrained_keras
import numpy as np
import pandas as pd
import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_text as text
from tensorflow import keras
import gradio as gr
def make_bert_preprocessing_model(sentence_features, seq_length=128, preprocess_path=None):
    """Return a Keras model mapping raw string features to BERT inputs.

    Args:
        sentence_features: Names of the string-valued input features. One
            scalar ``tf.string`` Input is created per name, in this order.
        seq_length: Sequence length passed to the TF Hub ``bert_pack_inputs``
            step, which packs the segments to this length.
        preprocess_path: TF Hub handle of the BERT preprocessing SavedModel.
            Defaults to the module-level ``bert_preprocess_path``.

    Returns:
        A ``keras.Model`` that can be called on a list or dict of string
        tensors (with the order or names, resp., given by
        ``sentence_features``) and returns the dict of tensors produced by
        ``bert_pack_inputs`` for input to BERT.
    """
    if preprocess_path is None:
        # Resolved at call time because the module constant is defined
        # after this function in the file.
        preprocess_path = bert_preprocess_path
    input_segments = [
        tf.keras.layers.Input(shape=(), dtype=tf.string, name=ft)
        for ft in sentence_features
    ]
    bert_preprocess = hub.load(preprocess_path)
    # Tokenize each text input into word pieces.
    tokenizer = hub.KerasLayer(bert_preprocess.tokenize, name="tokenizer")
    segments = [tokenizer(s) for s in input_segments]
    # bert_pack_inputs combines the segments and trims/pads them to
    # seq_length itself, so no separate truncation step is needed here.
    packer = hub.KerasLayer(
        bert_preprocess.bert_pack_inputs,
        arguments=dict(seq_length=seq_length),
        name="packer",
    )
    model_inputs = packer(segments)
    return keras.Model(input_segments, model_inputs)
def preprocess_image(image_path, resize):
    """Read an image file, decode it to RGB, and resize it.

    Args:
        image_path: File path of the image (str or scalar string tensor).
        resize: Target size passed to ``tf.image.resize`` (height, width).

    Returns:
        The resized image tensor with 3 channels.
    """
    raw_bytes = tf.io.read_file(image_path)
    # Detect the format (JPEG, PNG, GIF, BMP) from the file contents instead
    # of the filename: splitting the path on whitespace never isolated the
    # extension, so a name-based check could not select the JPEG decoder.
    # expand_animations=False keeps the result 3-D (first frame of a GIF) so
    # tf.image.resize always receives a single image.
    image = tf.io.decode_image(raw_bytes, channels=3, expand_animations=False)
    return tf.image.resize(image, resize)
def preprocess_text(text_1, text_2):
    """Convert a pair of sentences into unbatched BERT input tensors.

    Each sentence is wrapped as a length-1 batch and fed to the module-level
    ``bert_preprocess_model``. For every feature named in
    ``bert_input_features``, all size-1 dimensions are squeezed away.
    """
    batched_pair = [tf.convert_to_tensor([sentence]) for sentence in (text_1, text_2)]
    packed = bert_preprocess_model(batched_pair)
    squeezed = {}
    for name in bert_input_features:
        squeezed[name] = tf.squeeze(packed[name])
    return squeezed
def preprocess_text_and_image(sample, resize):
    """Build the classifier's input dict for one multimodal sample.

    Args:
        sample: Mapping with keys 'image_1_path', 'image_2_path', 'text_1'
            and 'text_2'.
        resize: Size forwarded to ``preprocess_image`` for both images.

    Returns:
        Dict with the two preprocessed images under "image_1"/"image_2" and
        the BERT inputs for the text pair under "text".
    """
    features = {
        "image_1": preprocess_image(sample["image_1_path"], resize),
        "image_2": preprocess_image(sample["image_2_path"], resize),
    }
    features["text"] = preprocess_text(sample["text_1"], sample["text_2"])
    return features
def classify_info(image_1, text_1, image_2, text_2):
    """Predict the entailment label for two (image, text) inputs.

    Args:
        image_1: File path of the first image (the UI uses type="filepath").
        text_1: Text accompanying the first image.
        image_2: File path of the second image.
        text_2: Text accompanying the second image.

    Returns:
        The entry of the module-level ``labels`` mapping for the class with
        the highest model output.
    """
    # Length-1 columns so from_tensor_slices yields exactly one example; a
    # plain dict replaces the previous pandas DataFrame round-trip.
    sample = {
        "image_1_path": [image_1],
        "image_2_path": [image_2],
        "text_1": [text_1],
        "text_2": [text_2],
    }
    # Dummy target 0 keeps the (features, label) structure the map expects.
    ds = tf.data.Dataset.from_tensor_slices((sample, [0]))
    # No .cache(): the dataset is rebuilt per call and read only once.
    ds = ds.map(lambda x, y: (preprocess_text_and_image(x, resize), y))
    ds = ds.batch(1).prefetch(tf.data.AUTOTUNE)
    output = model.predict(ds)
    # With a single example, the flat argmax is the predicted class index.
    return labels[np.argmax(output)]
# Pretrained classifier loaded from the Hugging Face Hub.
model = from_pretrained_keras("keras-io/multimodal-entailment")

# Size every input image is resized to (passed to tf.image.resize).
# Defined once; a second, duplicate assignment used to follow below.
resize = (128, 128)

# Keys of the BERT preprocessing output that are kept as text features.
bert_input_features = ["input_word_ids", "input_type_ids", "input_mask"]
# NOTE(review): not referenced anywhere in this file; presumably the encoder
# the model was trained with - confirm before removing.
bert_model_path = "https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-2_H-256_A-4/1"
bert_preprocess_path = "https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3"
bert_preprocess_model = make_bert_preprocessing_model(['text_1', 'text_2'])

# Maps the predicted class index to the label shown in the UI.
labels = {0: "Contradictory", 1: "Implies", 2: "No Entailment"}
# Gradio UI: two images and two text boxes in, one predicted label out.
# NOTE(review): gr.inputs / gr.outputs are deprecated in Gradio 3.x and
# removed in 4.x; migrate to gr.Image / gr.Textbox / gr.Label when upgrading.
image_1 = gr.inputs.Image(type="filepath")
image_2 = gr.inputs.Image(type="filepath")
text_1 = gr.inputs.Textbox(lines=5)
text_2 = gr.inputs.Textbox(lines=5)
label = gr.outputs.Label()

input_components = [image_1, text_1, image_2, text_2]
iface = gr.Interface(
    classify_info,
    inputs=input_components,
    outputs=label,
)
iface.launch()