# marconetplusplus / test_marconetplus.py
import torch
import cv2
import numpy as np
import time
import argparse
import os
import os.path as osp
from models.TextEnhancement import MARCONetPlus
from utils.utils_image import get_image_paths, imread_uint, uint2tensor4, tensor2uint
from networks.rrdbnet2_arch import RRDBNet as BSRGAN


def inference(input_path=None, output_path=None, aligned=False, bg_sr=False, scale_factor=2, save_text=False, device=None):
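    """Restore low-resolution text images with MARCONetPlus.

    Summary (based on the CLI help at the bottom of this file): `input_path`
    holds the LR images; `aligned` marks inputs that are already-cropped text
    lines; `bg_sr` additionally restores the background with BSRGAN;
    `scale_factor` is the output upscaling factor; `save_text` also writes
    debug text layouts.
    """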
    # Resolve the compute device: use CUDA when available unless CPU is requested.
    use_cuda = torch.cuda.is_available() and device != 'cpu'
    device = torch.device('cuda' if use_cuda else 'cpu')
    if input_path is None:
        exit('Input image path is None. Please see the documentation.')
    if output_path is None:
        TIMESTAMP = time.strftime("%m-%d_%H-%M", time.localtime())
        input_path = input_path.rstrip('/\\')  # drop any trailing path separator
        output_path = input_path + '_' + TIMESTAMP + '_MARCONetPlus'
    os.makedirs(output_path, exist_ok=True)
    # Use BSRGAN to restore the background of the whole image
    if bg_sr:
        ## BG model
        BGModel = BSRGAN(in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=2)  # define network
        model_old = torch.load('./checkpoints/bsrgan_bg.pth', map_location='cpu')  # load on CPU first; moved to `device` below
        state_dict = BGModel.state_dict()
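        # The checkpoint's parameter names may not match the current architecture,
        # so weights are copied positionally; this assumes both state dicts
        # enumerate the same layers in the same order.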
        for (key, param), (key2, param2) in zip(model_old.items(), state_dict.items()):
            state_dict[key2] = param
        BGModel.load_state_dict(state_dict, strict=True)
        BGModel.eval()
        for k, v in BGModel.named_parameters():
            v.requires_grad = False
        BGModel = BGModel.to(device)
    torch.cuda.empty_cache()
    lq_paths = get_image_paths(input_path)
    if len(lq_paths) == 0:
        exit('No image found in the LR path.')
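    # Pretrained MARCONetPlus components (a w-space encoder, a character structure
    # prior, and the SR network, judging by the checkpoint names) plus a YOLO
    # detector for locating characters.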
    WEncoderPath = './checkpoints/net_w_encoder_860000.pth'
    PriorModelPath = './checkpoints/net_prior_860000.pth'
    SRModelPath = './checkpoints/net_sr_860000.pth'
    YoloPath = './checkpoints/yolo11m_short_character.pt'
    TextModel = MARCONetPlus(WEncoderPath, PriorModelPath, SRModelPath, YoloPath, device=device)
    print('{:>25s} : {:s}'.format('Model Name', 'MARCONetPlusPlus'))
    if use_cuda:
        print('{:>25s} : {:<d}'.format('GPU ID', torch.cuda.current_device()))
    else:
        print('{:>25s} : {:s}'.format('GPU ID', 'No GPU is available. Using CPU instead.'))
    torch.cuda.empty_cache()
    L_path = input_path
    E_path = output_path  # save path
    print('{:>25s} : {:s}'.format('Input Path', L_path))
    print('{:>25s} : {:s}'.format('Output Path', E_path))
    if aligned:
        print('{:>25s} : {:s}'.format('Image Details', 'Aligned text layout. No text detection is used.'))
    else:
        print('{:>25s} : {:s}'.format('Image Details', 'Unaligned text image. Text regions are detected and cropped, restored, and pasted back.'))
    print('{:>25s} : {}'.format('Scale Factor', scale_factor))
    print('{:>25s} : {:s}'.format('Save LR & SR text layout', 'True' if save_text else 'False'))
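    # Main restoration loop: read each LR image, optionally super-resolve the
    # background, then restore the text regions and save the results.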
    for idx, img_path in enumerate(lq_paths, 1):
        ####################################
        ##### (1) Read Image
        ####################################
        img_name, ext = os.path.splitext(os.path.basename(img_path))
        print('{:>20s} {:04d} --> {:<s}'.format('Restoring ', idx, img_name + ext))
        img_L = imread_uint(img_path, n_channels=3)  # RGB, 0~255
        height_L, width_L = img_L.shape[:2]
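        # For whole images, first run BSRGAN on a copy whose sides are cropped
        # to multiples of 8 (presumably what the RRDB backbone expects).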
        if not aligned and bg_sr:
            img_E = cv2.resize(img_L, (width_L // 8 * 8, height_L // 8 * 8), interpolation=cv2.INTER_AREA)
            img_E = uint2tensor4(img_E).to(device)  # N*C*H*W, 0~1
            with torch.no_grad():
                try:
                    img_E = BGModel(img_E)
                except RuntimeError:  # typically CUDA out of memory
                    del img_E
                    torch.cuda.empty_cache()
                    max_size = 1536
                    print(' ' * 25 + f' ... Background image is too large (OOM). Resizing so the longer side is at most {max_size} pixels.')
                    scale = min(max_size / width_L, max_size / height_L, 1.0)
                    new_width = int(width_L * scale)
                    new_height = int(height_L * scale)
                    img_E = cv2.resize(img_L, (new_width // 8 * 8, new_height // 8 * 8), interpolation=cv2.INTER_AREA)
                    img_E = uint2tensor4(img_E).to(device)
                    img_E = BGModel(img_E)
            img_E = tensor2uint(img_E)
        else:
            img_E = img_L
        width_S = width_L * scale_factor
        height_S = height_L * scale_factor
        img_E = cv2.resize(img_E, (width_S, height_S), interpolation=cv2.INTER_AREA)
        ############################################################
        ##### (2) Restore Each Region and Paste to the whole image
        ############################################################
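        # handle_texts returns the full restored image (SQ) plus, per detected
        # text region: the original crop, the enhanced crop, a debug layout,
        # and the predicted text string (inferred from how they are used below).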
        SQ, ori_texts, en_texts, debug_texts, pred_texts = TextModel.handle_texts(img=img_L, bg=img_E, sf=scale_factor, is_aligned=aligned)
        ext = '.png'  # always save results as PNG
        if SQ is None or len(en_texts) == 0:
            continue
        if not aligned:
            SQ = cv2.resize(SQ.astype(np.float32), (width_S, height_S), interpolation=cv2.INTER_AREA)
            cv2.imwrite(os.path.join(E_path, img_name + ext), SQ[:, :, ::-1].astype(np.uint8))  # RGB -> BGR for OpenCV
        else:
            cv2.imwrite(os.path.join(E_path, img_name + ext), en_texts[0][:, :, ::-1].astype(np.uint8))  # RGB -> BGR for OpenCV
        #####################################################
        ##### (3) Save Character Prior, location, SR Results
        #####################################################
        if save_text:
            for m, (et, ot, dt, pt) in enumerate(zip(en_texts, ori_texts, debug_texts, pred_texts)):  # save each text region
                h, w, c = et.shape  # height, width, channels
                cv2.imwrite(os.path.join(E_path, img_name + '_patch_{}_{}_Debug.jpg'.format(m, pt)), dt[:, :, ::-1].astype(np.uint8))


if __name__ == '__main__':
    '''
    For the whole image:     python test_marconetplus.py -i ./Testsets/LR_Whole -b -s -f 2
    For aligned text lines:  python test_marconetplus.py -i ./Testsets/LR_TextLines -a -s
    '''
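    # Programmatic use (a minimal sketch; assumes the checkpoints listed above
    # are present under ./checkpoints):
    #   from test_marconetplus import inference
    #   inference(input_path='./testsets/LR_Whole', bg_sr=True, scale_factor=2)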
    parser = argparse.ArgumentParser()
    parser.add_argument('-i', '--input_path', type=str, default='./testsets/LR_Whole', help='The LR text image path')
    parser.add_argument('-o', '--output_path', type=str, default=None, help='The save path for the text SR results')
    parser.add_argument('-a', '--aligned', action='store_true', help='Whether the input image contains only the text region, default: False')
    parser.add_argument('-b', '--bg_sr', action='store_true', help='When restoring whole text images, use -b to restore the background region with BSRGAN')
    parser.add_argument('-f', '--factor_scale', type=int, default=2, help='When restoring whole text images, use -f to set the scale factor')
    parser.add_argument('-s', '--save_text', action='store_true', help='Whether to save the LR, SR, and debug text layouts')
    parser.add_argument('-d', '--device', type=str, default=None, help='Use cpu or gpu')
    args = parser.parse_args()
    inference(args.input_path, args.output_path, args.aligned, args.bg_sr, args.factor_scale, args.save_text, args.device)