Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn.functional as F | |
| import numpy as np | |
| import os | |
| import time | |
| import gradio as gr | |
| import cv2 | |
| from PIL import Image | |
| from model.CyueNet_models import MMS | |
| from utils1.data import transform_image | |
| # 设置GPU/CPU | |
| device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') | |
| def load_model(): | |
| """加载预训练的模型""" | |
| model = MMS() | |
| try: | |
| # 使用相对路径,模型文件将存储在HuggingFace Spaces上 | |
| model.load_state_dict(torch.load('models/CyueNet_EORSSD6.pth.54', map_location=device)) | |
| print("模型加载成功") | |
| except RuntimeError as e: | |
| print(f"加载状态字典时出现部分不匹配,错误信息: {e}") | |
| model.to(device) | |
| model.eval() | |
| return model | |
| def process_image(image, model, testsize=256): | |
| """处理图像并返回显著性检测结果""" | |
| # 预处理图像 | |
| image = Image.fromarray(image).convert('RGB') | |
| image = transform_image(image, testsize) | |
| image = image.unsqueeze(0) | |
| image = image.to(device) | |
| # 计时 | |
| time_start = time.time() | |
| # 推理 | |
| with torch.no_grad(): | |
| x1, res, s1_sig, edg1, edg_s, s2, e2, s2_sig, e2_sig, s3, e3, s3_sig, e3_sig, s4, e4, s4_sig, e4_sig, s5, e5, s5_sig, e5_sig, sk1, sk1_sig, sk2, sk2_sig, sk3, sk3_sig, sk4, sk4_sig, sk5, sk5_sig = model(image) | |
| time_end = time.time() | |
| inference_time = time_end - time_start | |
| # 处理输出结果 | |
| res = res.sigmoid().data.cpu().numpy().squeeze() | |
| res = (res - res.min()) / (res.max() - res.min() + 1e-8) | |
| # 将输出调整为原始图像大小 | |
| original_image = np.array(Image.fromarray(image.cpu().squeeze().permute(1, 2, 0).numpy())) | |
| h, w = original_image.shape[:2] | |
| res_resized = cv2.resize(res, (w, h)) | |
| # 转换为可视化图像 | |
| res_vis = (res_resized * 255).astype(np.uint8) | |
| # 创建热力图 | |
| heatmap = cv2.applyColorMap(res_vis, cv2.COLORMAP_JET) | |
| # 将热力图与原始图像混合 | |
| alpha = 0.5 | |
| overlayed = cv2.addWeighted(original_image, 1-alpha, heatmap, alpha, 0) | |
| # 二值化结果用于分割 | |
| _, binary_mask = cv2.threshold(res_vis, 127, 255, cv2.THRESH_BINARY) | |
| segmented = cv2.bitwise_and(original_image, original_image, mask=binary_mask) | |
| return original_image, res_vis, heatmap, overlayed, segmented, f"推理时间: {inference_time:.4f}秒" | |
| def run_demo(input_image): | |
| """Gradio界面的主函数""" | |
| if input_image is None: | |
| return [None] * 5 + ["请上传图片"] | |
| # 处理图像 | |
| original, saliency_map, heatmap, overlayed, segmented, time_info = process_image(input_image, model) | |
| return original, saliency_map, heatmap, overlayed, segmented, time_info | |
| # 加载模型 | |
| print("正在加载模型...") | |
| model = load_model() | |
| # 创建Gradio界面 | |
| with gr.Blocks(title="显著性目标检测Demo") as demo: | |
| gr.Markdown("# 显著性目标检测Demo") | |
| gr.Markdown("上传一张图片,系统将自动检测显著性区域") | |
| with gr.Row(): | |
| with gr.Column(): | |
| input_image = gr.Image(label="输入图像", type="numpy") | |
| submit_btn = gr.Button("开始检测") | |
| with gr.Column(): | |
| original_output = gr.Image(label="原始图像") | |
| saliency_output = gr.Image(label="显著性图") | |
| heatmap_output = gr.Image(label="热力图") | |
| overlayed_output = gr.Image(label="叠加结果") | |
| segmented_output = gr.Image(label="分割结果") | |
| time_info = gr.Textbox(label="处理信息") | |
| submit_btn.click( | |
| fn=run_demo, | |
| inputs=input_image, | |
| outputs=[original_output, saliency_output, heatmap_output, overlayed_output, segmented_output, time_info] | |
| ) | |
| gr.Markdown("## 使用说明") | |
| gr.Markdown("1. 点击'输入图像'区域上传一张图片") | |
| gr.Markdown("2. 点击'开始检测'按钮进行显著性目标检测") | |
| gr.Markdown("3. 系统将显示原始图像、显著性图、热力图、叠加结果和分割结果") | |
| # 启动Gradio应用 | |
| demo.launch() |