# Scraped Hugging Face Spaces page header (not code): "Spaces: Running on Zero"
import os | |
import cv2 | |
import utils | |
import argparse | |
import numpy as np | |
import torch | |
from utils import convert_state_dict | |
from models import restormer_arch | |
from data.preprocess.crop_merge_image import stride_integral | |
os.sys.path.append('./data/MBD/') | |
from data.MBD.infer import net1_net2_infer_single_im | |
def dewarp_prompt(img):
    """Build the dewarping prompt for ``img``.

    A mask is obtained from ``net1_net2_infer_single_im`` using the
    ``data/MBD/checkpoint/mbd.pkl`` checkpoint, and every pixel of ``img``
    where that mask is 0 is zeroed in place. The prompt stacks the grid from
    ``utils.getBasecoord(256, 256)`` divided by 256 with the mask resized to
    256x256 and divided by 255.

    Args:
        img: Image array. It is modified in place, so callers pass a copy.

    Returns:
        ``(img, prompt)``: the masked image, and the prompt whose last axis
        holds the coordinate-grid channels followed by one mask channel.
    """
    prompt_size = 256
    seg_mask = net1_net2_infer_single_im(img, 'data/MBD/checkpoint/mbd.pkl')
    # NOTE(review): presumably a document-foreground mask with values 0..255;
    # the /255 below assumes that range. Confirm against data/MBD/infer.py.
    img[seg_mask == 0] = 0
    coord_grid = utils.getBasecoord(prompt_size, prompt_size) / prompt_size
    mask_channel = cv2.resize(seg_mask, (prompt_size, prompt_size)) / 255
    return img, np.concatenate((coord_grid, mask_channel[..., np.newaxis]), -1)
def deshadow_prompt(img):
    """Estimate the smooth background of ``img`` as the deshadowing prompt.

    Working at 1024x1024, each channel is dilated with a 7x7 kernel and then
    median-blurred with a 21x21 aperture. The merged background is resized
    back to the input's size.

    The original implementation also computed a normalised difference image
    and a grayscale "shadow map" that were never returned. That dead work
    has been dropped; the returned value is unchanged.

    Args:
        img: 8-bit image (callers pass the output of ``cv2.imread``).

    Returns:
        uint8 background estimate with the same height and width as ``img``.
    """
    h, w = img.shape[:2]
    # Process at a fixed size, presumably so the fixed kernel sizes below act
    # at a consistent scale regardless of input resolution.
    resized = cv2.resize(img, (1024, 1024))
    kernel = np.ones((7, 7), np.uint8)
    backgrounds = []
    for plane in cv2.split(resized):
        # Dilation (a 7x7 max filter) suppresses dark details narrower than
        # the kernel, such as text strokes. The median blur then smooths the
        # remainder into a background estimate.
        dilated = cv2.dilate(plane, kernel)
        backgrounds.append(cv2.medianBlur(dilated, 21))
    return cv2.resize(cv2.merge(backgrounds), (w, h))
def deblur_prompt(img):
    """Return an edge-strength prompt for deblurring.

    Computes ``0.5*|Sobel_x| + 0.5*|Sobel_y|`` (16-bit derivatives converted
    back to uint8), converts it to grayscale, and replicates it to three
    channels.

    Args:
        img: BGR image (the map is converted with ``COLOR_BGR2GRAY``).

    Returns:
        3-channel uint8 edge map with the same height and width as ``img``.
    """
    abs_dx = cv2.convertScaleAbs(cv2.Sobel(img, cv2.CV_16S, 1, 0))  # back to uint8
    abs_dy = cv2.convertScaleAbs(cv2.Sobel(img, cv2.CV_16S, 0, 1))
    edge_map = cv2.addWeighted(abs_dx, 0.5, abs_dy, 0.5, 0)
    edge_gray = cv2.cvtColor(edge_map, cv2.COLOR_BGR2GRAY)
    return cv2.cvtColor(edge_gray, cv2.COLOR_GRAY2BGR)
