# Hugging Face page residue (not Python source): uploaded by lqhsmash,
# "Upload 32 files", commit c52a8b3, raw file size 7.47 kB.
import argparse
import os
import time

import cv2
import numpy as np
import torch
import torch.nn as nn
#import torchvision
import torchvision.transforms as transforms
from PIL import Image

from network.Math_Module import P, Q
from network.decom import Decom
from utils import *
def one2three(x):
    """Tile a single-channel tensor to three identical channels along dim 1."""
    tiled = torch.cat((x, x, x), dim=1)
    # Match the source tensor's dtype/device (no-op here, kept for safety).
    return tiled.to(x)
class Inference(nn.Module):
    """URetinex-Net inference pipeline.

    Decomposes a low-light image into reflectance (R) and illumination (L),
    refines them with an iterative unfolding scheme, brightens L with the
    adjustment network, and recombines into the enhanced image.
    """

    def __init__(self, opts):
        super().__init__()
        self.opts = opts
        # Load the decomposition model (low-light image -> initial R, L).
        self.model_Decom_low = Decom()
        self.model_Decom_low = load_initialize(self.model_Decom_low,
                                               self.opts.Decom_model_low_path)
        # Load the unfolding options and the R / L refinement models.
        self.unfolding_opts, self.model_R, self.model_L = load_unfolding(
            self.opts.unfolding_model_path)
        # Load the illumination-adjustment model.
        self.adjust_model = load_adjustment(self.opts.adjust_model_path)
        self.P = P()
        self.Q = Q()
        self.transform = transforms.Compose([transforms.ToTensor()])
        print(self.model_Decom_low)
        print(self.model_R)
        print(self.model_L)
        print(self.adjust_model)

    def unfolding(self, input_low_img):
        """Run the iterative unfolding optimization; return refined (R, L)."""
        for t in range(self.unfolding_opts.round):
            if t == 0:  # initialization: R0, L0 come from the decomposition net
                P, Q = self.model_Decom_low(input_low_img)
            else:  # update P and Q from the previous iteration's R and L
                w_p = (self.unfolding_opts.gamma +
                       self.unfolding_opts.Roffset * t)
                w_q = (self.unfolding_opts.lamda +
                       self.unfolding_opts.Loffset * t)
                P = self.P(I=input_low_img, Q=Q, R=R, gamma=w_p)
                Q = self.Q(I=input_low_img, P=P, L=L, lamda=w_q)
            R = self.model_R(r=P, l=Q)
            L = self.model_L(l=Q)
        return R, L

    def lllumination_adjust(self, L, ratio):
        """Brighten the illumination map L by `ratio`.

        Fixes two defects of the original: the `ratio` argument was ignored
        in favour of self.opts.ratio, and the ratio map was allocated with
        torch.ones(L.shape) on the CPU, causing a device mismatch when L
        lives on CUDA. ones_like keeps device and dtype consistent with L.
        (Method name typo kept for backward compatibility with callers.)
        """
        alpha = torch.ones_like(L) * ratio
        return self.adjust_model(l=L, alpha=alpha)

    def forward(self, input_low_img):
        """Enhance a low-light image tensor.

        Returns (enhanced_image, processing_time_seconds).
        """
        if torch.cuda.is_available():
            input_low_img = input_low_img.cuda()
        with torch.no_grad():
            start = time.time()
            R, L = self.unfolding(input_low_img)
            High_L = self.lllumination_adjust(L, self.opts.ratio)
            I_enhance = High_L * R
            p_time = time.time() - start
            return I_enhance, p_time

    def run(self, low_img_path):
        """Enhance the image at `low_img_path` and save it under opts.output."""
        # Fix: name the output after the image actually processed
        # (the original used self.opts.img_path here, which is wrong when
        # run() is called with a different path).
        file_name = os.path.basename(low_img_path)
        name = file_name.split('.')[0]
        low_img = self.transform(Image.open(low_img_path)).unsqueeze(0)
        enhance, p_time = self.forward(input_low_img=low_img)
        if not os.path.exists(self.opts.output):
            os.makedirs(self.opts.output)
        save_path = os.path.join(
            self.opts.output,
            file_name.replace(name,
                              "%s_%d_URetinexNet" % (name, self.opts.ratio)))
        np_save_TensorImg(enhance, save_path)
        print(
            "================================= time for %s: %f============================"
            % (file_name, p_time))

    # In-memory variant of run(): avoids writing the result to disk so it
    # can be returned directly to the web front end.
    def runForWeb(self, image):
        """Gradio entry point: enhance an HxWxC uint8 numpy image in memory.

        Downsamples until the image fits the pixel budget, enhances it,
        then upsamples the result back toward the original size.
        """
        # Downsample with an image pyramid until the pixel budget is met.
        max_pixel_limit = 600 * 600
        pyr_down_times = 0
        while len(image) * len(image[0]) > max_pixel_limit:
            pyr_down_times += 1
            image = cv2.pyrDown(image)
        print(image.shape)
        low_img = self.transform(Image.fromarray(np.uint8(image))).unsqueeze(0)
        enhance, p_time = self.forward(input_low_img=low_img)
        result_image = result_for_gradio(enhance)
        # Undo the downsampling. NOTE(review): pyrUp doubles each side, so
        # odd original dimensions are not restored exactly — confirm this
        # is acceptable for the web UI.
        for _ in range(pyr_down_times):
            result_image = cv2.pyrUp(result_image)
        print(result_image.shape)
        return result_image
