AIOmarRehan commited on
Commit
1c8dc00
·
verified ·
1 Parent(s): 892909e

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +61 -23
model.py CHANGED
@@ -1,32 +1,70 @@
1
- import tensorflow as tf
2
- import numpy as np
3
  from PIL import Image
4
- import os
 
5
 
6
- MODEL_PATH = os.path.join(
7
- os.path.dirname(__file__),
8
- "saved_model",
9
- "Inception_V3_Animals_Classification.h5"
10
- )
11
 
12
- model = tf.keras.models.load_model(MODEL_PATH)
13
 
14
- CLASS_NAMES = ["Cat", "Dog", "Snake"]
 
 
15
 
16
- def preprocess_image(img: Image.Image, target_size=(256, 256)):
17
- img = img.convert("RGB")
18
- img = img.resize(target_size)
19
- img = np.array(img).astype("float32") / 255.0
20
- img = np.expand_dims(img, axis=0)
21
- return img
22
 
23
- def predict(img: Image.Image):
24
- input_tensor = preprocess_image(img)
25
- preds = model.predict(input_tensor)[0]
 
 
26
 
27
- class_idx = int(np.argmax(preds))
28
- confidence = float(np.max(preds))
29
 
30
- prob_dict = {CLASS_NAMES[i]: float(preds[i]) for i in range(len(CLASS_NAMES))}
 
31
 
32
- return CLASS_NAMES[class_idx], confidence, prob_dict
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ import gradio as gr
3
  from PIL import Image
4
+ from model import predict
5
+ from datasets import load_dataset
6
 
7
+ # Load dataset (NO streaming → allows len() and indexing)
8
+ dataset = load_dataset("AIOmarRehan/AnimalsDataset", split="train")
 
 
 
9
 
10
+ def classify_image(img: Image.Image):
11
 
12
+ # Handle empty input safely
13
+ if img is None:
14
+ return "No image uploaded", 0, {}
15
 
16
+ label, confidence, probs = predict(img)
 
 
 
 
 
17
 
18
+ return (
19
+ label,
20
+ round(confidence, 3),
21
+ {k: round(v, 3) for k, v in probs.items()}
22
+ )
23
 
24
+ # Pick a random example
25
+ def random_example():
26
 
27
+ idx = random.randint(0, len(dataset) - 1)
28
+ item = dataset[idx]
29
 
30
+ img = item["image"].convert("RGB")
31
+ label = item["label"]
32
+ label_str = dataset.features["label"].int2str(label)
33
+
34
+ return img, label_str
35
+
36
+
37
+ demo = gr.Blocks()
38
+
39
+ with demo:
40
+ gr.Markdown("## Animal Image Classifier with Random Dataset Samples")
41
+
42
+ with gr.Row():
43
+ input_img = gr.Image(type="pil", label="Upload an image")
44
+ rand_img = gr.Button("Random Dataset Image")
45
+
46
+ with gr.Row():
47
+ pred_btn = gr.Button("Predict")
48
+
49
+ output_label = gr.Label(label="Predicted Class")
50
+ output_conf = gr.Number(label="Confidence")
51
+ output_probs = gr.JSON(label="All Probabilities")
52
+
53
+ # Display random dataset sample
54
+ rand_display = gr.Image(type="pil", label="Random Dataset Sample")
55
+ rand_label = gr.Textbox(label="Sample Label")
56
+
57
+ # Actions
58
+ pred_btn.click(
59
+ classify_image,
60
+ inputs=input_img,
61
+ outputs=[output_label, output_conf, output_probs]
62
+ )
63
+
64
+ rand_img.click(
65
+ random_example,
66
+ outputs=[rand_display, rand_label]
67
+ )
68
+
69
+ if __name__ == "__main__":
70
+ demo.launch()