# YOLO-ARENA / app.py
# Last commit 34cb512 by SkalskiP: "YOLOv10 added" (file size 5.31 kB)
from typing import Optional
from typing import Tuple

import gradio as gr
import numpy as np
import supervision as sv
from inference import get_model
# Header text rendered at the top of the page by gr.Markdown.
# Blank lines separate the list from the surrounding paragraphs; without the
# one after the last bullet, Markdown folds the "Powered by" sentence into the
# "YOLO-NAS" list item.
MARKDOWN = """
<h1 style='text-align: center'>YOLO-ARENA 🏟️</h1>

Welcome to YOLO-Arena! This demo showcases the performance of various YOLO models:

- YOLOv8
- YOLOv9
- YOLOv10
- YOLO-NAS

Powered by Roboflow [Inference](https://github.com/roboflow/inference) and
[Supervision](https://github.com/roboflow/supervision).
"""
# Rows for gr.Examples. Each row supplies one value per input component, in
# the same order as the Examples `inputs` list: (input image, confidence
# threshold, IoU threshold). The IoU value matches the slider's default so
# that running an example always uses the same settings.
IMAGE_EXAMPLES = [
    ['https://media.roboflow.com/dog.jpeg', 0.3, 0.5],
]
# Detectors compared side by side, loaded once at import time through
# Roboflow Inference's get_model (in this order). The app's output panels
# label them YOLOv8m, YOLO-NAS M, YOLOv9c and YOLOv10m, all @ 640x640.
YOLO_V8_MODEL, YOLO_NAS_MODEL, YOLO_V9_MODEL, YOLO_V10_MODEL = [
    get_model(model_id=checkpoint)
    for checkpoint in ("yolov8m-640", "coco/15", "coco/17", "coco/22")
]

# Annotators shared by every model's output: class labels in black text,
# and plain bounding boxes.
LABEL_ANNOTATORS = sv.LabelAnnotator(text_color=sv.Color.black())
BOUNDING_BOX_ANNOTATORS = sv.BoundingBoxAnnotator()
def detect_and_annotate(
    model,
    input_image: np.ndarray,
    confidence_threshold: float,
    iou_threshold: float
) -> np.ndarray:
    """Run a single detector on an image and return an annotated copy.

    Args:
        model: A loaded model (as returned by ``get_model``) whose ``infer``
            accepts ``confidence`` and ``iou_threshold`` keyword arguments.
            Only the first element of its result is used.
        input_image: Image to run detection on. Annotations are drawn on a
            copy, not on this array.
        confidence_threshold: Forwarded to ``infer`` as ``confidence``.
        iou_threshold: Forwarded to ``infer`` as ``iou_threshold``.

    Returns:
        A copy of ``input_image`` with bounding boxes drawn first and class
        labels drawn on top.
    """
    predictions = model.infer(
        input_image,
        confidence=confidence_threshold,
        iou_threshold=iou_threshold,
    )
    detections = sv.Detections.from_inference(predictions[0])

    # Boxes go on a fresh copy; labels are then layered over the boxes.
    boxed = BOUNDING_BOX_ANNOTATORS.annotate(
        scene=input_image.copy(), detections=detections)
    return LABEL_ANNOTATORS.annotate(scene=boxed, detections=detections)
def process_image(
    input_image: Optional[np.ndarray],
    confidence_threshold: float,
    iou_threshold: float
) -> Tuple[
    Optional[np.ndarray],
    Optional[np.ndarray],
    Optional[np.ndarray],
    Optional[np.ndarray],
]:
    """Run all four detectors on one image and return their annotated results.

    Args:
        input_image: Image from the input component, or ``None`` when the
            user submits without having uploaded an image.
        confidence_threshold: Minimum confidence forwarded to every model.
        iou_threshold: Non-maximum-suppression IoU threshold forwarded to
            every model.

    Returns:
        Annotated images in the order YOLOv8, YOLO-NAS, YOLOv9, YOLOv10. This
        is the same order as the four output components this function is
        wired to. When ``input_image`` is ``None``, four ``None`` values are
        returned, which clears the output panels instead of passing ``None``
        into the models.
    """
    # Guard first, so an empty submission never reaches model inference.
    if input_image is None:
        return None, None, None, None

    # Order here defines the order of the returned tuple (see Returns).
    models = (YOLO_V8_MODEL, YOLO_NAS_MODEL, YOLO_V9_MODEL, YOLO_V10_MODEL)
    return tuple(
        detect_and_annotate(
            model, input_image, confidence_threshold, iou_threshold)
        for model in models
    )
# Confidence slider. It is created outside the Blocks context and placed into
# the layout later by calling .render() on it.
confidence_threshold_component = gr.Slider(
    label="Confidence Threshold",
    value=0.3,
    minimum=0,
    maximum=1.0,
    step=0.01,
    info=(
        "The confidence threshold for the YOLO model. Lower the threshold to "
        "reduce false negatives, enhancing the model's sensitivity to detect "
        "sought-after objects. Conversely, increase the threshold to minimize false "
        "positives, preventing the model from identifying objects it shouldn't."
    ),
)
# IoU slider for non-maximum suppression. Like the confidence slider, it is
# built here and placed into the layout with .render().
iou_threshold_component = gr.Slider(
    label="IoU Threshold",
    value=0.5,
    minimum=0,
    maximum=1.0,
    step=0.01,
    info=(
        "The Intersection over Union (IoU) threshold for non-maximum suppression. "
        "Decrease the value to lessen the occurrence of overlapping bounding boxes, "
        "making the detection process stricter. On the other hand, increase the value "
        "to allow more overlapping bounding boxes, accommodating a broader range of "
        "detections."
    ),
)
# Page layout:
#   - the header markdown;
#   - a collapsed "Configuration" accordion holding both threshold sliders;
#   - the input image next to a 2x2 grid of per-model outputs;
#   - a Submit button;
#   - a clickable example.
with gr.Blocks() as demo:
    gr.Markdown(MARKDOWN)

    with gr.Accordion("Configuration", open=False):
        # The sliders were constructed at module level; render() mounts them here.
        confidence_threshold_component.render()
        iou_threshold_component.render()

    with gr.Row():
        input_image_component = gr.Image(type='numpy', label='Input')
        with gr.Column():
            with gr.Row():
                yolo_v8_output_image_component = gr.Image(
                    type='numpy', label='YOLOv8m @ 640x640')
                yolo_nas_output_image_component = gr.Image(
                    type='numpy', label='YOLO-NAS M @ 640x640')
            with gr.Row():
                yolo_v9_output_image_component = gr.Image(
                    type='numpy', label='YOLOv9c @ 640x640')
                yolo_v10_output_image_component = gr.Image(
                    type='numpy', label='YOLOv10m @ 640x640')

    submit_button_component = gr.Button(
        value='Submit', scale=1, variant='primary')

    # The examples table and the Submit button both call process_image with
    # the same inputs, and fill the same four output panels in model order.
    process_inputs = [
        input_image_component,
        confidence_threshold_component,
        iou_threshold_component,
    ]
    process_outputs = [
        yolo_v8_output_image_component,
        yolo_nas_output_image_component,
        yolo_v9_output_image_component,
        yolo_v10_output_image_component,
    ]

    gr.Examples(
        fn=process_image,
        examples=IMAGE_EXAMPLES,
        inputs=process_inputs,
        outputs=process_outputs,
    )
    submit_button_component.click(
        fn=process_image,
        inputs=process_inputs,
        outputs=process_outputs,
    )

demo.launch(debug=False, show_error=True, max_threads=1)