Spaces:
Running
Running
import zlib
from collections import defaultdict

import cv2
import gradio as gr
import numpy as np
from ultralytics import YOLO
# Shared YOLO segmentation model; stays None until the first request loads it.
model = None


def load_model():
    """Return the shared YOLO segmentation model, loading it on first use.

    The weights in ``./yolo11x-seg.pt`` are read only on the first call.
    The instance is cached in the module-level ``model`` and returned
    unchanged on every later call.
    """
    global model
    if model is not None:
        return model
    model = YOLO('./yolo11x-seg.pt')
    return model
def _to_rgb(image):
    """Return *image* as a 3-channel RGB array.

    Grayscale (2-D or single-channel) and RGBA inputs are converted.
    Three-channel input is assumed to already be RGB, which is what
    Gradio's numpy image component delivers. It is returned unchanged.
    """
    if image.ndim == 2 or image.shape[2] == 1:
        return cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
    if image.shape[2] == 4:
        return cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
    return image


def _class_color(class_name):
    """Return a stable pseudo-random uint8 colour (3 channels) for *class_name*.

    The built-in ``hash`` is not used because string hashing is randomised
    per interpreter process. With it, every class changed colour whenever
    the app restarted. ``zlib.crc32`` is deterministic.
    """
    digest = zlib.crc32(class_name.encode("utf-8"))
    return np.array(
        [digest & 0xFF, (digest >> 8) & 0xFF, (digest >> 16) & 0xFF],
        dtype=np.uint8,
    )


def _draw_detection(canvas, mask_area, color, label, box_xyxy, thickness):
    """Draw one instance (tinted mask, box and filled label tag) onto *canvas* in place."""
    # Tint the masked pixels at 40% opacity. Outside the mask both blend
    # inputs are identical, so those pixels stay (up to rounding) unchanged.
    overlay = canvas.copy()
    overlay[mask_area] = color
    cv2.addWeighted(overlay, 0.4, canvas, 0.6, 0, canvas)

    color_tuple = color.tolist()
    x1, y1, x2, y2 = box_xyxy
    cv2.rectangle(canvas, (x1, y1), (x2, y2), color_tuple, thickness)

    font_scale = 0.6 * thickness / 2
    (text_w, text_h), _ = cv2.getTextSize(
        label, cv2.FONT_HERSHEY_SIMPLEX, font_scale, thickness)
    tag_top = y1 - text_h - 10
    if tag_top < 0:
        # The box touches the top edge: draw the tag just inside the box
        # instead of off-image.
        tag_top = y1
    tag_bottom = tag_top + text_h + 10
    cv2.rectangle(canvas, (x1, tag_top), (x1 + text_w, tag_bottom),
                  color_tuple, -1)
    cv2.putText(canvas, label, (x1, tag_bottom - 5),
                cv2.FONT_HERSHEY_SIMPLEX, font_scale,
                (255, 255, 255), thickness, cv2.LINE_AA)


