Spaces:
Sleeping
Sleeping
from typing import List | |
import os | |
import cv2 | |
import gradio as gr | |
import numpy as np | |
import supervision as sv | |
import torch | |
from tqdm import tqdm | |
from inference.models import YOLOWorld | |
from utils.efficient_sam import load, inference_with_boxes | |
from utils.video import ( | |
generate_file_name, | |
calculate_end_frame_index, | |
create_directory, | |
remove_files_older_than | |
) | |
# Markdown rendered at the top of the demo page via gr.Markdown(MARKDOWN):
# title, badge links (Colab, blog post, YouTube, GitHub, arXiv), a short
# description of the demo, and a tip about tuning the thresholds.
MARKDOWN = """ | |
# YOLO-World + EfficientSAM 🔥 | |
<div> | |
<a href="https://colab.research.google.com/github/roboflow-ai/notebooks/blob/main/notebooks/zero-shot-object-detection-with-yolo-world.ipynb"> | |
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Colab" style="display:inline-block;"> | |
</a> | |
<a href="https://blog.roboflow.com/what-is-yolo-world/"> | |
<img src="https://raw.githubusercontent.com/roboflow-ai/notebooks/main/assets/badges/roboflow-blogpost.svg" alt="Roboflow" style="display:inline-block;"> | |
</a> | |
<a href="https://www.youtube.com/watch?v=X7gKBGVz4vs"> | |
<img src="https://badges.aleen42.com/src/youtube.svg" alt="YouTube" style="display:inline-block;"> | |
</a> | |
<a href="https://github.com/AILab-CVC/YOLO-World"> | |
<img src="https://badges.aleen42.com/src/github.svg" alt="GitHub" style="display:inline-block;"> | |
</a> | |
<a href="https://arxiv.org/abs/2401.17270"> | |
<img src="https://img.shields.io/badge/arXiv-2401.17270-b31b1b.svg" alt="arXiv" style="display:inline-block;"> | |
</a> | |
</div> | |
This is a demo of zero-shot object detection and instance segmentation using | |
[YOLO-World](https://github.com/AILab-CVC/YOLO-World) and | |
[EfficientSAM](https://github.com/yformer/EfficientSAM). | |
Powered by Roboflow [Inference](https://github.com/roboflow/inference) and | |
[Supervision](https://github.com/roboflow/supervision). | |
❗ **Don't give up right away if YOLO-World doesn't detect the objects you are looking | |
for on the first try.** Use the `Configuration` tab and experiment with | |
`confidence_threshold` and `iou_threshold`. YOLO-World tends to return low `confidence` | |
values for objects outside the | |
[COCO](https://universe.roboflow.com/microsoft/coco) dataset. Check out this | |
[notebook](https://supervision.roboflow.com/develop/notebooks/zero-shot-object-detection-with-yolo-world) | |
to learn more about YOLO-World's prompting. | |
""" | |
# Directory into which process_video writes its rendered result files.
RESULTS = "results"

# Example rows for the Image tab. Column order follows the `inputs` list given
# to gr.Examples: image, categories, confidence threshold, IoU threshold,
# with segmentation, display confidence, class-agnostic NMS.
IMAGE_EXAMPLES = [
    [
        "https://media.roboflow.com/dog.jpeg",
        "dog, eye, nose, tongue, car",
        0.005,
        0.1,
        True,
        False,
        False,
    ],
    [
        "https://media.roboflow.com/albert-4x.png",
        "hand, hair",
        0.005,
        0.1,
        True,
        False,
        False,
    ],
]

# Example rows for the Video tab, same column order as IMAGE_EXAMPLES with a
# video in the first column.
VIDEO_EXAMPLES = [
    [
        "https://media.roboflow.com/supervision/video-examples/croissant-1280x720.mp4",
        "croissant",
        0.01,
        0.2,
        False,
        False,
        False,
    ],
    [
        "https://media.roboflow.com/supervision/video-examples/suitcases-1280x720.mp4",
        "suitcase",
        0.1,
        0.2,
        False,
        False,
        False,
    ],
    [
        "https://media.roboflow.com/supervision/video-examples/tokyo-walk-1280x720.mp4",
        "woman walking",
        0.1,
        0.2,
        False,
        False,
        False,
    ],
    [
        "https://media.roboflow.com/supervision/video-examples/wooly-mammoth-1280x720.mp4",
        "mammoth",
        0.01,
        0.2,
        False,
        False,
        False,
    ],
]