def appearance_prompt(img):
    """Compute the appearance-enhancement prompt for ``img``.

    Working at 1024x1024, each channel's background is estimated with a 7x7
    dilation followed by a 21x21 median blur. The prompt channel is
    ``255 - |channel - background|``, min-max normalised to 0..255. The
    merged result is resized back to the input's size.

    Args:
        img: 8-bit image (callers pass the output of ``cv2.imread``).

    Returns:
        uint8 array with the same height and width as ``img`` and one
        channel per input channel.
    """
    h, w = img.shape[:2]
    # Process at a fixed size, presumably so the fixed kernel sizes act at a
    # consistent scale regardless of input resolution.
    resized = cv2.resize(img, (1024, 1024))
    kernel = np.ones((7, 7), np.uint8)  # loop-invariant, built once
    norm_planes = []
    for plane in cv2.split(resized):
        background = cv2.medianBlur(cv2.dilate(plane, kernel), 21)
        # Pixels equal to the background map to 255. Pixels that differ from
        # it (e.g. dark strokes) become darker.
        diff = 255 - cv2.absdiff(plane, background)
        norm_planes.append(
            cv2.normalize(diff, None, alpha=0, beta=255,
                          norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_8UC1))
    return cv2.resize(cv2.merge(norm_planes), (w, h))
def binarization_promptv2(img):
    """Build the 3-channel binarization prompt for ``img``.

    Channels, in order:
      0. ``thresh`` from ``utils.SauvolaModBinarization``, cast to uint8;
      1. grayscale edge strength ``0.5*|Sobel_x| + 0.5*|Sobel_y|``;
      2. the binarization ``result``, hard-thresholded at 155 to {0, 255}.

    Args:
        img: BGR image (the edge map is converted with ``COLOR_BGR2GRAY``).

    Returns:
        The three maps stacked along a new last axis. This is presumably
        ``(H, W, 3)``, assuming ``thresh`` and ``result`` are 2-D; TODO
        confirm against ``utils.SauvolaModBinarization``.
    """
    sauvola_result, thresh = utils.SauvolaModBinarization(img)
    thresh = thresh.astype(np.uint8)
    # Hard-threshold in place. The two assignments do not interfere: the
    # first only writes 255, which the second (<= 155) leaves untouched.
    sauvola_result[sauvola_result > 155] = 255
    sauvola_result[sauvola_result <= 155] = 0
    abs_dx = cv2.convertScaleAbs(cv2.Sobel(img, cv2.CV_16S, 1, 0))  # back to uint8
    abs_dy = cv2.convertScaleAbs(cv2.Sobel(img, cv2.CV_16S, 0, 1))
    edge_gray = cv2.cvtColor(cv2.addWeighted(abs_dx, 0.5, abs_dy, 0.5, 0),
                             cv2.COLOR_BGR2GRAY)
    return np.stack((thresh, edge_gray, sauvola_result), axis=-1)
def dewarping(model, im_path):
    """Geometrically rectify (dewarp) the document image at ``im_path``.

    The masked image and the ``dewarp_prompt`` grid/mask prompt are resized
    to 256x256 and fed to ``model`` (in float32, on ``DEVICE``). The first
    two output channels plus the base grid form a sampling map. The map is
    smoothed, scaled to pixel units and applied to the original image with
    ``cv2.remap``.

    Args:
        model: Restoration network taking 6 input channels.
        im_path: Path of the input image.

    Returns:
        ``(prompt0, prompt1, prompt2, out_im)``: the three prompt channels as
        uint8 at the original size, and the dewarped image.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``im_path``.
    """
    INPUT_SIZE = 256
    im_org = cv2.imread(im_path)
    # cv2.imread signals failure by returning None rather than raising.
    if im_org is None:
        raise FileNotFoundError(f'Could not read image: {im_path}')
    im_masked, prompt_org = dewarp_prompt(im_org.copy())
    h, w = im_masked.shape[:2]
    im_masked = cv2.resize(im_masked, (INPUT_SIZE, INPUT_SIZE)) / 255.0
    im_masked = torch.from_numpy(im_masked.transpose(2, 0, 1)).unsqueeze(0)
    im_masked = im_masked.float().to(DEVICE)
    prompt = torch.from_numpy(prompt_org.transpose(2, 0, 1)).unsqueeze(0)
    prompt = prompt.float().to(DEVICE)
    in_im = torch.cat((im_masked, prompt), dim=1)
    # inference
    base_coord = utils.getBasecoord(INPUT_SIZE, INPUT_SIZE) / INPUT_SIZE
    model = model.float()  # Module.float() converts in place and returns self
    with torch.no_grad():
        pred = model(in_im)
    pred = pred[0][:2].permute(1, 2, 0).cpu().numpy()
    # The model output is added to the base grid to give the sampling map.
    pred = pred + base_coord
    # Smooth the sampling map with 15 passes of an edge-replicated 3x3 blur.
    for _ in range(15):
        pred = cv2.blur(pred, (3, 3), borderType=cv2.BORDER_REPLICATE)
    # Upsample to the original size and scale channel 0 by w and channel 1
    # by h. cv2.remap uses these as the x and y source coordinates.
    pred = (cv2.resize(pred, (w, h)) * (w, h)).astype(np.float32)
    out_im = cv2.remap(im_org, pred[:, :, 0], pred[:, :, 1], cv2.INTER_LINEAR)
    prompt_org = (prompt_org * 255).astype(np.uint8)
    prompt_org = cv2.resize(prompt_org, im_org.shape[:2][::-1])
    return prompt_org[:, :, 0], prompt_org[:, :, 1], prompt_org[:, :, 2], out_im
