| import tensorflow as tf |
| import numpy as np |
| import gradio as gr |
| from PIL import Image |
|
|
| |
# Trained Keras model, loaded once at import time and reused by predict().
# The path is relative, so it is resolved against the current working
# directory of the process that runs this script.
MODEL_PATH = "cifar10_custom_cnn.keras"
model = tf.keras.models.load_model(MODEL_PATH)
|
|
| |
# Human-readable labels; predict() indexes this list with the argmax of the
# model's output, so the order must match the label indices used in training.
class_names = "Airplane Automobile Bird Cat Deer Dog Frog Horse Ship Truck".split()
|
|
def predict(image):
    """Classify a single image into one of the CIFAR-10 classes.

    Args:
        image: A PIL image supplied by the Gradio ``gr.Image(type="pil")``
            input, or ``None`` when the user submits with no image.

    Returns:
        The predicted class name from ``class_names``, or ``None`` when no
        image was provided.
    """
    if image is None:
        # Gradio passes None for an empty input; there is nothing to classify.
        return None

    # Uploads may be RGBA (PNG with alpha), grayscale ("L") or palette ("P")
    # images; the reshape below requires exactly 3 colour channels.
    image = image.convert("RGB").resize((32, 32))

    # Scale pixels to [0, 1] as float32.
    # NOTE(review): assumes the model was trained on inputs divided by 255
    # with no built-in Rescaling layer — confirm against the training code.
    pixels = np.asarray(image, dtype=np.float32) / 255.0
    batch = pixels.reshape(1, 32, 32, 3)

    # verbose=0 suppresses the Keras progress bar printed on every call.
    predictions = model.predict(batch, verbose=0)
    class_index = int(np.argmax(predictions[0]))

    return class_names[class_index]
|
|
# Web UI: one image upload in, the predicted class name shown as a label out.
interface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs="label",
    title="CIFAR-10 Image Classification",
    description="Custom CNN model trained on CIFAR-10 dataset",
)


if __name__ == "__main__":
    # Start the (blocking) server only when run as a script, so importing
    # this module elsewhere does not launch the UI as a side effect.
    interface.launch()
|
|