faster-rcnn / app.py
akhaliq's picture
akhaliq HF staff
Create app.py
8f5cd7f
raw
history blame
4.49 kB
from PIL import Image
import numpy as np
import torch
from torchvision import transforms, models
from onnx import numpy_helper
import os
import onnxruntime as rt
from matplotlib.colors import hsv_to_rgb
import cv2
import gradio as gr
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import pycocotools.mask as mask_util
def preprocess(image):
    """Turn a PIL image into the float32 CHW/BGR array passed to the model.

    The image is rescaled so its shorter side is roughly 800 pixels,
    reordered from RGB to BGR, has a fixed per-channel constant subtracted,
    and is zero-padded on the bottom/right so that both spatial dimensions
    are multiples of 32.

    Args:
        image: A ``PIL.Image.Image``. Modes other than RGB are converted to
            RGB first.

    Returns:
        A ``numpy.ndarray`` of dtype float32 with shape
        ``(3, padded_height, padded_width)``.
    """
    # Grayscale/palette images yield 2-D arrays, which would break the
    # channel indexing below; normalise every input to three RGB channels.
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Resize
    width, height = image.size
    ratio = 800.0 / min(width, height)
    image = image.resize((int(ratio * width), int(ratio * height)), Image.BILINEAR)

    # Convert to BGR (reverse the channel axis), then HWC -> CHW
    bgr = np.array(image)[:, :, ::-1].astype(np.float32)
    chw = np.transpose(bgr, (2, 0, 1))

    # Normalize: subtract one constant per (BGR-ordered) channel in a single
    # broadcast operation.
    mean_vec = np.array([102.9801, 115.9465, 122.7717], dtype=np.float32)
    chw = chw - mean_vec[:, None, None]

    # Pad to be divisible by 32 (integer ceil-division, no float rounding)
    _, resized_h, resized_w = chw.shape
    padded_h = -(-resized_h // 32) * 32
    padded_w = -(-resized_w // 32) * 32
    padded_image = np.zeros((3, padded_h, padded_w), dtype=np.float32)
    padded_image[:, :resized_h, :resized_w] = chw
    return padded_image
# Starting with ORT 1.10, ONNX Runtime requires the `providers` parameter to be
# set explicitly when instantiating InferenceSession if you want to use any
# execution provider other than the default CPU provider (previously providers
# were registered by default based on the build flags).
# For example, if an NVIDIA GPU is available and the ORT Python package is
# built with CUDA, call the API as follows:
# onnxruntime.InferenceSession(path/to/model, providers=['CUDAExecutionProvider'])

# Only download when the model is missing: wget would otherwise fetch the file
# on every start and store the copy under a new name (e.g. "...onnx.1").
if not os.path.exists("MaskRCNN-10.onnx"):
    os.system("wget https://github.com/AK391/models/raw/main/vision/object_detection_segmentation/mask-rcnn/model/MaskRCNN-10.onnx")

# Pass every provider this ORT build offers so session creation works on both
# CPU-only and GPU-enabled installations.
sess = rt.InferenceSession("MaskRCNN-10.onnx", providers=rt.get_available_providers())
outputs = sess.get_outputs()

# One class name per line; the line index is used as the label id when
# annotating detections.
with open('coco_classes.txt') as class_file:
    classes = [line.rstrip('\n') for line in class_file]
def display_objdetect_image(image, boxes, labels, scores, masks, score_threshold=0.7):
    """Draw detections on ``image`` and save the rendering to ``out.png``.

    For every detection whose score exceeds ``score_threshold`` the mask
    contour is drawn onto the image, and a rectangle plus a
    ``"<class>:<score>"`` annotation is added to the figure.

    Args:
        image: The original (un-resized) PIL image.
        boxes: Per-detection ``(x0, y0, x1, y1)`` boxes in the coordinates of
            the 800-pixel-short-side resized image; they are scaled back to
            the original image here. The caller's array is not modified.
        labels: Per-detection indices into the module-level ``classes`` list.
        scores: Per-detection confidence scores.
        masks: Per-detection masks; each is indexed as ``mask[0, :, :]`` and
            thresholded at 0.5 after being resized to its box.
        score_threshold: Detections with ``score <= score_threshold`` are
            skipped.

    Returns:
        None. The figure is written to ``out.png`` in the working directory.
    """
    # Resize boxes back to original-image coordinates. Build a new array
    # instead of dividing in place so the caller's data stays untouched.
    ratio = 800.0 / min(image.size[0], image.size[1])
    boxes = np.asarray(boxes) / ratio

    fig, ax = plt.subplots(1, figsize=(12, 9))
    try:
        image = np.array(image)
        img_h, img_w = image.shape[0], image.shape[1]

        for mask, box, label, score in zip(masks, boxes, labels, scores):
            # Showing boxes with score > score_threshold
            if score <= score_threshold:
                continue

            # Finding contour based on mask
            mask = mask[0, :, :, None]
            int_box = [int(i) for i in box]
            mask = cv2.resize(mask, (int_box[2] - int_box[0] + 1, int_box[3] - int_box[1] + 1))
            mask = mask > 0.5

            # Region of the image covered by the box, clipped to the image.
            x_0 = max(int_box[0], 0)
            x_1 = min(int_box[2] + 1, img_w)
            y_0 = max(int_box[1], 0)
            y_1 = min(int_box[3] + 1, img_h)

            # Matching region inside the resized mask. Use the integer box
            # here: float offsets cannot be used as slice indices.
            mask_y_0 = max(y_0 - int_box[1], 0)
            mask_y_1 = mask_y_0 + y_1 - y_0
            mask_x_0 = max(x_0 - int_box[0], 0)
            mask_x_1 = mask_x_0 + x_1 - x_0

            im_mask = np.zeros((img_h, img_w), dtype=np.uint8)
            im_mask[y_0:y_1, x_0:x_1] = mask[mask_y_0:mask_y_1, mask_x_0:mask_x_1]
            im_mask = im_mask[:, :, None]

            # OpenCV version 4.x
            contours, hierarchy = cv2.findContours(
                im_mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
            )
            image = cv2.drawContours(image, contours, -1, 25, 3)

            rect = patches.Rectangle(
                (box[0], box[1]),
                box[2] - box[0],
                box[3] - box[1],
                linewidth=1,
                edgecolor='b',
                facecolor='none',
            )
            ax.annotate(classes[label] + ':' + str(np.round(score, 2)), (box[0], box[1]), color='w', fontsize=12)
            ax.add_patch(rect)

        ax.imshow(image)
        ax.axis('off')
        fig.savefig('out.png', bbox_inches='tight')
    finally:
        # Figures created through pyplot stay registered until closed; release
        # this one so repeated calls do not accumulate figures.
        plt.close(fig)
def inference(img):
    """Run the model on the image file ``img`` and render the detections.

    Args:
        img: Path of the input image file.

    Returns:
        The string ``'out.png'``, the path of the rendered result written by
        ``display_objdetect_image``.
    """
    output_names = [output.name for output in outputs]
    input_name = sess.get_inputs()[0].name

    # Keep the file open only while the pixels are needed, then release it.
    with Image.open(img) as input_image:
        input_tensor = preprocess(input_image)
        boxes, labels, scores, masks = sess.run(output_names, {input_name: input_tensor})
        display_objdetect_image(input_image, boxes, labels, scores, masks)
    return 'out.png'
# Text and example input shown on the demo page.
title = "Mask R-CNN"
description = "This model is a real-time neural network for object instance segmentation that detects 80 different classes."
examples = [["examplemask-rcnn.jpeg"]]

# Build the Gradio interface (file path in, rendered image file out) and
# start serving it with request queueing enabled.
gr.Interface(
    fn=inference,
    inputs=gr.inputs.Image(type="filepath"),
    outputs=gr.outputs.Image(type="file"),
    title=title,
    description=description,
    examples=examples,
).launch(enable_queue=True)