def appearance(model, im_path):
    """Enhance the appearance of the document image at ``im_path``.

    The image is stacked with ``appearance_prompt`` (6 channels) and run
    through ``model`` in half precision on ``DEVICE``.

    - If the longer side is below ``MAX_SIZE``, the input is padded with
      ``stride_integral`` (stride 8) and processed at native size.
    - Otherwise the input is resized to ``MAX_SIZE`` x ``MAX_SIZE``. The
      per-pixel ratio between the resized input and the prediction is then
      upsampled to the original size and divided out of the full-resolution
      image.

    Args:
        model: Restoration network taking 6 input channels.
        im_path: Path of the input image.

    Returns:
        ``(prompt0, prompt1, prompt2, out_im)``: the three prompt channels at
        the original size, and the uint8 result.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``im_path``.
    """
    MAX_SIZE = 1600
    # obtain im and prompt
    im_org = cv2.imread(im_path)
    # cv2.imread signals failure by returning None rather than raising.
    if im_org is None:
        raise FileNotFoundError(f'Could not read image: {im_path}')
    h, w = im_org.shape[:2]
    prompt = appearance_prompt(im_org)
    in_im = np.concatenate((im_org, prompt), -1)
    # constrain the max resolution
    native = max(w, h) < MAX_SIZE
    if native:
        in_im, padding_h, padding_w = stride_integral(in_im, 8)
    else:
        # Squashes the aspect ratio; the ratio-map step below maps the
        # result back to (w, h).
        in_im = cv2.resize(in_im, (MAX_SIZE, MAX_SIZE))
    # normalize to [0, 1] and reorder HWC -> 1xCxHxW
    in_im = torch.from_numpy((in_im / 255.0).transpose(2, 0, 1)).unsqueeze(0)
    # inference
    in_im = in_im.half().to(DEVICE)
    model = model.half()  # Module.half() converts in place and returns self
    with torch.no_grad():
        pred = model(in_im)
    pred = torch.clamp(pred, 0, 1)
    pred = pred[0].permute(1, 2, 0).cpu().numpy()
    pred = (pred * 255).astype(np.uint8)
    if native:
        # Crop using the offsets returned by stride_integral.
        out_im = pred[padding_h:, padding_w:]
    else:
        pred[pred == 0] = 1  # avoid division by zero in the ratio map
        shadow_map = cv2.resize(im_org, (MAX_SIZE, MAX_SIZE)).astype(float) / pred.astype(float)
        shadow_map = cv2.resize(shadow_map, (w, h))
        shadow_map[shadow_map == 0] = 0.00001
        out_im = np.clip(im_org.astype(float) / shadow_map, 0, 255).astype(np.uint8)
    return prompt[:, :, 0], prompt[:, :, 1], prompt[:, :, 2], out_im
