reichenbach commited on
Commit
5af1bb1
β€’
1 Parent(s): 9d8f891

Gradio App Code

Browse files
Files changed (1) hide show
  1. app.py +120 -0
app.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from huggingface_hub import from_pretrained_keras
2
+ import numpy as np
3
+ import pandas as pd
4
+ import tensorflow as tf
5
+ import tensorflow_hub as hub
6
+ import tensorflow_text as text
7
+ from tensorflow import keras
8
+ import gradio as gr
9
+
10
+
11
+ def make_bert_preprocessing_model(sentence_features, seq_length=128):
12
+ """Returns Model mapping string features to BERT inputs.
13
+
14
+ Args:
15
+ sentence_features: A list with the names of string-valued features.
16
+ seq_length: An integer that defines the sequence length of BERT inputs.
17
+
18
+ Returns:
19
+ A Keras Model that can be called on a list or dict of string Tensors
20
+ (with the order or names, resp., given by sentence_features) and
21
+ returns a dict of tensors for input to BERT.
22
+ """
23
+
24
+ input_segments = [
25
+ tf.keras.layers.Input(shape=(), dtype=tf.string, name=ft)
26
+ for ft in sentence_features
27
+ ]
28
+
29
+ # tokenize the text to word pieces
30
+ bert_preprocess = hub.load(bert_preprocess_path)
31
+ tokenizer = hub.KerasLayer(bert_preprocess.tokenize,
32
+ name="tokenizer")
33
+
34
+ segments = [tokenizer(s) for s in input_segments]
35
+
36
+ truncated_segments = segments
37
+
38
+ packer = hub.KerasLayer(bert_preprocess.bert_pack_inputs,
39
+ arguments=dict(seq_length=seq_length),
40
+ name="packer")
41
+ model_inputs = packer(truncated_segments)
42
+ return keras.Model(input_segments, model_inputs)
43
+
44
+
45
+ def preprocess_image(image_path, resize):
46
+ extension = tf.strings.split(image_path)[-1]
47
+
48
+ image = tf.io.read_file(image_path)
49
+ if extension == b"jpg":
50
+ image = tf.image.decode_jpeg(image, 3)
51
+ else:
52
+ image = tf.image.decode_png(image, 3)
53
+
54
+ image = tf.image.resize(image, resize)
55
+ return image
56
+
57
+ def preprocess_text(text_1, text_2):
58
+
59
+ text_1 = tf.convert_to_tensor([text_1])
60
+ text_2 = tf.convert_to_tensor([text_2])
61
+
62
+ output = bert_preprocess_model([text_1, text_2])
63
+
64
+ output = {feature: tf.squeeze(output[feature]) for feature in bert_input_features}
65
+
66
+ return output
67
+
68
+ def preprocess_text_and_image(sample, resize):
69
+
70
+ image_1 = preprocess_image(sample['image_1_path'], resize)
71
+ image_2 = preprocess_image(sample['image_2_path'], resize)
72
+
73
+ text = preprocess_text(sample['text_1'], sample['text_2'])
74
+
75
+ return {"image_1": image_1, "image_2": image_2, "text": text}
76
+
77
+
78
+ def classify_info(image_1, text_1, image_2, text_2):
79
+
80
+ sample = dict()
81
+ sample['image_1_path'] = image_1
82
+ sample['image_2_path'] = image_2
83
+ sample['text_1'] = text_1
84
+ sample['text_2'] = text_2
85
+
86
+ dataframe = pd.DataFrame(sample, index=[0])
87
+
88
+ ds = tf.data.Dataset.from_tensor_slices((dict(dataframe), [0]))
89
+ ds = ds.map(lambda x, y: (preprocess_text_and_image(x, resize), y)).cache()
90
+ batch_size = 1
91
+ auto = tf.data.AUTOTUNE
92
+ ds = ds.batch(batch_size).prefetch(auto)
93
+ output = model.predict(ds)
94
+
95
+ label = np.argmax(output)
96
+ return labels[label]
97
+
98
+
99
+ model = from_pretrained_keras("keras-io/multimodal-entailment")
100
+ resize = (128, 128)
101
+ bert_input_features = ["input_word_ids", "input_type_ids", "input_mask"]
102
+ bert_model_path = ("https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-2_H-256_A-4/1")
103
+ bert_preprocess_path = "https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3"
104
+ bert_preprocess_model = make_bert_preprocessing_model(['text_1', 'text_2'])
105
+
106
+ labels = {0: "Contradictory", 1: "Implies", 2: "No Entailment"}
107
+
108
+ resize = (128, 128)
109
+ image_1 = gr.inputs.Image(type="filepath")
110
+ image_2 = gr.inputs.Image(type="filepath")
111
+
112
+ text_1 = gr.inputs.Textbox(lines=5)
113
+ text_2 = gr.inputs.Textbox(lines=5)
114
+
115
+ label = gr.outputs.Label()
116
+
117
+ iface = gr.Interface(classify_info,
118
+ inputs=[image_1, text_1, image_2, text_2],outputs=label)
119
+
120
+ iface.launch()