# YOLOv11-seg / app.py
# Hugging Face Space file by robot2no1 -- last commit "Update app.py"
# (897a1ee, verified), 5.93 kB.
import hashlib
from collections import defaultdict

import cv2
import gradio as gr
import numpy as np
from ultralytics import YOLO
# Process-wide model handle; stays None until load_model() first runs.
model = None


def load_model():
    """Return the shared YOLO segmentation model, creating it on first use.

    The weights file './yolo11x-seg.pt' is only read on the first call; every
    later call hands back the instance cached in the module-level ``model``.
    """
    global model
    if model is not None:
        return model
    model = YOLO('./yolo11x-seg.pt')
    return model
def _class_color(class_name):
    """Return a stable color (list of three ints in 0-255) for a class name."""
    # The built-in hash() of a str is salted per interpreter process, so it
    # would produce different colors after every restart. A digest is stable.
    digest = hashlib.sha256(class_name.encode("utf-8")).digest()
    return [digest[0], digest[1], digest[2]]


def _draw_instance(canvas, mask_area, color, corners, label, line_thickness):
    """Tint one instance mask into *canvas* in place, then draw its box and label."""
    x1, y1, x2, y2 = corners

    # Blend a flat-colored copy over the canvas; pixels outside the mask are
    # identical in both images and therefore stay unchanged.
    tinted = canvas.copy()
    tinted[mask_area] = color
    cv2.addWeighted(tinted, 0.4, canvas, 0.6, 0, canvas)

    cv2.rectangle(canvas, (x1, y1), (x2, y2), color, line_thickness)

    font_scale = 0.6 * line_thickness / 2
    (label_width, label_height), _ = cv2.getTextSize(
        label, cv2.FONT_HERSHEY_SIMPLEX, font_scale, line_thickness)
    # The label normally sits just above the box. When the box touches the top
    # of the image that spot is off-canvas, so move the label inside the box.
    label_bottom = y1
    if y1 - label_height - 10 < 0:
        label_bottom = y1 + label_height + 10
    cv2.rectangle(canvas,
                  (x1, label_bottom - label_height - 10),
                  (x1 + label_width, label_bottom),
                  color, -1)
    cv2.putText(canvas, label, (x1, label_bottom - 5),
                cv2.FONT_HERSHEY_SIMPLEX, font_scale,
                (255, 255, 255), line_thickness, cv2.LINE_AA)


def segment_image(image, conf_threshold, iou_threshold, mask_threshold, line_thickness, use_retina_masks):
    """Run instance segmentation on *image* and build the gallery entries.

    Args:
        image: Input picture as a NumPy array (grayscale, RGB or RGBA), or
            None when nothing was uploaded. The channel order is assumed to be
            RGB, which is what the Gradio image input in this file is expected
            to deliver -- NOTE(review): confirm if this is called elsewhere.
        conf_threshold: Minimum detection confidence passed to the model.
        iou_threshold: IoU threshold passed to the model for NMS.
        mask_threshold: Mask values above this count as part of the instance.
        line_thickness: Thickness of boxes and label text (coerced to int).
        use_retina_masks: Forwarded to the model as ``retina_masks``.

    Returns:
        A list of ``(image, caption)`` tuples -- first the picture with every
        detection drawn on it, then one picture per detected class in order of
        first detection -- or None when there is no input or no detection.
    """
    # Clicking the button without an upload passes None; there is nothing to do.
    if image is None:
        return None

    # cv2 drawing functions reject a float thickness.
    line_thickness = int(line_thickness)

    # Normalize to 3-channel RGB. All drawing happens on RGB canvases so the
    # gallery shows correct colors.
    if image.ndim == 2:
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
    elif image.shape[2] == 4:
        image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)

    model = load_model()

    # Ultralytics interprets NumPy inputs as BGR, so feed it a BGR copy.
    results = model(
        cv2.cvtColor(image, cv2.COLOR_RGB2BGR),
        conf=conf_threshold,
        iou=iou_threshold,
        device='cpu',
        retina_masks=use_retina_masks,
    )
    result = results[0]
    if result.masks is None:
        return None

    names = model.names
    height, width = image.shape[:2]

    # One canvas per class (created on first detection of that class) plus one
    # canvas that receives every detection. Drawing everything directly on the
    # combined canvas keeps all annotations at full strength, instead of
    # repeatedly averaging the per-class pictures together.
    class_images = defaultdict(image.copy)
    full_result = image.copy()

    for seg, box, cls in zip(result.masks, result.boxes, result.boxes.cls):
        class_name = names[int(cls)]

        # Bring the mask to the image resolution before thresholding it.
        segment = seg.data[0].cpu().numpy().astype(np.float32)
        segment = cv2.resize(segment, (width, height))
        mask_area = segment > mask_threshold

        color = _class_color(class_name)
        corners = tuple(map(int, box.xyxy[0]))
        label = f"{class_name} {float(box.conf):.2f}"

        for canvas in (class_images[class_name], full_result):
            _draw_instance(canvas, mask_area, color, corners, label, line_thickness)

    if not class_images:
        return None

    # Combined picture first, then the per-class pictures in detection order.
    gallery_output = [(full_result, "完整结果")]
    gallery_output.extend(
        (class_image, class_name)
        for class_name, class_image in class_images.items()
    )
    return gallery_output