def deshadowing(model, im_path):
    """Remove shadows from the document image at ``im_path``.

    The image is stacked with ``deshadow_prompt`` (6 channels) and run
    through ``model`` in half precision on ``DEVICE``.

    - If the longer side is below ``MAX_SIZE``, the input is padded with
      ``stride_integral`` (stride 8) and processed at native size.
    - Otherwise the input is resized to ``MAX_SIZE`` x ``MAX_SIZE``. The
      per-pixel ratio between the resized input and the prediction is then
      upsampled to the original size and divided out of the full-resolution
      image.

    Args:
        model: Restoration network taking 6 input channels.
        im_path: Path of the input image.

    Returns:
        ``(prompt0, prompt1, prompt2, out_im)``: the three prompt channels at
        the original size, and the uint8 result.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``im_path``.
    """
    MAX_SIZE = 1600
    # obtain im and prompt
    im_org = cv2.imread(im_path)
    # cv2.imread signals failure by returning None rather than raising.
    if im_org is None:
        raise FileNotFoundError(f'Could not read image: {im_path}')
    h, w = im_org.shape[:2]
    prompt = deshadow_prompt(im_org)
    in_im = np.concatenate((im_org, prompt), -1)
    # constrain the max resolution
    native = max(w, h) < MAX_SIZE
    if native:
        in_im, padding_h, padding_w = stride_integral(in_im, 8)
    else:
        # Squashes the aspect ratio; the ratio-map step below maps the
        # result back to (w, h).
        in_im = cv2.resize(in_im, (MAX_SIZE, MAX_SIZE))
    # normalize to [0, 1] and reorder HWC -> 1xCxHxW
    in_im = torch.from_numpy((in_im / 255.0).transpose(2, 0, 1)).unsqueeze(0)
    # inference
    in_im = in_im.half().to(DEVICE)
    model = model.half()  # Module.half() converts in place and returns self
    with torch.no_grad():
        pred = model(in_im)
    pred = torch.clamp(pred, 0, 1)
    pred = pred[0].permute(1, 2, 0).cpu().numpy()
    pred = (pred * 255).astype(np.uint8)
    if native:
        # Crop using the offsets returned by stride_integral.
        out_im = pred[padding_h:, padding_w:]
    else:
        pred[pred == 0] = 1  # avoid division by zero in the ratio map
        shadow_map = cv2.resize(im_org, (MAX_SIZE, MAX_SIZE)).astype(float) / pred.astype(float)
        shadow_map = cv2.resize(shadow_map, (w, h))
        shadow_map[shadow_map == 0] = 0.00001
        out_im = np.clip(im_org.astype(float) / shadow_map, 0, 255).astype(np.uint8)
    return prompt[:, :, 0], prompt[:, :, 1], prompt[:, :, 2], out_im
def deblurring(model, im_path):
    """Deblur the document image at ``im_path``.

    The image is padded with ``stride_integral`` (stride 8). An edge prompt
    is computed from the padded image with ``deblur_prompt``, and both are
    fed to ``model`` in half precision on ``DEVICE``.

    Args:
        model: Restoration network taking 6 input channels.
        im_path: Path of the input image.

    Returns:
        ``(prompt0, prompt1, prompt2, out_im)``: the three prompt channels and
        the uint8 restored image. All four are cropped by the same
        ``stride_integral`` offsets, so they share one size.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``im_path``.
    """
    # setup image
    im_org = cv2.imread(im_path)
    # cv2.imread signals failure by returning None rather than raising.
    if im_org is None:
        raise FileNotFoundError(f'Could not read image: {im_path}')
    in_im, padding_h, padding_w = stride_integral(im_org, 8)
    prompt = deblur_prompt(in_im)
    in_im = np.concatenate((in_im, prompt), -1)
    in_im = torch.from_numpy((in_im / 255.0).transpose(2, 0, 1)).unsqueeze(0)
    in_im = in_im.half().to(DEVICE)
    # inference; move/eval again in case the caller did not use model_init
    model.to(DEVICE)
    model.eval()
    model = model.half()  # Module.half() converts in place and returns self
    with torch.no_grad():
        pred = model(in_im)
    pred = torch.clamp(pred, 0, 1)
    pred = pred[0].permute(1, 2, 0).cpu().numpy()
    pred = (pred * 255).astype(np.uint8)
    out_im = pred[padding_h:, padding_w:]
    # Crop the prompt the same way so the saved prompt maps match out_im.
    prompt = prompt[padding_h:, padding_w:]
    return prompt[:, :, 0], prompt[:, :, 1], prompt[:, :, 2], out_im
