Johannes Kolbe commited on
Commit
b28a8cb
1 Parent(s): 0181e70

working space

Browse files
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
1
+ venv
2
+ .mypy_cache
app.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import tensorflow as tf
3
+ from huggingface_hub import from_pretrained_keras
4
+ import numpy as np
5
+
6
+ adamatch_model = from_pretrained_keras("johko/adamatch-keras-io")
7
+ base_model = from_pretrained_keras("johko/wideresnet28-2-mnist")
8
+
9
+
10
+ labels = ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"]
11
+
12
+ def predict_image(image, model):
13
+ image = tf.constant(image)
14
+ image = tf.reshape(image, [-1, 32, 32, 3])
15
+ probs_ada_mnist = model.predict(image)[0,:]
16
+ top_pred = probs_ada_mnist.tolist()
17
+ return {labels[i]: top_pred[i] for i in range(10)}
18
+
19
+ def infer(mnist_img, svhn_img, model):
20
+ labels_out = []
21
+ for im in [mnist_img, svhn_img]:
22
+ labels_out.append(predict_image(im, model))
23
+ return labels_out
24
+
25
+ def infer_ada(mnist_image, svhn_image):
26
+ return infer(mnist_image, svhn_image, adamatch_model)
27
+
28
+ def infer_base(mnist_image, svhn_image):
29
+ return infer(mnist_image, svhn_image, base_model)
30
+
31
+
32
+ def infer_all(mnist_image, svhn_image):
33
+ base_res = infer_base(mnist_image, svhn_image)
34
+ ada_res = infer_ada(mnist_image, svhn_image)
35
+ return base_res.extend(ada_res)
36
+
37
+ article = """<center>
38
+
39
+ Authors: <a href='https://twitter.com/johko990' target='_blank'>Johannes Kolbe</a> based on an example by [Sayak Paul](https://twitter.com/RisingSayak) on
40
+ <a href='https://keras.io/examples/vision/adamatch/' target='_blank'>**keras.io**</a>"""
41
+
42
+
43
+
44
+ description = """<center>
45
+
46
+ This space lets you compare image classification results of identical architecture (WideResNet-2-28) models. The training of one of the models was improved
47
+ by using AdaMatch as seen in the example on [keras.io](https://keras.io/examples/vision/adamatch/).
48
+
49
+ The base model was only trained on the MNIST dataset and shows a low classification accuracy (8.96%) for a different domain dataset like SVHN. The AdaMatch model
50
+ uses a semi-supervised domain adaption approach to adapt to the SVHN dataset and shows a significantly higher accuracy (26.51%).
51
+ """
52
+ mnist_image_base = gr.inputs.Image(shape=(32, 32))
53
+ svhn_image_base = gr.inputs.Image(shape=(32, 32))
54
+ mnist_image_ada = gr.inputs.Image(shape=(32, 32))
55
+ svhn_image_ada = gr.inputs.Image(shape=(32, 32))
56
+
57
+ label_mnist_base = gr.outputs.Label(num_top_classes=3, label="MNIST Prediction Base")
58
+ label_svhn_base = gr.outputs.Label(num_top_classes=3, label="SVHN Prediction Base")
59
+ label_mnist_ada = gr.outputs.Label(num_top_classes=3, label="MNIST Prediction AdaMatch")
60
+ label_svhn_ada = gr.outputs.Label(num_top_classes=3, label="SVHN Prediction AdaMatch")
61
+
62
+
63
+ base_iface = gr.Interface(
64
+ fn=infer_base,
65
+ inputs=[mnist_image_base, svhn_image_base],
66
+ outputs=[label_mnist_base,label_svhn_base]
67
+ )
68
+
69
+ ada_iface = gr.Interface(
70
+ fn=infer_ada,
71
+ inputs=[mnist_image_ada, svhn_image_ada],
72
+ outputs=[label_mnist_ada,label_svhn_ada]
73
+ )
74
+
75
+ gr.Parallel(base_iface,
76
+ ada_iface,
77
+ examples=[
78
+ ["examples/mnist_3.jpg", "examples/svhn_3.jpeg"],
79
+ ["examples/mnist_8.jpg", "examples/svhn_8.jpg"]
80
+ ],
81
+ title="Domain Adaption with AdaMatch",
82
+ article=article,
83
+ description=description,
84
+ ).launch()
examples/mnist_3.jpg ADDED
examples/mnist_8.jpg ADDED
examples/svhn_3.jpeg ADDED
examples/svhn_8.jpg ADDED
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ gradio
2
+ tensorflow
3
+ huggingface_hub