# malaria_yolo / app.py
# Author: hongvin
# Last commit: 99b8720 ("Update app.py")
import cv2
import numpy as np
import tensorflow as tf
#from tensorflow.compat.v1 import ConfigProto
from tensorflow.python.saved_model import tag_constants
import gradio as gr
from utils import crop_objects, format_boxes, CLASS_NAMES
# Cache of loaded SavedModels keyed by their directory path. Filled lazily by
# _load_model so each model is read from disk once per process instead of on
# every Gradio request.
_MODEL_CACHE = {}


def _load_model(weights):
    """Return the SavedModel stored at ``weights``, loading it at most once."""
    model = _MODEL_CACHE.get(weights)
    if model is None:
        model = tf.saved_model.load(weights, tags=[tag_constants.SERVING])
        _MODEL_CACHE[weights] = model
    return model


def _to_rgb(image):
    """Return ``image`` as a 3-channel RGB array.

    The Gradio ``"image"`` input delivers RGB arrays by default
    (NOTE(review): confirm for the Gradio version in use). Grayscale (2-D)
    and RGBA (4-channel) arrays are converted so that both the model input
    and the ``(h, w, c)`` shape unpacking in ``main`` see three channels.
    """
    if image.ndim == 2:
        return cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
    if image.shape[2] == 4:
        return cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
    return image


def main(image,
         input_size: int = 416,
         weights: str = "yolov4-416",
         iou: float = 0.45,
         score: float = 0.50):
    """Run the YOLOv4 SavedModel on one image and return the cropped detections.

    Args:
        image: Image array from the Gradio ``"image"`` input, assumed RGB.
        input_size: Side length the image is resized to before inference.
        weights: Path of the TensorFlow SavedModel directory to load.
        iou: IoU threshold for non-max suppression.
        score: Minimum class score for a box to survive non-max suppression.

    Returns:
        The result of ``utils.crop_objects`` for the surviving detections,
        which the Gradio interface displays as an image.

    Raises:
        gr.Error: If no image was submitted.
        ValueError: If the model's ``serving_default`` signature returns
            no outputs.
    """
    if image is None:
        raise gr.Error("Please upload an image.")

    infer = _load_model(weights).signatures['serving_default']

    # Gradio already supplies RGB, so no channel swap is applied before
    # inference.
    original_image = _to_rgb(image)
    original_h, original_w, _ = original_image.shape

    image_data = cv2.resize(original_image, (input_size, input_size))
    image_data = (image_data / 255.).astype(np.float32)
    # Post-processing below indexes batch element [0], so feed a batch of one.
    batch_data = tf.constant(image_data[np.newaxis, ...])

    outputs = infer(batch_data)
    if not outputs:
        raise ValueError("serving_default signature returned no outputs")
    # A single output is expected; if there are several, the last one is used.
    pred = list(outputs.values())[-1]
    # pred is rank 3: the first 4 values of the last axis are box coordinates,
    # the remaining values are per-class confidences.
    boxes = pred[:, :, 0:4]
    pred_conf = pred[:, :, 4:]

    # Run non-max suppression on the raw detections.
    boxes, scores, classes, valid_detections = tf.image.combined_non_max_suppression(
        boxes=tf.reshape(boxes, (tf.shape(boxes)[0], -1, 1, 4)),
        scores=tf.reshape(
            pred_conf, (tf.shape(pred_conf)[0], -1, tf.shape(pred_conf)[-1])),
        max_output_size_per_class=50,
        max_total_size=50,
        iou_threshold=iou,
        score_threshold=score,
    )

    # Format boxes from normalized ymin, xmin, ymax, xmax -> xmin, ymin, xmax, ymax.
    bboxes = format_boxes(boxes.numpy()[0], original_h, original_w)
    detections = [bboxes, scores.numpy()[0], classes.numpy()[0],
                  valid_detections.numpy()[0]]
    # Pass a copy so crop_objects cannot modify the caller's input array.
    return crop_objects(original_image.copy(), detections, CLASS_NAMES)
# Build the UI at module level so `demo` stays importable, but only start the
# web server when this file is executed as a script.
demo = gr.Interface(fn=main, inputs="image", outputs="image")

if __name__ == "__main__":
    demo.launch()