| import gradio as gr |
| from PIL import Image |
| from model import predict |
|
|
def classify_image(img: Image.Image):
    """Classify an uploaded image and format the result for the Gradio UI.

    Args:
        img: The uploaded image as a PIL image (the input component is
            configured with ``type="pil"``). Gradio's Image component passes
            ``None`` when the user submits without uploading anything.

    Returns:
        A 3-tuple ``(label, confidence, probs)`` matching, in order, the
        interface's Label, Number and JSON output components.
        ``confidence`` and every value in ``probs`` are rounded to 3 decimals.

    Raises:
        gr.Error: If no image was provided, so the UI shows a readable
            message instead of a traceback from ``predict``.
    """
    if img is None:
        raise gr.Error("Please upload an image first.")

    label, confidence, probs = predict(img)

    # Coerce to built-in floats before rounding: if ``predict`` returns NumPy
    # scalars, round() keeps the NumPy type, which may not serialize cleanly
    # in the JSON output. NOTE(review): assumption about predict's return
    # types — confirm against model.py.
    return (
        label,
        round(float(confidence), 3),
        {k: round(float(v), 3) for k, v in probs.items()},
    )
|
|
# Wire the classifier into a one-input / three-output Gradio interface.
# fn, inputs and outputs are passed positionally (Interface's first three
# parameters); the output components are listed in the same order as the
# (label, confidence, probs) tuple returned by classify_image.
demo = gr.Interface(
    classify_image,
    gr.Image(type="pil", label="Upload an image"),
    [
        gr.Label(label="Predicted Class"),
        gr.Number(label="Confidence"),
        gr.JSON(label="All Probabilities"),
    ],
    title="Animal Image Classifier",
    description="Upload an image and the model will predict the animal.",
)


if __name__ == "__main__":
    # Start the web UI only when executed as a script, not when imported.
    demo.launch()