Spaces:
Runtime error
Runtime error
| from ultralytics import YOLO | |
| import gradio as gr | |
| import cv2 | |
| import numpy as np | |
| from collections import defaultdict | |
# Segmentation model shared by every request; stays None until first use.
model = None


def load_model():
    """Return the shared YOLO model, loading './yolo11x-seg.pt' on the first call.

    Later calls reuse the instance cached in the module-level ``model``.
    """
    global model
    if model is not None:
        return model
    model = YOLO('./yolo11x-seg.pt')
    return model
def _to_rgb(image):
    """Return ``image`` as a 3-channel RGB array.

    ``gr.Image`` (numpy type) provides RGB arrays. Grayscale (2-D or
    single-channel) and RGBA uploads are converted so every later step can
    assume three RGB channels. An RGB input is returned as-is (not copied).
    """
    if image.ndim == 2 or image.shape[2] == 1:
        return cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
    if image.shape[2] == 4:
        return cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
    return image


def _stable_color(class_name):
    """Return a per-class uint8 colour triple that is identical across restarts.

    CRC32 is used instead of the built-in ``hash()``, which is salted per
    process for ``str`` and would give a different colour on every run.
    """
    import zlib

    seed = zlib.crc32(class_name.encode("utf-8"))
    return np.array([seed & 0xFF, (seed >> 8) & 0xFF, (seed >> 16) & 0xFF],
                    dtype=np.uint8)


def _draw_detection(canvas, region, color, label, xyxy, thickness):
    """Draw one instance (tinted mask, box and label) onto ``canvas`` in place.

    Args:
        canvas: uint8 HxWx3 image, modified in place.
        region: boolean HxW array selecting the instance's pixels.
        color: length-3 uint8 array used for the mask tint, box and label tab.
        label: text drawn in the label tab.
        xyxy: integer (x1, y1, x2, y2) box corners in pixel coordinates.
        thickness: integer line/text thickness in pixels.
    """
    # Paint the mask on a copy, then blend it back at 40%; unmasked pixels
    # are identical in both inputs and so come out effectively unchanged.
    overlay = canvas.copy()
    overlay[region] = color
    canvas[:] = cv2.addWeighted(overlay, 0.4, canvas, 0.6, 0)

    rgb = tuple(int(c) for c in color)
    x1, y1, x2, y2 = xyxy
    cv2.rectangle(canvas, (x1, y1), (x2, y2), rgb, thickness)

    font_scale = 0.6 * thickness / 2
    (label_w, label_h), _ = cv2.getTextSize(
        label, cv2.FONT_HERSHEY_SIMPLEX, font_scale, thickness)
    # Put the label tab just above the box, but clamp it to the image top so
    # it stays visible for boxes touching the upper edge.
    top = max(y1 - label_h - 10, 0)
    bottom = top + label_h + 10
    cv2.rectangle(canvas, (x1, top), (x1 + label_w, bottom), rgb, -1)
    cv2.putText(canvas, label, (x1, bottom - 5), cv2.FONT_HERSHEY_SIMPLEX,
                font_scale, (255, 255, 255), thickness, cv2.LINE_AA)


