# documents-restoration / inference_gradio.py
# Last commit e7c7d09 by qubvel-hf (HF staff): "Init model on GPU"
import sys
import cv2
import utils
import numpy as np
import torch
from PIL import Image
from utils import convert_state_dict
from models import restormer_arch
from data.preprocess.crop_merge_image import stride_integral
sys.path.append("./data/MBD/")
from data.MBD.infer import net1_net2_infer_single_im
def dewarp_prompt(img):
    """Build the 3-channel prompt used by the dewarping task.

    The document mask comes from the MBD network. Pixels of ``img`` where the
    mask is 0 are zeroed *in place* (callers pass a copy for that reason).

    Args:
        img: Input image array; it is modified in place.

    Returns:
        A tuple ``(img, prompt)``. ``prompt`` is 256x256 with channels taken
        from ``utils.getBasecoord(256, 256) / 256`` followed by the mask
        resized to 256x256 and divided by 255.
    """
    mask = net1_net2_infer_single_im(img, "data/MBD/checkpoint/mbd.pkl")
    grid = utils.getBasecoord(256, 256) / 256
    # Black out everything the mask rejects (assumes mask shares img's H x W).
    img[mask == 0] = 0
    mask_channel = np.expand_dims(cv2.resize(mask, (256, 256)) / 255, -1)
    return img, np.concatenate((grid, mask_channel), -1)
def deshadow_prompt(img):
    """Estimate the page background, used as the prompt for deshadowing.

    The image is resized to 1024x1024. Each channel is then dilated with a
    7x7 kernel and median-blurred with a 21-pixel aperture. A grey-level
    dilation is a local maximum filter, so thin dark strokes are absorbed
    into the lighter surroundings. The per-channel estimates are merged and
    resized back to the input resolution.

    Only the background estimate is computed. Earlier revisions also built a
    normalized-difference image and a shadow map that were never returned.

    Args:
        img: Input image array (BGR in this pipeline).

    Returns:
        The background estimate with the same height and width as ``img``.
    """
    h, w = img.shape[:2]
    resized = cv2.resize(img, (1024, 1024))
    kernel = np.ones((7, 7), np.uint8)
    bg_planes = [
        cv2.medianBlur(cv2.dilate(plane, kernel), 21)
        for plane in cv2.split(resized)
    ]
    bg_imgs = cv2.merge(bg_planes)
    return cv2.resize(bg_imgs, (w, h))
def deblur_prompt(img):
    """Return an edge-strength prompt for the deblurring task.

    Horizontal and vertical Sobel responses (16-bit signed) are converted back
    to uint8 magnitudes and averaged 50/50. The result is collapsed to
    grayscale and replicated to three channels.

    Args:
        img: BGR image (``COLOR_BGR2GRAY`` requires 3 or 4 channels).

    Returns:
        A 3-channel uint8 image whose channels are identical.
    """
    # convertScaleAbs turns the signed 16-bit gradients back into uint8.
    grad_x = cv2.convertScaleAbs(cv2.Sobel(img, cv2.CV_16S, 1, 0))
    grad_y = cv2.convertScaleAbs(cv2.Sobel(img, cv2.CV_16S, 0, 1))
    edges = cv2.addWeighted(grad_x, 0.5, grad_y, 0.5, 0)
    edges_gray = cv2.cvtColor(edges, cv2.COLOR_BGR2GRAY)
    return cv2.cvtColor(edges_gray, cv2.COLOR_GRAY2BGR)
def appearance_prompt(img):
    """Build the normalized background-difference prompt for appearance enhancement.

    The image is resized to 1024x1024. Each channel then goes through three
    steps:

    1. Estimate the background with a 7x7 dilation followed by a 21-pixel
       median blur.
    2. Compute ``255 - |plane - background|``.
    3. Min-max normalize that difference to 0..255 as uint8.

    The normalized channels are merged and resized back to the input
    resolution.

    Args:
        img: Input image array (BGR in this pipeline).

    Returns:
        A uint8 image with the same height, width and channel count as ``img``.
    """
    h, w = img.shape[:2]
    resized = cv2.resize(img, (1024, 1024))
    kernel = np.ones((7, 7), np.uint8)
    norm_planes = []
    for plane in cv2.split(resized):
        background = cv2.medianBlur(cv2.dilate(plane, kernel), 21)
        diff = 255 - cv2.absdiff(plane, background)
        norm_planes.append(
            cv2.normalize(
                diff,
                None,
                alpha=0,
                beta=255,
                norm_type=cv2.NORM_MINMAX,
                dtype=cv2.CV_8UC1,
            )
        )
    return cv2.resize(cv2.merge(norm_planes), (w, h))
