File size: 4,312 Bytes
74d7b02 d4461b5 74d7b02 d4461b5 74d7b02 d4461b5 74d7b02 8f78001 42f5968 74d7b02 be786c9 74d7b02 be786c9 74d7b02 be786c9 74d7b02 be786c9 74d7b02 be786c9 74d7b02 d4461b5 be786c9 d4461b5 be786c9 d4461b5 be786c9 74d7b02 d4461b5 74d7b02 d4461b5 74d7b02 be786c9 d4461b5 74d7b02 cd645f2 74d7b02 be786c9 42f5968 74d7b02 d4461b5 74d7b02 be786c9 42f5968 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 |
import gradio as gr
import sahi
import torch
from ultralyticsplus import YOLO, render_model_output
# Fetch the sample images referenced by the interface examples.
# Each pair is (source URL, local file name); the names are relative paths.
for _sample_url, _sample_path in (
    (
        "https://raw.githubusercontent.com/kadirnar/dethub/main/data/images/highway.jpg",
        "highway.jpg",
    ),
    (
        "https://raw.githubusercontent.com/obss/sahi/main/tests/data/small-vehicles1.jpeg",
        "small-vehicles1.jpeg",
    ),
    (
        "https://raw.githubusercontent.com/ultralytics/yolov5/master/data/images/zidane.jpg",
        "zidane.jpg",
    ),
):
    sahi.utils.file.download_from_url(_sample_url, _sample_path)
# YOLOv8 segmentation checkpoints offered in the "Model type" dropdown,
# in the order n, s, m, l, x.
model_names = [f"yolov8{size}-seg.pt" for size in "nsmlx"]

# Checkpoint loaded at startup; yolov8_inference replaces the global `model`
# whenever a different name is selected.
current_model_name = "yolov8m-seg.pt"
model = YOLO(current_model_name)
def _as_list(values):
    """Return *values* (tensor, array, or sequence) as a plain list for JSON output."""
    # Tensors / NumPy arrays expose .tolist(); plain sequences fall back to list().
    to_list = getattr(values, "tolist", None)
    return to_list() if callable(to_list) else list(values)


def yolov8_inference(
    image: str = None,
    model_name: str = None,
    image_size: int = 640,
    conf_threshold: float = 0.25,
    iou_threshold: float = 0.45,
):
    """
    Run YOLOv8 segmentation on one image and return one record per detected object.

    Args:
        image: Input image as delivered by ``gr.Image(type="filepath")``,
            i.e. a path to the uploaded file.
        model_name: Weights file to use. ``None`` or an empty string keeps the
            currently loaded model instead of trying to load ``YOLO(None)``.
        image_size: Forwarded to ``model.predict`` as ``imgsz``.
        conf_threshold: Stored in ``model.overrides["conf"]``.
        iou_threshold: Stored in ``model.overrides["iou"]``.

    Returns:
        A list of dicts. When a result carries masks, each dict is
        ``{"label": str, "mask_coords": list}``, where ``mask_coords`` is the
        per-object mask data converted to nested lists. Otherwise each dict is
        ``{"label": str, "bbox_coords": list}`` with the box's ``xyxy`` values.
    """
    global model
    global current_model_name

    # Reload weights only when a different, non-empty model name is requested.
    if model_name and model_name != current_model_name:
        model = YOLO(model_name)
        current_model_name = model_name

    # Thresholds are written onto the shared global model, so they persist
    # until the next call overwrites them.
    model.overrides["conf"] = conf_threshold
    model.overrides["iou"] = iou_threshold

    # NOTE(review): the result schema used below is inferred from this
    # function's own indexing: "boxes" is a mapping of per-object columns
    # ("cls", "xyxy"), and "masks" is a mapping with a "data" column. Confirm
    # it against the installed ultralytics/ultralyticsplus version, whose
    # `predict` output format has changed between releases.
    results = model.predict(image, imgsz=image_size, return_outputs=True)

    output = []
    for result in results:
        boxes = result["boxes"]
        # Read the class column once, so labels line up index-for-index with
        # the mask and box columns instead of iterating the mapping's keys.
        labels = [model.names[int(cls_id)] for cls_id in boxes["cls"]]

        mask_info = result["masks"] if "masks" in result else None
        if mask_info is not None:
            for label, mask in zip(labels, mask_info["data"]):
                output.append({"label": label, "mask_coords": _as_list(mask)})
        else:
            # No masks: fall back to bounding-box coordinates.
            for label, xyxy in zip(labels, boxes["xyxy"]):
                output.append({"label": label, "bbox_coords": _as_list(xyxy)})
    return output
# Gradio input components, in the same order as yolov8_inference's parameters.
inputs = [
    # type="filepath" passes the function a path to the uploaded image.
    gr.Image(type="filepath", label="Input Image"),
    gr.Dropdown(
        model_names,
        value=current_model_name,
        label="Model type",
    ),
    gr.Slider(minimum=320, maximum=1280, value=640, step=32, label="Image Size"),
    gr.Slider(
        minimum=0.0, maximum=1.0, value=0.25, step=0.05, label="Confidence Threshold"
    ),
    gr.Slider(minimum=0.0, maximum=1.0, value=0.45, step=0.05, label="IOU Threshold"),
]
# yolov8_inference returns a list of dicts, one per detected object, each with
# a "label" and either "mask_coords" or "bbox_coords"; it is shown as JSON.
outputs = gr.JSON(label="Output Masks and Labels")
title = "Ultralytics YOLOv8 Segmentation Demo"
# Example rows: [image, model, image size, confidence, IoU]. The model column
# uses the model loaded at startup, so the two stay in sync if the default changes.
examples = [
    ["zidane.jpg", current_model_name, 640, 0.6, 0.45],
    ["highway.jpg", current_model_name, 640, 0.25, 0.45],
    ["small-vehicles1.jpeg", current_model_name, 640, 0.25, 0.45],
]
# Assemble the single-page Gradio Interface around the inference function.
demo_app = gr.Interface(
    title=title,
    theme="default",
    fn=yolov8_inference,
    inputs=inputs,
    outputs=outputs,
    examples=examples,
    # Run examples on click rather than caching their outputs ahead of time.
    cache_examples=False,
)
# Turn on request queuing, then start the server. queue() already enables the
# queue, so the legacy launch(enable_queue=...) flag is not passed: it is
# deprecated in Gradio 3.x and rejected by Gradio 4.
demo_app.queue().launch(
    debug=True,  # block the main thread and print errors to the console
    server_name="0.0.0.0",  # listen on all network interfaces
    server_port=7860,  # port the app is served on
    share=True,  # also request a public share link
)