"""Gradio demo comparing the teacher and student models from the Keras
"consistency training with supervision" example on a single uploaded image.
"""

import numpy as np
import tensorflow as tf
import gradio as gr
from huggingface_hub import from_pretrained_keras

# Both models are downloaded from the Hugging Face Hub when this module loads.
teacher_model = from_pretrained_keras(
    "keras-io/consistency_training_with_supervision_teacher_model"
)
student_model = from_pretrained_keras(
    "keras-io/consistency_training_with_supervision_student_model"
)

# class_names[i] is the label reported when output index i scores highest.
# NOTE(review): assumes this order matches the label order used in training — confirm.
class_names = [
    "Airplane",
    "Automobile",
    "Bird",
    "Cat",
    "Deer",
    "Dog",
    "Frog",
    "Horse",
    "Ship",
    "Truck",
]

# Example images offered in the UI; paths are relative to the working directory.
examples = [
    ["./aeroplane.png"],
    ["./horse.png"],
    ["./ship.png"],
    ["./truck.png"],
]

# Side length, in pixels, that every input image is resized to before inference.
IMG_SIZE = 72


def _predict_label(model, image_tensor):
    """Return the name of the highest-scoring class for a single image.

    Args:
        model: A Keras model whose output has one score per entry in
            ``class_names``.
        image_tensor: One image without a batch axis; a batch axis of size 1
            is added here.

    Returns:
        The matching entry of ``class_names``.
    """
    batch = tf.expand_dims(image_tensor, axis=0)
    # verbose=0 keeps predict() from printing a progress bar on every request.
    scores = np.squeeze(model.predict(batch, verbose=0))
    return class_names[int(np.argmax(scores))]


def teacher_model_output(image_tensor):
    """Return the teacher model's predicted class name for ``image_tensor``."""
    return _predict_label(teacher_model, image_tensor)


def student_model_output(image_tensor):
    """Return the student model's predicted class name for ``image_tensor``."""
    return _predict_label(student_model, image_tensor)


def infer(input_image):
    """Classify ``input_image`` with both the teacher and the student model.

    Args:
        input_image: The image passed in by the Gradio ``Image`` input;
            expected to be height x width x 3 channels.

    Returns:
        A ``(teacher_label, student_label)`` tuple of class-name strings,
        one per Gradio output component.
    """
    image_tensor = tf.convert_to_tensor(input_image)
    # Pin the static shape to rank 3 with exactly 3 channels; input with a
    # different channel count is rejected here rather than inside the model.
    image_tensor.set_shape([None, None, 3])
    # NOTE(review): pixels are passed on unscaled (resize yields float32 in the
    # original value range); assumes the models normalize internally — confirm.
    image_tensor = tf.image.resize(image_tensor, (IMG_SIZE, IMG_SIZE))
    return teacher_model_output(image_tensor), student_model_output(image_tensor)


image_input = gr.inputs.Image(shape=(IMG_SIZE, IMG_SIZE))
output = [
    gr.outputs.Label(label="Teacher Model Output"),
    gr.outputs.Label(label="Student Model Output"),
]

title = "Image Classification using Consistency training with supervision"
description = f"""Upload an image or select from examples to classify it.
The allowed classes are - {", ".join(class_names)}.

Teacher Model Repo - https://huggingface.co/keras-io/consistency_training_with_supervision_teacher_model
Student Model Repo - https://huggingface.co/keras-io/consistency_training_with_supervision_student_model
Keras Example - https://keras.io/examples/vision/consistency_training/

"""
article = """
Space by Vivek Rai
Keras example by Sayak Paul
"""

gr_interface = gr.Interface(
    infer,
    image_input,
    output,
    examples=examples,
    allow_flagging=False,
    analytics_enabled=False,
    title=title,
    description=description,
    article=article,
)
gr_interface.launch(enable_queue=True, debug=True)