def binarization_promptv2(img):
    """Build the 3-channel prompt for the binarization task.

    Channels, in order:
        0. second return value of ``utils.SauvolaModBinarization`` cast to
           uint8 (presumably the local threshold map; verify in utils);
        1. grayscale Sobel edge magnitude (x/y responses averaged 50/50);
        2. first return value of ``utils.SauvolaModBinarization``, hard
           thresholded to {0, 255} at 155. This array is modified in place.

    Args:
        img: BGR image (``COLOR_BGR2GRAY`` requires 3 or 4 channels).

    Returns:
        An ``H x W x 3`` array stacking the three channels above.
    """
    binarized, thresh = utils.SauvolaModBinarization(img)
    thresh = thresh.astype(np.uint8)
    # Order matters: values raised to 255 stay above 155 for the second mask.
    binarized[binarized > 155] = 255
    binarized[binarized <= 155] = 0

    # convertScaleAbs turns the signed 16-bit gradients back into uint8.
    grad_x = cv2.convertScaleAbs(cv2.Sobel(img, cv2.CV_16S, 1, 0))
    grad_y = cv2.convertScaleAbs(cv2.Sobel(img, cv2.CV_16S, 0, 1))
    edges = cv2.cvtColor(
        cv2.addWeighted(grad_x, 0.5, grad_y, 0.5, 0), cv2.COLOR_BGR2GRAY
    )
    return np.stack((thresh, edges, binarized), axis=-1)
def dewarping(model, im_org, device):
    """Geometrically rectify a document image.

    The masked image and the dewarp prompt are fed to the model at 256x256.
    The model's first two output channels are added to the base coordinate
    grid, smoothed with 15 passes of a 3x3 box blur, and scaled to pixel
    units. ``cv2.remap`` then samples ``im_org`` through the result.

    Args:
        model: Network taking the concatenated image + prompt tensor. It is
            cast to float32 via ``model.float()``.
        im_org: Original image; it is not modified (a copy is masked).
        device: Torch device for the input tensors.

    Returns:
        ``(prompt_c0, prompt_c1, prompt_c2, out_im)``. The prompt channels are
        scaled by 255, cast to uint8 and resized to ``im_org``'s size.
        ``out_im`` is the remapped image.
    """
    INPUT_SIZE = 256
    masked, prompt_np = dewarp_prompt(im_org.copy())
    h, w = masked.shape[:2]

    def _to_batch(arr):
        # HWC numpy -> 1xCxHxW float32 tensor on the target device.
        return torch.from_numpy(arr.transpose(2, 0, 1)).unsqueeze(0).float().to(device)

    image_t = _to_batch(cv2.resize(masked, (INPUT_SIZE, INPUT_SIZE)) / 255.0)
    prompt_t = _to_batch(prompt_np)
    net_input = torch.cat((image_t, prompt_t), dim=1)

    grid = utils.getBasecoord(INPUT_SIZE, INPUT_SIZE) / INPUT_SIZE
    model = model.float()
    with torch.no_grad():
        flow = model(net_input)[0][:2].permute(1, 2, 0).cpu().numpy()
    flow = flow + grid

    # Smooth the sampling grid before upscaling it to full resolution.
    for _ in range(15):
        flow = cv2.blur(flow, (3, 3), borderType=cv2.BORDER_REPLICATE)
    flow = (cv2.resize(flow, (w, h)) * (w, h)).astype(np.float32)
    out_im = cv2.remap(im_org, flow[:, :, 0], flow[:, :, 1], cv2.INTER_LINEAR)

    prompt_u8 = cv2.resize(
        (prompt_np * 255).astype(np.uint8), im_org.shape[:2][::-1]
    )
    return prompt_u8[:, :, 0], prompt_u8[:, :, 1], prompt_u8[:, :, 2], out_im