def binarization(model, im_path):
    """Binarize the document image at ``im_path``.

    The image is padded with ``stride_integral`` (stride 8) and stacked with
    ``binarization_promptv2``. The first two output channels of ``model``
    are treated as two-class scores; the winning class index (0 or 1) per
    pixel is mapped to 0/255.

    Args:
        model: Restoration network taking 6 input channels.
        im_path: Path of the input image.

    Returns:
        ``(prompt0, prompt1, prompt2, out_im)``: the three prompt channels and
        the uint8 binary image. All four are cropped by the same
        ``stride_integral`` offsets, so they share one size.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``im_path``.
    """
    im_org = cv2.imread(im_path)
    # cv2.imread signals failure by returning None rather than raising.
    if im_org is None:
        raise FileNotFoundError(f'Could not read image: {im_path}')
    im, padding_h, padding_w = stride_integral(im_org, 8)
    prompt = binarization_promptv2(im)
    h, w = im.shape[:2]
    in_im = np.concatenate((im, prompt), -1)
    in_im = in_im / 255.0
    in_im = torch.from_numpy(in_im.transpose(2, 0, 1)).unsqueeze(0)
    in_im = in_im.to(DEVICE)
    model = model.half()  # Module.half() converts in place and returns self
    in_im = in_im.half()
    with torch.no_grad():
        pred = model(in_im)
    pred = pred[:, :2, :, :]
    pred = torch.max(torch.softmax(pred, 1), 1)[1]  # per-pixel class index
    pred = pred[0].cpu().numpy()
    pred = (pred * 255).astype(np.uint8)
    pred = cv2.resize(pred, (w, h))
    out_im = pred[padding_h:, padding_w:]
    # Crop the prompt the same way so the saved prompt maps match out_im.
    prompt = prompt[padding_h:, padding_w:]
    return prompt[:, :, 0], prompt[:, :, 1], prompt[:, :, 2], out_im
def get_args():
    """Parse the command-line options of the DocRes inference script.

    Returns:
        ``argparse.Namespace`` with ``model_path``, ``im_path``,
        ``out_folder``, ``task`` and ``save_dtsprompt``.

    Exits through ``parser.error`` (``SystemExit`` with status 2) when
    ``--task`` is missing a value or names an unsupported task.
    """
    possible_tasks = ['dewarping', 'deshadowing', 'appearance', 'deblurring', 'binarization', 'end2end']
    parser = argparse.ArgumentParser(description='Params')
    parser.add_argument('--model_path', nargs='?', type=str, default='./checkpoints/docres.pkl',
                        help='Path of the saved checkpoint')
    parser.add_argument('--im_path', nargs='?', type=str, default='./distorted/',
                        help='Path of input document image')
    parser.add_argument('--out_folder', nargs='?', type=str, default='./restorted/',
                        help='Folder of the output images')
    parser.add_argument('--task', nargs='?', type=str, default='dewarping', choices=possible_tasks,
                        help='task that need to be executed')
    parser.add_argument('--save_dtsprompt', nargs='?', type=int, default=0,
                        help='If non-zero, also save the three prompt maps next to the output image')
    args = parser.parse_args()
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    # It also catches a bare `--task` with no value, which yields None.
    if args.task not in possible_tasks:
        parser.error('Unsupported task, task must be one of ' + ', '.join(possible_tasks))
    return args
