File size: 1,695 Bytes
284eba0
 
 
3e7dbba
284eba0
0de8536
 
 
 
 
 
 
 
 
 
 
 
 
 
 
284eba0
0de8536
 
 
 
 
 
284eba0
0de8536
284eba0
0de8536
 
15f8afb
0de8536
 
 
 
4e2ea8a
0de8536
4e2ea8a
 
 
284eba0
0de8536
 
4e2ea8a
cbb0a6b
 
0de8536
 
cbb0a6b
0de8536
cbb0a6b
3e7dbba
284eba0
4191bbd
 
 
 
94ad081
 
4191bbd
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import gradio as gr
import tensorflow as tf
import gdown
from PIL import Image

# Image geometry as (height, width, channels). predict_class resizes inputs to
# input_shape; resized_shape is not referenced by the code in this file.
input_shape = (32, 32, 3)
resized_shape = (224, 224, 3)

# Class index -> display name. The index order is assumed to match the
# model's output units (these appear to be the CIFAR-10 classes).
labels = dict(enumerate((
    "plane", "car", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
)))
num_classes = len(labels)

# Download the model file
def download_model(
    url="https://drive.google.com/uc?id=12700bE-pomYKoVQ214VrpBoJ7akXcTpL",
    output="modelV2Lmixed.keras",
):
    """Fetch the Keras model file from Google Drive and return its local path.

    If ``output`` already exists, the download is skipped, so restarting the
    app does not re-fetch the model every time. Delete the file to force a
    fresh download.

    Args:
        url: Google Drive download URL passed to ``gdown.download``.
        output: Local path the model file is written to.

    Returns:
        The local path of the model file (``output``).

    Raises:
        RuntimeError: If ``gdown.download`` reports failure by returning None.
    """
    import os

    if os.path.isfile(output):
        return output
    # Some gdown versions signal failure by returning None instead of raising.
    # Without this check, load_model would later fail with an unrelated error.
    if gdown.download(url, output, quiet=False) is None:
        raise RuntimeError(f"Failed to download model from {url}")
    return output

# Obtain the local model path, then deserialize the Keras model once at module
# level. predict_class reuses this single global instance for every request.
model_file = download_model()
model = tf.keras.models.load_model(filepath=model_file)

# Perform image classification
def predict_class(image):
    """Run the model on one image and return the predicted label string.

    Args:
        image: Image array from the Gradio image input (presumably an
            H x W x C uint8 array; confirm against the Gradio version in use).

    Returns:
        The ``labels`` entry for the highest-scoring output unit.
    """
    height, width = input_shape[:2]
    # float32 cast -> resize to the model's spatial size -> batch of one.
    batch = tf.expand_dims(
        tf.image.resize(tf.cast(image, tf.float32), [height, width]),
        axis=0,
    )
    scores = model.predict(batch)[0]
    return labels[tf.argmax(scores).numpy()]

# UI Design
def classify_image(image):
    """Gradio callback: classify ``image`` and format the result as HTML.

    Args:
        image: The uploaded image from the Gradio input, or ``None`` when the
            form is submitted without an image.

    Returns:
        An HTML snippet containing the predicted class. If ``image`` is
        ``None``, a prompt asking the user to upload an image instead.
    """
    # Without this guard, None would reach tf.cast inside predict_class and
    # raise a conversion error instead of giving the user useful feedback.
    if image is None:
        return "<p>Please upload an image.</p>"
    predicted_class = predict_class(image)
    return f"<h2>Predicted Class:</h2><p>{predicted_class}</p>"

# UI components. gr.Image / gr.HTML replace the gr.inputs / gr.outputs
# namespaces, which were deprecated in Gradio 3.x and removed in 4.0.
inputs = gr.Image(label="Upload an image")
outputs = gr.HTML()

title = "<h1 style='text-align: center;'>Image Classifier</h1>"
description = "Upload an image and get the predicted class."

demo = gr.Interface(
    fn=classify_image,
    inputs=inputs,
    outputs=outputs,
    title=title,
    # Presumably these example files ship alongside this script; confirm.
    examples=["00_plane.jpg", "01_car.jpg", "02_bird.jpg", "03_cat.jpg", "04_deer.jpg"],
    # NOTE(review): CSS background-image only accepts images, so this .mp4
    # will not render. A looping video background would need an HTML
    # <video> element instead.
    css="body {background-image: url('file=wave.mp4')}",
    description=description,
)
demo.launch()