| import random |
| import gradio as gr |
| from PIL import Image |
| from model import predict |
| from datasets import load_dataset |
|
|
# Hugging Face Hub repository that supplies the random example images.
DATASET_REPO_ID = "AIOmarRehan/AnimalsDataset"

# Open the training split with streaming=True so examples are fetched on
# demand while iterating instead of downloading the whole split at startup.
dataset = load_dataset(DATASET_REPO_ID, split="train", streaming=True)
|
|
def classify_image(img: Image.Image):
    """Run the model on an uploaded image and format its output for the UI.

    Args:
        img: The image from the upload widget, or ``None`` when nothing has
            been provided.

    Returns:
        A 3-tuple ``(label, confidence, probabilities)`` as produced by
        ``predict``, with the confidence and every value of the
        probabilities mapping rounded to 3 decimal places. When ``img`` is
        ``None``, returns ``("No image uploaded", 0, {})`` without calling
        the model.
    """
    # Guard clause: e.g. Predict clicked while the upload widget is empty.
    if img is None:
        return "No image uploaded", 0, {}

    predicted_label, top_score, class_scores = predict(img)

    rounded_confidence = round(top_score, 3)
    rounded_scores = {
        name: round(score, 3) for name, score in class_scores.items()
    }
    return predicted_label, rounded_confidence, rounded_scores
|
|
def random_example():
    """Draw one random sample from the streaming training split for the UI.

    Returns:
        A 3-tuple ``(image, image, label_name)``: the sample image converted
        to RGB, returned twice so it can populate both the classifier input
        and the preview, followed by the sample's label as a string.

    Raises:
        gr.Error: if the dataset stream yields no examples.
    """
    # A streaming dataset cannot be randomly indexed, so randomness comes
    # from a shuffle buffer. Per the `datasets` docs, the buffer is filled
    # with `buffer_size` streamed examples before anything is drawn, so each
    # click re-streams up to 1500 examples: a larger buffer mixes better but
    # makes the button slower.
    shuffled = dataset.shuffle(buffer_size=1500)
    item = next(iter(shuffled), None)
    if item is None:
        # Without a default, `next` would raise StopIteration out of the
        # callback; report an empty stream as a readable UI error instead.
        raise gr.Error("The dataset stream returned no samples.")

    img = item["image"].convert("RGB")
    label = item["label"]

    # `features` of a streaming dataset may be None when the schema is not
    # known up front, and `int2str` only exists on ClassLabel features; fall
    # back to the raw value instead of crashing.
    label_feature = (dataset.features or {}).get("label")
    if hasattr(label_feature, "int2str"):
        label_str = label_feature.int2str(label)
    else:
        label_str = str(label)

    return img, img, label_str
|
|
|
|
# Build the UI; components render in the order they are created below.
with gr.Blocks() as demo:
    gr.Markdown("## Animal Image Classifier with Random Dataset Samples")

    # Upload widget next to the button that fills it from the dataset.
    with gr.Row():
        input_img = gr.Image(type="pil", label="Upload an image")
        rand_img = gr.Button("Random Dataset Image")  # a button, despite the name

    pred_btn = gr.Button("Predict")

    # Classifier outputs.
    output_label = gr.Label(label="Predicted Class")
    output_conf = gr.Number(label="Confidence")
    output_probs = gr.JSON(label="All Probabilities")

    # Preview of the most recently drawn dataset sample and its dataset label.
    rand_display = gr.Image(type="pil", label="Random Dataset Sample")
    rand_label = gr.Textbox(label="Sample Label")

    prediction_outputs = [output_label, output_conf, output_probs]
    pred_btn.click(classify_image, inputs=input_img, outputs=prediction_outputs)

    # random_example returns (image, image, label): the image feeds both the
    # classifier input and the preview widget.
    rand_img.click(
        fn=random_example,
        outputs=[input_img, rand_display, rand_label],
    )

if __name__ == "__main__":
    demo.launch()