File size: 1,330 Bytes
35d6a79
 
159731b
35d6a79
159731b
35d6a79
 
159731b
35d6a79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7240b44
35d6a79
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import tensorflow as tf
import numpy as np
import gradio as gr
from PIL import Image

# Load the trained Keras model once, at import time, so every request reuses it.
# NOTE: the path is relative, so it resolves against the current working
# directory, not this script's directory.
model = tf.keras.models.load_model("cifar10_cnn_model.keras")

# CIFAR-10 label names; position i is the label for output index i of the model.
class_names = [
    "airplane",
    "automobile",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
]

# Define a function to preprocess the input image
def preprocess_image(image, size=(32, 32)):
    """Convert a PIL image into a single-image batch for the CIFAR-10 model.

    Args:
        image: A PIL image of any size or colour mode.
        size: Target ``(width, height)`` passed to ``Image.resize``; defaults
            to ``(32, 32)``, the CIFAR-10 input resolution.

    Returns:
        A float32 NumPy array of shape ``(1, height, width, 3)`` whose pixel
        values are scaled from [0, 255] to [0.0, 1.0].
    """
    # Force 3-channel RGB first: RGBA PNGs, grayscale ("L") and palette ("P")
    # images would otherwise yield 4-channel or 2-D arrays the model cannot
    # accept. Converting before resizing also avoids resampling palette data.
    image = image.convert("RGB").resize(size)
    # NOTE(review): assumes the model was trained on inputs scaled by 1/255 —
    # confirm against the training code.
    pixels = np.asarray(image, dtype=np.float32) / 255.0
    # Prepend the batch axis expected by Model.predict.
    return np.expand_dims(pixels, axis=0)

# Define the prediction function
def classify_image(image):
    """Classify an uploaded image and return a human-readable result string.

    Args:
        image: PIL image supplied by the ``gr.Image(type="pil")`` input, or
            ``None`` when the user submits without uploading anything.

    Returns:
        ``"Prediction: <class> (Confidence: <score>)"``, where ``<score>`` is
        the largest value in the model's output vector formatted to two
        decimals; or a prompt to upload an image when ``image`` is ``None``.
    """
    # Gradio passes None when the form is submitted with no image; without
    # this guard preprocess_image would fail with AttributeError on None.
    if image is None:
        return "No image provided. Please upload an image to classify."
    batch = preprocess_image(image)
    # verbose=0 stops Keras printing a progress bar on every web request.
    scores = model.predict(batch, verbose=0)[0]
    best = int(np.argmax(scores))
    # NOTE(review): this is a probability only if the model's final layer
    # applies softmax — confirm against the training code.
    confidence = float(scores[best])
    return f"Prediction: {class_names[best]} (Confidence: {confidence:.2f})"

# Wire the classifier into a Gradio UI: a PIL image upload in, plain text out.
interface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="pil"),
    outputs="text",
    title="CIFAR-10 Image Classifier",
    description=(
        "Upload an image of a CIFAR-10 category "
        "(airplane, automobile, bird, cat, deer, dog, frog, horse, ship, truck), "
        "and the model will classify it."
    ),
)

# Start the web server only when executed as a script, not when imported.
if __name__ == "__main__":
    interface.launch()