Spaces:
Running
on
Zero
Running
on
Zero
import torch | |
import cv2 | |
import numpy as np | |
import torch.nn.functional as F | |
import time | |
import argparse | |
import os | |
import os.path as osp | |
from models.TextEnhancement import MARCONetPlus | |
from utils.utils_image import get_image_paths, imread_uint, uint2tensor4, tensor2uint | |
from networks.rrdbnet2_arch import RRDBNet as BSRGAN | |
def inference(input_path=None, output_path=None, aligned=False, bg_sr=False, scale_factor=2, save_text=False, device=None): | |
if device == None or device == 'gpu': | |
use_cuda = torch.cuda.is_available() | |
if device == 'cpu': | |
use_cuda = False | |
device = torch.device('cuda' if use_cuda else 'cpu') | |
if input_path is None: | |
exit('input image path is none. Please see our document') | |
if output_path is None: | |
TIMESTAMP = time.strftime("%m-%d_%H-%M", time.localtime()) | |
if input_path[-1] == '/' or input_path[-1] == '\\': | |
input_path = input_path[:-1] | |
output_path = osp.join(input_path+'_'+TIMESTAMP+'_MARCONetPlus') | |
os.makedirs(output_path, exist_ok=True) | |
# use bsrgan to restore the background of the whole image | |
if bg_sr: | |
##BG model | |
BGModel = BSRGAN(in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=2) # define network | |
model_old = torch.load('./checkpoints/bsrgan_bg.pth') | |
state_dict = BGModel.state_dict() | |
for ((key, param),(key2, param2)) in zip(model_old.items(), state_dict.items()): | |
state_dict[key2] = param | |
BGModel.load_state_dict(state_dict, strict=True) | |
BGModel.eval() | |
for k, v in BGModel.named_parameters(): | |
v.requires_grad = False | |
BGModel = BGModel.to(device) | |
torch.cuda.empty_cache() | |
lq_paths = get_image_paths(input_path) | |
if len(lq_paths) ==0: | |
exit('No Image in the LR path.') | |
WEncoderPath='./checkpoints/net_w_encoder_860000.pth' | |
PriorModelPath='./checkpoints/net_prior_860000.pth' | |
SRModelPath='./checkpoints/net_sr_860000.pth' | |
YoloPath = './checkpoints/yolo11m_short_character.pt' | |
TextModel = MARCONetPlus(WEncoderPath, PriorModelPath, SRModelPath, YoloPath, device=device) | |
print('{:>25s} : {:s}'.format('Model Name', 'MARCONetPlusPlus')) | |
if use_cuda: | |
print('{:>25s} : {:<d}'.format('GPU ID', torch.cuda.current_device())) | |
else: | |
print('{:>25s} : {:s}'.format('GPU ID', 'No GPU is available. Use CPU instead.')) | |
torch.cuda.empty_cache() | |
L_path = input_path | |
E_path = output_path # save path | |
print('{:>25s} : {:s}'.format('Input Path', L_path)) | |
print('{:>25s} : {:s}'.format('Output Path', E_path)) | |
if aligned: | |
print('{:>25s} : {:s}'.format('Image Details', 'Aligned Text Layout. No text detection is used.')) | |
else: | |
print('{:>25s} : {:s}'.format('Image Details', 'UnAligned Text Image. It will crop text region using CnSTD, restore, and paste results back.')) | |
print('{:>25s} : {}'.format('Scale Facter', scale_factor)) | |
print('{:>25s} : {:s}'.format('Save LR & SR text layout', 'True' if save_text else 'False')) | |
idx = 0 | |
for iix, img_path in enumerate(lq_paths): | |
#################################### | |
#####(1) Read Image | |
#################################### | |
idx += 1 | |
img_name, ext = os.path.splitext(os.path.basename(img_path)) | |
print('{:>20s} {:04d} --> {:<s}'.format('Restoring ', idx, img_name+ext)) | |
img_L = imread_uint(img_path, n_channels=3) #RGB 0~255 | |
height_L, width_L = img_L.shape[:2] | |
if not aligned and bg_sr: | |
img_E = cv2.resize(img_L, (int(width_L//8*8), int(height_L//8*8)), interpolation = cv2.INTER_AREA) | |
img_E = uint2tensor4(img_E).to(device) #N*C*W*H 0~1 | |
with torch.no_grad(): | |
try: | |
img_E = BGModel(img_E) | |
except: | |
del img_E | |
torch.cuda.empty_cache() | |
max_size = 1536 | |
print(' ' * 25 + f' ... Background image is too large... OOM... Resize the image with maximum dimension at most {max_size} pixels') | |
scale = min(max_size / width_L, max_size / height_L, 1.0) | |
new_width = int(width_L * scale) | |
new_height = int(height_L * scale) | |
img_E = cv2.resize(img_L, (new_width//8*8, new_height//8*8), interpolation=cv2.INTER_AREA) | |
img_E = uint2tensor4(img_E).to(device) | |
img_E = BGModel(img_E) | |
img_E = tensor2uint(img_E) | |
else: | |
img_E = img_L | |
width_S = (width_L * scale_factor) | |
height_S = (height_L * scale_factor) | |
img_E = cv2.resize(img_E, (width_S, height_S), interpolation = cv2.INTER_AREA) | |
############################################################ | |
#####(2) Restore Each Region and Paste to the whole image | |
############################################################ | |
SQ, ori_texts, en_texts, debug_texts, pred_texts = TextModel.handle_texts(img=img_L, bg=img_E, sf=scale_factor, is_aligned=aligned) | |
ext = '.png' | |
if SQ is None or len(en_texts) == 0: | |
continue | |
if not aligned: | |
SQ = cv2.resize(SQ.astype(np.float32), (width_S, height_S), interpolation=cv2.INTER_AREA) | |
cv2.imwrite(os.path.join(E_path, img_name+ext), SQ[:,:,::-1].astype(np.uint8)) | |
else: | |
cv2.imwrite(os.path.join(E_path, img_name+ext), en_texts[0][:,:,::-1].astype(np.uint8)) | |
##################################################### | |
#####(3) Save Character Prior, location, SR Results | |
##################################################### | |
if save_text: | |
for m, (et, ot, dt, pt) in enumerate(zip(en_texts, ori_texts, debug_texts, pred_texts)): ##save each face region | |
w, h, c = et.shape | |
cv2.imwrite(os.path.join(E_path, img_name +'_patch_{}_{}_Debug.jpg'.format(m, pt)), dt[:,:,::-1].astype(np.uint8)) | |
if __name__ == '__main__': | |
''' | |
For the whole image: python test_marconetplus.py -i ./Testsets/LR_Whole -b -s -f 2 | |
python test_marconetplus.py -i ./Testsets/LR_TextLines -a -s | |
''' | |
parser = argparse.ArgumentParser() | |
parser.add_argument('-i', '--input_path', type=str, default='./testsets/LR_Whole', help='The lr text image path') | |
parser.add_argument('-o', '--output_path', type=str, default=None, help='The save path for text sr result') | |
parser.add_argument('-a', '--aligned', action='store_true', help='The input text image contains only text region or not, default:False') | |
parser.add_argument('-b', '--bg_sr', action='store_true', help='When restoring the whole text images, use -b to restore the background region using BSRGAN') | |
parser.add_argument('-f', '--factor_scale', type=int, default=2, help='When restoring the whole text images, use -f to define the scale factor') | |
parser.add_argument('-s', '--save_text', action='store_true', help='Save the LR, SR and debug text layout or not') | |
parser.add_argument('-d', '--device', type=str, default=None, help='using cpu or gpu') | |
args = parser.parse_args() | |
inference(args.input_path, args.output_path, args.aligned, args.bg_sr, args.factor_scale, args.save_text, args.device) | |