frgfm commited on
Commit
04b53ce
1 Parent(s): f2f8458

feat: Added Gradio demo

Browse files
Files changed (2) hide show
  1. app.py +70 -0
  2. requirements.txt +5 -0
app.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2022, Pyronear.
2
+
3
+ # This program is licensed under the Apache License 2.0.
4
+ # See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0> for full license details.
5
+
6
+ import argparse
7
+ import json
8
+
9
+ import gradio as gr
10
+ import numpy as np
11
+ import onnxruntime
12
+ from huggingface_hub import hf_hub_download
13
+ from PIL import Image
14
+
15
+
16
+ # Download model config & checkpoint
17
+ with open(hf_hub_download(args.repo, filename="config.json"), "rb") as f:
18
+ cfg = json.load(f)
19
+
20
+ ort_session = onnxruntime.InferenceSession(hf_hub_download(args.repo, filename="model.onnx"))
21
+
22
+ def preprocess_image(pil_img: Image.Image) -> np.ndarray:
23
+ """Preprocess an image for inference
24
+
25
+ Args:
26
+ pil_img: a valid pillow image
27
+
28
+ Returns:
29
+ the resized and normalized image of shape (1, C, H, W)
30
+ """
31
+
32
+ # Resizing
33
+ img = pil_img.resize(cfg["input_shape"][-2:], Image.BILINEAR)
34
+ # (H, W, C) --> (C, H, W)
35
+ img = np.asarray(img).transpose((2, 0, 1)).astype(np.float32) / 255
36
+ # Normalization
37
+ img -= np.array(cfg["mean"])[:, None, None]
38
+ img /= np.array(cfg["std"])[:, None, None]
39
+
40
+ return img[None, ...]
41
+
42
+ def predict(image):
43
+ # Preprocessing
44
+ np_img = preprocess_image(image)
45
+ ort_input = {ort_session.get_inputs()[0].name: np_img}
46
+
47
+ # Inference
48
+ ort_out = ort_session.run(None, ort_input)
49
+ # Post-processing
50
+ probs = 1 / (1 + np.exp(-ort_out[0][0]))
51
+
52
+ return {class_name: float(conf) for class_name, conf in zip(cfg["classes"], probs)}
53
+
54
+
55
+ img = gr.inputs.Image(type="pil")
56
+ outputs = gr.outputs.Label(num_top_classes=1)
57
+
58
+
59
+ gr.Interface(
60
+ fn=predict,
61
+ inputs=[img],
62
+ outputs=outputs,
63
+ title="PyroVision: image classification demo",
64
+ article=(
65
+ "<p style='text-align: center'><a href='https://github.com/pyronear/pyro-vision'>"
66
+ "Github Repo</a> | "
67
+ "<a href='https://pyronear.org/pyro-vision/'>Documentation</a></p>"
68
+ ),
69
+ live=True,
70
+ ).launch()
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
1
+ gradio>=3.0.2,<4.0.0
2
+ Pillow>=8.4.0
3
+ onnxruntime>=1.10.0,<2.0.0
4
+ huggingface-hub>=0.4.0,<1.0.0
5
+ numpy>=1.19.5,<2.0.0