File size: 2,180 Bytes
524506b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b143557
bc5886e
524506b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b143557
524506b
 
 
 
 
 
 
 
 
b143557
524506b
b143557
524506b
 
 
 
 
b143557
524506b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import gradio as gr
from huggingface_hub import from_pretrained_keras
import tensorflow as tf


# CIFAR-10 label names, keyed by the position of the matching entry in the
# model's output vector (index 0 -> "airplane", ..., index 9 -> "truck").
CLASSES = dict(
    enumerate(
        (
            "airplane",
            "automobile",
            "bird",
            "cat",
            "deer",
            "dog",
            "frog",
            "horse",
            "ship",
            "truck",
        )
    )
)
# Side length in pixels of the square every input image is resized to
# before being passed to the model.
IMAGE_SIZE = 32


# Load the pretrained Keras CCT model from the ``keras-io/cct`` Hugging Face
# Hub repository once, at import time, so every request handled by
# ``classify_image`` reuses this same instance.
model = from_pretrained_keras("keras-io/cct")


def reshape_image(image):
    """Prepare a single image as a one-element batch for the model.

    Args:
        image: Anything ``tf.convert_to_tensor`` accepts that has shape
            ``(height, width, 3)``; height and width may be arbitrary.

    Returns:
        A float tensor of shape ``(1, IMAGE_SIZE, IMAGE_SIZE, 3)``.
    """
    tensor = tf.convert_to_tensor(image)
    # Pin the channel dimension while leaving height/width dynamic, so an
    # input that is not (H, W, 3) raises here rather than deep in the model.
    tensor.set_shape([None, None, 3])
    resized = tf.image.resize(tensor, (IMAGE_SIZE, IMAGE_SIZE))
    # Prepend the batch axis the model expects.
    return resized[tf.newaxis, ...]


def classify_image(input_image):
    """Classify one image into the classes listed in ``CLASSES``.

    Args:
        input_image: The image delivered by the Gradio ``Image`` input
            (presumably an ``(H, W, 3)`` NumPy array — TODO confirm against
            the installed Gradio version), or ``None`` when the user submits
            without providing an image.

    Returns:
        A dict mapping every class name to its softmax probability, the
        format consumed by the Gradio ``Label`` output.

    Raises:
        ValueError: If ``input_image`` is ``None``.
    """
    if input_image is None:
        # Without this guard the failure surfaces as an opaque tensor
        # conversion error inside ``reshape_image``.
        raise ValueError("No image provided; upload an image or pick an example.")

    batch = reshape_image(input_image)
    # Call the model directly instead of ``model.predict``: Keras recommends
    # ``__call__`` for inputs that fit in a single batch, because ``predict``
    # adds per-call overhead intended for large datasets. ``training=False``
    # keeps layers with training-specific behaviour in inference mode, as
    # ``predict`` does.
    logits = tf.reshape(model(batch, training=False), [-1])
    # Softmax is applied here, so the model output is treated as
    # unnormalised logits. Convert to NumPy once rather than indexing an
    # eager tensor once per class.
    probabilities = tf.nn.softmax(logits).numpy()
    return {label: float(probabilities[index]) for index, label in CLASSES.items()}


# Gradio Interface
examples = [["./bird.png"], ["./cat.png"], ["./dog.png"], ["./horse.png"]]
title = "Image Classification using Compact Convolutional Transformer (CCT)"
description = """
Upload an image or select one from the examples and ask the model to label it!
<br />
The model was trained on the <a href="https://www.cs.toronto.edu/~kriz/cifar.html" target="_blank">CIFAR-10 dataset</a>. Therefore, it is able to recognise these 10 classes: airplane, automobile, bird, cat, deer, dog, frog, horse, ship, truck.
<br />
<br />
<p>
    <b>Model:</b> <a href="https://huggingface.co/keras-io/cct" target="_blank">https://huggingface.co/keras-io/cct</a>
    <br />
    <b>Keras Example:</b> <a href="https://keras.io/examples/vision/cct/" target="_blank">https://keras.io/examples/vision/cct/</a>
</p>
<br />
"""
article = """
<div style="text-align: center;">
    Space by <a href="https://github.com/EdAbati" target="_blank">Edoardo Abati</a>
    <br />
    Keras example by <a href="https://twitter.com/RisingSayak" target="_blank">Sayak Paul</a>
</div>
"""

# NOTE(review): ``gr.inputs`` / ``gr.outputs`` and ``launch(enable_queue=...)``
# are the legacy Gradio API (deprecated in Gradio 3.x, removed in 4.0). When
# the pinned Gradio version is upgraded, switch to ``gr.Image()``,
# ``gr.Label()`` and ``interface.queue().launch()``.
interface = gr.Interface(
    fn=classify_image,
    inputs=gr.inputs.Image(),
    outputs=gr.outputs.Label(),
    examples=examples,
    title=title,
    description=description,
    article=article,
    allow_flagging="never",
)

# Start the web server only when this file is executed as a script, so that
# importing the module (e.g. to reuse ``classify_image``) does not block on
# a running server.
if __name__ == "__main__":
    interface.launch(enable_queue=True)