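'''
Inference script for MARCONetPlus text image super-resolution.

Given low-resolution text images, the script optionally restores the image
background with BSRGAN, then detects, restores, and pastes back each text
region with MARCONetPlus. Usage examples are given under __main__ below.
'''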
import torch
import cv2
import numpy as np
import torch.nn.functional as F
import time
import argparse
import os
import os.path as osp
from models.TextEnhancement import MARCONetPlus
from utils.utils_image import get_image_paths, imread_uint, uint2tensor4, tensor2uint
from networks.rrdbnet2_arch import RRDBNet as BSRGAN
def inference(input_path=None, output_path=None, aligned=False, bg_sr=False, scale_factor=2, save_text=False, device=None):
    if device == 'cpu':
        use_cuda = False
    else:  # None or 'gpu': use CUDA when available, otherwise fall back to CPU
        use_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')
if input_path is None:
        exit('Input image path is None. Please see the documentation.')
if output_path is None:
TIMESTAMP = time.strftime("%m-%d_%H-%M", time.localtime())
if input_path[-1] == '/' or input_path[-1] == '\\':
input_path = input_path[:-1]
        output_path = input_path + '_' + TIMESTAMP + '_MARCONetPlus'
os.makedirs(output_path, exist_ok=True)
    # Use BSRGAN to restore the background of the whole image
    if bg_sr:
        ## BG model
        BGModel = BSRGAN(in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=2)  # define network
        model_old = torch.load('./checkpoints/bsrgan_bg.pth', map_location='cpu')
        state_dict = BGModel.state_dict()
        # Checkpoint key names may differ from the network's; copy parameters by position.
        for (key, param), (key2, param2) in zip(model_old.items(), state_dict.items()):
            state_dict[key2] = param
        BGModel.load_state_dict(state_dict, strict=True)
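        # Inference only: switch to eval mode, freeze all parameters, and move the model to the target device.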
BGModel.eval()
for k, v in BGModel.named_parameters():
v.requires_grad = False
BGModel = BGModel.to(device)
torch.cuda.empty_cache()
lq_paths = get_image_paths(input_path)
    if len(lq_paths) == 0:
        exit('No images found in the input path.')
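    # MARCONetPlus components (inferred from the checkpoint names below): a W-space
    # encoder, a character structure-prior network, the SR network, and a YOLO-based
    # character detector.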
    WEncoderPath = './checkpoints/net_w_encoder_860000.pth'
    PriorModelPath = './checkpoints/net_prior_860000.pth'
    SRModelPath = './checkpoints/net_sr_860000.pth'
    YoloPath = './checkpoints/yolo11m_short_character.pt'
TextModel = MARCONetPlus(WEncoderPath, PriorModelPath, SRModelPath, YoloPath, device=device)
    print('{:>25s} : {:s}'.format('Model Name', 'MARCONetPlus'))
if use_cuda:
print('{:>25s} : {:<d}'.format('GPU ID', torch.cuda.current_device()))
else:
        print('{:>25s} : {:s}'.format('GPU ID', 'No GPU available; using CPU instead.'))
torch.cuda.empty_cache()
L_path = input_path
E_path = output_path # save path
print('{:>25s} : {:s}'.format('Input Path', L_path))
print('{:>25s} : {:s}'.format('Output Path', E_path))
if aligned:
print('{:>25s} : {:s}'.format('Image Details', 'Aligned Text Layout. No text detection is used.'))
else:
        print('{:>25s} : {:s}'.format('Image Details', 'Unaligned text image. Text regions will be detected, cropped, restored, and pasted back.'))
    print('{:>25s} : {}'.format('Scale Factor', scale_factor))
print('{:>25s} : {:s}'.format('Save LR & SR text layout', 'True' if save_text else 'False'))
idx = 0
for iix, img_path in enumerate(lq_paths):
####################################
#####(1) Read Image
####################################
idx += 1
img_name, ext = os.path.splitext(os.path.basename(img_path))
print('{:>20s} {:04d} --> {:<s}'.format('Restoring ', idx, img_name+ext))
img_L = imread_uint(img_path, n_channels=3) #RGB 0~255
height_L, width_L = img_L.shape[:2]
if not aligned and bg_sr:
            # Round dimensions down to multiples of 8 (the BG network presumably needs stride-friendly sizes)
            img_E = cv2.resize(img_L, (width_L // 8 * 8, height_L // 8 * 8), interpolation=cv2.INTER_AREA)
            img_E = uint2tensor4(img_E).to(device)  # N*C*H*W, 0~1
with torch.no_grad():
                try:
                    img_E = BGModel(img_E)
                except RuntimeError:  # most likely CUDA OOM on large inputs
                    del img_E
                    torch.cuda.empty_cache()
                    max_size = 1536
                    print(' ' * 25 + f' ... Background image is too large (OOM). Resizing so the longer side is at most {max_size} pixels.')
                    scale = min(max_size / width_L, max_size / height_L, 1.0)
                    new_width = int(width_L * scale)
                    new_height = int(height_L * scale)
                    img_E = cv2.resize(img_L, (new_width // 8 * 8, new_height // 8 * 8), interpolation=cv2.INTER_AREA)
                    img_E = uint2tensor4(img_E).to(device)
                    img_E = BGModel(img_E)
img_E = tensor2uint(img_E)
else:
img_E = img_L
        width_S = width_L * scale_factor
        height_S = height_L * scale_factor
        img_E = cv2.resize(img_E, (width_S, height_S), interpolation=cv2.INTER_AREA)
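        # Both branches converge here: img_E is the (optionally BSRGAN-restored)
        # background, resized to the target SR resolution.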
############################################################
#####(2) Restore Each Region and Paste to the whole image
############################################################
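        # handle_texts detects each text region in img_L, restores it, and pastes the
        # result onto the background bg; it returns the pasted-back image (SQ), the
        # cropped LR/SR text regions, debug layouts, and predicted transcriptions.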
SQ, ori_texts, en_texts, debug_texts, pred_texts = TextModel.handle_texts(img=img_L, bg=img_E, sf=scale_factor, is_aligned=aligned)
ext = '.png'
if SQ is None or len(en_texts) == 0:
continue
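        # Unaligned inputs: save the full pasted-back image; aligned inputs: save the
        # first restored text line directly (channels flipped RGB -> BGR for cv2.imwrite).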
if not aligned:
SQ = cv2.resize(SQ.astype(np.float32), (width_S, height_S), interpolation=cv2.INTER_AREA)
cv2.imwrite(os.path.join(E_path, img_name+ext), SQ[:,:,::-1].astype(np.uint8))
else:
cv2.imwrite(os.path.join(E_path, img_name+ext), en_texts[0][:,:,::-1].astype(np.uint8))
#####################################################
#####(3) Save Character Prior, location, SR Results
#####################################################
if save_text:
            # Save the debug text layout for each detected text region; the predicted
            # transcription pt is embedded in the filename.
            for m, (dt, pt) in enumerate(zip(debug_texts, pred_texts)):
                cv2.imwrite(os.path.join(E_path, img_name + '_patch_{}_{}_Debug.jpg'.format(m, pt)), dt[:, :, ::-1].astype(np.uint8))
if __name__ == '__main__':
    '''
    Whole images:        python test_marconetplus.py -i ./testsets/LR_Whole -b -s -f 2
    Aligned text lines:  python test_marconetplus.py -i ./testsets/LR_TextLines -a -s
    '''
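    # Programmatic usage (a minimal sketch using the repo's default paths):
    #   inference(input_path='./testsets/LR_Whole', bg_sr=True, scale_factor=2, save_text=True)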
parser = argparse.ArgumentParser()
    parser.add_argument('-i', '--input_path', type=str, default='./testsets/LR_Whole', help='Path to the LR text image(s)')
    parser.add_argument('-o', '--output_path', type=str, default=None, help='Path for saving the text SR results')
    parser.add_argument('-a', '--aligned', action='store_true', help='Whether the input contains only the text region (default: False)')
    parser.add_argument('-b', '--bg_sr', action='store_true', help='When restoring whole text images, restore the background region with BSRGAN')
    parser.add_argument('-f', '--factor_scale', type=int, default=2, help='Scale factor for restoring whole text images')
    parser.add_argument('-s', '--save_text', action='store_true', help='Save the LR, SR, and debug text layouts')
    parser.add_argument('-d', '--device', type=str, default=None, help="Device to run on: 'cpu' or 'gpu' (default: auto-detect)")
args = parser.parse_args()
inference(args.input_path, args.output_path, args.aligned, args.bg_sr, args.factor_scale, args.save_text, args.device)