File size: 5,934 Bytes
c48ff3a
 
 
 
 
 
9b884b7
 
 
 
 
 
897a1ee
9b884b7
c48ff3a
 
9b884b7
 
 
c48ff3a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9b884b7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c48ff3a
 
9b884b7
 
 
 
 
 
 
 
 
 
 
 
 
 
79d1225
9b884b7
c48ff3a
 
9b884b7
897a1ee
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
from ultralytics import YOLO
import gradio as gr
import cv2
import numpy as np
from collections import defaultdict

# Process-wide cache for the segmentation model; filled in on first use.
model = None

def load_model():
    """Return the shared YOLO model, creating it on the first call.

    The weights are read from ``./yolo11x-seg.pt`` (relative to the current
    working directory) only when the module-level ``model`` cache is still
    ``None``; later calls return the cached object unchanged.
    """
    global model
    if model is not None:
        return model
    model = YOLO('./yolo11x-seg.pt')
    return model

def _class_color(class_name):
    """Return a stable ``[r, g, b]`` color (ints in 0-255) for a class name.

    The built-in ``hash()`` is salted per process for strings, so it would
    give a class a different color after every restart; CRC32 is stable.
    """
    import zlib  # local import keeps this block self-contained

    digest = zlib.crc32(class_name.encode("utf-8"))
    return [(digest >> shift) & 0xFF for shift in (0, 8, 16)]


def _draw_detection(canvas, mask_area, color, label, corners, thickness):
    """Draw one detection (mask tint, box and label) on ``canvas`` in place."""
    # Tint the masked pixels, then blend 40/60 with the untouched canvas.
    overlay = canvas.copy()
    overlay[mask_area] = color
    cv2.addWeighted(overlay, 0.4, canvas, 0.6, 0, canvas)

    x1, y1, x2, y2 = corners
    cv2.rectangle(canvas, (x1, y1), (x2, y2), color, thickness)

    font_scale = 0.6 * thickness / 2
    (label_width, label_height), _ = cv2.getTextSize(
        label, cv2.FONT_HERSHEY_SIMPLEX, font_scale, thickness)

    # Keep the label background inside the image when the box touches the
    # top edge; otherwise it sits directly above the box.
    label_top = max(y1 - label_height - 10, 0)
    label_bottom = label_top + label_height + 10
    cv2.rectangle(canvas, (x1, label_top), (x1 + label_width, label_bottom),
                  color, -1)
    cv2.putText(canvas, label, (x1, label_bottom - 5),
                cv2.FONT_HERSHEY_SIMPLEX, font_scale,
                (255, 255, 255), thickness, cv2.LINE_AA)


