import cv2
import numpy as np
import tensorflow as tf
#from tensorflow.compat.v1 import ConfigProto
from tensorflow.python.saved_model import tag_constants
import gradio as gr
from utils import crop_objects, format_boxes, CLASS_NAMES
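
# Gradio demo: runs a YOLOv4 TensorFlow SavedModel on an uploaded image,
# applies non-max suppression, and returns the detected objects cropped out.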
def main(image,
         input_size: int = 416,
         weights: str = "yolov4-416",
         iou: float = 0.45,
         score: float = 0.50):
    #config = ConfigProto()
    #config.gpu_options.allow_growth = True
    # load model
    saved_model_loaded = tf.saved_model.load(weights, tags=[tag_constants.SERVING])
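    # note: the SavedModel is reloaded on every request here; loading it once
    # at module level would avoid the repeated start-up cost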
    # preprocess the input image for the model
    # original_image = cv2.imread(image)
    original_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image_data = cv2.resize(original_image, (input_size, input_size))
    image_data = image_data / 255.
    # add a batch dimension: the model expects shape (1, input_size, input_size, 3)
    images_data = np.asarray([image_data]).astype(np.float32)
    print('input batch shape:', images_data.shape)
    infer = saved_model_loaded.signatures['serving_default']
    batch_data = tf.constant(images_data)
    pred_bbox = infer(batch_data)
    # split the model output into box coordinates and per-class confidences
    for key, value in pred_bbox.items():
        boxes = value[:, :, 0:4]
        pred_conf = value[:, :, 4:]
    print('pred_conf shape:', pred_conf.shape)
    print('boxes shape:', boxes.shape)
    # run non-max suppression on the detections
    boxes, scores, classes, valid_detections = tf.image.combined_non_max_suppression(
        boxes=tf.reshape(boxes, (tf.shape(boxes)[0], -1, 1, 4)),
        scores=tf.reshape(
            pred_conf, (tf.shape(pred_conf)[0], -1, tf.shape(pred_conf)[-1])),
        max_output_size_per_class=50,
        max_total_size=50,
        iou_threshold=iou,
        score_threshold=score
    )
    print("Shape of boxes:", boxes.shape)
    print("Shape of pred_conf:", pred_conf.shape)
    # format bounding boxes from normalized (ymin, xmin, ymax, xmax) to pixel (xmin, ymin, xmax, ymax)
    original_h, original_w, _ = original_image.shape
    bboxes = format_boxes(boxes.numpy()[0], original_h, original_w)
    # hold all detection data in one variable
    pred_bbox = [bboxes, scores.numpy()[0], classes.numpy()[0], valid_detections.numpy()[0]]
    allowed_classes = CLASS_NAMES
    crop_img = crop_objects(cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB), pred_bbox, allowed_classes)
    return crop_img


demo = gr.Interface(fn=main, inputs="image", outputs="image")
demo.launch()
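# demo.launch() serves the app locally; launch(share=True) would also create a
# temporary public link.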