F-G Fernandez commited on
Commit
fc36b00
1 Parent(s): a4f6936

fix: Fixed resizing and switched to ONNX

Browse files
Files changed (2) hide show
  1. app.py +53 -21
  2. requirements.txt +5 -2
app.py CHANGED
@@ -1,39 +1,71 @@
 
 
 
 
 
1
  import argparse
 
2
 
3
  import gradio as gr
4
- import torch
 
 
5
  from PIL import Image
6
- from torchvision.transforms import Compose, ConvertImageDtype, Normalize, PILToTensor, Resize
7
- from torchvision.transforms.functional import InterpolationMode
8
 
9
- from holocron import models
10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
- model = models.rexnet1_3x(pretrained=True).eval()
 
 
 
13
 
14
- preprocessor = Compose([
15
- Resize(model.default_cfg['input_shape'][1:], interpolation=InterpolationMode.BILINEAR),
16
- PILToTensor(),
17
- ConvertImageDtype(torch.float32),
18
- Normalize(model.default_cfg['mean'], model.default_cfg['std'])
19
- ])
20
 
21
- def predict(img):
22
- img = Image.fromarray(img.astype('uint8'), 'RGB')
23
- img = preprocessor(img)
24
- with torch.inference_mode():
25
- prediction = torch.nn.functional.softmax(model(img.unsqueeze(0))[0], dim=0)
26
- return {class_name: float(conf) for class_name, conf in zip(model.default_cfg['classes'], prediction)}
27
 
28
- image = gr.inputs.Image()
29
  outputs = gr.outputs.Label(num_top_classes=3)
30
 
31
  gr.Interface(
32
  fn=predict,
33
- inputs=[image],
34
  outputs=outputs,
35
  title="Holocron: image classification demo",
36
- article=("<p style='text-align: center'><a href='https://github.com/frgfm/Holocron'>" "Github Repo</a> | "
37
- "<a href='https://frgfm.github.io/Holocron/'>Documentation</a></p>"),
 
 
 
38
  live=True,
39
  ).launch()
 
1
+ # Copyright (C) 2022, François-Guillaume Fernandez.
2
+
3
+ # This program is licensed under the Apache License 2.0.
4
+ # See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0> for full license details.
5
+
6
  import argparse
7
+ import json
8
 
9
  import gradio as gr
10
+ import numpy as np
11
+ import onnxruntime
12
+ from huggingface_hub import hf_hub_download
13
  from PIL import Image
 
 
14
 
 
15
 
16
+ REPO = "frgfm/rexnet1_0x"
17
+
18
+ # Download model config & checkpoint
19
+ with open(hf_hub_download(args.repo, filename="config.json"), "rb") as f:
20
+ cfg = json.load(f)
21
+
22
+ ort_session = onnxruntime.InferenceSession(hf_hub_download(args.repo, filename="model.onnx"))
23
+
24
+ def preprocess_image(pil_img: Image.Image) -> np.ndarray:
25
+ """Preprocess an image for inference
26
+
27
+ Args:
28
+ pil_img: a valid pillow image
29
+
30
+ Returns:
31
+ the resized and normalized image of shape (1, C, H, W)
32
+ """
33
+
34
+ # Resizing (PIL takes (W, H) order for resizing)
35
+ img = pil_img.resize(cfg["input_shape"][-2:][::-1], Image.BILINEAR)
36
+ # (H, W, C) --> (C, H, W)
37
+ img = np.asarray(img).transpose((2, 0, 1)).astype(np.float32) / 255
38
+ # Normalization
39
+ img -= np.array(cfg["mean"])[:, None, None]
40
+ img /= np.array(cfg["std"])[:, None, None]
41
+
42
+ return img[None, ...]
43
 
44
+ def predict(image):
45
+ # Preprocessing
46
+ np_img = preprocess_image(image)
47
+ ort_input = {ort_session.get_inputs()[0].name: np_img}
48
 
49
+ # Inference
50
+ ort_out = ort_session.run(None, ort_input)
51
+ # Post-processing
52
+ out_exp = np.exp(ort_out[0][0])
53
+ probs = out_exp / out_exp.sum()
 
54
 
55
+ return {class_name: float(conf) for class_name, conf in zip(cfg["classes"], probs)}
 
 
 
 
 
56
 
57
+ img = gr.inputs.Image(type="pil")
58
  outputs = gr.outputs.Label(num_top_classes=3)
59
 
60
  gr.Interface(
61
  fn=predict,
62
+ inputs=[img],
63
  outputs=outputs,
64
  title="Holocron: image classification demo",
65
+ article=(
66
+ "<p style='text-align: center'><a href='https://github.com/frgfm/Holocron'>"
67
+ "Github Repo</a> | "
68
+ "<a href='https://frgfm.github.io/Holocron/'>Documentation</a></p>"
69
+ ),
70
  live=True,
71
  ).launch()
requirements.txt CHANGED
@@ -1,2 +1,5 @@
1
- -e git+https://github.com/frgfm/Holocron.git#egg=pylocron
2
- gradio>=3.0.2
 
 
 
 
1
+ gradio>=3.0.2,<4.0.0
2
+ Pillow>=8.4.0
3
+ onnxruntime>=1.10.0,<2.0.0
4
+ huggingface-hub>=0.4.0,<1.0.0
5
+ numpy>=1.19.5,<2.0.0