g4tes commited on
Commit
18742db
·
verified ·
1 Parent(s): f2f4624

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +86 -76
app.py CHANGED
@@ -1,76 +1,86 @@
1
- import gradio as gr
2
- import numpy as np
3
- from PIL import Image
4
- import tensorflow as tf
5
- from typing import List, Dict, Any
6
- import io
7
-
8
- # Labels must mirror src/classification-model/index.ts
9
- LABELS: List[str] = [
10
- "battery",
11
- "biological",
12
- "brown-glass",
13
- "cardboard",
14
- "clothes",
15
- "green-glass",
16
- "metal",
17
- "paper",
18
- "plastic",
19
- "shoes",
20
- "trash",
21
- "white-glass",
22
- ]
23
-
24
-
25
- def _load_image_to_rgb(image: Image.Image) -> np.ndarray:
26
- if image.mode != "RGB":
27
- image = image.convert("RGB")
28
- return np.asarray(image)
29
-
30
-
31
- def _resize_224(img_rgb: np.ndarray) -> np.ndarray:
32
- im = Image.fromarray(img_rgb)
33
- im = im.resize((224, 224), Image.NEAREST)
34
- return np.asarray(im)
35
-
36
-
37
- def _preprocess(image: Image.Image) -> np.ndarray:
38
- rgb = _load_image_to_rgb(image)
39
- rgb224 = _resize_224(rgb)
40
- # shape [1,224,224,3], float32 in 0..255
41
- arr = rgb224.astype("float32")
42
- return np.expand_dims(arr, axis=0)
43
-
44
-
45
- class PreTrainedModel:
46
- def __init__(self, model_path: str = "model/model_resnet50.keras") -> None:
47
- self.model = tf.keras.models.load_model(model_path)
48
-
49
- def predict_image(self, image: Image.Image) -> Dict[str, float]:
50
- x = _preprocess(image)
51
- preds = self.model.predict(x)
52
- if isinstance(preds, (list, tuple)):
53
- preds = preds[0]
54
- probs = np.asarray(preds).squeeze().tolist()
55
-
56
- return {label: score for label, score in zip(LABELS, probs)}
57
-
58
-
59
- model = PreTrainedModel()
60
-
61
-
62
- def predict(image):
63
- predictions = model.predict_image(image)
64
- return predictions
65
-
66
-
67
- iface = gr.Interface(
68
- fn=predict,
69
- inputs=gr.Image(type="pil"),
70
- outputs=gr.Label(num_top_classes=3),
71
- title="Waste Classification",
72
- description="Upload an image of waste to classify it.",
73
- )
74
-
75
- if __name__ == "__main__":
76
- iface.launch()
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ from PIL import Image
4
+ import tensorflow as tf
5
+ from typing import List, Dict, Any
6
+ import io
7
+
8
+ # Labels must mirror src/classification-model/index.ts
9
+ LABELS: List[str] = [
10
+ "battery",
11
+ "biological",
12
+ "brown-glass",
13
+ "cardboard",
14
+ "clothes",
15
+ "green-glass",
16
+ "metal",
17
+ "paper",
18
+ "plastic",
19
+ "shoes",
20
+ "trash",
21
+ "white-glass",
22
+ ]
23
+
24
+
25
+ def _load_image_to_rgb(image: Image.Image) -> np.ndarray:
26
+ if image.mode != "RGB":
27
+ image = image.convert("RGB")
28
+ return np.asarray(image)
29
+
30
+
31
+ def _resize_224(img_rgb: np.ndarray) -> np.ndarray:
32
+ im = Image.fromarray(img_rgb)
33
+ im = im.resize((224, 224), Image.NEAREST)
34
+ return np.asarray(im)
35
+
36
+
37
+ def _preprocess(image: Image.Image) -> np.ndarray:
38
+ rgb = _load_image_to_rgb(image)
39
+ rgb224 = _resize_224(rgb)
40
+ # shape [1,224,224,3], float32 in 0..255
41
+ arr = rgb224.astype("float32")
42
+ return np.expand_dims(arr, axis=0)
43
+
44
+
45
+ class PreTrainedModel:
46
+ def __init__(self, model_path: str = "model/model_resnet50.keras") -> None:
47
+ self.model = tf.keras.models.load_model(model_path)
48
+
49
+ def predict_image(self, image: Image.Image) -> Dict[str, float]:
50
+ x = _preprocess(image)
51
+ preds = self.model.predict(x)
52
+ if isinstance(preds, (list, tuple)):
53
+ preds = preds[0]
54
+ probs = np.asarray(preds).squeeze().tolist()
55
+
56
+ return {label: score for label, score in zip(LABELS, probs)}
57
+
58
+
59
+ model = PreTrainedModel()
60
+
61
+
62
+ def predict(image):
63
+ predictions = model.predict_image(image)
64
+
65
+ probs_percent = {label: round(p * 100, 2)
66
+ for label, p in predictions.items()}
67
+
68
+ max_label = max(probs_percent, key=probs_percent.get)
69
+
70
+ return {
71
+ "label": max_label,
72
+ "percentage": probs_percent[max_label],
73
+ "probabilities": probs_percent,
74
+ }
75
+
76
+
77
+ iface = gr.Interface(
78
+ fn=predict,
79
+ inputs=gr.Image(type="pil"),
80
+ outputs=gr.JSON(),
81
+ title="Waste Classification",
82
+ description="Upload an image of waste to classify it.",
83
+ )
84
+
85
+ if __name__ == "__main__":
86
+ iface.launch()