# FPro-Dehaze / model.py
# Commit ba83289 by yssszzzzzzzzy: "Upload 2 files" (verified)
# model.py - 整合去雨滴和去雨功能的完整版本
import yaml, torch, math, numpy as np
import torch.nn.functional as F
from PIL import Image
from io import BytesIO
from basicsr.models.archs.FPro_arch import FPro
import gc
# 强制使用 CPU
torch.set_num_threads(2)
torch.set_num_interop_threads(1)
device = torch.device('cpu')
# 全局模型变量
dehaze_model = None
demoiring_model = None
deblur_model = None
deraindrop_model = None # 新增:去雨滴模型
derain_model = None # 新增:去雨模型
def _tile_starts(length, crop_size, step):
    """Return the tile origins along one axis of size ``length``."""
    starts = list(range(0, length, step))
    # Drop origins whose tile would reach or pass the far edge; the
    # edge-aligned origin appended below replaces them.
    while starts and starts[-1] + crop_size >= length:
        starts.pop()
    # Clamp at 0: for an axis shorter than crop_size a negative origin would
    # make the slice keep only the last rows/cols and break mergeimage.
    starts.append(max(length - crop_size, 0))
    return starts


def splitimage(imgtensor, crop_size=128, overlap_size=64):
    """Cut a (B, C, H, W) tensor into overlapping square tiles.

    Tile origins advance by ``crop_size - overlap_size`` along each axis and
    the last tile is snapped to the far edge, so the whole image is covered.
    If an axis is shorter than ``crop_size``, a single tile starting at 0
    spans that entire axis (the tile is then smaller than ``crop_size``).

    Returns:
        (split_data, starts): the tile views and their (row, col) origins,
        in row-major order.
    """
    _, _, H, W = imgtensor.shape
    step = crop_size - overlap_size
    hstarts = _tile_starts(H, crop_size, step)
    wstarts = _tile_starts(W, crop_size, step)
    starts = [(hs, ws) for hs in hstarts for ws in wstarts]
    split_data = [imgtensor[:, :, hs:hs + crop_size, ws:ws + crop_size] for hs, ws in starts]
    return split_data, starts
def get_scoremap(H, W, C, B=1, is_mean=True):
    """Build a (B, C, H, W) float32 blending-weight map for one tile.

    With ``is_mean=True`` every weight is 1, so overlapping tiles are simply
    averaged. Otherwise each weight is the inverse distance from the pixel to
    the tile centre ``(H / 2, W / 2)``, with a 1e-6 guard inside the square
    root.
    """
    if is_mean:
        return torch.ones((B, C, H, W))
    # Vectorised form of the former per-pixel Python loop. The maths runs in
    # float64 (like math.sqrt did) and is cast to float32 only at the end.
    dh = torch.arange(H, dtype=torch.float64).view(H, 1) - H / 2
    dw = torch.arange(W, dtype=torch.float64).view(1, W) - W / 2
    weight = 1.0 / torch.sqrt(dh * dh + dw * dw + 1e-6)
    return weight.to(torch.float32).expand(B, C, H, W).contiguous()
def mergeimage(split_data, starts, crop_size=128, resolution=(1, 3, 128, 128)):
    """Blend restored tiles back into a single (B, C, H, W) tensor.

    Args:
        split_data: restored tiles, e.g. model outputs for ``splitimage`` tiles.
        starts: matching (row, col) top-left origins of each tile.
        crop_size: nominal tile size; kept for interface compatibility. The
            weight map and paste region are sized from each tile's own shape,
            so tiles smaller than ``crop_size`` (e.g. on an image narrower
            than one tile) are merged correctly.
        resolution: (B, C, H, W) shape of the full output.

    Overlapping pixels are combined as a weighted average using
    ``get_scoremap`` weights. Pixels covered by no tile divide 0 by 0 and
    come out as NaN.
    """
    B, C, H, W = resolution[0], resolution[1], resolution[2], resolution[3]
    tot_score = torch.zeros((B, C, H, W))
    merge_img = torch.zeros((B, C, H, W))
    scoremaps = {}  # weight maps cached per distinct tile size
    for simg, (hs, ws) in zip(split_data, starts):
        th, tw = simg.shape[-2], simg.shape[-1]
        if (th, tw) not in scoremaps:
            scoremaps[(th, tw)] = get_scoremap(th, tw, C, B=B, is_mean=True)
        scoremap = scoremaps[(th, tw)]
        merge_img[:, :, hs:hs + th, ws:ws + tw] += scoremap * simg
        tot_score[:, :, hs:hs + th, ws:ws + tw] += scoremap
    return merge_img / tot_score
