Spaces:
Runtime error
Runtime error
Upload 2 files
Browse files
app.py
CHANGED
|
@@ -1,200 +1,183 @@
|
|
| 1 |
import gradio as gr
|
| 2 |
-
import tempfile
|
| 3 |
-
import os
|
| 4 |
import traceback
|
|
|
|
|
|
|
| 5 |
from model import init, dehaze_inference, demoiring_inference, deblur_inference, get_model_status
|
| 6 |
from PIL import Image
|
| 7 |
from io import BytesIO
|
| 8 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
|
| 10 |
def safe_init():
|
| 11 |
-
"""
|
| 12 |
try:
|
| 13 |
-
print("
|
| 14 |
init()
|
| 15 |
-
print("
|
| 16 |
return True
|
| 17 |
except Exception as e:
|
| 18 |
-
print(f"
|
| 19 |
-
print(f"详细错误: {traceback.format_exc()}")
|
| 20 |
return False
|
| 21 |
|
| 22 |
|
| 23 |
-
#
|
| 24 |
model_loaded = safe_init()
|
| 25 |
|
| 26 |
|
| 27 |
-
def
|
| 28 |
-
"""
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
print(f"输入图像尺寸: {input_image.size}")
|
| 33 |
|
| 34 |
-
# 图像预处理:确保图像尺寸合适
|
| 35 |
-
width, height = input_image.size
|
| 36 |
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
scale = max(min_size / width, min_size / height)
|
| 41 |
-
new_width = int(width * scale)
|
| 42 |
-
new_height = int(height * scale)
|
| 43 |
-
input_image = input_image.resize((new_width, new_height), Image.LANCZOS)
|
| 44 |
-
print(f"图像已放大至: {input_image.size}")
|
| 45 |
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
if width > max_size or height > max_size:
|
| 49 |
-
scale = min(max_size / width, max_size / height)
|
| 50 |
-
new_width = int(width * scale)
|
| 51 |
-
new_height = int(height * scale)
|
| 52 |
-
input_image = input_image.resize((new_width, new_height), Image.LANCZOS)
|
| 53 |
-
print(f"图像已缩小至: {input_image.size}")
|
| 54 |
|
| 55 |
-
#
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
new_height = ((height + 7) // 8) * 8
|
| 59 |
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
|
| 64 |
-
return input_image
|
| 65 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
|
| 67 |
-
def create_error_image(error_message):
|
| 68 |
-
"""创建错误提示图像"""
|
| 69 |
try:
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
draw = ImageDraw.Draw(error_image)
|
| 73 |
|
| 74 |
-
#
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
font = None
|
| 79 |
|
| 80 |
-
#
|
| 81 |
-
|
| 82 |
-
|
| 83 |
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 89 |
|
| 90 |
-
|
| 91 |
-
"""统一的图像处理函数"""
|
| 92 |
-
try:
|
| 93 |
-
print(f"开始处理图像,任务类型: {task_type}")
|
| 94 |
|
| 95 |
-
#
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
input_bytes = f.read()
|
| 115 |
-
print(f"读取到 {len(input_bytes)} 字节的图像数据")
|
| 116 |
-
|
| 117 |
-
print("开始推理...")
|
| 118 |
-
|
| 119 |
-
# 根据任务类型选择推理函数
|
| 120 |
-
if task_type == "dehaze":
|
| 121 |
-
result_bytes = dehaze_inference(input_bytes)
|
| 122 |
-
elif task_type == "demoiring":
|
| 123 |
-
result_bytes = demoiring_inference(input_bytes)
|
| 124 |
-
elif task_type == "deblur":
|
| 125 |
-
result_bytes = deblur_inference(input_bytes)
|
| 126 |
-
else:
|
| 127 |
-
raise Exception(f"不支持的任务类型: {task_type}")
|
| 128 |
-
|
| 129 |
-
print(f"推理完成,输出 {len(result_bytes)} 字节")
|
| 130 |
-
|
| 131 |
-
# 清理临时文件
|
| 132 |
-
os.unlink(temp_input_path)
|
| 133 |
-
|
| 134 |
-
# 将字节流转换为PIL图像对象
|
| 135 |
result_image = Image.open(BytesIO(result_bytes))
|
| 136 |
-
|
|
|
|
|
|
|
| 137 |
|
| 138 |
return result_image
|
| 139 |
|
| 140 |
except Exception as e:
|
| 141 |
-
error_msg = f"
|
| 142 |
-
print(error_msg)
|
| 143 |
-
|
| 144 |
-
return create_error_image(error_msg)
|
| 145 |
|
| 146 |
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
return process_image(
|
| 150 |
|
| 151 |
|
| 152 |
-
def
|
| 153 |
-
""
|
| 154 |
-
return process_image(input_image, "demoiring")
|
| 155 |
|
| 156 |
|
| 157 |
-
def
|
| 158 |
-
""
|
| 159 |
-
return process_image(input_image, "deblur")
|
| 160 |
|
| 161 |
|
| 162 |
-
def
|
| 163 |
-
"""
|
| 164 |
status = get_model_status()
|
| 165 |
-
info = "
|
| 166 |
-
info += f"
|
| 167 |
-
info += f"
|
| 168 |
-
info +=
|
|
|
|
|
|
|
|
|
|
|
|
|
| 169 |
return info
|
| 170 |
|
| 171 |
|
| 172 |
-
#
|
| 173 |
def create_interface():
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 179 |
with gr.Row():
|
| 180 |
-
|
| 181 |
-
label="
|
| 182 |
-
value=
|
| 183 |
interactive=False,
|
| 184 |
-
lines=
|
|
|
|
| 185 |
)
|
| 186 |
|
| 187 |
-
#
|
| 188 |
with gr.Tabs():
|
| 189 |
-
#
|
| 190 |
-
with gr.TabItem("🌫️
|
| 191 |
-
gr.Markdown("### 上传有雾图像,模型将自动去除雾气")
|
| 192 |
with gr.Row():
|
| 193 |
with gr.Column():
|
| 194 |
dehaze_input = gr.Image(
|
| 195 |
type="pil",
|
| 196 |
label="上传有雾图像",
|
| 197 |
-
height=
|
| 198 |
)
|
| 199 |
dehaze_btn = gr.Button("开始去雾", variant="primary")
|
| 200 |
|
|
@@ -202,24 +185,19 @@ def create_interface():
|
|
| 202 |
dehaze_output = gr.Image(
|
| 203 |
type="pil",
|
| 204 |
label="去雾结果",
|
| 205 |
-
height=
|
| 206 |
)
|
| 207 |
|
| 208 |
-
dehaze_btn.click(
|
| 209 |
-
fn=dehaze_image,
|
| 210 |
-
inputs=[dehaze_input],
|
| 211 |
-
outputs=[dehaze_output]
|
| 212 |
-
)
|
| 213 |
|
| 214 |
-
#
|
| 215 |
with gr.TabItem("🔍 去摩尔纹"):
|
| 216 |
-
gr.Markdown("### 上传有摩尔纹的图像,模型将去除摩尔纹效应")
|
| 217 |
with gr.Row():
|
| 218 |
with gr.Column():
|
| 219 |
demoiring_input = gr.Image(
|
| 220 |
type="pil",
|
| 221 |
-
label="
|
| 222 |
-
height=
|
| 223 |
)
|
| 224 |
demoiring_btn = gr.Button("开始去摩尔纹", variant="primary")
|
| 225 |
|
|
@@ -227,24 +205,19 @@ def create_interface():
|
|
| 227 |
demoiring_output = gr.Image(
|
| 228 |
type="pil",
|
| 229 |
label="去摩尔纹结果",
|
| 230 |
-
height=
|
| 231 |
)
|
| 232 |
|
| 233 |
-
demoiring_btn.click(
|
| 234 |
-
fn=demoiring_image,
|
| 235 |
-
inputs=[demoiring_input],
|
| 236 |
-
outputs=[demoiring_output]
|
| 237 |
-
)
|
| 238 |
|
| 239 |
-
#
|
| 240 |
-
with gr.TabItem("🏃
|
| 241 |
-
gr.Markdown("### 上传运动模糊的图像,模型将去除运动模糊效应")
|
| 242 |
with gr.Row():
|
| 243 |
with gr.Column():
|
| 244 |
deblur_input = gr.Image(
|
| 245 |
type="pil",
|
| 246 |
-
label="
|
| 247 |
-
height=
|
| 248 |
)
|
| 249 |
deblur_btn = gr.Button("开始去模糊", variant="primary")
|
| 250 |
|
|
@@ -252,54 +225,43 @@ def create_interface():
|
|
| 252 |
deblur_output = gr.Image(
|
| 253 |
type="pil",
|
| 254 |
label="去模糊结果",
|
| 255 |
-
height=
|
| 256 |
)
|
| 257 |
|
| 258 |
-
deblur_btn.click(
|
| 259 |
-
fn=deblur_image,
|
| 260 |
-
inputs=[deblur_input],
|
| 261 |
-
outputs=[deblur_output]
|
| 262 |
-
)
|
| 263 |
|
| 264 |
-
# 使用说明
|
| 265 |
with gr.Accordion("使用说明", open=False):
|
| 266 |
gr.Markdown("""
|
| 267 |
-
###
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
-
|
| 271 |
-
-
|
| 272 |
-
-
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
-
|
| 277 |
-
-
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
-
|
| 282 |
-
-
|
| 283 |
-
|
| 284 |
-
**注意事项:**
|
| 285 |
-
- 支持 PNG、JPG 等常见图像格式
|
| 286 |
-
- 图像会自动调整尺寸以适应模型要求
|
| 287 |
-
- 处理时间根据图像大小和复杂度而定
|
| 288 |
-
- 建议图像尺寸不超过 2048x2048 像素
|
| 289 |
""")
|
| 290 |
|
| 291 |
-
# 底部信息
|
| 292 |
gr.Markdown("""
|
| 293 |
---
|
| 294 |
-
|
| 295 |
-
**部署平台:** Hugging Face Spaces
|
| 296 |
-
**开发者:** yssszzzzzzzzy
|
| 297 |
""")
|
| 298 |
|
| 299 |
return demo
|
| 300 |
|
| 301 |
|
| 302 |
if __name__ == "__main__":
|
| 303 |
-
print("启动
|
| 304 |
demo = create_interface()
|
|
|
|
| 305 |
demo.launch()
|
|
|
|
| 1 |
import gradio as gr
|
|
|
|
|
|
|
| 2 |
import traceback
|
| 3 |
+
import time
|
| 4 |
+
import os
|
| 5 |
from model import init, dehaze_inference, demoiring_inference, deblur_inference, get_model_status
|
| 6 |
from PIL import Image
|
| 7 |
from io import BytesIO
|
| 8 |
|
| 9 |
+
# HF Spaces环境检测
|
| 10 |
+
print("🚀 HuggingFace Spaces 环境检测:")
|
| 11 |
+
print(f"CPU核心数: {os.cpu_count()}")
|
| 12 |
+
|
| 13 |
|
| 14 |
def safe_init():
|
| 15 |
+
"""安全初始化"""
|
| 16 |
try:
|
| 17 |
+
print("初始化模型...")
|
| 18 |
init()
|
| 19 |
+
print("✅ 初始化成功")
|
| 20 |
return True
|
| 21 |
except Exception as e:
|
| 22 |
+
print(f"❌ 初始化失败: {e}")
|
|
|
|
| 23 |
return False
|
| 24 |
|
| 25 |
|
| 26 |
+
# 初始化
|
| 27 |
model_loaded = safe_init()
|
| 28 |
|
| 29 |
|
| 30 |
+
def create_warning_image(message):
|
| 31 |
+
"""创建警告图像"""
|
| 32 |
+
img = Image.new('RGB', (512, 256), color='#f0f0f0')
|
| 33 |
+
return img
|
|
|
|
|
|
|
| 34 |
|
|
|
|
|
|
|
| 35 |
|
| 36 |
+
def process_with_timeout(inference_func, image_bytes, timeout=30):
|
| 37 |
+
"""带超时的处理函数"""
|
| 38 |
+
import signal
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
|
| 40 |
+
def timeout_handler(signum, frame):
|
| 41 |
+
raise TimeoutError("处理超时")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
|
| 43 |
+
# 设置超时
|
| 44 |
+
signal.signal(signal.SIGALRM, timeout_handler)
|
| 45 |
+
signal.alarm(timeout)
|
|
|
|
| 46 |
|
| 47 |
+
try:
|
| 48 |
+
result = inference_func(image_bytes)
|
| 49 |
+
signal.alarm(0) # 取消超时
|
| 50 |
+
return result
|
| 51 |
+
except TimeoutError:
|
| 52 |
+
signal.alarm(0)
|
| 53 |
+
raise Exception(f"处理超时(>{timeout}秒)")
|
| 54 |
+
except Exception as e:
|
| 55 |
+
signal.alarm(0)
|
| 56 |
+
raise e
|
| 57 |
|
|
|
|
| 58 |
|
| 59 |
+
def process_image(input_image, task_type):
|
| 60 |
+
"""统一处理函数"""
|
| 61 |
+
if input_image is None:
|
| 62 |
+
return create_warning_image("请上传图像")
|
| 63 |
|
|
|
|
|
|
|
| 64 |
try:
|
| 65 |
+
start_time = time.time()
|
| 66 |
+
print(f"开始处理: {task_type}")
|
|
|
|
| 67 |
|
| 68 |
+
# 检查模型状态
|
| 69 |
+
status = get_model_status()
|
| 70 |
+
if not status.get(f"{task_type}_model_loaded", False):
|
| 71 |
+
return create_warning_image(f"{task_type}模型未加载")
|
|
|
|
| 72 |
|
| 73 |
+
# 图像预检查
|
| 74 |
+
if input_image.size[0] * input_image.size[1] > 1024 * 1024:
|
| 75 |
+
print("⚠️ 图像过大,将自动缩小")
|
| 76 |
|
| 77 |
+
# 转换为字节流
|
| 78 |
+
buf = BytesIO()
|
| 79 |
+
input_image.save(buf, format='PNG')
|
| 80 |
+
image_bytes = buf.getvalue()
|
| 81 |
|
| 82 |
+
# 选择推理函数
|
| 83 |
+
inference_map = {
|
| 84 |
+
"dehaze": dehaze_inference,
|
| 85 |
+
"demoiring": demoiring_inference,
|
| 86 |
+
"deblur": deblur_inference
|
| 87 |
+
}
|
| 88 |
|
| 89 |
+
inference_func = inference_map[task_type]
|
|
|
|
|
|
|
|
|
|
| 90 |
|
| 91 |
+
# 执行推理(带超时保护)
|
| 92 |
+
try:
|
| 93 |
+
result_bytes = process_with_timeout(inference_func, image_bytes, timeout=25)
|
| 94 |
+
except Exception as e:
|
| 95 |
+
if "超时" in str(e):
|
| 96 |
+
print("⚠️ 处理超时,尝试降低图像质量...")
|
| 97 |
+
# 缩小图像重试
|
| 98 |
+
small_image = input_image.resize(
|
| 99 |
+
(input_image.size[0] // 2, input_image.size[1] // 2),
|
| 100 |
+
Image.LANCZOS
|
| 101 |
+
)
|
| 102 |
+
buf = BytesIO()
|
| 103 |
+
small_image.save(buf, format='PNG')
|
| 104 |
+
image_bytes = buf.getvalue()
|
| 105 |
+
result_bytes = process_with_timeout(inference_func, image_bytes, timeout=20)
|
| 106 |
+
else:
|
| 107 |
+
raise e
|
| 108 |
+
|
| 109 |
+
# 转换结果
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 110 |
result_image = Image.open(BytesIO(result_bytes))
|
| 111 |
+
|
| 112 |
+
elapsed = time.time() - start_time
|
| 113 |
+
print(f"✅ 处理完成: {elapsed:.2f}秒")
|
| 114 |
|
| 115 |
return result_image
|
| 116 |
|
| 117 |
except Exception as e:
|
| 118 |
+
error_msg = f"处理失败: {str(e)}"
|
| 119 |
+
print(f"❌ {error_msg}")
|
| 120 |
+
return create_warning_image(error_msg)
|
|
|
|
| 121 |
|
| 122 |
|
| 123 |
+
# 创建处理函数
|
| 124 |
+
def dehaze_process(image):
|
| 125 |
+
return process_image(image, "dehaze")
|
| 126 |
|
| 127 |
|
| 128 |
+
def demoiring_process(image):
|
| 129 |
+
return process_image(image, "demoiring")
|
|
|
|
| 130 |
|
| 131 |
|
| 132 |
+
def deblur_process(image):
|
| 133 |
+
return process_image(image, "deblur")
|
|
|
|
| 134 |
|
| 135 |
|
| 136 |
+
def get_system_info():
|
| 137 |
+
"""获取系统信息"""
|
| 138 |
status = get_model_status()
|
| 139 |
+
info = "🖥️ HuggingFace Spaces (免费版)\n"
|
| 140 |
+
info += f"💾 内存: 有限\n"
|
| 141 |
+
info += f"⚡ CPU: {os.cpu_count()}核\n\n"
|
| 142 |
+
info += "📊 模型状态:\n"
|
| 143 |
+
info += f"🌫️ 去雾: {'✅' if status['dehaze_model_loaded'] else '❌'}\n"
|
| 144 |
+
info += f"🔍 去摩尔纹: {'✅' if status['demoiring_model_loaded'] else '❌'}\n"
|
| 145 |
+
info += f"🏃 去模糊: {'✅' if status['deblur_model_loaded'] else '❌'}\n\n"
|
| 146 |
+
info += "⚠️ 注意: 大图像会自动缩小以适应免费版限制"
|
| 147 |
return info
|
| 148 |
|
| 149 |
|
| 150 |
+
# 创建界面
|
| 151 |
def create_interface():
|
| 152 |
+
# 使用简单主题以减少资源占用
|
| 153 |
+
with gr.Blocks(
|
| 154 |
+
title="FPro图像处理 - HF Spaces版",
|
| 155 |
+
theme=gr.themes.Default(),
|
| 156 |
+
css=".gradio-container {max-width: 100% !important}"
|
| 157 |
+
) as demo:
|
| 158 |
+
gr.Markdown("# 🚀 FPro 图像处理工具")
|
| 159 |
+
gr.Markdown("**HuggingFace Spaces 免费版 - 极速优化**")
|
| 160 |
+
|
| 161 |
+
# 系统信息
|
| 162 |
with gr.Row():
|
| 163 |
+
system_info = gr.Textbox(
|
| 164 |
+
label="系统信息",
|
| 165 |
+
value=get_system_info(),
|
| 166 |
interactive=False,
|
| 167 |
+
lines=6,
|
| 168 |
+
max_lines=6
|
| 169 |
)
|
| 170 |
|
| 171 |
+
# 主要功能区
|
| 172 |
with gr.Tabs():
|
| 173 |
+
# 去雾
|
| 174 |
+
with gr.TabItem("🌫️ 去雾"):
|
|
|
|
| 175 |
with gr.Row():
|
| 176 |
with gr.Column():
|
| 177 |
dehaze_input = gr.Image(
|
| 178 |
type="pil",
|
| 179 |
label="上传有雾图像",
|
| 180 |
+
height=300
|
| 181 |
)
|
| 182 |
dehaze_btn = gr.Button("开始去雾", variant="primary")
|
| 183 |
|
|
|
|
| 185 |
dehaze_output = gr.Image(
|
| 186 |
type="pil",
|
| 187 |
label="去雾结果",
|
| 188 |
+
height=300
|
| 189 |
)
|
| 190 |
|
| 191 |
+
dehaze_btn.click(dehaze_process, dehaze_input, dehaze_output)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 192 |
|
| 193 |
+
# 去摩尔纹
|
| 194 |
with gr.TabItem("🔍 去摩尔纹"):
|
|
|
|
| 195 |
with gr.Row():
|
| 196 |
with gr.Column():
|
| 197 |
demoiring_input = gr.Image(
|
| 198 |
type="pil",
|
| 199 |
+
label="上传摩尔纹图像",
|
| 200 |
+
height=300
|
| 201 |
)
|
| 202 |
demoiring_btn = gr.Button("开始去摩尔纹", variant="primary")
|
| 203 |
|
|
|
|
| 205 |
demoiring_output = gr.Image(
|
| 206 |
type="pil",
|
| 207 |
label="去摩尔纹结果",
|
| 208 |
+
height=300
|
| 209 |
)
|
| 210 |
|
| 211 |
+
demoiring_btn.click(demoiring_process, demoiring_input, demoiring_output)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 212 |
|
| 213 |
+
# 去模糊
|
| 214 |
+
with gr.TabItem("🏃 去模糊"):
|
|
|
|
| 215 |
with gr.Row():
|
| 216 |
with gr.Column():
|
| 217 |
deblur_input = gr.Image(
|
| 218 |
type="pil",
|
| 219 |
+
label="上传模糊图像",
|
| 220 |
+
height=300
|
| 221 |
)
|
| 222 |
deblur_btn = gr.Button("开始去模糊", variant="primary")
|
| 223 |
|
|
|
|
| 225 |
deblur_output = gr.Image(
|
| 226 |
type="pil",
|
| 227 |
label="去模糊结果",
|
| 228 |
+
height=300
|
| 229 |
)
|
| 230 |
|
| 231 |
+
deblur_btn.click(deblur_process, deblur_input, deblur_output)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 232 |
|
| 233 |
+
# 使用说明 - 简化版
|
| 234 |
with gr.Accordion("使用说明", open=False):
|
| 235 |
gr.Markdown("""
|
| 236 |
+
### ⚡ HF Spaces 优化说明
|
| 237 |
+
|
| 238 |
+
**🎯 自动优化特性:**
|
| 239 |
+
- 图像自动缩放至512px以内
|
| 240 |
+
- 智能切块减少内存使用
|
| 241 |
+
- 超时保护防止崩溃
|
| 242 |
+
- 内存自动清理
|
| 243 |
+
|
| 244 |
+
**⚠️ 免费版限制:**
|
| 245 |
+
- 处理时间: 约20-30秒
|
| 246 |
+
- 图像尺寸: 自动缩放
|
| 247 |
+
- 并发处理: 不支持
|
| 248 |
+
|
| 249 |
+
**💡 使用建议:**
|
| 250 |
+
- 上传清晰的原图
|
| 251 |
+
- 耐心等待处理完成
|
| 252 |
+
- 大图会自动缩小
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 253 |
""")
|
| 254 |
|
|
|
|
| 255 |
gr.Markdown("""
|
| 256 |
---
|
| 257 |
+
**基于 FPro (Frequency Prompted Transformer) | HuggingFace Spaces 优化版**
|
|
|
|
|
|
|
| 258 |
""")
|
| 259 |
|
| 260 |
return demo
|
| 261 |
|
| 262 |
|
| 263 |
if __name__ == "__main__":
|
| 264 |
+
print("🌐 启动 HuggingFace Spaces 应用...")
|
| 265 |
demo = create_interface()
|
| 266 |
+
demo.queue(concurrency_count=1, max_size=3) # 限制并发
|
| 267 |
demo.launch()
|
model.py
CHANGED
|
@@ -1,262 +1,307 @@
|
|
| 1 |
-
# model.py -
|
| 2 |
import yaml, torch, math, numpy as np
|
| 3 |
import torch.nn.functional as F
|
| 4 |
from PIL import Image
|
| 5 |
from io import BytesIO
|
| 6 |
from basicsr.models.archs.FPro_arch import FPro
|
|
|
|
| 7 |
|
| 8 |
-
#
|
|
|
|
|
|
|
| 9 |
device = torch.device('cpu')
|
| 10 |
|
| 11 |
-
#
|
| 12 |
dehaze_model = None
|
| 13 |
demoiring_model = None
|
| 14 |
deblur_model = None
|
| 15 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
|
| 17 |
-
|
| 18 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
_, C, H, W = imgtensor.shape
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
split_data = []
|
| 30 |
-
|
| 31 |
-
for ws in wstarts:
|
| 32 |
-
cimgdata = imgtensor[:, :, hs:hs + crop_size, ws:ws + crop_size]
|
| 33 |
-
starts.append((hs, ws))
|
| 34 |
-
split_data.append(cimgdata)
|
| 35 |
-
return split_data, starts
|
| 36 |
|
|
|
|
|
|
|
| 37 |
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
|
|
|
|
|
|
|
|
|
| 42 |
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
for w in range(W):
|
| 47 |
-
score[:, :, h, w] = 1.0 / (math.sqrt((h - center_h) ** 2 + (w - center_w) ** 2 + 1e-6))
|
| 48 |
-
return score
|
| 49 |
|
|
|
|
|
|
|
|
|
|
| 50 |
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
B, C, H, W = resolution[0], resolution[1], resolution[2], resolution[3]
|
| 54 |
-
tot_score = torch.zeros((B, C, H, W))
|
| 55 |
-
merge_img = torch.zeros((B, C, H, W))
|
| 56 |
-
scoremap = get_scoremap(crop_size, crop_size, C, B=B, is_mean=True)
|
| 57 |
|
| 58 |
-
for simg, cstart in zip(split_data, starts):
|
| 59 |
-
hs, ws = cstart
|
| 60 |
-
merge_img[:, :, hs:hs + crop_size, ws:ws + crop_size] += scoremap * simg
|
| 61 |
-
tot_score[:, :, hs:hs + crop_size, ws:ws + crop_size] += scoremap
|
| 62 |
|
| 63 |
-
|
| 64 |
-
|
|
|
|
|
|
|
|
|
|
| 65 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
global dehaze_model, demoiring_model, deblur_model
|
| 70 |
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
print(f"去雾模型加载失败: {e}")
|
| 85 |
-
dehaze_model = None
|
| 86 |
|
| 87 |
-
# 初始化去摩尔纹模型
|
| 88 |
-
try:
|
| 89 |
-
print("正在加载去摩尔纹模型...")
|
| 90 |
-
# 尝试加载YAML配置文件
|
| 91 |
-
try:
|
| 92 |
-
demoiring_cfg = yaml.safe_load(open("./option/RealDemoiring_FPro.yml", "r"))['network_g']
|
| 93 |
-
demoiring_cfg.pop('type', None)
|
| 94 |
-
except FileNotFoundError:
|
| 95 |
-
# 如果没有单独的配置文件,使用默认配置
|
| 96 |
-
print("未找到去摩尔纹配置文件,使用默认配置")
|
| 97 |
-
demoiring_cfg = {
|
| 98 |
-
'inp_channels': 3,
|
| 99 |
-
'out_channels': 3,
|
| 100 |
-
'dim': 48,
|
| 101 |
-
'num_blocks': [4, 6, 6, 8],
|
| 102 |
-
'num_refinement_blocks': 4,
|
| 103 |
-
'heads': [1, 2, 4, 8],
|
| 104 |
-
'ffn_expansion_factor': 2.66,
|
| 105 |
-
'bias': False,
|
| 106 |
-
'LayerNorm_type': 'WithBias',
|
| 107 |
-
'dual_pixel_task': False
|
| 108 |
-
}
|
| 109 |
-
|
| 110 |
-
demoiring_model = FPro(**demoiring_cfg)
|
| 111 |
-
demoiring_model = demoiring_model.to(device)
|
| 112 |
-
demoiring_ckpt = torch.load("./model/demoire_noAug.pth", map_location='cpu')
|
| 113 |
-
demoiring_model.load_state_dict(demoiring_ckpt['params'])
|
| 114 |
-
demoiring_model.eval()
|
| 115 |
-
demoiring_model = demoiring_model.cpu()
|
| 116 |
-
print("去摩尔纹模型加载成功!")
|
| 117 |
-
except Exception as e:
|
| 118 |
-
print(f"去摩尔纹模型加载失败: {e}")
|
| 119 |
-
demoiring_model = None
|
| 120 |
|
| 121 |
-
|
|
|
|
| 122 |
try:
|
| 123 |
-
|
| 124 |
-
# 尝试加载YAML配置文件
|
| 125 |
try:
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
print("
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 151 |
except Exception as e:
|
| 152 |
-
print(f"
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
def inference(body: bytes, task_type: str = "dehaze") -> bytes:
|
| 157 |
-
"""
|
| 158 |
-
推理函数:支持去雾、去摩尔纹和运动去模糊
|
| 159 |
-
|
| 160 |
-
Args:
|
| 161 |
-
body: 图像字节流
|
| 162 |
-
task_type: 任务类型,"dehaze"、"demoiring" 或 "deblur"
|
| 163 |
-
|
| 164 |
-
Returns:
|
| 165 |
-
处理后的图像字节流
|
| 166 |
-
"""
|
| 167 |
-
# 选择对应的模型和参数
|
| 168 |
-
if task_type == "dehaze":
|
| 169 |
-
if dehaze_model is None:
|
| 170 |
-
raise Exception("去雾模型未加载")
|
| 171 |
-
model = dehaze_model
|
| 172 |
-
# 去雾任务的参数
|
| 173 |
-
crop_size_arg = 256
|
| 174 |
-
overlap_size_arg = 158
|
| 175 |
-
elif task_type == "demoiring":
|
| 176 |
-
if demoiring_model is None:
|
| 177 |
-
raise Exception("去摩尔纹模型未加载")
|
| 178 |
-
model = demoiring_model
|
| 179 |
-
# 去摩尔纹任务的参数
|
| 180 |
-
crop_size_arg = 256
|
| 181 |
-
overlap_size_arg = 200
|
| 182 |
-
elif task_type == "deblur":
|
| 183 |
-
if deblur_model is None:
|
| 184 |
-
raise Exception("运动去模糊模型未加载")
|
| 185 |
-
model = deblur_model
|
| 186 |
-
# 运动去模糊任务的参数
|
| 187 |
-
crop_size_arg = 256
|
| 188 |
-
overlap_size_arg = 200
|
| 189 |
-
else:
|
| 190 |
-
raise Exception(f"不支持的任务类型: {task_type}")
|
| 191 |
|
| 192 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 193 |
img = Image.open(BytesIO(body)).convert("RGB")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 194 |
arr = np.float32(img) / 255.0
|
| 195 |
t = torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0)
|
| 196 |
|
| 197 |
-
#
|
| 198 |
-
t = t.to(device)
|
| 199 |
-
|
| 200 |
-
# Padding in case images are not multiples of 8
|
| 201 |
-
factor = 8
|
| 202 |
h, w = t.shape[2], t.shape[3]
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
|
| 209 |
B, C, H, W = t.shape
|
| 210 |
|
| 211 |
-
#
|
| 212 |
-
|
| 213 |
-
|
|
|
|
| 214 |
restored = model(t)
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
for i,
|
| 222 |
-
|
| 223 |
-
|
| 224 |
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
restored = restored.to(device)
|
| 228 |
|
| 229 |
-
|
|
|
|
|
|
|
|
|
|
| 230 |
restored = restored[:, :, :h, :w]
|
| 231 |
|
| 232 |
-
#
|
| 233 |
-
|
| 234 |
-
|
| 235 |
|
| 236 |
-
# 输出
|
| 237 |
-
out_img = Image.fromarray(
|
| 238 |
buf = BytesIO()
|
| 239 |
-
out_img.save(buf, format="PNG")
|
|
|
|
|
|
|
| 240 |
return buf.getvalue()
|
| 241 |
|
| 242 |
|
|
|
|
| 243 |
def dehaze_inference(body: bytes) -> bytes:
|
| 244 |
-
""
|
| 245 |
-
return inference(body, task_type="dehaze")
|
| 246 |
|
| 247 |
|
| 248 |
def demoiring_inference(body: bytes) -> bytes:
|
| 249 |
-
""
|
| 250 |
-
return inference(body, task_type="demoiring")
|
| 251 |
|
| 252 |
|
| 253 |
def deblur_inference(body: bytes) -> bytes:
|
| 254 |
-
""
|
| 255 |
-
return inference(body, task_type="deblur")
|
| 256 |
|
| 257 |
|
| 258 |
def get_model_status():
|
| 259 |
-
"""获取模型加载状态"""
|
| 260 |
return {
|
| 261 |
"dehaze_model_loaded": dehaze_model is not None,
|
| 262 |
"demoiring_model_loaded": demoiring_model is not None,
|
|
|
|
| 1 |
+
# model.py - HuggingFace Spaces极限优化版
|
| 2 |
import yaml, torch, math, numpy as np
|
| 3 |
import torch.nn.functional as F
|
| 4 |
from PIL import Image
|
| 5 |
from io import BytesIO
|
| 6 |
from basicsr.models.archs.FPro_arch import FPro
|
| 7 |
+
import gc
|
| 8 |
|
| 9 |
+
# HF Spaces环境优化设置
|
| 10 |
+
torch.set_num_threads(2) # HF Spaces通常是2核CPU
|
| 11 |
+
torch.set_num_interop_threads(1)
|
| 12 |
device = torch.device('cpu')
|
| 13 |
|
| 14 |
+
# 全局模型变量
|
| 15 |
dehaze_model = None
|
| 16 |
demoiring_model = None
|
| 17 |
deblur_model = None
|
| 18 |
|
| 19 |
+
# 极限优化配置
|
| 20 |
+
EXTREME_OPTIMIZATION = {
|
| 21 |
+
'max_image_size': 512, # 严格限制最大图像尺寸
|
| 22 |
+
'min_tile_size': 128, # 最小切块尺寸
|
| 23 |
+
'max_tiles': 4, # 最大切块数量 (2x2)
|
| 24 |
+
'overlap_ratio': 0.1, # 最小重叠比例
|
| 25 |
+
'memory_limit_mb': 1024, # 内存限制
|
| 26 |
+
'enable_half_precision': False, # HF Spaces上通常不支持
|
| 27 |
+
}
|
| 28 |
|
| 29 |
+
|
| 30 |
+
def aggressive_image_resize(image, max_size=512):
|
| 31 |
+
"""激进的图像缩放策略"""
|
| 32 |
+
w, h = image.size
|
| 33 |
+
|
| 34 |
+
# 如果图像太大,强制缩小
|
| 35 |
+
if max(w, h) > max_size:
|
| 36 |
+
scale = max_size / max(w, h)
|
| 37 |
+
new_w, new_h = int(w * scale), int(h * scale)
|
| 38 |
+
|
| 39 |
+
# 确保是8的倍数,但优先保证小尺寸
|
| 40 |
+
new_w = ((new_w + 7) // 8) * 8
|
| 41 |
+
new_h = ((new_h + 7) // 8) * 8
|
| 42 |
+
|
| 43 |
+
# 如果调整后还是太大,再次缩小
|
| 44 |
+
if max(new_w, new_h) > max_size:
|
| 45 |
+
new_w = min(new_w, max_size)
|
| 46 |
+
new_h = min(new_h, max_size)
|
| 47 |
+
new_w = (new_w // 8) * 8
|
| 48 |
+
new_h = (new_h // 8) * 8
|
| 49 |
+
|
| 50 |
+
image = image.resize((new_w, new_h), Image.LANCZOS)
|
| 51 |
+
print(f"图像已缩放: {w}x{h} -> {new_w}x{new_h}")
|
| 52 |
+
|
| 53 |
+
return image
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def minimal_splitimage(imgtensor, max_tiles=4):
|
| 57 |
+
"""最小化切块策略"""
|
| 58 |
_, C, H, W = imgtensor.shape
|
| 59 |
+
|
| 60 |
+
# 如果图像小于256x256,直接返回
|
| 61 |
+
if H <= 256 and W <= 256:
|
| 62 |
+
return [imgtensor], [(0, 0)]
|
| 63 |
+
|
| 64 |
+
# 计算最优切块配置以限制切块数量
|
| 65 |
+
if max_tiles == 1:
|
| 66 |
+
return [imgtensor], [(0, 0)]
|
| 67 |
+
elif max_tiles == 4: # 2x2
|
| 68 |
+
crop_h = H // 2 if H > 256 else H
|
| 69 |
+
crop_w = W // 2 if W > 256 else W
|
| 70 |
+
overlap = min(32, crop_h // 8, crop_w // 8)
|
| 71 |
+
else:
|
| 72 |
+
# 默认单切块
|
| 73 |
+
return [imgtensor], [(0, 0)]
|
| 74 |
+
|
| 75 |
+
# 简化的2x2切块
|
| 76 |
split_data = []
|
| 77 |
+
starts = []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
|
| 79 |
+
h_mid = H // 2
|
| 80 |
+
w_mid = W // 2
|
| 81 |
|
| 82 |
+
# 四个切块的位置
|
| 83 |
+
positions = [
|
| 84 |
+
(0, 0, min(crop_h + overlap, H), min(crop_w + overlap, W)),
|
| 85 |
+
(0, max(0, w_mid - overlap), min(crop_h + overlap, H), W),
|
| 86 |
+
(max(0, h_mid - overlap), 0, H, min(crop_w + overlap, W)),
|
| 87 |
+
(max(0, h_mid - overlap), max(0, w_mid - overlap), H, W)
|
| 88 |
+
]
|
| 89 |
|
| 90 |
+
for h_start, w_start, h_end, w_end in positions:
|
| 91 |
+
if h_end <= h_start or w_end <= w_start:
|
| 92 |
+
continue
|
|
|
|
|
|
|
|
|
|
| 93 |
|
| 94 |
+
patch = imgtensor[:, :, h_start:h_end, w_start:w_end]
|
| 95 |
+
split_data.append(patch)
|
| 96 |
+
starts.append((h_start, w_start))
|
| 97 |
|
| 98 |
+
print(f"图像分割为 {len(split_data)} 个切块")
|
| 99 |
+
return split_data, starts
|
|
|
|
|
|
|
|
|
|
|
|
|
| 100 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 101 |
|
| 102 |
+
def fast_merge(split_data, starts, resolution):
|
| 103 |
+
"""快速合并,避免复杂权重计算"""
|
| 104 |
+
B, C, H, W = resolution
|
| 105 |
+
result = torch.zeros((B, C, H, W))
|
| 106 |
+
count = torch.zeros((B, C, H, W))
|
| 107 |
|
| 108 |
+
for patch, (h_start, w_start) in zip(split_data, starts):
|
| 109 |
+
patch_h, patch_w = patch.shape[2], patch.shape[3]
|
| 110 |
+
h_end = min(h_start + patch_h, H)
|
| 111 |
+
w_end = min(w_start + patch_w, W)
|
| 112 |
|
| 113 |
+
actual_h = h_end - h_start
|
| 114 |
+
actual_w = w_end - w_start
|
|
|
|
| 115 |
|
| 116 |
+
result[:, :, h_start:h_end, w_start:w_end] += patch[:, :, :actual_h, :actual_w]
|
| 117 |
+
count[:, :, h_start:h_end, w_start:w_end] += 1
|
| 118 |
+
|
| 119 |
+
# 避免除零
|
| 120 |
+
result = result / torch.clamp(count, min=1)
|
| 121 |
+
return result
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def memory_cleanup():
|
| 125 |
+
"""内存清理"""
|
| 126 |
+
gc.collect()
|
| 127 |
+
if torch.cuda.is_available():
|
| 128 |
+
torch.cuda.empty_cache()
|
|
|
|
|
|
|
| 129 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 130 |
|
| 131 |
+
def load_model_minimal(config_path, weight_path, default_config):
|
| 132 |
+
"""最小化模型加载"""
|
| 133 |
try:
|
| 134 |
+
# 尝试加载配置
|
|
|
|
| 135 |
try:
|
| 136 |
+
with open(config_path, 'r') as f:
|
| 137 |
+
cfg = yaml.safe_load(f)['network_g']
|
| 138 |
+
cfg.pop('type', None)
|
| 139 |
+
except:
|
| 140 |
+
print(f"使用默认配置: {config_path}")
|
| 141 |
+
cfg = default_config
|
| 142 |
+
|
| 143 |
+
# 创建模型
|
| 144 |
+
model = FPro(**cfg)
|
| 145 |
+
|
| 146 |
+
# 加载权重
|
| 147 |
+
checkpoint = torch.load(weight_path, map_location='cpu')
|
| 148 |
+
if 'params' in checkpoint:
|
| 149 |
+
model.load_state_dict(checkpoint['params'])
|
| 150 |
+
else:
|
| 151 |
+
model.load_state_dict(checkpoint)
|
| 152 |
+
|
| 153 |
+
model.eval()
|
| 154 |
+
|
| 155 |
+
# 尝试JIT优化(如果失败就用原始模型)
|
| 156 |
+
try:
|
| 157 |
+
sample_input = torch.randn(1, 3, 128, 128)
|
| 158 |
+
with torch.no_grad():
|
| 159 |
+
traced_model = torch.jit.trace(model, sample_input)
|
| 160 |
+
traced_model.eval()
|
| 161 |
+
print("JIT优化成功")
|
| 162 |
+
return traced_model
|
| 163 |
+
except:
|
| 164 |
+
print("JIT优化失败,使用原始模型")
|
| 165 |
+
return model
|
| 166 |
+
|
| 167 |
except Exception as e:
|
| 168 |
+
print(f"模型加载失败: {e}")
|
| 169 |
+
return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 170 |
|
| 171 |
+
|
| 172 |
+
def init():
|
| 173 |
+
"""优化的初始化函数"""
|
| 174 |
+
global dehaze_model, demoiring_model, deblur_model
|
| 175 |
+
|
| 176 |
+
print("HuggingFace Spaces 极限优化模式启动...")
|
| 177 |
+
|
| 178 |
+
# 默认配置
|
| 179 |
+
default_config = {
|
| 180 |
+
'inp_channels': 3,
|
| 181 |
+
'out_channels': 3,
|
| 182 |
+
'dim': 48,
|
| 183 |
+
'num_blocks': [4, 6, 6, 8],
|
| 184 |
+
'num_refinement_blocks': 4,
|
| 185 |
+
'heads': [1, 2, 4, 8],
|
| 186 |
+
'ffn_expansion_factor': 2.66,
|
| 187 |
+
'bias': False,
|
| 188 |
+
'LayerNorm_type': 'WithBias',
|
| 189 |
+
'dual_pixel_task': False
|
| 190 |
+
}
|
| 191 |
+
|
| 192 |
+
# 去雾模型
|
| 193 |
+
print("加载去雾模型...")
|
| 194 |
+
dehaze_model = load_model_minimal(
|
| 195 |
+
"./option/RealDehazing_FPro.yml",
|
| 196 |
+
"./model/synDehaze.pth",
|
| 197 |
+
default_config
|
| 198 |
+
)
|
| 199 |
+
|
| 200 |
+
# 去摩尔纹模型
|
| 201 |
+
print("加载去摩尔纹模型...")
|
| 202 |
+
demoiring_model = load_model_minimal(
|
| 203 |
+
"./option/RealDemoiring_FPro.yml",
|
| 204 |
+
"./model/demoire_noAug.pth",
|
| 205 |
+
default_config
|
| 206 |
+
)
|
| 207 |
+
|
| 208 |
+
# 运动去模糊模型
|
| 209 |
+
print("加载运动去模糊模型...")
|
| 210 |
+
deblur_model = load_model_minimal(
|
| 211 |
+
"./option/Deblurring_FPro.yml",
|
| 212 |
+
"./model/deblur.pth",
|
| 213 |
+
default_config
|
| 214 |
+
)
|
| 215 |
+
|
| 216 |
+
memory_cleanup()
|
| 217 |
+
print("模型初始化完成")
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
def ultra_fast_inference(body: bytes, task_type: str = "dehaze") -> bytes:
|
| 221 |
+
"""超快推理模式"""
|
| 222 |
+
# 选择模型
|
| 223 |
+
model_map = {
|
| 224 |
+
"dehaze": dehaze_model,
|
| 225 |
+
"demoiring": demoiring_model,
|
| 226 |
+
"deblur": deblur_model
|
| 227 |
+
}
|
| 228 |
+
|
| 229 |
+
model = model_map.get(task_type)
|
| 230 |
+
if model is None:
|
| 231 |
+
raise Exception(f"{task_type}模型未加载")
|
| 232 |
+
|
| 233 |
+
# 加载图像
|
| 234 |
img = Image.open(BytesIO(body)).convert("RGB")
|
| 235 |
+
print(f"原始尺寸: {img.size}")
|
| 236 |
+
|
| 237 |
+
# 激进缩放
|
| 238 |
+
img = aggressive_image_resize(img, EXTREME_OPTIMIZATION['max_image_size'])
|
| 239 |
+
|
| 240 |
+
# 转换为张量
|
| 241 |
arr = np.float32(img) / 255.0
|
| 242 |
t = torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0)
|
| 243 |
|
| 244 |
+
# 最小padding
|
|
|
|
|
|
|
|
|
|
|
|
|
| 245 |
h, w = t.shape[2], t.shape[3]
|
| 246 |
+
pad_h = (8 - h % 8) % 8
|
| 247 |
+
pad_w = (8 - w % 8) % 8
|
| 248 |
+
|
| 249 |
+
if pad_h > 0 or pad_w > 0:
|
| 250 |
+
t = F.pad(t, (0, pad_w, 0, pad_h), 'reflect')
|
| 251 |
|
| 252 |
B, C, H, W = t.shape
|
| 253 |
|
| 254 |
+
# 推理
|
| 255 |
+
with torch.no_grad():
|
| 256 |
+
if H <= 256 and W <= 256:
|
| 257 |
+
# 小图像直接处理
|
| 258 |
restored = model(t)
|
| 259 |
+
else:
|
| 260 |
+
# 最小化切块处理
|
| 261 |
+
split_data, starts = minimal_splitimage(t, EXTREME_OPTIMIZATION['max_tiles'])
|
| 262 |
|
| 263 |
+
# 逐个处理切块
|
| 264 |
+
processed_data = []
|
| 265 |
+
for i, patch in enumerate(split_data):
|
| 266 |
+
if i > 0: # 每处理一个切块就清理内存
|
| 267 |
+
memory_cleanup()
|
| 268 |
|
| 269 |
+
processed = model(patch)
|
| 270 |
+
processed_data.append(processed)
|
|
|
|
| 271 |
|
| 272 |
+
# 快速合并
|
| 273 |
+
restored = fast_merge(processed_data, starts, (B, C, H, W))
|
| 274 |
+
|
| 275 |
+
# 去除padding
|
| 276 |
restored = restored[:, :, :h, :w]
|
| 277 |
|
| 278 |
+
# 快速后处理
|
| 279 |
+
result = torch.clamp(restored, 0, 1).squeeze(0).permute(1, 2, 0).numpy()
|
| 280 |
+
result = (result * 255).astype(np.uint8)
|
| 281 |
|
| 282 |
+
# 输出
|
| 283 |
+
out_img = Image.fromarray(result)
|
| 284 |
buf = BytesIO()
|
| 285 |
+
out_img.save(buf, format="PNG", optimize=True)
|
| 286 |
+
|
| 287 |
+
memory_cleanup()
|
| 288 |
return buf.getvalue()
|
| 289 |
|
| 290 |
|
| 291 |
+
# 导出函数
|
| 292 |
def dehaze_inference(body: bytes) -> bytes:
|
| 293 |
+
return ultra_fast_inference(body, "dehaze")
|
|
|
|
| 294 |
|
| 295 |
|
| 296 |
def demoiring_inference(body: bytes) -> bytes:
|
| 297 |
+
return ultra_fast_inference(body, "demoiring")
|
|
|
|
| 298 |
|
| 299 |
|
| 300 |
def deblur_inference(body: bytes) -> bytes:
|
| 301 |
+
return ultra_fast_inference(body, "deblur")
|
|
|
|
| 302 |
|
| 303 |
|
| 304 |
def get_model_status():
|
|
|
|
| 305 |
return {
|
| 306 |
"dehaze_model_loaded": dehaze_model is not None,
|
| 307 |
"demoiring_model_loaded": demoiring_model is not None,
|