def segment_image(image, conf_threshold, iou_threshold, mask_threshold, line_thickness, use_retina_masks):
    """Run instance segmentation and return annotated images for a Gallery.

    Args:
        image: Input picture as a NumPy array (grayscale, RGB or RGBA), or
            ``None`` when nothing was uploaded.
        conf_threshold: Minimum detection confidence passed to the model.
        iou_threshold: IoU threshold passed to the model for NMS.
        mask_threshold: Mask pixels strictly above this value are tinted.
        line_thickness: Thickness of boxes and label text; coerced to ``int``
            because OpenCV drawing functions reject floats.
        use_retina_masks: Forwarded to the model as ``retina_masks``.

    Returns:
        A list of ``(image, caption)`` tuples: first one image with every
        detection drawn on it, then one image per detected class in order of
        first detection. ``None`` if there is no input or nothing was found.
    """
    if image is None:
        return None

    thickness = int(line_thickness)

    # The UI hands over RGB-ordered arrays (Gradio's default) and the Gallery
    # displays RGB, so all drawing happens on RGB canvases.
    if image.ndim == 2:
        rgb = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
    elif image.shape[2] == 4:
        rgb = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
    else:
        rgb = image
    height, width = rgb.shape[:2]

    # Ultralytics interprets NumPy inputs as BGR, so only the copy that is
    # fed to the model gets its channels swapped.
    bgr = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)

    yolo = load_model()
    results = yolo(
        bgr,
        conf=conf_threshold,
        iou=iou_threshold,
        device='cpu',
        retina_masks=use_retina_masks
    )
    result = results[0]
    if result.masks is None:
        return None

    names = yolo.names
    full_result = rgb.copy()
    # A plain dict keeps classes in first-detection order, which makes the
    # gallery order deterministic.
    class_images = {}

    for seg, box, cls in zip(result.masks, result.boxes, result.boxes.cls):
        class_name = names[int(cls)]
        if class_name not in class_images:
            class_images[class_name] = rgb.copy()

        # Bring the mask to the canvas resolution before thresholding.
        segment = seg.data[0].cpu().numpy()
        segment = cv2.resize(segment, (width, height))
        mask_area = segment > mask_threshold

        color = _class_color(class_name)
        label = f"{class_name} {float(box.conf):.2f}"
        corners = tuple(map(int, box.xyxy[0]))

        # Draw on both canvases so the combined view shows every detection
        # at full strength instead of a faded blend of per-class images.
        for canvas in (class_images[class_name], full_result):
            _draw_detection(canvas, mask_area, color, label, corners, thickness)

    if not class_images:
        return None

    gallery_output = [(full_result, "完整结果")]
    gallery_output.extend(
        (class_image, class_name)
        for class_name, class_image in class_images.items()
    )
    return gallery_output

def create_demo():
    """Assemble the Gradio Blocks interface for the segmentation demo.

    The layout is a two-column row (input image plus parameter controls on
    the left, result gallery on the right) followed by a button that calls
    ``segment_image`` with the control values and writes its return value
    to the gallery.

    Returns:
        The constructed ``gr.Blocks`` object; the caller launches it.
    """
    with gr.Blocks() as app:
        gr.Markdown("# YOLO 图像分割")
        gr.Markdown("上传一张图片,模型将对图片进行实例分割。每个检测到的类别将单独显示。")

        with gr.Row():
            # Left column: image upload and inference/drawing parameters.
            with gr.Column(scale=1):
                image_in = gr.Image()
                with gr.Row():
                    conf_slider = gr.Slider(
                        label="置信度阈值", info="检测置信度的最小值",
                        minimum=0.1, maximum=1.0, step=0.05, value=0.25
                    )
                    iou_slider = gr.Slider(
                        label="IOU阈值", info="非极大值抑制的IOU阈值",
                        minimum=0.1, maximum=1.0, step=0.05, value=0.7
                    )
                with gr.Row():
                    mask_slider = gr.Slider(
                        label="掩码阈值", info="分割掩码的阈值",
                        minimum=0.1, maximum=1.0, step=0.05, value=0.5
                    )
                    thickness_slider = gr.Slider(
                        label="线条粗细", info="边界框和文本的粗细",
                        minimum=1, maximum=5, step=1, value=2
                    )
                with gr.Row():
                    retina_checkbox = gr.Checkbox(
                        value=True,
                        label="高分辨率掩码",
                        info="启用高分辨率分割掩码(可能会降低速度)"
                    )

            # Right column: gallery that receives the annotated images.
            with gr.Column(scale=1):
                results_gallery = gr.Gallery(
                    label="分割结果",
                    show_label=True,
                    columns=2,
                    rows=2,
                    height=600,
                    object_fit="contain"
                )

        run_button = gr.Button("开始分割")

        # Order must match the positional parameters of segment_image.
        handler_inputs = [
            image_in,
            conf_slider,
            iou_slider,
            mask_slider,
            thickness_slider,
            retina_checkbox,
        ]
        run_button.click(
            fn=segment_image,
            inputs=handler_inputs,
            outputs=results_gallery
        )

    return app

# Script entry point: build the UI and start the Gradio server.
if __name__ == "__main__":
    create_demo().launch()