# Interface provided to the gradio framework.
# Gradio supplies the backend wiring and the front-end page display.
def functionForGradio(image):
    """Gradio callback: enhance `image` (a numpy array) and return the result.

    Fix: the original rebuilt the argument parser and reloaded all three
    model checkpoints from disk on every single request. The Inference
    model is now constructed once on the first call and cached on the
    function object; results are unchanged.
    """
    if getattr(functionForGradio, "_model", None) is None:
        parser = argparse.ArgumentParser(description='Configure')
        # specify your data path here!
        parser.add_argument('--img_path', type=str, default="./demo/input/3.png")
        parser.add_argument('--output', type=str, default="./demo/output")
        # ratio are recommended to be 3-5, bigger ratio will lead to over-exposure
        parser.add_argument('--ratio', type=int, default=5)
        # model path
        parser.add_argument('--Decom_model_low_path',
                            type=str,
                            default="./ckpt/init_low.pth")
        parser.add_argument('--unfolding_model_path',
                            type=str,
                            default="./ckpt/unfolding.pth")
        parser.add_argument('--adjust_model_path',
                            type=str,
                            default="./ckpt/L_adjust.pth")
        parser.add_argument('--gpu_id', type=int, default=0)
        opts = parser.parse_args()
        for k, v in vars(opts).items():
            print(k, v)
        os.environ['CUDA_VISIBLE_DEVICES'] = str(opts.gpu_id)
        functionForGradio._model = Inference(opts)
    # Hand the numpy array to the in-memory inference path.
    return functionForGradio._model.runForWeb(image)
# This is the algorithm's original main function; the Gradio interface
# above (functionForGradio) was adapted from it.
# if __name__ == "__main__":
# parser = argparse.ArgumentParser(description='Configure')
# # specify your data path here!
# parser.add_argument('--img_path', type=str, default="./demo/input/3.png")
# parser.add_argument('--output', type=str, default="./demo/output")
# # ratio are recommended to be 3-5, bigger ratio will lead to over-exposure
# parser.add_argument('--ratio', type=int, default=5)
# # model path
# parser.add_argument('--Decom_model_low_path',
# type=str,
# default="./ckpt/init_low.pth")
# parser.add_argument('--unfolding_model_path',
# type=str,
# default="./ckpt/unfolding.pth")
# parser.add_argument('--adjust_model_path',
# type=str,
# default="./ckpt/L_adjust.pth")
# parser.add_argument('--gpu_id', type=int, default=0)
# opts = parser.parse_args()
# for k, v in vars(opts).items():
# print(k, v)
# os.environ['CUDA_VISIBLE_DEVICES'] = str(opts.gpu_id)
# model = Inference(opts)
# model.run(opts.img_path)