def load_model_safe(config_path, weight_path, default_config):
    """Build an FPro network, load its weights and return it, or None on failure.

    Hyper-parameters come from the ``network_g`` section of the YAML file at
    ``config_path``, with its ``type`` key removed. If that file cannot be
    read or parsed, or lacks the section, a copy of ``default_config`` is
    used instead. Weights are read from ``weight_path``; if the checkpoint
    has a ``'params'`` entry, that entry is used as the state dict. The model
    is put in eval mode and moved to ``device``. Any error while building or
    loading is printed, and None is returned so other models can still load.
    """
    try:
        try:
            with open(config_path, 'r', encoding='utf-8') as f:
                x = yaml.safe_load(f)
            cfg = x['network_g'].copy()
            cfg.pop('type', None)
        except Exception:
            # Narrowed from a bare ``except:`` so Ctrl-C / SystemExit are not
            # swallowed while still falling back on any config problem.
            print(f"配置文件加载失败,使用默认配置: {config_path}")
            # Copy so the caller's shared default dict is never aliased.
            cfg = dict(default_config)
        model = FPro(**cfg)
        # NOTE(review): torch.load unpickles the file; only point this at
        # trusted weight files.
        checkpoint = torch.load(weight_path, map_location='cpu')
        state_dict = checkpoint['params'] if 'params' in checkpoint else checkpoint
        model.load_state_dict(state_dict)
        model.eval()
        model = model.to(device)
        print(f"模型加载成功: {weight_path}")
        return model
    except Exception as e:
        print(f"模型加载失败: {e}")
        return None
def init():
    """Load every restoration network into its module-level slot.

    Each model goes through load_model_safe(), which returns None on failure,
    so one broken checkpoint leaves only its own slot as None. The final
    "done" message is printed regardless of individual failures.
    """
    global dehaze_model, demoiring_model, deblur_model, deraindrop_model, derain_model
    print("开始加载模型...")
    # Fallback FPro hyper-parameters, used when an option file is unusable.
    fallback_cfg = {
        'inp_channels': 3,
        'out_channels': 3,
        'dim': 48,
        'num_blocks': [4, 6, 6, 8],
        'num_refinement_blocks': 4,
        'heads': [1, 2, 4, 8],
        'ffn_expansion_factor': 2.66,
        'bias': False,
        'LayerNorm_type': 'WithBias',
        'dual_pixel_task': False,
    }

    def _load(option_path, weight_path):
        # Thin local shorthand: every model shares the same fallback config.
        return load_model_safe(option_path, weight_path, fallback_cfg)

    # Dehazing (run tile-by-tile at inference time).
    dehaze_model = _load("./option/RealDehazing_FPro.yml", "./model/synDehaze.pth")
    # Moire removal (run on the whole image at inference time).
    demoiring_model = _load("./option/RealDemoiring_FPro.yml", "./model/demoire_noAug.pth")
    # Deblurring (tiled).
    deblur_model = _load("./option/Deblurring_FPro.yml", "./model/deblur.pth")
    # Raindrop removal (tiled).
    deraindrop_model = _load("./option/RealDeraindrop_FPro.yml", "./model/deraindrop_proIR.pth")
    # Rain removal (tiled).
    derain_model = _load("./option/Deraining_FPro_spad.yml", "./model/derain_spad.pth")
    # Release temporaries created while reading checkpoints.
    gc.collect()
    print("所有模型加载完成")
