File size: 6,229 Bytes
561c629
 
 
 
 
 
8f3d49d
561c629
 
 
 
 
 
 
 
 
 
 
 
 
 
d26bbd5
561c629
 
 
 
 
 
d26bbd5
561c629
 
 
 
 
 
d26bbd5
 
 
 
 
 
 
 
 
 
561c629
 
 
 
 
 
 
8f3d49d
 
 
 
 
 
561c629
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8f3d49d
561c629
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8f3d49d
561c629
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
'''
    This file runs super-resolution inference on a single image or on every image in a folder.
'''
import argparse
import os, sys, cv2, shutil, warnings
import torch
import gradio as gr
from torchvision.transforms import ToTensor
from torchvision.utils import save_image
warnings.simplefilter("default")
os.environ["PYTHONWARNINGS"] = "default"


# Import files from the local folder
root_path = os.path.abspath('.')
sys.path.append(root_path)
from test_code.test_utils import load_grl, load_rrdb, load_cunet



@torch.no_grad()    # Disable autograd during inference; otherwise activations are kept for backprop and we run Out of Memory
def super_resolve_img(generator, input_path, output_path=None, weight_dtype=torch.float32, downsample_threshold=720, crop_for_4x=True):
    ''' Super Resolve a low resolution image
    Args:
        generator (torch.nn.Module):    the generator that is already loaded (assumed to live on CUDA with dtype
                                        ``weight_dtype``, since the input tensor is moved there below)
        input_path (str):               the path to the input lr image
        output_path (str):              the file path to store the generated image; nothing is written when None
        weight_dtype (torch.dtype):     the dtype the input tensor is cast to (torch.float32 / torch.float16)
        downsample_threshold (int):     if the short side is larger than this, the image is resized so that the short
                                        side becomes (approximately) this value; -1 disables downsampling
        crop_for_4x (bool):             whether we crop the lr image so height and width are multiples of 4
                                        (needed for some situation)
    Returns:
        torch.Tensor: the generator output for the (1, 3, H, W) RGB input in [0, 1]
    Raises:
        gr.Error: if the image cannot be read, or its processed area exceeds 720x1280 pixels
    '''
    print("Processing image {}".format(input_path))

    # Read the image (BGR, HxWx3) and do preprocess
    img_lr = cv2.imread(input_path)
    if img_lr is None:
        # cv2.imread signals unreadable / non-image files by returning None instead of raising
        raise gr.Error("Cannot read the input image {}".format(input_path))
    h, w, _ = img_lr.shape

    # Downsample if needed
    short_side = min(h, w)
    if downsample_threshold != -1 and short_side > downsample_threshold:
        resize_ratio = short_side / downsample_threshold
        img_lr = cv2.resize(img_lr, (int(w / resize_ratio), int(h / resize_ratio)), interpolation=cv2.INTER_LINEAR)

    # Crop if needed: drop the trailing rows/columns so both sides are multiples of 4
    if crop_for_4x:
        h, w, _ = img_lr.shape
        img_lr = img_lr[:4 * (h // 4), :4 * (w // 4), :]

    # Check if the size is out of the boundary
    h, w, _ = img_lr.shape
    if h * w > 720 * 1280:
        raise gr.Error("The input image size is too large. The largest area we support is 720x1280=921600 pixel!")

    # Transform to tensor: BGR -> RGB, HWC uint8 -> (1, C, H, W) float in [0, 1] on CUDA
    img_lr = cv2.cvtColor(img_lr, cv2.COLOR_BGR2RGB)
    img_lr = ToTensor()(img_lr).unsqueeze(0).cuda()
    img_lr = img_lr.to(dtype=weight_dtype)

    # Model inference
    print("lr shape is ", img_lr.shape)
    super_resolved_img = generator(img_lr)

    # Store the generated result. save_image converts float -> uint8 itself, so no autocast
    # context is needed (torch.cuda.amp.autocast is also deprecated in recent PyTorch).
    if output_path is not None:
        save_image(super_resolved_img, output_path)

    # Empty the cache every time you finish processing one image
    torch.cuda.empty_cache()

    return super_resolved_img




def parse_bool_arg(value):
    ''' Convert a command-line string such as "True" / "false" / "1" / "no" to a bool.

    argparse's ``type=bool`` calls ``bool(str)``, which is True for every non-empty string,
    so ``--float16_inference False`` used to *enable* float16. This parser fixes that.

    Raises:
        argparse.ArgumentTypeError: if the string is not a recognised boolean spelling
    '''
    if isinstance(value, bool):     # already a bool (e.g. a default / const value)
        return value
    normalized = value.strip().lower()
    if normalized in ("true", "t", "yes", "y", "1"):
        return True
    if normalized in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("Expected a boolean value (True/False), got {!r}".format(value))


def build_arg_parser():
    ''' Build the command-line parser of this inference script.

    Returns:
        argparse.ArgumentParser: parser with the --input_dir, --model, --scale, --weight_path,
        --store_dir and --float16_inference options
    '''
    parser = argparse.ArgumentParser()
    parser.add_argument('--input_dir', type = str, default = '__assets__/lr_inputs', help="Can be either single image input or a folder input")
    parser.add_argument('--model', type = str, default = 'GRL', choices = ['GRL', 'RRDB', 'CUNET'], help=" 'GRL' || 'RRDB' (for ESRNET & ESRGAN) || 'CUNET' (for Real-ESRGAN) ")
    parser.add_argument('--scale', type = int, default = 4, help="Up scaler factor")
    parser.add_argument('--weight_path', type = str, default = 'pretrained/4x_APISR_GRL_GAN_generator.pth', help="Weight path directory, usually under saved_models folder")
    parser.add_argument('--store_dir', type = str, default = 'sample_outputs', help="The folder to store the super-resolved images")
    # Accepts "--float16_inference", "--float16_inference True" and "--float16_inference False".
    # Currently, this is only supported in RRDB, there is some bug with GRL model
    parser.add_argument('--float16_inference', type = parse_bool_arg, nargs = '?', const = True, default = False, help="Float16 inference, only useful in RRDB now")
    return parser


if __name__ == "__main__":

    # Fundamental setting
    args = build_arg_parser().parse_args()

    # Sample Command
    # 4x GRL (Default):     python test_code/inference.py --model GRL --scale 4 --weight_path pretrained/4x_APISR_GRL_GAN_generator.pth
    # 2x RRDB:              python test_code/inference.py --model RRDB --scale 2 --weight_path pretrained/2x_APISR_RRDB_GAN_generator.pth


    # Read argument and prepare the folder needed
    input_dir = args.input_dir
    model = args.model
    weight_path = args.weight_path
    store_dir = args.store_dir
    scale = args.scale
    float16_inference = args.float16_inference


    # Check the path of the weight
    if not os.path.exists(weight_path):
        # TODO: I am not sure if I should automatically download weight from github release based on the upscale factor and model name.
        # sys.exit(msg) prints msg to stderr and exits with status 1; os._exit(0) reported
        # success and skipped flushing buffered stdout, so the message could be lost.
        sys.exit("we cannot locate weight path " + weight_path)

    # Check the path of the input (otherwise a missing path is treated as a single image and fails later)
    if not os.path.exists(input_dir):
        sys.exit("we cannot locate input path " + input_dir)


    # Prepare the store folder. It is wiped first, so refuse when it is (or contains) the input,
    # otherwise rmtree would delete the images we are about to process.
    real_store_dir = os.path.realpath(store_dir)
    real_input_dir = os.path.realpath(input_dir)
    if real_input_dir == real_store_dir or real_input_dir.startswith(os.path.join(real_store_dir, "")):
        sys.exit("store_dir " + store_dir + " must not be or contain the input " + input_dir)
    if os.path.exists(store_dir):
        shutil.rmtree(store_dir)
    os.makedirs(store_dir)


    # Define the weight type
    if float16_inference:
        torch.backends.cudnn.benchmark = True
        weight_dtype = torch.float16
    else:
        weight_dtype = torch.float32


    # Load the model (argparse `choices` guarantees `model` is one of these keys)
    # NOTE(review): assumes load_cunet shares the (weight_path, scale=...) signature of the
    # other loaders — confirm in test_code/test_utils.py.
    model_loaders = {
        "GRL": load_grl,        # GRL for Real-World SR only support 4x upscaling
        "RRDB": load_rrdb,      # Can be any size
        "CUNET": load_cunet,
    }
    generator = model_loaders[model](weight_path, scale=scale)
    generator = generator.to(dtype=weight_dtype)


    # Take the input path and do inference
    if os.path.isdir(input_dir):    # If the input is a directory, we will iterate it
        for filename in sorted(os.listdir(input_dir)):
            input_path = os.path.join(input_dir, filename)
            if not os.path.isfile(input_path):
                continue    # skip sub-folders; only files can be read as images
            output_path = os.path.join(store_dir, filename)
            # In default, we will automatically use crop to match 4x size
            super_resolve_img(generator, input_path, output_path, weight_dtype, crop_for_4x=True)

    else:   # If the input is a single image, store the result as <name>_<scale>x.png inside store_dir
        filename = os.path.splitext(os.path.basename(input_dir))[0]
        output_path = os.path.join(store_dir, filename + "_" + str(scale) + "x.png")
        # In default, we will automatically use crop to match 4x size
        super_resolve_img(generator, input_dir, output_path, weight_dtype, crop_for_4x=True)