# Use the GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Models are created once at import time and shared by every request.
EFFICIENT_SAM_MODEL = load(device=DEVICE)
YOLO_WORLD_MODEL = YOLOWorld(model_id="yolo_world/l")

# Annotators shared by annotate_image.
BOUNDING_BOX_ANNOTATOR = sv.BoundingBoxAnnotator()
MASK_ANNOTATOR = sv.MaskAnnotator()
LABEL_ANNOTATOR = sv.LabelAnnotator()

# Create the directory for video results.
create_directory(directory_path=RESULTS)
def process_categories(categories: str) -> List[str]:
    """Split a comma-separated prompt into a list of category names.

    Every name is stripped of surrounding whitespace. Blank entries, such as
    the ones produced by a trailing comma or by ``"dog,,cat"``, are dropped so
    that no empty class name is passed on to the detector.

    Args:
        categories: Comma-separated category names, e.g. ``"dog, cat"``.

    Returns:
        The non-empty category names in their original order. The list is
        empty when the input contains no names at all.
    """
    stripped_names = (name.strip() for name in categories.split(','))
    return [name for name in stripped_names if name]
def annotate_image(
    input_image: np.ndarray,
    detections: sv.Detections,
    categories: List[str],
    with_confidence: bool = False,
) -> np.ndarray:
    """Draw masks, bounding boxes and labels for the detections on an image.

    Args:
        input_image: Image handed to the mask annotator as the drawing scene.
        detections: Detections whose ``class_id`` values index ``categories``.
        categories: Category names used as label text.
        with_confidence: When true, each label also shows the confidence
            formatted with three decimals.

    Returns:
        The image returned by the last annotator in the chain.
    """
    labels = []
    for class_id, confidence in zip(detections.class_id, detections.confidence):
        name = categories[class_id]
        if with_confidence:
            labels.append(f"{name}: {confidence:.3f}")
        else:
            labels.append(f"{name}")

    # Masks first, then boxes, then labels: each annotator receives the output
    # of the previous one.
    annotated = MASK_ANNOTATOR.annotate(input_image, detections)
    annotated = BOUNDING_BOX_ANNOTATOR.annotate(annotated, detections)
    return LABEL_ANNOTATOR.annotate(annotated, detections, labels=labels)
def process_image(
    input_image: np.ndarray,
    categories: str,
    confidence_threshold: float = 0.3,
    iou_threshold: float = 0.5,
    with_segmentation: bool = True,
    with_confidence: bool = False,
    with_class_agnostic_nms: bool = False,
) -> np.ndarray:
    """Detect the requested categories in an image and return it annotated.

    The categories are set on the YOLO-World model, inference is run with the
    given confidence threshold, overlapping boxes are filtered with
    non-maximum suppression, and (optionally) a mask is computed for every
    remaining box with EfficientSAM before the result is drawn.

    Args:
        input_image: Image to process. It is converted RGB -> BGR before
            annotation and BGR -> RGB afterwards.
        categories: Comma-separated category names to look for.
        confidence_threshold: Confidence passed to the YOLO-World inference.
        iou_threshold: Threshold passed to non-maximum suppression.
        with_segmentation: Whether to compute masks for the detected boxes.
        with_confidence: Whether labels include the detection confidence.
        with_class_agnostic_nms: Whether NMS is applied across all classes.

    Returns:
        The annotated image.

    Raises:
        gr.Error: If no image was provided or the category prompt is blank.
    """
    # Cleanup of old result files.
    # NOTE(review): the unit of the `30` age limit is defined by
    # remove_files_older_than in utils.video - confirm there.
    remove_files_older_than(RESULTS, 30)

    # Fail early with a readable message instead of an obscure error raised
    # from inside the model code.
    if input_image is None:
        raise gr.Error("Please provide an input image.")
    categories = process_categories(categories or "")
    if not any(categories):
        raise gr.Error("Please provide at least one category.")

    YOLO_WORLD_MODEL.set_classes(categories)
    results = YOLO_WORLD_MODEL.infer(input_image, confidence=confidence_threshold)
    detections = sv.Detections.from_inference(results)
    detections = detections.with_nms(
        class_agnostic=with_class_agnostic_nms,
        threshold=iou_threshold
    )
    if with_segmentation:
        detections.mask = inference_with_boxes(
            image=input_image,
            xyxy=detections.xyxy,
            model=EFFICIENT_SAM_MODEL,
            device=DEVICE
        )

    # Annotate a BGR copy and convert the result back to RGB for display.
    output_image = cv2.cvtColor(input_image, cv2.COLOR_RGB2BGR)
    output_image = annotate_image(
        input_image=output_image,
        detections=detections,
        categories=categories,
        with_confidence=with_confidence
    )
    return cv2.cvtColor(output_image, cv2.COLOR_BGR2RGB)
