akhaliq HF staff commited on
Commit
7beb196
1 Parent(s): 5a0ebda

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +86 -0
app.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ import numpy as np
3
+ import torch
4
+ from torchvision import transforms, models
5
+ from onnx import numpy_helper
6
+ import os
7
+ import onnxruntime as rt
8
+ from matplotlib.colors import hsv_to_rgb
9
+ import cv2
10
+ import gradio as gr
11
+
12
+ preprocess = transforms.Compose([
13
+ transforms.ToTensor(),
14
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
15
+ ])
16
+
17
+ # Start from ORT 1.10, ORT requires explicitly setting the providers parameter if you want to use execution providers
18
+ # other than the default CPU provider (as opposed to the previous behavior of providers getting set/registered by default
19
+ # based on the build flags) when instantiating InferenceSession.
20
+ # For example, if NVIDIA GPU is available and ORT Python package is built with CUDA, then call API as following:
21
+ # onnxruntime.InferenceSession(path/to/model, providers=['CUDAExecutionProvider'])
22
+ sess = rt.InferenceSession("../model/fcn-resnet101-11.onnx")
23
+
24
+ outputs = sess.get_outputs()
25
+ output_names = list(map(lambda output: output.name, outputs))
26
+ input_name = sess.get_inputs()[0].name
27
+
28
+ classes = [line.rstrip('\n') for line in open('voc_classes.txt')]
29
+ num_classes = len(classes)
30
+
31
+ def get_palette():
32
+ # prepare and return palette
33
+ palette = [0] * num_classes * 3
34
+
35
+ for hue in range(num_classes):
36
+ if hue == 0: # Background color
37
+ colors = (0, 0, 0)
38
+ else:
39
+ colors = hsv_to_rgb((hue / num_classes, 0.75, 0.75))
40
+
41
+ for i in range(3):
42
+ palette[hue * 3 + i] = int(colors[i] * 255)
43
+
44
+ return palette
45
+
46
+ def colorize(labels):
47
+ # generate colorized image from output labels and color palette
48
+ result_img = Image.fromarray(labels).convert('P', colors=num_classes)
49
+ result_img.putpalette(get_palette())
50
+ return np.array(result_img.convert('RGB'))
51
+
52
+ def visualize_output(image, output):
53
+ assert(image.shape[0] == output.shape[1] and \
54
+ image.shape[1] == output.shape[2]) # Same height and width
55
+ assert(output.shape[0] == num_classes)
56
+
57
+ # get classification labels
58
+ raw_labels = np.argmax(output, axis=0).astype(np.uint8)
59
+
60
+ # comput confidence score
61
+ confidence = float(np.max(output, axis=0).mean())
62
+
63
+ # generate segmented image
64
+ result_img = colorize(raw_labels)
65
+
66
+ # generate blended image
67
+ blended_img = cv2.addWeighted(image[:, :, ::-1], 0.5, result_img, 0.5, 0)
68
+
69
+ result_img = Image.fromarray(result_img)
70
+ blended_img = Image.fromarray(blended_img)
71
+
72
+ return confidence, result_img, blended_img, raw_labels
73
+
74
+ def inference(img):
75
+ input_image = Image.open(img)
76
+ orig_tensor = np.asarray(input_image)
77
+ input_tensor = preprocess(input_image)
78
+ input_tensor = input_tensor.unsqueeze(0)
79
+ input_tensor = input_tensor.detach().cpu().numpy()
80
+
81
+ detections = sess.run(output_names, {input_name: input_tensor})
82
+ output, aux = detections
83
+ conf, result_img, blended_img, _ = visualize_output(orig_tensor, output[0])
84
+ return blended_img
85
+
86
+ gr.Interface(inference,gr.inputs.Image(type="filepath"),gr.outputs.Image(type="pil")).launch()