File size: 3,861 Bytes
7beb196
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b55b1ea
 
7beb196
 
8ceafbd
7beb196
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8ceafbd
 
7beb196
 
 
 
 
8ceafbd
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
from PIL import Image
import numpy as np
import torch
from torchvision import transforms, models
from onnx import numpy_helper
import os
import onnxruntime as rt
from matplotlib.colors import hsv_to_rgb
import cv2
import gradio as gr

# Input pipeline applied to each image before it is handed to the model:
# tensor conversion followed by per-channel normalization.
# NOTE(review): the mean/std values look like the usual ImageNet channel
# statistics -- confirm against the model's training recipe.
_to_tensor = transforms.ToTensor()
_normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)
preprocess = transforms.Compose([_to_tensor, _normalize])

# Starting with ORT 1.10, ORT requires the providers parameter to be set explicitly when instantiating
# InferenceSession if you want to use execution providers other than the default CPU provider (previously,
# providers were set/registered by default based on the build flags).
# For example, if an NVIDIA GPU is available and the ORT Python package is built with CUDA, call the API as follows:
# onnxruntime.InferenceSession(path/to/model, providers=['CUDAExecutionProvider'])
# Download the FCN ResNet-101 ONNX model only when it is not already present.
# An unconditional wget re-downloads the file on every start and, because the
# target already exists, stores each new copy under a numbered name
# ("fcn-resnet101-11.onnx.1", ".2", ...).
MODEL_URL = "https://github.com/AK391/models/raw/main/vision/object_detection_segmentation/fcn/model/fcn-resnet101-11.onnx"
MODEL_PATH = "fcn-resnet101-11.onnx"
if not os.path.exists(MODEL_PATH):
    # Fetch into a temporary name first so that an interrupted download is
    # never mistaken for a complete model on the next start.
    partial_path = MODEL_PATH + ".part"
    if os.system(f"wget -O {partial_path} {MODEL_URL}") != 0:
        raise RuntimeError(f"Failed to download model from {MODEL_URL}")
    os.replace(partial_path, MODEL_PATH)
sess = rt.InferenceSession(MODEL_PATH)

# Output metadata of the loaded model; inference() reads the names from it.
outputs = sess.get_outputs()


# Class names, one per line of voc_classes.txt. The context manager closes
# the file handle immediately instead of leaving it open until it happens to
# be garbage-collected.
with open('voc_classes.txt') as class_file:
    classes = [line.rstrip('\n') for line in class_file]
num_classes = len(classes)

def get_palette():
    """Build the flat color palette used to paint the class labels.

    Returns:
        A list of ``num_classes * 3`` integers laid out as
        ``[r0, g0, b0, r1, g1, b1, ...]``. Class 0 (background) is black;
        every other class gets a hue of ``index / num_classes`` at fixed
        saturation and value of 0.75.
    """
    flat_palette = []

    for class_idx in range(num_classes):
        if class_idx == 0:
            # Background color
            rgb = (0, 0, 0)
        else:
            rgb = hsv_to_rgb((class_idx / num_classes, 0.75, 0.75))

        # Scale the three channels to 0..255 and truncate to int.
        flat_palette.extend(int(rgb[channel] * 255) for channel in range(3))

    return flat_palette

def colorize(labels):
    """Render an array of class labels as an RGB color image.

    Args:
        labels: Array of class indices (the caller in this file passes a
            uint8 array produced by an argmax over the class axis).

    Returns:
        A NumPy array holding the RGB rendering, where each label is painted
        with its entry from ``get_palette()``.
    """
    indexed_img = Image.fromarray(labels)
    indexed_img = indexed_img.convert('P', colors=num_classes)
    # Attach the class colors, then expand the palette indices to RGB pixels.
    indexed_img.putpalette(get_palette())
    rgb_img = indexed_img.convert('RGB')
    return np.array(rgb_img)

def visualize_output(image, output):
    """Turn raw per-class scores into label, segmentation and overlay images.

    Args:
        image: Image array indexed as (height, width, channel). Its channel
            axis is reversed before blending.
        output: Score array indexed as (class, height, width); the class
            axis must have ``num_classes`` entries.

    Returns:
        A tuple ``(confidence, result_img, blended_img, raw_labels)``:
        the mean over all pixels of the per-pixel maximum score (float),
        the colorized segmentation (PIL image), the 50/50 blend of the
        channel-reversed input with the segmentation (PIL image), and the
        uint8 array of per-pixel class indices.
    """
    # Same height and width, and one score map per class.
    assert image.shape[0] == output.shape[1]
    assert image.shape[1] == output.shape[2]
    assert output.shape[0] == num_classes

    # Winning class per pixel.
    raw_labels = np.argmax(output, axis=0).astype(np.uint8)

    # Compute the confidence score: average of the winning scores.
    best_scores = np.max(output, axis=0)
    confidence = float(best_scores.mean())

    # Segmentation painted with the class palette.
    segmentation = colorize(raw_labels)

    # Equal-weight blend of the channel-reversed input and the segmentation.
    flipped_image = image[:, :, ::-1]
    overlay = cv2.addWeighted(flipped_image, 0.5, segmentation, 0.5, 0)

    result_img = Image.fromarray(segmentation)
    blended_img = Image.fromarray(overlay)
    return confidence, result_img, blended_img, raw_labels

def inference(img):
    """Segment the image stored at *img* and return the blended overlay.

    Args:
        img: Path of the input image file (anything ``Image.open`` accepts).

    Returns:
        A PIL image showing the input blended with the colorized
        segmentation, as produced by ``visualize_output``.
    """
    # Force three 8-bit channels: `preprocess` normalizes with three-channel
    # statistics and the blend expects a (height, width, 3) array, so RGBA,
    # palette or grayscale files would otherwise not match. The context
    # manager also closes the underlying file once the pixels are loaded.
    with Image.open(img) as opened:
        input_image = opened.convert('RGB')

    # visualize_output reverses the channel axis of the image it receives and
    # then builds an RGB PIL image from the blend. Hand it BGR data so the
    # photo in the overlay keeps its true colors instead of having red and
    # blue swapped.
    orig_tensor = np.asarray(input_image)[:, :, ::-1]

    # Add the batch axis and convert to the NumPy array ONNX Runtime expects.
    input_tensor = preprocess(input_image).unsqueeze(0).detach().cpu().numpy()

    output_names = [output.name for output in outputs]
    input_name = sess.get_inputs()[0].name
    detections = sess.run(output_names, {input_name: input_tensor})

    # Only the first model output (the per-class scores) is visualized; take
    # the single batch element from it.
    scores = detections[0]
    _, _, blended_img, _ = visualize_output(orig_tensor, scores[0])
    return blended_img
  
# Texts shown at the top of the demo page.
title = "Fully Convolutional Network"
description = "FCNs are a model for real-time neural network for class-wise image segmentation. As the name implies, every weight layer in the network is convolutional. The final layer has the same height/width as the input image, making FCNs a useful tool for doing dense pixel-wise predictions without a significant amount of postprocessing. Being fully convolutional also provides great flexibility in the resolutions this model can handle. This specific model detects 20 different classes. The models have been pre-trained on the COCO train2017 dataset on this class subset."

# Web UI: the uploaded image is passed to inference() as a file path and the
# returned PIL image is displayed.
demo = gr.Interface(
    fn=inference,
    inputs=gr.inputs.Image(type="filepath"),
    outputs=gr.outputs.Image(type="pil"),
    title=title,
    description=description,
)
demo.launch(enable_queue=True)