def process_video(
    input_video: str,
    categories: str,
    confidence_threshold: float = 0.3,
    iou_threshold: float = 0.5,
    with_segmentation: bool = True,
    with_confidence: bool = False,
    with_class_agnostic_nms: bool = False,
    progress=gr.Progress(track_tqdm=True)
) -> str:
    """Detect the requested categories in a video and write an annotated copy.

    Each frame goes through the same steps as a single image: YOLO-World
    inference, non-maximum suppression, optional EfficientSAM masks, and
    annotation. The annotated frames are written to a new file inside the
    ``RESULTS`` directory.

    Args:
        input_video: Path of the video to process.
        categories: Comma-separated category names to look for.
        confidence_threshold: Confidence passed to the YOLO-World inference.
        iou_threshold: Threshold passed to non-maximum suppression.
        with_segmentation: Whether to compute masks for the detected boxes.
        with_confidence: Whether labels include the detection confidence.
        with_class_agnostic_nms: Whether NMS is applied across all classes.
        progress: Gradio progress tracker; it is not used directly, the
            ``track_tqdm=True`` default lets Gradio follow the tqdm loop.

    Returns:
        Path of the annotated video file.

    Raises:
        gr.Error: If no video was provided or the category prompt is blank.
    """
    # Cleanup of old result files.
    # NOTE(review): the unit of the `30` age limit is defined by
    # remove_files_older_than in utils.video - confirm there.
    remove_files_older_than(RESULTS, 30)

    # Fail early with a readable message instead of an obscure error raised
    # from inside the video or model code.
    if not input_video:
        raise gr.Error("Please provide an input video.")
    categories = process_categories(categories or "")
    if not any(categories):
        raise gr.Error("Please provide at least one category.")

    YOLO_WORLD_MODEL.set_classes(categories)
    video_info = sv.VideoInfo.from_video_path(input_video)
    total = calculate_end_frame_index(input_video)
    frame_generator = sv.get_video_frames_generator(
        source_path=input_video,
        end=total
    )
    result_file_name = generate_file_name(extension="mp4")
    result_file_path = os.path.join(RESULTS, result_file_name)
    with sv.VideoSink(result_file_path, video_info=video_info) as sink:
        # Iterate the generator itself (bounded by `end=total` above) instead
        # of calling next() a fixed number of times, so a video that yields
        # fewer frames than `total` ends the loop cleanly rather than raising
        # StopIteration.
        for frame in tqdm(
            frame_generator,
            total=total,
            desc="Processing video..."
        ):
            results = YOLO_WORLD_MODEL.infer(frame, confidence=confidence_threshold)
            detections = sv.Detections.from_inference(results)
            detections = detections.with_nms(
                class_agnostic=with_class_agnostic_nms,
                threshold=iou_threshold
            )
            if with_segmentation:
                detections.mask = inference_with_boxes(
                    image=frame,
                    xyxy=detections.xyxy,
                    model=EFFICIENT_SAM_MODEL,
                    device=DEVICE
                )
            frame = annotate_image(
                input_image=frame,
                detections=detections,
                categories=categories,
                with_confidence=with_confidence
            )
            sink.write_frame(frame)
    return result_file_path
# Shared configuration widgets. They are created here, outside the Blocks
# context, and placed into the layout later through their .render() calls.

# Slider feeding the `confidence_threshold` argument of the processing functions.
confidence_threshold_component = gr.Slider(
    minimum=0,
    maximum=1.0,
    value=0.3,
    step=0.01,
    label="Confidence Threshold",
    info=(
        "The confidence threshold for the YOLO-World model. Lower the threshold to "
        "reduce false negatives, enhancing the model's sensitivity to detect "
        "sought-after objects. Conversely, increase the threshold to minimize false "
        "positives, preventing the model from identifying objects it shouldn't."
    ),
)

