"""Gradio demo: CIFAR-10 image classification with a Compact Convolutional Transformer.

Loads the pretrained ``keras-io/cct`` model from the Hugging Face Hub and serves
a web UI that maps an uploaded image to a probability for each of the ten
CIFAR-10 classes.
"""

import gradio as gr
from huggingface_hub import from_pretrained_keras
import tensorflow as tf

# CIFAR-10 label index -> human-readable class name.
CLASSES = {
    0: "airplane",
    1: "automobile",
    2: "bird",
    3: "cat",
    4: "deer",
    5: "dog",
    6: "frog",
    7: "horse",
    8: "ship",
    9: "truck",
}
# Height and width every input image is resized to before inference.
IMAGE_SIZE = 32

# Downloaded from the Hugging Face Hub when this module is imported.
model = from_pretrained_keras("keras-io/cct")


def reshape_image(image):
    """Convert a single image into a batch of one, resized to the model's input size.

    Args:
        image: Array-like image of shape (height, width, 3). Presumably the RGB
            numpy array produced by ``gr.Image()`` — confirm if the input
            component's ``type``/``image_mode`` is changed.

    Returns:
        A float32 tensor of shape (1, IMAGE_SIZE, IMAGE_SIZE, 3).
    """
    image = tf.convert_to_tensor(image)
    # Pin the channel count so the resize sees a well-defined rank-3 RGB image;
    # this raises if the image does not have exactly 3 channels.
    image.set_shape([None, None, 3])
    # Pixel values are not rescaled here. NOTE(review): this assumes the model
    # normalises its inputs internally — confirm against the training setup.
    image = tf.image.resize(images=image, size=[IMAGE_SIZE, IMAGE_SIZE])
    # Add the leading batch dimension expected by model.predict.
    return tf.expand_dims(image, axis=0)


def classify_image(input_image):
    """Return a ``{class name: probability}`` mapping for the given image.

    Args:
        input_image: Image supplied by the Gradio input component, or ``None``
            when the user submits without providing one.

    Returns:
        Dict mapping every CIFAR-10 class name to its softmax probability,
        in the format ``gr.Label`` expects.

    Raises:
        gr.Error: If no image was provided, so the UI shows a readable message
            instead of a TensorFlow conversion traceback.
    """
    if input_image is None:
        raise gr.Error("Please upload an image or select one of the examples.")

    batch = reshape_image(input_image)
    # The model's raw outputs are treated as logits; softmax turns them into
    # probabilities that sum to 1 across the ten classes.
    logits = model.predict(batch).flatten()
    # Convert once to numpy instead of indexing the tensor element by element.
    probabilities = tf.nn.softmax(logits).numpy()
    return {name: float(probabilities[index]) for index, name in CLASSES.items()}


# Gradio Interface
examples = [["./bird.png"], ["./cat.png"], ["./dog.png"], ["./horse.png"]]
title = "Image Classification using Compact Convolutional Transformer (CCT)"
description = """ Upload an image or select one from the examples and ask the model to label it!
The model was trained on the CIFAR-10 dataset. Therefore, it is able to recognise these 10 classes: airplane, automobile, bird, cat, deer, dog, frog, horse, ship, truck.

Model: https://huggingface.co/keras-io/cct
Keras Example: https://keras.io/examples/vision/cct/


"""
article = """
Space by Edoardo Abati
Keras example by Sayak Paul
"""

# gr.Image / gr.Label replace the gr.inputs / gr.outputs namespaces, which
# were deprecated in Gradio 3 and removed in Gradio 4.
interface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(),
    outputs=gr.Label(),
    examples=examples,
    title=title,
    description=description,
    article=article,
    allow_flagging="never",
)

if __name__ == "__main__":
    # launch(enable_queue=True) was removed in Gradio 4; queue() is the
    # supported way to enable request queuing.
    interface.queue().launch()