yssszzzzzzzzy committed on
Commit bd92798 · verified · 1 Parent(s): 64df45e

Upload 2 files

Files changed (2)
  1. app.py +153 -191
  2. model.py +244 -199
app.py CHANGED
@@ -1,200 +1,183 @@
import gradio as gr
- import tempfile
- import os
import traceback

from model import init, dehaze_inference, demoiring_inference, deblur_inference, get_model_status
from PIL import Image
from io import BytesIO


def safe_init():
-     """Safely initialize the models."""
    try:
-         print("Initializing models...")
        init()
-         print("Model initialization complete!")
        return True
    except Exception as e:
-         print(f"Model initialization failed: {e}")
-         print(f"Traceback: {traceback.format_exc()}")
        return False


- # Initialize the models
model_loaded = safe_init()


- def preprocess_image(input_image):
-     """Preprocess the image: adjust size and format."""
-     if input_image is None:
-         raise Exception("Input image is empty")
-
-     print(f"Input image size: {input_image.size}")

-     # Make sure the image size is suitable
-     width, height = input_image.size

-     # Enforce a minimum size
-     min_size = 64
-     if width < min_size or height < min_size:
-         scale = max(min_size / width, min_size / height)
-         new_width = int(width * scale)
-         new_height = int(height * scale)
-         input_image = input_image.resize((new_width, new_height), Image.LANCZOS)
-         print(f"Image upscaled to: {input_image.size}")

-     # Cap the maximum size to avoid memory issues
-     max_size = 2048
-     if width > max_size or height > max_size:
-         scale = min(max_size / width, max_size / height)
-         new_width = int(width * scale)
-         new_height = int(height * scale)
-         input_image = input_image.resize((new_width, new_height), Image.LANCZOS)
-         print(f"Image downscaled to: {input_image.size}")