def appearance(model, im_org, device):
    """Run appearance enhancement on one image in float16.

    The image is concatenated with ``appearance_prompt(im_org)`` along the
    channel axis.

    If the longer side is below 1600, the stack goes through
    ``stride_integral(..., 8)``. The returned paddings are cropped off the
    top/left of the prediction.

    Otherwise the stack is squashed to 1600x1600 (aspect ratio not preserved).
    The prediction then defines a per-pixel ratio map against the similarly
    resized input. That map is upsampled to full size, and the original image
    is divided by it.

    Note: ``model.half()`` casts the module in place.

    Returns:
        ``(prompt_c0, prompt_c1, prompt_c2, out_im)``.
    """
    MAX_SIZE = 1600
    h, w = im_org.shape[:2]
    prompt = appearance_prompt(im_org)
    stacked = np.concatenate((im_org, prompt), -1)

    small_enough = max(w, h) < MAX_SIZE
    if small_enough:
        stacked, padding_h, padding_w = stride_integral(stacked, 8)
    else:
        stacked = cv2.resize(stacked, (MAX_SIZE, MAX_SIZE))

    batch = torch.from_numpy((stacked / 255.0).transpose(2, 0, 1)).unsqueeze(0)
    batch = batch.half().to(device)
    model = model.half()
    with torch.no_grad():
        restored = torch.clamp(model(batch), 0, 1)
    restored = (restored[0].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)

    if small_enough:
        return (
            prompt[:, :, 0],
            prompt[:, :, 1],
            prompt[:, :, 2],
            restored[padding_h:, padding_w:],
        )

    # Avoid division by zero in the ratio maps below.
    restored[restored == 0] = 1
    reference = cv2.resize(im_org, (MAX_SIZE, MAX_SIZE)).astype(float)
    ratio = cv2.resize(reference / restored.astype(float), (w, h))
    ratio[ratio == 0] = 0.00001
    out_im = np.clip(im_org.astype(float) / ratio, 0, 255).astype(np.uint8)
    return prompt[:, :, 0], prompt[:, :, 1], prompt[:, :, 2], out_im
def deshadowing(model, im_org, device):
    """Run shadow removal on one image in float16.

    The image is concatenated with ``deshadow_prompt(im_org)`` along the
    channel axis.

    If the longer side is below 1600, the stack goes through
    ``stride_integral(..., 8)``, and the returned paddings are cropped off the
    top/left of the prediction.

    Otherwise the stack is squashed to 1600x1600, and the prediction is used
    to build a shadow-ratio map. That map is upsampled to full size, and the
    original image is divided by it.

    Note: ``model.half()`` casts the module in place.

    Returns:
        ``(prompt_c0, prompt_c1, prompt_c2, out_im)``.
    """
    MAX_SIZE = 1600
    h, w = im_org.shape[:2]
    prompt = deshadow_prompt(im_org)
    net_in = np.concatenate((im_org, prompt), -1)

    fits = max(w, h) < MAX_SIZE
    if fits:
        net_in, padding_h, padding_w = stride_integral(net_in, 8)
    else:
        net_in = cv2.resize(net_in, (MAX_SIZE, MAX_SIZE))

    net_in = torch.from_numpy((net_in / 255.0).transpose(2, 0, 1)).unsqueeze(0)
    net_in = net_in.half().to(device)
    model = model.half()
    with torch.no_grad():
        pred = torch.clamp(model(net_in), 0, 1)
    pred = (pred[0].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)

    if fits:
        out_im = pred[padding_h:, padding_w:]
    else:
        # Guard both divisions against zeros.
        pred[pred == 0] = 1
        small_input = cv2.resize(im_org, (MAX_SIZE, MAX_SIZE)).astype(float)
        shadow_map = cv2.resize(small_input / pred.astype(float), (w, h))
        shadow_map[shadow_map == 0] = 0.00001
        out_im = np.clip(im_org.astype(float) / shadow_map, 0, 255).astype(np.uint8)
    return prompt[:, :, 0], prompt[:, :, 1], prompt[:, :, 2], out_im