def preprocess_image(img):
    """Convert an H x W x 3 image into a (1, 3, H, W) float32 tensor.

    ``img`` is anything NumPy can turn into an H x W x C array (a PIL RGB
    image in this module). Values are divided by 255, so uint8 input lands
    in [0, 1].
    """
    normalized = np.asarray(img, dtype=np.float32) / 255.0
    channels_first = torch.from_numpy(normalized).permute(2, 0, 1)
    return channels_first.unsqueeze(0)
def postprocess_tensor(tensor, original_h, original_w):
    """Turn a (1, C, H', W') model output into an H x W x C uint8 array.

    The tensor is first cropped to ``original_h`` x ``original_w``, which
    removes any padding added before inference. It is then clamped to [0, 1],
    scaled to [0, 255] and rounded to the nearest integer. The previous
    ``astype`` truncation biased every pixel about half a level darker.
    """
    cropped = tensor[:, :, :original_h, :original_w]
    result = torch.clamp(cropped, 0, 1).detach().cpu().permute(0, 2, 3, 1).squeeze(0).numpy()
    return np.rint(result * 255.0).astype(np.uint8)
def _pad_to_factor(input_tensor, factor=8):
    """Pad a (B, C, H, W) tensor on the bottom/right so H and W are multiples of ``factor``.

    Returns ``(padded, h, w)``, where ``h``/``w`` are the size before padding.
    The pad amount is ``(-size) % factor``, which is zero for sizes that are
    already multiples. Reflection padding is used, as in the original
    per-task code. ``reflect`` requires each pad to be smaller than the
    dimension it extends, so tiny inputs fall back to ``replicate`` instead of
    raising.
    """
    h, w = input_tensor.shape[2], input_tensor.shape[3]
    padh = (-h) % factor
    padw = (-w) % factor
    mode = 'reflect' if padh < h and padw < w else 'replicate'
    return F.pad(input_tensor, (0, padw, 0, padh), mode), h, w


def _run_tiled(model, input_tensor, crop_size, overlap_size):
    """Run ``model`` on overlapping ``crop_size`` tiles and blend the outputs.

    Tiles come from ``splitimage`` and are merged by ``mergeimage`` into a
    tensor the size of ``input_tensor``. An input smaller than one tile in
    either dimension cannot be tiled: ``splitimage`` would produce a negative
    tile origin and ``mergeimage`` a shape mismatch. Such an input is padded
    to a multiple of 8 and processed in a single forward pass instead. That
    output may be larger than the input; callers crop it back to the original
    size.
    """
    B, C, H, W = input_tensor.shape
    with torch.no_grad():
        if H < crop_size or W < crop_size:
            # NOTE(review): assumes the network needs sizes divisible by 8, as
            # the padded tasks imply; confirm for deblurring.
            padded, _, _ = _pad_to_factor(input_tensor)
            return model(padded).cpu()
        tiles, starts = splitimage(input_tensor, crop_size=crop_size, overlap_size=overlap_size)
        outputs = [model(tile).cpu() for tile in tiles]
    return mergeimage(outputs, starts, crop_size=crop_size, resolution=(B, C, H, W))


def dehaze_inference_impl(input_tensor):
    """Dehaze, following the original test_SOTS.py procedure.

    The input is padded to a multiple of 8 and processed as 256px tiles with
    158px overlap. Returns ``(restored, h, w)``, where ``h``/``w`` are the
    unpadded input size.
    """
    padded, h, w = _pad_to_factor(input_tensor)
    restored = _run_tiled(dehaze_model, padded, crop_size=256, overlap_size=158)
    return restored, h, w


def demoiring_inference_impl(input_tensor):
    """Remove moire patterns, following the original test_moire.py procedure.

    The input is padded to a multiple of 8 and processed in one forward pass,
    without tiling. Returns ``(restored, h, w)``, where ``h``/``w`` are the
    unpadded input size.
    """
    padded, h, w = _pad_to_factor(input_tensor)
    with torch.no_grad():
        restored = demoiring_model(padded)
    return restored, h, w


def deblur_inference_impl(input_tensor):
    """Deblur, following the original test_FPro.py procedure.

    The input is not padded and is processed as 256px tiles with 200px
    overlap. Returns ``(restored, H, W)``, where ``H``/``W`` are the input
    size.
    """
    h, w = input_tensor.shape[2], input_tensor.shape[3]
    restored = _run_tiled(deblur_model, input_tensor, crop_size=256, overlap_size=200)
    return restored, h, w