def create_demo():
    """Assemble and return the Gradio Blocks UI for the segmentation demo.

    The page has a heading and a description, an image input with sliders and
    a checkbox for the inference settings, a gallery for the results, and a
    button that runs ``segment_image`` with the current control values.

    NOTE(review): the nesting of rows/columns below was reconstructed, because
    the source's indentation was lost -- confirm it matches the intended layout.
    """
    with gr.Blocks() as demo:
        gr.Markdown("# YOLO 图像分割")
        gr.Markdown("上传一张图片,模型将对图片进行实例分割。每个检测到的类别将单独显示。")

        with gr.Row():
            # Input side: the picture and the tuning controls.
            with gr.Column(scale=1):
                image_input = gr.Image()

                with gr.Row():
                    conf_slider = gr.Slider(
                        label="置信度阈值",
                        info="检测置信度的最小值",
                        minimum=0.1,
                        maximum=1.0,
                        step=0.05,
                        value=0.25,
                    )
                    iou_slider = gr.Slider(
                        label="IOU阈值",
                        info="非极大值抑制的IOU阈值",
                        minimum=0.1,
                        maximum=1.0,
                        step=0.05,
                        value=0.7,
                    )

                with gr.Row():
                    mask_slider = gr.Slider(
                        label="掩码阈值",
                        info="分割掩码的阈值",
                        minimum=0.1,
                        maximum=1.0,
                        step=0.05,
                        value=0.5,
                    )
                    thickness_slider = gr.Slider(
                        label="线条粗细",
                        info="边界框和文本的粗细",
                        minimum=1,
                        maximum=5,
                        step=1,
                        value=2,
                    )

                with gr.Row():
                    retina_checkbox = gr.Checkbox(
                        value=True,
                        label="高分辨率掩码",
                        info="启用高分辨率分割掩码(可能会降低速度)",
                    )

            # Output side: one gallery tile per returned picture.
            with gr.Column(scale=1):
                result_gallery = gr.Gallery(
                    label="分割结果",
                    show_label=True,
                    columns=2,
                    rows=2,
                    height=600,
                    object_fit="contain",
                )

        run_button = gr.Button("开始分割")

        # Order must match the positional parameters of segment_image.
        handler_inputs = [
            image_input,
            conf_slider,
            iou_slider,
            mask_slider,
            thickness_slider,
            retina_checkbox,
        ]
        run_button.click(
            fn=segment_image,
            inputs=handler_inputs,
            outputs=result_gallery,
        )

    return demo
# Script entry point: build the interface and start the Gradio server.
if __name__ == "__main__":
    # Bound to the module-level name `demo`, exactly as before.
    demo = create_demo()
    demo.launch()