-     # Make sure the dimensions are multiples of 8
-     width, height = input_image.size
-     new_width = ((width + 7) // 8) * 8
-     new_height = ((height + 7) // 8) * 8

-     if new_width != width or new_height != height:
-         input_image = input_image.resize((new_width, new_height), Image.LANCZOS)
-         print(f"Image resized to a multiple of 8: {input_image.size}")

-     return input_image


- def create_error_image(error_message):
-     """Create an error placeholder image."""
    try:
-         from PIL import ImageDraw, ImageFont
-         error_image = Image.new('RGB', (400, 200), color='lightgray')
-         draw = ImageDraw.Draw(error_image)

-         # Try to use the default font
-         try:
-             font = ImageFont.load_default()
-         except:
-             font = None

-         # Add the error message
-         error_text = f"Processing failed: {str(error_message)[:50]}"
-         draw.text((10, 90), error_text, fill='red', font=font)

-         return error_image
-     except:
-         # If even the error image fails, return a plain red image
-         return Image.new('RGB', (400, 200), color='red')


- def process_image(input_image, task_type):
-     """Unified image-processing function."""
-     try:
-         print(f"Processing image, task type: {task_type}")

-         # Check the model status
-         model_status = get_model_status()
-         if task_type == "dehaze" and not model_status["dehaze_model_loaded"]:
-             raise Exception("Dehazing model is not loaded")
-         elif task_type == "demoiring" and not model_status["demoiring_model_loaded"]:
-             raise Exception("Demoiréing model is not loaded")
-         elif task_type == "deblur" and not model_status["deblur_model_loaded"]:
-             raise Exception("Motion-deblurring model is not loaded")
-
-         # Preprocess the image
-         processed_image = preprocess_image(input_image)
-
-         # Save the input image to a temporary file
-         with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_input:
-             processed_image.save(temp_input.name, format="PNG")
-             temp_input_path = temp_input.name
-
-         # Read the input image and run inference
-         with open(temp_input_path, "rb") as f:
-             input_bytes = f.read()
-         print(f"Read {len(input_bytes)} bytes of image data")
-
-         print("Starting inference...")
-
-         # Pick the inference function for the task type
-         if task_type == "dehaze":
-             result_bytes = dehaze_inference(input_bytes)
-         elif task_type == "demoiring":
-             result_bytes = demoiring_inference(input_bytes)
-         elif task_type == "deblur":
-             result_bytes = deblur_inference(input_bytes)
-         else:
-             raise Exception(f"Unsupported task type: {task_type}")
-
-         print(f"Inference finished, output {len(result_bytes)} bytes")
-
-         # Clean up the temporary file
-         os.unlink(temp_input_path)
-
-         # Convert the byte stream into a PIL image
        result_image = Image.open(BytesIO(result_bytes))
-         print(f"Output image size: {result_image.size}")

        return result_image

    except Exception as e:
-         error_msg = f"Error while processing image: {e}"
-         print(error_msg)
-         print(f"Traceback: {traceback.format_exc()}")
-         return create_error_image(error_msg)


- def dehaze_image(input_image):
-     """Dehazing handler."""
-     return process_image(input_image, "dehaze")


- def demoiring_image(input_image):
-     """Demoiréing handler."""
-     return process_image(input_image, "demoiring")


- def deblur_image(input_image):
-     """Motion-deblurring handler."""
-     return process_image(input_image, "deblur")


- def get_model_info():
-     """Return model status information."""
    status = get_model_status()
-     info = "Model load status:\n"
-     info += f"🌫️ Dehazing model: {'✅ loaded' if status['dehaze_model_loaded'] else '❌ not loaded'}\n"
-     info += f"🔍 Demoiréing model: {'✅ loaded' if status['demoiring_model_loaded'] else '❌ not loaded'}\n"
-     info += f"🏃 Motion-deblurring model: {'✅ loaded' if status['deblur_model_loaded'] else '❌ not loaded'}"
    return info


- # Build the Gradio UI
def create_interface():
-     with gr.Blocks(title="FPro Image Processing Tool") as demo:
-         gr.Markdown("# FPro Image Processing Tool")
-         gr.Markdown("An image-processing tool based on the FPro (Frequency Prompted Transformer) model, supporting dehazing, demoiréing and motion deblurring.")
-
-         # Model status info
        with gr.Row():
-             model_info = gr.Textbox(
-                 label="Model status",
-                 value=get_model_info(),
                interactive=False,
-                 lines=4
            )

-         # Tabs
        with gr.Tabs():
-             # Dehazing tab
-             with gr.TabItem("🌫️ Image Dehazing"):
-                 gr.Markdown("### Upload a hazy image; the model will remove the haze automatically")
                with gr.Row():
                    with gr.Column():
                        dehaze_input = gr.Image(
                            type="pil",
                            label="Upload hazy image",
-                             height=400
                        )
                        dehaze_btn = gr.Button("Start dehazing", variant="primary")

@@ -202,24 +185,19 @@ def create_interface():
                        dehaze_output = gr.Image(
                            type="pil",
                            label="Dehazed result",
-                             height=400
                        )

-                 dehaze_btn.click(
-                     fn=dehaze_image,
-                     inputs=[dehaze_input],
-                     outputs=[dehaze_output]
-                 )

-             # Demoiréing tab
            with gr.TabItem("🔍 Demoiréing"):
-                 gr.Markdown("### Upload an image with moiré patterns; the model will remove them")
                with gr.Row():
                    with gr.Column():
                        demoiring_input = gr.Image(
                            type="pil",
-                             label="Upload moiré image",
-                             height=400
                        )
                        demoiring_btn = gr.Button("Start demoiréing", variant="primary")

@@ -227,24 +205,19 @@ def create_interface():
                        demoiring_output = gr.Image(
                            type="pil",
                            label="Demoiréing result",
-                             height=400
                        )

-                 demoiring_btn.click(
-                     fn=demoiring_image,
-                     inputs=[demoiring_input],
-                     outputs=[demoiring_output]
-                 )

-             # Motion-deblurring tab
-             with gr.TabItem("🏃 Motion Deblurring"):
-                 gr.Markdown("### Upload a motion-blurred image; the model will remove the blur")
                with gr.Row():
                    with gr.Column():
                        deblur_input = gr.Image(
                            type="pil",
-                             label="Upload motion-blurred image",
-                             height=400
                        )
                        deblur_btn = gr.Button("Start deblurring", variant="primary")

@@ -252,54 +225,43 @@ def create_interface():
                        deblur_output = gr.Image(
                            type="pil",
                            label="Deblurred result",
-                             height=400
                        )

-                 deblur_btn.click(
-                     fn=deblur_image,
-                     inputs=[deblur_input],
-                     outputs=[deblur_output]
-                 )

-         # Usage notes
        with gr.Accordion("Usage Notes", open=False):
            gr.Markdown("""
-             ### 📋 User Guide
-
-             **Dehazing:**
-             - For images degraded by haze, fog or smog
-             - Handles a wide range of outdoor scenes
-             - Uploading a sharp, high-quality original is recommended
-
-             **Demoiréing:**
-             - For moiré patterns caused by photographing screens, scanning, etc.
-             - Handles a variety of periodic interference patterns
-             - Especially suited to photos of digital screens
-
-             **Motion deblurring:**
-             - For blur caused by camera shake or object motion during capture
-             - Restores sharpness to motion-blurred images
-             - Suited to blurred shots of fast-moving scenes
-
-             **Notes:**
-             - Common image formats such as PNG and JPG are supported
-             - Images are resized automatically to meet the model's requirements
-             - Processing time depends on image size and complexity
-             - Keeping images below 2048x2048 pixels is recommended
            """)

-         # Footer
        gr.Markdown("""
        ---
-         **Powered by:** the FPro (Frequency Prompted Transformer) architecture
-         **Deployment:** Hugging Face Spaces
-         **Developer:** yssszzzzzzzzy
        """)

        return demo


if __name__ == "__main__":
-     print("Starting Gradio app...")
    demo = create_interface()
    demo.launch()

import gradio as gr
import traceback
+ import time
+ import os
from model import init, dehaze_inference, demoiring_inference, deblur_inference, get_model_status
from PIL import Image
from io import BytesIO

+ # HF Spaces environment check
+ print("🚀 HuggingFace Spaces environment check:")
+ print(f"CPU cores: {os.cpu_count()}")
+

def safe_init():
+     """Safe initialization."""
    try:
+         print("Initializing models...")
        init()
+         print("✅ Initialization succeeded")
        return True
    except Exception as e:
+         print(f"Initialization failed: {e}")
        return False


+ # Initialize
model_loaded = safe_init()


+ def create_warning_image(message):
+     """Create a warning placeholder image."""
+     img = Image.new('RGB', (512, 256), color='#f0f0f0')
+     return img


+ def process_with_timeout(inference_func, image_bytes, timeout=30):
+     """Run inference with a timeout."""
+     import signal

+     def timeout_handler(signum, frame):
+         raise TimeoutError("Processing timed out")

+     # Arm the timeout
+     signal.signal(signal.SIGALRM, timeout_handler)
+     signal.alarm(timeout)

+     try:
+         result = inference_func(image_bytes)
+         signal.alarm(0)  # Cancel the timeout
+         return result
+     except TimeoutError:
+         signal.alarm(0)
+         raise Exception(f"Processing timed out (>{timeout}s)")
+     except Exception as e:
+         signal.alarm(0)
+         raise e


+ def process_image(input_image, task_type):
+     """Unified processing function."""
+     if input_image is None:
+         return create_warning_image("Please upload an image")

    try:
+         start_time = time.time()
+         print(f"Processing started: {task_type}")

+         # Check the model status
+         status = get_model_status()
+         if not status.get(f"{task_type}_model_loaded", False):
+             return create_warning_image(f"The {task_type} model is not loaded")

+         # Pre-check the image
+         if input_image.size[0] * input_image.size[1] > 1024 * 1024:
+             print("⚠️ Image is large and will be downscaled automatically")

+         # Convert to a byte stream
+         buf = BytesIO()
+         input_image.save(buf, format='PNG')
+         image_bytes = buf.getvalue()

+         # Pick the inference function
+         inference_map = {
+             "dehaze": dehaze_inference,
+             "demoiring": demoiring_inference,
+             "deblur": deblur_inference
+         }

+         inference_func = inference_map[task_type]

+         # Run inference (with timeout protection)
+         try:
+             result_bytes = process_with_timeout(inference_func, image_bytes, timeout=25)
+         except Exception as e:
+             if "timed out" in str(e):
+                 print("⚠️ Processing timed out, retrying at reduced resolution...")
+                 # Downscale the image and retry
+                 small_image = input_image.resize(
+                     (input_image.size[0] // 2, input_image.size[1] // 2),
+                     Image.LANCZOS
+                 )
+                 buf = BytesIO()
+                 small_image.save(buf, format='PNG')
+                 image_bytes = buf.getvalue()
+                 result_bytes = process_with_timeout(inference_func, image_bytes, timeout=20)
+             else:
+                 raise e
+
+         # Convert the result
        result_image = Image.open(BytesIO(result_bytes))
+
+         elapsed = time.time() - start_time
+         print(f"✅ Done in {elapsed:.2f}s")

        return result_image

    except Exception as e:
+         error_msg = f"Processing failed: {str(e)}"
+         print(f"❌ {error_msg}")
+         return create_warning_image(error_msg)


+ # Per-task handlers
+ def dehaze_process(image):
+     return process_image(image, "dehaze")


+ def demoiring_process(image):
+     return process_image(image, "demoiring")


+ def deblur_process(image):
+     return process_image(image, "deblur")


+ def get_system_info():
+     """Return system information."""
    status = get_model_status()
+     info = "🖥️ HuggingFace Spaces (free tier)\n"
+     info += f"💾 Memory: limited\n"
+     info += f"CPU: {os.cpu_count()} cores\n\n"
+     info += "📊 Model status:\n"
+     info += f"🌫️ Dehazing: {'✅' if status['dehaze_model_loaded'] else '❌'}\n"
+     info += f"🔍 Demoiréing: {'✅' if status['demoiring_model_loaded'] else '❌'}\n"
+     info += f"🏃 Deblurring: {'✅' if status['deblur_model_loaded'] else '❌'}\n\n"
+     info += "⚠️ Note: large images are downscaled automatically to fit the free-tier limits"
    return info


+ # Build the UI
def create_interface():
+     # Use a simple theme to keep resource usage low
+     with gr.Blocks(
+         title="FPro Image Processing - HF Spaces Edition",
+         theme=gr.themes.Default(),
+         css=".gradio-container {max-width: 100% !important}"
+     ) as demo:
+         gr.Markdown("# 🚀 FPro Image Processing Tool")
+         gr.Markdown("**HuggingFace Spaces free tier - speed-optimized**")
+
+         # System info
        with gr.Row():
+             system_info = gr.Textbox(
+                 label="System info",
+                 value=get_system_info(),
                interactive=False,
+                 lines=6,
+                 max_lines=6
            )

+         # Main feature area
        with gr.Tabs():
+             # Dehazing
+             with gr.TabItem("🌫️ Dehazing"):
                with gr.Row():
                    with gr.Column():
                        dehaze_input = gr.Image(
                            type="pil",
                            label="Upload hazy image",
+                             height=300
                        )
                        dehaze_btn = gr.Button("Start dehazing", variant="primary")

                        dehaze_output = gr.Image(
                            type="pil",
                            label="Dehazed result",
+                             height=300
                        )

+                 dehaze_btn.click(dehaze_process, dehaze_input, dehaze_output)

+             # Demoiréing
            with gr.TabItem("🔍 Demoiréing"):
                with gr.Row():
                    with gr.Column():
                        demoiring_input = gr.Image(
                            type="pil",
+                             label="Upload moiré image",
+                             height=300
                        )
                        demoiring_btn = gr.Button("Start demoiréing", variant="primary")

                        demoiring_output = gr.Image(
                            type="pil",
                            label="Demoiréing result",
+                             height=300
                        )

+                 demoiring_btn.click(demoiring_process, demoiring_input, demoiring_output)

+             # Deblurring
+             with gr.TabItem("🏃 Deblurring"):
                with gr.Row():
                    with gr.Column():
                        deblur_input = gr.Image(
                            type="pil",
+                             label="Upload blurred image",
+                             height=300
                        )
                        deblur_btn = gr.Button("Start deblurring", variant="primary")

                        deblur_output = gr.Image(
                            type="pil",
                            label="Deblurred result",
+                             height=300
                        )

+                 deblur_btn.click(deblur_process, deblur_input, deblur_output)

+         # Usage notes - simplified
        with gr.Accordion("Usage Notes", open=False):
            gr.Markdown("""
+             ### HF Spaces Optimization Notes
+
+             **🎯 Automatic optimizations:**
+             - Images are scaled down to at most 512 px
+             - Smart tiling reduces memory use
+             - Timeout protection prevents crashes
+             - Memory is cleaned up automatically
+
+             **⚠️ Free-tier limits:**
+             - Processing time: roughly 20-30 seconds
+             - Image size: scaled automatically
+             - Concurrent processing: not supported
+
+             **💡 Tips:**
+             - Upload a clear original image
+             - Please wait patiently for processing to finish
+             - Large images are downscaled automatically
            """)

        gr.Markdown("""
        ---
+         **Based on FPro (Frequency Prompted Transformer) | HuggingFace Spaces optimized edition**
        """)

        return demo


if __name__ == "__main__":
+     print("🌐 Starting the HuggingFace Spaces app...")
    demo = create_interface()
+     demo.queue(concurrency_count=1, max_size=3)  # Limit concurrency
    demo.launch()
model.py CHANGED
@@ -1,262 +1,307 @@
- # model.py - combined version: dehazing, demoiréing and motion deblurring
import yaml, torch, math, numpy as np
import torch.nn.functional as F
from PIL import Image
from io import BytesIO
from basicsr.models.archs.FPro_arch import FPro

- # Force CPU
device = torch.device('cpu')

- # Global variables holding the three models
dehaze_model = None
demoiring_model = None
deblur_model = None


- def splitimage(imgtensor, crop_size=256, overlap_size=158):
-     """Tile-splitting function, kept consistent with the original code."""
    _, C, H, W = imgtensor.shape
-     hstarts = [x for x in range(0, H, crop_size - overlap_size)]
-     while hstarts and hstarts[-1] + crop_size >= H:
-         hstarts.pop()
-     hstarts.append(H - crop_size)
-     wstarts = [x for x in range(0, W, crop_size - overlap_size)]
-     while wstarts and wstarts[-1] + crop_size >= W:
-         wstarts.pop()
-     wstarts.append(W - crop_size)
-     starts = []
    split_data = []
-     for hs in hstarts:
-         for ws in wstarts:
-             cimgdata = imgtensor[:, :, hs:hs + crop_size, ws:ws + crop_size]
-             starts.append((hs, ws))
-             split_data.append(cimgdata)
-     return split_data, starts


- def get_scoremap(H, W, C, B=1, is_mean=True):
-     """Weight-map generator, kept consistent with the original code."""
-     center_h = H / 2
-     center_w = W / 2

-     score = torch.ones((B, C, H, W))
-     if not is_mean:
-         for h in range(H):
-             for w in range(W):
-                 score[:, :, h, w] = 1.0 / (math.sqrt((h - center_h) ** 2 + (w - center_w) ** 2 + 1e-6))
-     return score


- def mergeimage(split_data, starts, crop_size=256, resolution=(1, 3, 128, 128)):
-     """Tile-merging function, kept consistent with the original code."""
-     B, C, H, W = resolution[0], resolution[1], resolution[2], resolution[3]
-     tot_score = torch.zeros((B, C, H, W))
-     merge_img = torch.zeros((B, C, H, W))
-     scoremap = get_scoremap(crop_size, crop_size, C, B=B, is_mean=True)

-     for simg, cstart in zip(split_data, starts):
-         hs, ws = cstart
-         merge_img[:, :, hs:hs + crop_size, ws:ws + crop_size] += scoremap * simg
-         tot_score[:, :, hs:hs + crop_size, ws:ws + crop_size] += scoremap

-     merge_img = merge_img / tot_score
-     return merge_img


- def init():
-     """Initialize the three models."""
-     global dehaze_model, demoiring_model, deblur_model

-     # Initialize the dehazing model
-     try:
-         print("Loading dehazing model...")
-         dehaze_cfg = yaml.safe_load(open("./option/RealDehazing_FPro.yml", "r"))['network_g']
-         dehaze_cfg.pop('type', None)
-         dehaze_model = FPro(**dehaze_cfg)
-         dehaze_model = dehaze_model.to(device)
-         dehaze_ckpt = torch.load("./model/synDehaze.pth", map_location='cpu')
-         dehaze_model.load_state_dict(dehaze_ckpt['params'])
-         dehaze_model.eval()
-         dehaze_model = dehaze_model.cpu()
-         print("Dehazing model loaded!")
-     except Exception as e:
-         print(f"Failed to load dehazing model: {e}")
-         dehaze_model = None

-     # Initialize the demoiréing model
-     try:
-         print("Loading demoiréing model...")
-         # Try to load the YAML config file
-         try:
-             demoiring_cfg = yaml.safe_load(open("./option/RealDemoiring_FPro.yml", "r"))['network_g']
-             demoiring_cfg.pop('type', None)
-         except FileNotFoundError:
-             # Fall back to the default config if there is no separate config file
-             print("Demoiréing config not found, using the default config")
-             demoiring_cfg = {
-                 'inp_channels': 3,
-                 'out_channels': 3,
-                 'dim': 48,
-                 'num_blocks': [4, 6, 6, 8],
-                 'num_refinement_blocks': 4,
-                 'heads': [1, 2, 4, 8],
-                 'ffn_expansion_factor': 2.66,
-                 'bias': False,
-                 'LayerNorm_type': 'WithBias',
-                 'dual_pixel_task': False
-             }
-
-         demoiring_model = FPro(**demoiring_cfg)
-         demoiring_model = demoiring_model.to(device)
-         demoiring_ckpt = torch.load("./model/demoire_noAug.pth", map_location='cpu')
-         demoiring_model.load_state_dict(demoiring_ckpt['params'])
-         demoiring_model.eval()
-         demoiring_model = demoiring_model.cpu()
-         print("Demoiréing model loaded!")
-     except Exception as e:
-         print(f"Failed to load demoiréing model: {e}")
-         demoiring_model = None

-     # Initialize the motion-deblurring model
    try:
-         print("Loading motion-deblurring model...")
-         # Try to load the YAML config file
        try:
-             deblur_cfg = yaml.safe_load(open("./option/Deblurring_FPro.yml", "r"))['network_g']
-             deblur_cfg.pop('type', None)
-         except FileNotFoundError:
-             # Fall back to the default config if there is no separate config file
-             print("Motion-deblurring config not found, using the default config")
-             deblur_cfg = {
-                 'inp_channels': 3,
-                 'out_channels': 3,
-                 'dim': 48,
-                 'num_blocks': [4, 6, 6, 8],
-                 'num_refinement_blocks': 4,
-                 'heads': [1, 2, 4, 8],
-                 'ffn_expansion_factor': 2.66,
-                 'bias': False,
-                 'LayerNorm_type': 'WithBias',
-                 'dual_pixel_task': False
-             }
-
-         deblur_model = FPro(**deblur_cfg)
-         deblur_model = deblur_model.to(device)
-         deblur_ckpt = torch.load("./model/deblur.pth", map_location='cpu')
-         deblur_model.load_state_dict(deblur_ckpt['params'])
-         deblur_model.eval()
-         deblur_model = deblur_model.cpu()
-         print("Motion-deblurring model loaded!")
    except Exception as e:
-         print(f"Failed to load motion-deblurring model: {e}")
-         deblur_model = None
-
-
- def inference(body: bytes, task_type: str = "dehaze") -> bytes:
-     """
-     Inference entry point supporting dehazing, demoiréing and motion deblurring.
-
-     Args:
-         body: image byte stream
-         task_type: task type, "dehaze", "demoiring" or "deblur"
-
-     Returns:
-         the processed image as a byte stream
-     """
-     # Pick the model and parameters for the task
-     if task_type == "dehaze":
-         if dehaze_model is None:
-             raise Exception("Dehazing model is not loaded")
-         model = dehaze_model
-         # Parameters for the dehazing task
-         crop_size_arg = 256
-         overlap_size_arg = 158
-     elif task_type == "demoiring":
-         if demoiring_model is None:
-             raise Exception("Demoiréing model is not loaded")
-         model = demoiring_model
-         # Parameters for the demoiréing task
-         crop_size_arg = 256
-         overlap_size_arg = 200
-     elif task_type == "deblur":
-         if deblur_model is None:
-             raise Exception("Motion-deblurring model is not loaded")
-         model = deblur_model
-         # Parameters for the motion-deblurring task
-         crop_size_arg = 256
-         overlap_size_arg = 200
-     else:
-         raise Exception(f"Unsupported task type: {task_type}")

-     # Load the input image
    img = Image.open(BytesIO(body)).convert("RGB")
    arr = np.float32(img) / 255.0
    t = torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0)

-     # Make sure the tensor is on the CPU
-     t = t.to(device)
-
-     # Padding in case images are not multiples of 8
-     factor = 8
    h, w = t.shape[2], t.shape[3]
-     H = ((h + factor) // factor) * factor
-     W = ((w + factor) // factor) * factor
-     padh = H - h if h % factor != 0 else 0
-     padw = W - w if w % factor != 0 else 0
-     t = F.pad(t, (0, padw, 0, padh), 'reflect')

    B, C, H, W = t.shape

-     # If the image fits within one tile, process it directly
-     if H <= crop_size_arg and W <= crop_size_arg:
-         with torch.no_grad():
            restored = model(t)
-     else:
-         # Fall back to the tiling logic
-         split_data, starts = splitimage(t, crop_size=crop_size_arg, overlap_size=overlap_size_arg)

-         # Run inference on each tile
-         with torch.no_grad():
-             for i, data in enumerate(split_data):
-                 data = data.to(device)
-                 split_data[i] = model(data).cpu()

-         # Merge the results
-         restored = mergeimage(split_data, starts, crop_size=crop_size_arg, resolution=(B, C, H, W))
-         restored = restored.to(device)

-     # Unpad images to original dimensions
    restored = restored[:, :, :h, :w]

-     # Post-process the output
-     merged = torch.clamp(restored, 0, 1).squeeze(0).permute(1, 2, 0).numpy()
-     merged = (merged * 255).astype(np.uint8)

-     # Write out a PNG
-     out_img = Image.fromarray(merged)
    buf = BytesIO()
-     out_img.save(buf, format="PNG")
    return buf.getvalue()


def dehaze_inference(body: bytes) -> bytes:
-     """Convenience wrapper for dehazing inference."""
-     return inference(body, task_type="dehaze")


def demoiring_inference(body: bytes) -> bytes:
-     """Convenience wrapper for demoiréing inference."""
-     return inference(body, task_type="demoiring")


def deblur_inference(body: bytes) -> bytes:
-     """Convenience wrapper for motion-deblurring inference."""
-     return inference(body, task_type="deblur")


def get_model_status():
-     """Return the model load status."""
    return {
        "dehaze_model_loaded": dehaze_model is not None,
        "demoiring_model_loaded": demoiring_model is not None,

+ # model.py - aggressively optimized version for HuggingFace Spaces
import yaml, torch, math, numpy as np
import torch.nn.functional as F
from PIL import Image
from io import BytesIO
from basicsr.models.archs.FPro_arch import FPro
+ import gc

+ # Environment tuning for HF Spaces
+ torch.set_num_threads(2)  # HF Spaces usually provides a 2-core CPU
+ torch.set_num_interop_threads(1)
device = torch.device('cpu')

+ # Global model variables
dehaze_model = None
demoiring_model = None
deblur_model = None

+ # Aggressive optimization settings
+ EXTREME_OPTIMIZATION = {
+     'max_image_size': 512,           # Hard cap on the image size
+     'min_tile_size': 128,            # Minimum tile size
+     'max_tiles': 4,                  # Maximum number of tiles (2x2)
+     'overlap_ratio': 0.1,            # Minimum overlap ratio
+     'memory_limit_mb': 1024,         # Memory limit
+     'enable_half_precision': False,  # Usually unsupported on HF Spaces
+ }

+
+ def aggressive_image_resize(image, max_size=512):
+     """Aggressive image-downscaling strategy."""
+     w, h = image.size
+
+     # Force a downscale if the image is too large
+     if max(w, h) > max_size:
+         scale = max_size / max(w, h)
+         new_w, new_h = int(w * scale), int(h * scale)
+
+         # Round to multiples of 8, but keep the size small
+         new_w = ((new_w + 7) // 8) * 8
+         new_h = ((new_h + 7) // 8) * 8
+
+         # If still too large after rounding, shrink again
+         if max(new_w, new_h) > max_size:
+             new_w = min(new_w, max_size)
+             new_h = min(new_h, max_size)
+             new_w = (new_w // 8) * 8
+             new_h = (new_h // 8) * 8
+
+         image = image.resize((new_w, new_h), Image.LANCZOS)
+         print(f"Image resized: {w}x{h} -> {new_w}x{new_h}")
+
+     return image
+
+
+ def minimal_splitimage(imgtensor, max_tiles=4):
+     """Minimal tiling strategy."""
    _, C, H, W = imgtensor.shape
+
+     # Return directly if the image is at most 256x256
+     if H <= 256 and W <= 256:
+         return [imgtensor], [(0, 0)]
+
+     # Choose a tiling layout that caps the number of tiles
+     if max_tiles == 1:
+         return [imgtensor], [(0, 0)]
+     elif max_tiles == 4:  # 2x2
+         crop_h = H // 2 if H > 256 else H
+         crop_w = W // 2 if W > 256 else W
+         overlap = min(32, crop_h // 8, crop_w // 8)
+     else:
+         # Default to a single tile
+         return [imgtensor], [(0, 0)]
+
+     # Simplified 2x2 tiling
    split_data = []
+     starts = []

+     h_mid = H // 2
+     w_mid = W // 2

+     # Positions of the four tiles
+     positions = [
+         (0, 0, min(crop_h + overlap, H), min(crop_w + overlap, W)),
+         (0, max(0, w_mid - overlap), min(crop_h + overlap, H), W),
+         (max(0, h_mid - overlap), 0, H, min(crop_w + overlap, W)),
+         (max(0, h_mid - overlap), max(0, w_mid - overlap), H, W)
+     ]

+     for h_start, w_start, h_end, w_end in positions:
+         if h_end <= h_start or w_end <= w_start:
+             continue

+         patch = imgtensor[:, :, h_start:h_end, w_start:w_end]
+         split_data.append(patch)
+         starts.append((h_start, w_start))

+     print(f"Image split into {len(split_data)} tiles")
+     return split_data, starts


+ def fast_merge(split_data, starts, resolution):
+     """Fast merge that avoids the expensive weight-map computation."""
+     B, C, H, W = resolution
+     result = torch.zeros((B, C, H, W))
+     count = torch.zeros((B, C, H, W))

+     for patch, (h_start, w_start) in zip(split_data, starts):
+         patch_h, patch_w = patch.shape[2], patch.shape[3]
+         h_end = min(h_start + patch_h, H)
+         w_end = min(w_start + patch_w, W)

+         actual_h = h_end - h_start
+         actual_w = w_end - w_start

+         result[:, :, h_start:h_end, w_start:w_end] += patch[:, :, :actual_h, :actual_w]
+         count[:, :, h_start:h_end, w_start:w_end] += 1
+
+     # Avoid division by zero
+     result = result / torch.clamp(count, min=1)
+     return result
+
+
+ def memory_cleanup():
+     """Free memory."""
+     gc.collect()
+     if torch.cuda.is_available():
+         torch.cuda.empty_cache()


+ def load_model_minimal(config_path, weight_path, default_config):
+     """Minimal model loading."""
    try:
+         # Try to load the config
        try:
+             with open(config_path, 'r') as f:
+                 cfg = yaml.safe_load(f)['network_g']
+             cfg.pop('type', None)
+         except:
+             print(f"Using the default config: {config_path}")
+             cfg = default_config
+
+         # Build the model
+         model = FPro(**cfg)
+
+         # Load the weights
+         checkpoint = torch.load(weight_path, map_location='cpu')
+         if 'params' in checkpoint:
+             model.load_state_dict(checkpoint['params'])
+         else:
+             model.load_state_dict(checkpoint)
+
+         model.eval()
+
+         # Try JIT optimization (fall back to the raw model on failure)
+         try:
+             sample_input = torch.randn(1, 3, 128, 128)
+             with torch.no_grad():
+                 traced_model = torch.jit.trace(model, sample_input)
+             traced_model.eval()
+             print("JIT optimization succeeded")
+             return traced_model
+         except:
+             print("JIT optimization failed, using the raw model")
+             return model
+
    except Exception as e:
+         print(f"Model loading failed: {e}")
+         return None

+
+ def init():
+     """Optimized initialization."""
+     global dehaze_model, demoiring_model, deblur_model
+
+     print("Starting HuggingFace Spaces aggressive-optimization mode...")
+
+     # Default config
+     default_config = {
+         'inp_channels': 3,
+         'out_channels': 3,
+         'dim': 48,
+         'num_blocks': [4, 6, 6, 8],
+         'num_refinement_blocks': 4,
+         'heads': [1, 2, 4, 8],
+         'ffn_expansion_factor': 2.66,
+         'bias': False,
+         'LayerNorm_type': 'WithBias',
+         'dual_pixel_task': False
+     }
+
+     # Dehazing model
+     print("Loading dehazing model...")
+     dehaze_model = load_model_minimal(
+         "./option/RealDehazing_FPro.yml",
+         "./model/synDehaze.pth",
+         default_config
+     )
+
+     # Demoiréing model
+     print("Loading demoiréing model...")
+     demoiring_model = load_model_minimal(
+         "./option/RealDemoiring_FPro.yml",
+         "./model/demoire_noAug.pth",
+         default_config
+     )
+
+     # Motion-deblurring model
+     print("Loading motion-deblurring model...")
+     deblur_model = load_model_minimal(
+         "./option/Deblurring_FPro.yml",
+         "./model/deblur.pth",
+         default_config
+     )
+
+     memory_cleanup()
+     print("Model initialization complete")
+
+
+ def ultra_fast_inference(body: bytes, task_type: str = "dehaze") -> bytes:
+     """Ultra-fast inference mode."""
+     # Pick the model
+     model_map = {
+         "dehaze": dehaze_model,
+         "demoiring": demoiring_model,
+         "deblur": deblur_model
+     }
+
+     model = model_map.get(task_type)
+     if model is None:
+         raise Exception(f"The {task_type} model is not loaded")
+
+     # Load the image
    img = Image.open(BytesIO(body)).convert("RGB")
+     print(f"Original size: {img.size}")
+
+     # Aggressive downscaling
+     img = aggressive_image_resize(img, EXTREME_OPTIMIZATION['max_image_size'])
+
+     # Convert to a tensor
    arr = np.float32(img) / 255.0
    t = torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0)

+     # Minimal padding
    h, w = t.shape[2], t.shape[3]
+     pad_h = (8 - h % 8) % 8
+     pad_w = (8 - w % 8) % 8
+
+     if pad_h > 0 or pad_w > 0:
+         t = F.pad(t, (0, pad_w, 0, pad_h), 'reflect')

    B, C, H, W = t.shape

+     # Inference
+     with torch.no_grad():
+         if H <= 256 and W <= 256:
+             # Process small images directly
            restored = model(t)
+         else:
+             # Minimal tiled processing
+             split_data, starts = minimal_splitimage(t, EXTREME_OPTIMIZATION['max_tiles'])

+             # Process the tiles one by one
+             processed_data = []
+             for i, patch in enumerate(split_data):
+                 if i > 0:  # Free memory after each tile
+                     memory_cleanup()

+                 processed = model(patch)
+                 processed_data.append(processed)

+             # Fast merge
+             restored = fast_merge(processed_data, starts, (B, C, H, W))
+
+     # Remove the padding
    restored = restored[:, :, :h, :w]

+     # Fast post-processing
+     result = torch.clamp(restored, 0, 1).squeeze(0).permute(1, 2, 0).numpy()
+     result = (result * 255).astype(np.uint8)

+     # Output
+     out_img = Image.fromarray(result)
    buf = BytesIO()
+     out_img.save(buf, format="PNG", optimize=True)
+
+     memory_cleanup()
    return buf.getvalue()


+ # Exported wrappers
def dehaze_inference(body: bytes) -> bytes:
+     return ultra_fast_inference(body, "dehaze")


def demoiring_inference(body: bytes) -> bytes:
+     return ultra_fast_inference(body, "demoiring")


def deblur_inference(body: bytes) -> bytes:
+     return ultra_fast_inference(body, "deblur")


def get_model_status():
    return {
        "dehaze_model_loaded": dehaze_model is not None,
        "demoiring_model_loaded": demoiring_model is not None,