def deblurring(model, im_org, device):
    """Run deblurring on one image in float16.

    The image goes through ``stride_integral(..., 8)`` and is concatenated with
    its Sobel-edge prompt. The model is moved to ``device``, switched to eval
    mode and cast to float16; these module methods act in place. The
    stride_integral paddings are cropped off the top/left of the prediction.

    Returns:
        ``(prompt_c0, prompt_c1, prompt_c2, out_im)``. The prompt channels are
        at the padded resolution.
    """
    padded, padding_h, padding_w = stride_integral(im_org, 8)
    prompt = deblur_prompt(padded)
    batch = np.concatenate((padded, prompt), -1) / 255.0
    batch = torch.from_numpy(batch.transpose(2, 0, 1)).unsqueeze(0)
    batch = batch.half().to(device)

    model.to(device)
    model.eval()
    model = model.half()
    with torch.no_grad():
        restored = torch.clamp(model(batch), 0, 1)
    restored = (restored[0].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
    return (
        prompt[:, :, 0],
        prompt[:, :, 1],
        prompt[:, :, 2],
        restored[padding_h:, padding_w:],
    )
def binarization(model, im_org, device):
    """Binarize one image with the shared model in float16.

    The first two output channels are treated as class scores. The per-pixel
    argmax of their softmax (0 or 1) is scaled to a {0, 255} uint8 mask.
    The stride_integral paddings are cropped off the top/left.

    Returns:
        ``(prompt_c0, prompt_c1, prompt_c2, out_im)``. The prompt channels are
        at the padded resolution.
    """
    im, padding_h, padding_w = stride_integral(im_org, 8)
    prompt = binarization_promptv2(im)
    h, w = im.shape[:2]

    stacked = np.concatenate((im, prompt), -1) / 255.0
    net_in = torch.from_numpy(stacked.transpose(2, 0, 1)).unsqueeze(0).to(device)
    model = model.half()
    net_in = net_in.half()
    with torch.no_grad():
        scores = model(net_in)[:, :2, :, :]
        labels = torch.softmax(scores, 1).max(1)[1]
    mask = (labels[0].cpu().numpy() * 255).astype(np.uint8)
    mask = cv2.resize(mask, (w, h))
    return prompt[:, :, 0], prompt[:, :, 1], prompt[:, :, 2], mask[padding_h:, padding_w:]
def model_init(model_path, device):
    """Build the 6-input / 3-output Restormer and load its checkpoint.

    The checkpoint is always deserialized onto the CPU, and the finished model
    is then moved to ``device``. The previous code mapped every ``device`` other
    than the exact string ``"cpu"`` to ``"cuda:0"``. That raised on machines
    without CUDA (e.g. ``device="mps"``) and for ``torch.device("cpu")``.
    Loading to CPU also avoids holding a second copy of the weights in GPU
    memory while ``load_state_dict`` copies them into the CPU-built model.

    Args:
        model_path: Path to a checkpoint whose ``"model_state"`` entry holds
            the state dict (passed through ``convert_state_dict``).
        device: Any value accepted by ``nn.Module.to`` (string or
            ``torch.device``).

    Returns:
        The model in eval mode on ``device``.
    """
    model = restormer_arch.Restormer(
        inp_channels=6,
        out_channels=3,
        dim=48,
        num_blocks=[2, 3, 3, 4],
        num_refinement_blocks=4,
        heads=[1, 2, 4, 8],
        ffn_expansion_factor=2.66,
        bias=False,
        LayerNorm_type="WithBias",
        dual_pixel_task=True,
    )
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(model_path, map_location="cpu")
    state = convert_state_dict(checkpoint["model_state"])
    model.load_state_dict(state)
    model.eval()
    model = model.to(device)
    return model
def resize(image, max_size):
    """Downscale ``image`` so its longer side equals ``max_size``.

    The aspect ratio is preserved and the resampling uses Pillow's Lanczos
    filter. Images whose longer side is already ``<= max_size`` are returned
    unchanged (the same object). The shorter side is truncated to an int but
    clamped to at least 1 pixel. Without that clamp, very elongated images
    produced a 0-pixel dimension and Pillow's ``resize`` raised.

    Args:
        image: NumPy image array accepted by ``Image.fromarray``.
        max_size: Maximum allowed length of the longer side, in pixels.

    Returns:
        The (possibly) resized image as a NumPy array.
    """
    h, w = image.shape[:2]
    if max(h, w) <= max_size:
        return image
    if h > w:
        h_new = max_size
        w_new = max(1, int(w * h_new / h))
    else:
        w_new = max_size
        h_new = max(1, int(h * w_new / w))
    pil_image = Image.fromarray(image)
    pil_image = pil_image.resize((w_new, h_new), Image.Resampling.LANCZOS)
    return np.array(pil_image)
def inference_one_image(model, image, tasks, device):
    """Apply the requested restoration tasks to a single BGR image.

    Tasks always run in a fixed order, regardless of their order in ``tasks``:
    dewarping, deshadowing, appearance, deblurring, binarization.

    If ``tasks`` consists only of ``"dewarping"``, the dewarped image is
    returned at full resolution. Otherwise the image is first capped to a
    longer side of 1536 pixels, and then each remaining requested task is
    applied to the previous task's output.

    Args:
        model: Shared restoration model passed to every task.
        image: Input image in BGR channel order.
        tasks: Collection of task names, tested with ``in``.
        device: Torch device forwarded to each task.

    Returns:
        The processed image.
    """
    if "dewarping" in tasks:
        *_, image = dewarping(model, image, device)
        if len(tasks) == 1:
            return image

    image = resize(image, 1536)
    pipeline = (
        ("deshadowing", deshadowing),
        ("appearance", appearance),
        ("deblurring", deblurring),
        ("binarization", binarization),
    )
    for task_name, run_task in pipeline:
        if task_name in tasks:
            *_, image = run_task(model, image, device)
    return image