import torch
import cv2
import numpy as np
import time
import argparse
import os
import os.path as osp
from models.TextEnhancement import MARCONetPlus
from utils.utils_image import get_image_paths, imread_uint, uint2tensor4, tensor2uint
from networks.rrdbnet2_arch import RRDBNet as BSRGAN




def inference(input_path=None, output_path=None, aligned=False, bg_sr=False, scale_factor=2, save_text=False, device=None):
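    """Restore low-quality text images with MARCONetPlus.

    Args:
        input_path: Path to a low-resolution text image or a directory of images.
        output_path: Directory for the restored results. Defaults to
            ``<input_path>_<timestamp>_MARCONetPlus``.
        aligned: If True, each input already contains only a text line and no
            text detection is performed.
        bg_sr: If True, restore the background of the whole image with BSRGAN
            before pasting the restored text regions back.
        scale_factor: Upscaling factor applied to the whole image.
        save_text: If True, also save the LR/SR/debug text-line layouts.
        device: 'cpu', 'gpu', or None (auto-detect CUDA).
    """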

    if device == 'cpu':
        use_cuda = False
    else:  # None or 'gpu': use CUDA when it is available
        use_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')

    if input_path is None:
        exit('Input image path is None. Please see the documentation.')

    if output_path is None:
        TIMESTAMP = time.strftime("%m-%d_%H-%M", time.localtime())
        input_path = input_path.rstrip('/\\')
        output_path = input_path + '_' + TIMESTAMP + '_MARCONetPlus'
    os.makedirs(output_path, exist_ok=True)

    # Use BSRGAN to restore the background of the whole image
    if bg_sr:
        ## Background SR model (RRDBNet, x2)
        BGModel = BSRGAN(in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=2)  # define network
        model_old = torch.load('./checkpoints/bsrgan_bg.pth', map_location='cpu')
        state_dict = BGModel.state_dict()
        # The checkpoint keys may not match the network's parameter names,
        # so parameters are copied over by position.
        for (key, param), (key2, _) in zip(model_old.items(), state_dict.items()):
            state_dict[key2] = param
        BGModel.load_state_dict(state_dict, strict=True)
        BGModel.eval()
        for _, v in BGModel.named_parameters():
            v.requires_grad = False
        BGModel = BGModel.to(device)
        torch.cuda.empty_cache()
    

    lq_paths = get_image_paths(input_path)
    if len(lq_paths) == 0:
        exit('No image found in the LR path.')


    WEncoderPath='./checkpoints/net_w_encoder_860000.pth'
    PriorModelPath='./checkpoints/net_prior_860000.pth'
    SRModelPath='./checkpoints/net_sr_860000.pth'
    YoloPath = './checkpoints/yolo11m_short_character.pt'
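    # MARCONetPlus loads the checkpoints above: a w-space encoder, a character
    # structure-prior network, the SR network, and a YOLO character detector
    # (roles inferred from the checkpoint names).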

    TextModel = MARCONetPlus(WEncoderPath, PriorModelPath, SRModelPath, YoloPath, device=device)

    print('{:>25s} : {:s}'.format('Model Name', 'MARCONetPlus'))
    if use_cuda:
        print('{:>25s} : {:<d}'.format('GPU ID', torch.cuda.current_device()))
    else:
        print('{:>25s} : {:s}'.format('GPU ID', 'No GPU is available. Use CPU instead.'))
    torch.cuda.empty_cache()

    L_path = input_path
    E_path = output_path # save path
    print('{:>25s} : {:s}'.format('Input Path', L_path))
    print('{:>25s} : {:s}'.format('Output Path', E_path))
    if aligned:
        print('{:>25s} : {:s}'.format('Image Details', 'Aligned Text Layout. No text detection is used.'))
    else:
        print('{:>25s} : {:s}'.format('Image Details', 'Unaligned Text Image. Text regions will be cropped, restored, and pasted back.'))
        print('{:>25s} : {}'.format('Scale Factor', scale_factor))
    print('{:>25s} : {:s}'.format('Save LR & SR text layout', 'True' if save_text else 'False'))


    for idx, img_path in enumerate(lq_paths, start=1):
        ####################################
        #####(1) Read Image
        ####################################
        img_name, ext = os.path.splitext(os.path.basename(img_path))
        print('{:>20s} {:04d} --> {:<s}'.format('Restoring ', idx, img_name+ext))

        img_L = imread_uint(img_path, n_channels=3) #RGB 0~255
        height_L, width_L = img_L.shape[:2]

        if not aligned and bg_sr:
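            # Round the spatial size down to a multiple of 8 before feeding the background SR model.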
            img_E = cv2.resize(img_L, (width_L // 8 * 8, height_L // 8 * 8), interpolation=cv2.INTER_AREA)
            img_E = uint2tensor4(img_E).to(device)  # N*C*H*W, 0~1
            with torch.no_grad():
                try:
                    img_E = BGModel(img_E)
                except RuntimeError:
                    # Most likely an out-of-memory error: free the tensor and retry
                    # with the image resized so its longer side is at most max_size.
                    del img_E
                    torch.cuda.empty_cache()
                    max_size = 1536
                    print(' ' * 25 + f' ... Background image is too large (OOM). Resizing so the longer side is at most {max_size} pixels.')
                    scale = min(max_size / width_L, max_size / height_L, 1.0)
                    new_width = int(width_L * scale)
                    new_height = int(height_L * scale)
                    img_E = cv2.resize(img_L, (new_width // 8 * 8, new_height // 8 * 8), interpolation=cv2.INTER_AREA)
                    img_E = uint2tensor4(img_E).to(device)
                    img_E = BGModel(img_E)

            img_E = tensor2uint(img_E)
        else:
            img_E = img_L

        # Upscale to the target SR size; restored text regions are pasted onto this canvas
        width_S = width_L * scale_factor
        height_S = height_L * scale_factor
        img_E = cv2.resize(img_E, (width_S, height_S), interpolation=cv2.INTER_AREA)

        ############################################################
        #####(2) Restore Each Region and Paste to the whole image
        ############################################################
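        # handle_texts crops text regions from img_L (unless is_aligned), restores them,
        # and pastes the results onto bg. The returned values are inferred from their
        # names: full restored image, original LR text lines, enhanced text lines,
        # debug layouts, and predicted character strings.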
        
        SQ, ori_texts, en_texts, debug_texts, pred_texts = TextModel.handle_texts(img=img_L, bg=img_E, sf=scale_factor, is_aligned=aligned)

        
        ext = '.png'
        if SQ is None or len(en_texts) == 0:
            continue
        if not aligned:
            SQ = cv2.resize(SQ.astype(np.float32), (width_S, height_S), interpolation=cv2.INTER_AREA)
            cv2.imwrite(os.path.join(E_path, img_name + ext), SQ[:, :, ::-1].astype(np.uint8))  # RGB -> BGR for OpenCV
        else:
            cv2.imwrite(os.path.join(E_path, img_name + ext), en_texts[0][:, :, ::-1].astype(np.uint8))  # RGB -> BGR for OpenCV

        #####################################################
        #####(3) Save Character Prior, location, SR Results
        #####################################################
        if save_text:
            for m, (et, ot, dt, pt) in enumerate(zip(en_texts, ori_texts, debug_texts, pred_texts)):  ## save each text region
                cv2.imwrite(os.path.join(E_path, img_name + '_patch_{}_{}_Debug.jpg'.format(m, pt)), dt[:, :, ::-1].astype(np.uint8))

    
if __name__ == '__main__':
    '''
    Examples:
        Whole image:        python test_marconetplus.py -i ./Testsets/LR_Whole -b -s -f 2
        Aligned text lines: python test_marconetplus.py -i ./Testsets/LR_TextLines -a -s
    '''
    parser = argparse.ArgumentParser()
    parser.add_argument('-i', '--input_path', type=str, default='./testsets/LR_Whole', help='The LR text image path')
    parser.add_argument('-o', '--output_path', type=str, default=None, help='The save path for the text SR results')
    parser.add_argument('-a', '--aligned', action='store_true', help='Whether the input image contains only an aligned text region (no text detection), default: False')
    parser.add_argument('-b', '--bg_sr', action='store_true', help='When restoring whole text images, use -b to restore the background region with BSRGAN')
    parser.add_argument('-f', '--factor_scale', type=int, default=2, help='When restoring whole text images, use -f to set the scale factor')
    parser.add_argument('-s', '--save_text', action='store_true', help='Save the LR, SR, and debug text layouts or not')
    parser.add_argument('-d', '--device', type=str, default=None, help="'cpu' or 'gpu' (default: auto-detect)")

    args = parser.parse_args()
    inference(args.input_path, args.output_path, args.aligned, args.bg_sr, args.factor_scale, args.save_text, args.device)