import gradio as gr
import tensorflow as tf
from huggingface_hub import from_pretrained_keras
import numpy as np

adamatch_model = from_pretrained_keras("keras-io/adamatch-domain-adaption")
base_model = from_pretrained_keras("johko/wideresnet28-2-mnist")

# Side length (pixels) of the square RGB images both models are fed.
IMG_SIZE = 32

# Class names shown in the UI; entry i is paired with column i of the model output.
labels = ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"]


def predict_image(image, model):
    """Classify a single image with ``model``.

    Args:
        image: Array-like image. It is reshaped to ``(-1, 32, 32, 3)``, so it
            must hold a multiple of 32*32*3 values.
        model: Keras model whose ``predict`` returns one row of class scores
            per input image.

    Returns:
        dict mapping each entry of ``labels`` to the model's score for the
        first image of the batch (values come from ``ndarray.tolist()``).
    """
    batch = tf.reshape(tf.constant(image), [-1, IMG_SIZE, IMG_SIZE, 3])
    scores = model.predict(batch)[0, :]
    # NOTE(review): gr.Label treats these values as confidences; this assumes
    # the model's last layer already yields probabilities. If it emits logits,
    # apply tf.nn.softmax here -- confirm against the model definitions.
    return dict(zip(labels, scores.tolist()))


def infer(mnist_img, svhn_img, model):
    """Run ``model`` on an MNIST image and an SVHN image.

    Returns:
        ``[mnist_prediction, svhn_prediction]``: the two dicts produced by
        ``predict_image``, in input order, matching the two Label outputs of
        each interface below.
    """
    return [predict_image(img, model) for img in (mnist_img, svhn_img)]


def infer_ada(mnist_image, svhn_image):
    """Predict both images with the AdaMatch-trained model."""
    return infer(mnist_image, svhn_image, adamatch_model)


def infer_base(mnist_image, svhn_image):
    """Predict both images with the MNIST-only base model."""
    return infer(mnist_image, svhn_image, base_model)


def infer_all(mnist_image, svhn_image):
    """Predict both images with both models.

    Returns:
        A 4-element list:
        [base MNIST, base SVHN, AdaMatch MNIST, AdaMatch SVHN].
    """
    # Concatenate instead of calling ``list.extend``: extend() mutates in place
    # and returns None, which would make this function always return None.
    return infer_base(mnist_image, svhn_image) + infer_ada(mnist_image, svhn_image)


article = """
Authors: Johannes Kolbe based on an example by [Sayak Paul](https://twitter.com/RisingSayak) on **keras.io**"""
description = """
This space lets you compare image classification results of identical architecture (WideResNet-2-28) models. The training of one of the models was improved by using AdaMatch as seen in the example on [keras.io](https://keras.io/examples/vision/adamatch/). The base model was only trained on the MNIST dataset and shows a low classification accuracy (8.96%) for a different domain dataset like SVHN. The AdaMatch model uses a semi-supervised domain adaption approach to adapt to the SVHN dataset and shows a significantly higher accuracy (26.51%). """


def _digit_image():
    """Create a Gradio image input with ``shape=(32, 32)`` for the classifiers."""
    return gr.inputs.Image(shape=(IMG_SIZE, IMG_SIZE))


def _prediction_label(title):
    """Create a Gradio label output that lists the top-3 predicted classes."""
    return gr.outputs.Label(num_top_classes=3, label=title)


# Each interface gets its own component instances rather than sharing them.
mnist_image_base = _digit_image()
svhn_image_base = _digit_image()
mnist_image_ada = _digit_image()
svhn_image_ada = _digit_image()

label_mnist_base = _prediction_label("MNIST Prediction Base")
label_svhn_base = _prediction_label("SVHN Prediction Base")
label_mnist_ada = _prediction_label("MNIST Prediction AdaMatch")
label_svhn_ada = _prediction_label("SVHN Prediction AdaMatch")

base_iface = gr.Interface(
    fn=infer_base,
    inputs=[mnist_image_base, svhn_image_base],
    outputs=[label_mnist_base, label_svhn_base],
)

ada_iface = gr.Interface(
    fn=infer_ada,
    inputs=[mnist_image_ada, svhn_image_ada],
    outputs=[label_mnist_ada, label_svhn_ada],
)

# Show both interfaces side by side so the same examples can be compared.
gr.Parallel(
    base_iface,
    ada_iface,
    examples=[
        ["examples/mnist_3.jpg", "examples/svhn_3.jpeg"],
        ["examples/mnist_8.jpg", "examples/svhn_8.jpg"],
    ],
    title="Semi-Supervised Domain Adaption with AdaMatch",
    article=article,
    description=description,
).launch()