# Slider feeding the `iou_threshold` argument (non-maximum suppression).
iou_threshold_component = gr.Slider(
    minimum=0,
    maximum=1.0,
    value=0.5,
    step=0.01,
    label="IoU Threshold",
    info=(
        "The Intersection over Union (IoU) threshold for non-maximum suppression. "
        "Decrease the value to lessen the occurrence of overlapping bounding boxes, "
        "making the detection process stricter. On the other hand, increase the value "
        "to allow more overlapping bounding boxes, accommodating a broader range of "
        "detections."
    ),
)

# Checkbox feeding `with_segmentation`; enabled by default.
with_segmentation_component = gr.Checkbox(
    value=True,
    label="With Segmentation",
    info="Whether to run EfficientSAM for instance segmentation.",
)

# Checkbox feeding `with_confidence`; disabled by default.
with_confidence_component = gr.Checkbox(
    value=False,
    label="Display Confidence",
    info="Whether to display the confidence of the detected objects.",
)

# Checkbox feeding `with_class_agnostic_nms`; disabled by default.
with_class_agnostic_nms_component = gr.Checkbox(
    value=False,
    label="Use Class-Agnostic NMS",
    info="Suppress overlapping bounding boxes across all classes.",
)
# UI layout and event wiring for the demo.
# NOTE(review): the source listing had lost its indentation; the nesting below
# is reconstructed from the statement order - confirm against the deployed
# layout.
with gr.Blocks() as demo:
    gr.Markdown(MARKDOWN)

    # Settings shared by the Image and Video tabs.
    with gr.Accordion("Configuration", open=False):
        confidence_threshold_component.render()
        iou_threshold_component.render()
        with gr.Row():
            with_segmentation_component.render()
            with_confidence_component.render()
            with_class_agnostic_nms_component.render()

    # Order matches the trailing parameters of process_image / process_video.
    configuration_inputs = [
        confidence_threshold_component,
        iou_threshold_component,
        with_segmentation_component,
        with_confidence_component,
        with_class_agnostic_nms_component
    ]

    with gr.Tab(label="Image"):
        with gr.Row():
            input_image_component = gr.Image(
                type='numpy',
                label='Input Image'
            )
            output_image_component = gr.Image(
                type='numpy',
                label='Output Image'
            )
        with gr.Row():
            image_categories_text_component = gr.Textbox(
                label='Categories',
                placeholder='comma separated list of categories',
                scale=7
            )
            image_submit_button_component = gr.Button(
                value='Submit',
                scale=1,
                variant='primary'
            )
        # Built once so the examples and the submit button cannot drift apart.
        image_inputs = [
            input_image_component,
            image_categories_text_component,
            *configuration_inputs
        ]
        gr.Examples(
            fn=process_image,
            examples=IMAGE_EXAMPLES,
            inputs=image_inputs,
            outputs=output_image_component
        )

    with gr.Tab(label="Video"):
        with gr.Row():
            input_video_component = gr.Video(
                label='Input Video'
            )
            output_video_component = gr.Video(
                label='Output Video'
            )
        with gr.Row():
            video_categories_text_component = gr.Textbox(
                label='Categories',
                placeholder='comma separated list of categories',
                scale=7
            )
            video_submit_button_component = gr.Button(
                value='Submit',
                scale=1,
                variant='primary'
            )
        # Built once so the examples and the submit button cannot drift apart.
        video_inputs = [
            input_video_component,
            video_categories_text_component,
            *configuration_inputs
        ]
        gr.Examples(
            fn=process_video,
            examples=VIDEO_EXAMPLES,
            inputs=video_inputs,
            # process_video returns a video file path, so the result belongs in
            # the video output (this previously pointed at the image output).
            outputs=output_video_component
        )

    image_submit_button_component.click(
        fn=process_image,
        inputs=image_inputs,
        outputs=output_image_component
    )
    video_submit_button_component.click(
        fn=process_video,
        inputs=video_inputs,
        outputs=output_video_component
    )

demo.launch(debug=False, show_error=True, max_threads=1)