def model_init(args):
    """Create the DocRes Restormer network and load its checkpoint.

    The network is built with 6 input channels (image plus 3-channel prompt)
    and 3 output channels. Weights come from the ``'model_state'`` entry of
    ``args.model_path``, passed through ``convert_state_dict``, and are mapped
    to ``'cpu'`` or ``'cuda:0'`` according to ``DEVICE``.

    Args:
        args: Parsed options; only ``args.model_path`` is read.

    Returns:
        The model in eval mode, moved to ``DEVICE``.
    """
    # prepare model
    restormer_kwargs = dict(
        inp_channels=6,
        out_channels=3,
        dim=48,
        num_blocks=[2, 3, 3, 4],
        num_refinement_blocks=4,
        heads=[1, 2, 4, 8],
        ffn_expansion_factor=2.66,
        bias=False,
        LayerNorm_type='WithBias',
        dual_pixel_task=True,
    )
    net = restormer_arch.Restormer(**restormer_kwargs)
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    location = 'cpu' if DEVICE.type == 'cpu' else 'cuda:0'
    checkpoint = torch.load(args.model_path, map_location=location)
    net.load_state_dict(convert_state_dict(checkpoint['model_state']))
    net.eval()
    return net.to(DEVICE)
def inference_one_im(model, im_path, task):
    """Run one restoration ``task`` on the image at ``im_path``.

    ``'end2end'`` chains dewarping -> deshadowing -> appearance. Each stage
    reads its input from a path, so the intermediates are written to
    ``restorted/step1.jpg`` and ``restorted/step2.jpg`` and left on disk.

    Args:
        model: Model returned by ``model_init``.
        im_path: Path of the input image.
        task: One of ``'dewarping'``, ``'deshadowing'``, ``'appearance'``,
            ``'deblurring'``, ``'binarization'``, ``'end2end'``.

    Returns:
        ``(prompt1, prompt2, prompt3, restorted)`` from the (last) task run.

    Raises:
        ValueError: if ``task`` is not a supported task name.
        OSError: if an end2end intermediate image cannot be written.
    """
    if task == 'dewarping':
        return dewarping(model, im_path)
    if task == 'deshadowing':
        return deshadowing(model, im_path)
    if task == 'appearance':
        return appearance(model, im_path)
    if task == 'deblurring':
        return deblurring(model, im_path)
    if task == 'binarization':
        return binarization(model, im_path)
    if task == 'end2end':
        # cv2.imwrite does not create folders; make sure it exists.
        os.makedirs('restorted', exist_ok=True)
        restorted = dewarping(model, im_path)[-1]
        # cv2.imwrite can report failure by returning False instead of raising.
        if not cv2.imwrite('restorted/step1.jpg', restorted):
            raise OSError('Could not write intermediate image restorted/step1.jpg')
        restorted = deshadowing(model, 'restorted/step1.jpg')[-1]
        if not cv2.imwrite('restorted/step2.jpg', restorted):
            raise OSError('Could not write intermediate image restorted/step2.jpg')
        return appearance(model, 'restorted/step2.jpg')
    raise ValueError(
        f'Unsupported task {task!r}; expected one of: dewarping, deshadowing, '
        'appearance, deblurring, binarization, end2end')
# The task functions read the module-level DEVICE. Defining it outside the
# __main__ guard keeps them usable when this module is imported.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

if __name__ == '__main__':
    ## model init
    args = get_args()
    model = model_init(args)
    ## inference
    prompt1, prompt2, prompt3, restorted = inference_one_im(model, args.im_path, args.task)
    ## results saving
    # splitext takes only the final extension. str.replace would rewrite
    # every occurrence of it anywhere in the path.
    im_stem, im_format = os.path.splitext(os.path.basename(args.im_path))
    # cv2.imwrite picks the encoder from the extension; fall back to PNG.
    im_format = im_format or '.png'
    if args.out_folder:
        os.makedirs(args.out_folder, exist_ok=True)
    save_base = os.path.join(args.out_folder, f'{im_stem}_{args.task}')
    outputs = [(save_base + im_format, restorted)]
    if args.save_dtsprompt:
        for idx, prompt_map in enumerate((prompt1, prompt2, prompt3), start=1):
            outputs.append((f'{save_base}_prompt{idx}{im_format}', prompt_map))
    for out_path, out_img in outputs:
        # cv2.imwrite can report failure by returning False instead of raising.
        if not cv2.imwrite(out_path, out_img):
            raise OSError(f'Failed to write {out_path}')