def deraindrop_inference_impl(input_tensor):
    """Remove raindrops.

    The input is padded to a multiple of 8 and processed as 256px tiles with
    200px overlap. Returns ``(restored, h, w)``, where ``h``/``w`` are the
    unpadded input size.
    """
    padded, h, w = _pad_to_factor(input_tensor)
    restored = _run_tiled(deraindrop_model, padded, crop_size=256, overlap_size=200)
    return restored, h, w


def derain_inference_impl(input_tensor):
    """Remove rain.

    The input is padded to a multiple of 8 and processed as 256px tiles with
    200px overlap. Returns ``(restored, h, w)``, where ``h``/``w`` are the
    unpadded input size.
    """
    padded, h, w = _pad_to_factor(input_tensor)
    restored = _run_tiled(derain_model, padded, crop_size=256, overlap_size=200)
    return restored, h, w
def inference(body: bytes, task_type: str = "dehaze") -> bytes:
    """Run one restoration task on an encoded image and return PNG bytes.

    Args:
        body: encoded image bytes in any format PIL can open; converted to RGB.
        task_type: one of "dehaze", "demoiring", "deblur", "deraindrop",
            "derain".

    The output PNG has the same size as the input image.

    Raises:
        ValueError: ``task_type`` is not a supported task. Previously this was
            misreported as "model not loaded".
        Exception: the model for ``task_type`` is not loaded.
        Any exception raised during inference is printed and re-raised.
    """
    # Built per call so it sees the globals assigned by init().
    handlers = {
        "dehaze": (dehaze_model, dehaze_inference_impl),
        "demoiring": (demoiring_model, demoiring_inference_impl),
        "deblur": (deblur_model, deblur_inference_impl),
        "deraindrop": (deraindrop_model, deraindrop_inference_impl),
        "derain": (derain_model, derain_inference_impl),
    }
    if task_type not in handlers:
        raise ValueError(f"不支持的任务类型: {task_type}")
    model, impl = handlers[task_type]
    if model is None:
        raise Exception(f"{task_type}模型未加载")
    # Decode at the original size; no resizing.
    img = Image.open(BytesIO(body)).convert("RGB")
    print(f"原始图像尺寸: {img.size}")
    input_tensor = preprocess_image(img).to(device)
    try:
        restored, orig_h, orig_w = impl(input_tensor)
        result = postprocess_tensor(restored, orig_h, orig_w)
        buf = BytesIO()
        Image.fromarray(result).save(buf, format="PNG", optimize=True)
        return buf.getvalue()
    except Exception as e:
        print(f"推理失败: {e}")
        raise
    finally:
        # Free large intermediate tensors on success and on failure alike.
        gc.collect()
# Public per-task entry points: each takes encoded image bytes, runs one task
# through inference(), and returns PNG-encoded bytes.
def dehaze_inference(body: bytes) -> bytes:
    """Dehaze an encoded image."""
    return inference(body, task_type="dehaze")


def demoiring_inference(body: bytes) -> bytes:
    """Remove moire patterns from an encoded image."""
    return inference(body, task_type="demoiring")


def deblur_inference(body: bytes) -> bytes:
    """Deblur an encoded image."""
    return inference(body, task_type="deblur")


def deraindrop_inference(body: bytes) -> bytes:
    """Remove raindrops from an encoded image."""
    return inference(body, task_type="deraindrop")


def derain_inference(body: bytes) -> bytes:
    """Remove rain from an encoded image."""
    return inference(body, task_type="derain")
def get_model_status():
    """Report which task models are currently loaded.

    Returns a dict mapping each "<task>_model_loaded" key to True when the
    corresponding global model handle is not None.
    """
    slots = (
        ("dehaze_model_loaded", dehaze_model),
        ("demoiring_model_loaded", demoiring_model),
        ("deblur_model_loaded", deblur_model),
        ("deraindrop_model_loaded", deraindrop_model),
        ("derain_model_loaded", derain_model),
    )
    return {key: handle is not None for key, handle in slots}