def segment_image(image, conf_threshold, iou_threshold, mask_threshold, line_thickness, use_retina_masks):
    """Run YOLO instance segmentation and build the gallery contents.

    Args:
        image: Uploaded picture as a numpy array (RGB, grayscale or RGBA),
            or ``None`` when nothing was uploaded.
        conf_threshold: Minimum detection confidence, passed as ``conf``.
        iou_threshold: NMS IoU threshold, passed as ``iou``.
        mask_threshold: Per-pixel cutoff applied to each resized mask.
        line_thickness: Stroke width for boxes and label text.
        use_retina_masks: Forwarded to the model as ``retina_masks``.

    Returns:
        ``None`` if there is no input or nothing was segmented. Otherwise a
        list of ``(image, caption)`` tuples. The first entry is the combined
        result, captioned "完整结果". It is followed by one annotated image
        per detected class, in the order classes were first detected.
    """
    if image is None:
        return None

    rgb = _to_rgb(image)
    # Ultralytics treats numpy sources as BGR (OpenCV order), but the input
    # here is RGB. Swap channels for inference only, and draw on the RGB copy
    # so the gallery shows true colours.
    bgr = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
    # cv2 drawing functions require an integer thickness.
    thickness = max(1, int(line_thickness))

    model = load_model()
    result = model(
        bgr,
        conf=conf_threshold,
        iou=iou_threshold,
        device='cpu',
        retina_masks=use_retina_masks
    )[0]

    if result.masks is None:
        return None

    height, width = rgb.shape[:2]
    names = model.names
    full_result = rgb.copy()
    class_images = {}  # class name -> annotated copy; dict keeps first-seen order

    for seg, box, cls in zip(result.masks, result.boxes, result.boxes.cls):
        class_name = names[int(cls)]
        # NOTE(review): with retina_masks off, the mask may be at the network
        # input resolution. If that input was letterbox-padded, a plain
        # resize can be slightly misaligned — confirm.
        mask = cv2.resize(seg.data[0].cpu().numpy(), (width, height))
        mask_area = mask > mask_threshold
        color = _class_color(class_name)
        label = f"{class_name} {float(box.conf):.2f}"
        box_xyxy = tuple(map(int, box.xyxy[0]))

        if class_name not in class_images:
            class_images[class_name] = rgb.copy()
        # Draw every instance on both its class image and the combined image.
        # Blending per-class images together would fade earlier classes.
        for canvas in (class_images[class_name], full_result):
            _draw_detection(canvas, mask_area, color, label, box_xyxy, thickness)

    if not class_images:
        return None

    gallery_output = [(full_result, "完整结果")]
    gallery_output.extend((img, name) for name, img in class_images.items())
    return gallery_output
def create_demo():
    """Assemble the Gradio interface for the YOLO segmentation demo.

    The left column holds the image upload and the inference controls. The
    right column holds a gallery that receives ``segment_image``'s output.
    The button beneath them triggers segmentation.

    Returns:
        The assembled ``gr.Blocks`` app (not yet launched).
    """
    with gr.Blocks() as demo:
        gr.Markdown("# YOLO 图像分割")
        gr.Markdown("上传一张图片,模型将对图片进行实例分割。每个检测到的类别将单独显示。")

        with gr.Row():
            # Inputs and tuning controls.
            with gr.Column(scale=1):
                image_in = gr.Image()

                with gr.Row():
                    conf_slider = gr.Slider(
                        minimum=0.1,
                        maximum=1.0,
                        value=0.25,
                        step=0.05,
                        label="置信度阈值",
                        info="检测置信度的最小值",
                    )
                    iou_slider = gr.Slider(
                        minimum=0.1,
                        maximum=1.0,
                        value=0.7,
                        step=0.05,
                        label="IOU阈值",
                        info="非极大值抑制的IOU阈值",
                    )

                with gr.Row():
                    mask_slider = gr.Slider(
                        minimum=0.1,
                        maximum=1.0,
                        value=0.5,
                        step=0.05,
                        label="掩码阈值",
                        info="分割掩码的阈值",
                    )
                    thickness_slider = gr.Slider(
                        minimum=1,
                        maximum=5,
                        value=2,
                        step=1,
                        label="线条粗细",
                        info="边界框和文本的粗细",
                    )

                with gr.Row():
                    retina_toggle = gr.Checkbox(
                        label="高分辨率掩码",
                        value=True,
                        info="启用高分辨率分割掩码(可能会降低速度)",
                    )

            # Results gallery.
            with gr.Column(scale=1):
                gallery = gr.Gallery(
                    label="分割结果",
                    show_label=True,
                    columns=2,
                    rows=2,
                    height=600,
                    object_fit="contain",
                )

        run_button = gr.Button("开始分割")

        # The order must match segment_image's positional parameters.
        handler_inputs = [
            image_in,
            conf_slider,
            iou_slider,
            mask_slider,
            thickness_slider,
            retina_toggle,
        ]
        run_button.click(fn=segment_image, inputs=handler_inputs, outputs=gallery)

    return demo
if __name__ == "__main__":
    # Script entry point: build the UI, then serve it with Gradio's default launch settings.
    demo = create_demo()
    demo.launch()