def segment_image(image, conf_threshold, iou_threshold, mask_threshold, line_thickness, use_retina_masks):
    """Run YOLO instance segmentation and build the Gradio gallery output.

    Args:
        image: numpy array from ``gr.Image`` (RGB, RGBA or grayscale), or
            None when nothing was uploaded.
        conf_threshold: minimum detection confidence passed to the model.
        iou_threshold: NMS IoU threshold passed to the model.
        mask_threshold: per-pixel cut-off applied to the resized mask.
        line_thickness: box/text thickness from the slider (coerced to int).
        use_retina_masks: forwarded as the model's ``retina_masks`` option.

    Returns:
        A list of ``(rgb_image, caption)`` tuples, or None if there is no input
        or nothing was detected. The first entry is a composite with every
        instance drawn; it is followed by one image per detected class, in
        first-detected order.
    """
    if image is None:
        return None

    model = load_model()
    # OpenCV drawing functions require an integer thickness; sliders may
    # deliver floats.
    thickness = max(1, int(round(line_thickness)))
    rgb = _to_rgb(image)

    # Ultralytics treats numpy input as BGR; convert so channels are not
    # swapped. All drawing stays in RGB, which is what the gallery displays.
    results = model(
        cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR),
        conf=conf_threshold,
        iou=iou_threshold,
        device='cpu',
        retina_masks=use_retina_masks,
    )
    result = results[0]
    if result.masks is None:
        return None

    names = model.names
    height, width = rgb.shape[:2]
    class_images = defaultdict(rgb.copy)  # class name -> its own annotated copy
    full_result = rgb.copy()               # every instance drawn on one image

    for seg, box, cls in zip(result.masks, result.boxes, result.boxes.cls):
        class_name = names[int(cls)]
        # NOTE(review): with retina_masks=False the mask is at the inference
        # resolution; a plain resize may be slightly off if letterbox padding
        # was applied — confirm against the Ultralytics Masks docs.
        mask = cv2.resize(seg.data[0].cpu().numpy().astype(np.float32),
                          (width, height))
        region = mask > mask_threshold
        color = _stable_color(class_name)
        label = f"{class_name} {float(box.conf):.2f}"
        xyxy = tuple(int(v) for v in box.xyxy[0])
        for canvas in (class_images[class_name], full_result):
            _draw_detection(canvas, region, color, label, xyxy, thickness)

    if not class_images:
        return None

    gallery_output = [(full_result, "完整结果")]
    gallery_output.extend((img, name) for name, img in class_images.items())
    return gallery_output
def create_demo():
    """Build and return the Gradio Blocks UI for YOLO image segmentation.

    The left column holds the upload widget and inference controls. The right
    column holds the result gallery and the button that calls
    ``segment_image``.
    """
    with gr.Blocks() as app:
        gr.Markdown("# YOLO 图像分割")
        gr.Markdown("上传一张图片,模型将对图片进行实例分割。每个检测到的类别将单独显示。")

        with gr.Row():
            # Left column: input image plus tunable inference parameters.
            with gr.Column(scale=1):
                image_in = gr.Image()
                with gr.Row():
                    conf_slider = gr.Slider(
                        label="置信度阈值", info="检测置信度的最小值",
                        minimum=0.1, maximum=1.0, step=0.05, value=0.25,
                    )
                    iou_slider = gr.Slider(
                        label="IOU阈值", info="非极大值抑制的IOU阈值",
                        minimum=0.1, maximum=1.0, step=0.05, value=0.7,
                    )
                with gr.Row():
                    mask_slider = gr.Slider(
                        label="掩码阈值", info="分割掩码的阈值",
                        minimum=0.1, maximum=1.0, step=0.05, value=0.5,
                    )
                    thickness_slider = gr.Slider(
                        label="线条粗细", info="边界框和文本的粗细",
                        minimum=1, maximum=5, step=1, value=2,
                    )
                with gr.Row():
                    retina_box = gr.Checkbox(
                        value=True,
                        label="高分辨率掩码",
                        info="启用高分辨率分割掩码(可能会降低速度)",
                    )

            # Right column: gallery of per-class results and the trigger button.
            with gr.Column(scale=1):
                gallery = gr.Gallery(
                    label="分割结果",
                    show_label=True,
                    object_fit="contain",
                    columns=2,
                    rows=2,
                    height=600,
                )
                run_btn = gr.Button("开始分割")

        # Order must match segment_image's positional parameters.
        controls = [
            image_in,
            conf_slider,
            iou_slider,
            mask_slider,
            thickness_slider,
            retina_box,
        ]
        run_btn.click(fn=segment_image, inputs=controls, outputs=gallery)

    return app
if __name__ == "__main__":
    # Script entry point: build the Blocks app, then start the Gradio server.
    demo = create_demo()
    demo.launch()