'''
    Gradio demo (almost the same code as the one used in the Hugging Face Space)
'''
import os, sys
import cv2
import time
import datetime, pytz
import gradio as gr
import torch
import numpy as np
from torchvision.utils import save_image


# Import files from the local folder
root_path = os.path.abspath('.')
sys.path.append(root_path)
from test_code.inference import super_resolve_img
from test_code.test_utils import load_grl, load_rrdb, load_dat


# Map each expected local weight path to its GitHub release download URL
WEIGHT_URLS = {
    "pretrained/4x_APISR_RRDB_GAN_generator.pth": "https://github.com/Kiteretsu77/APISR/releases/download/v0.2.0/4x_APISR_RRDB_GAN_generator.pth",
    "pretrained/4x_APISR_GRL_GAN_generator.pth":  "https://github.com/Kiteretsu77/APISR/releases/download/v0.1.0/4x_APISR_GRL_GAN_generator.pth",
    "pretrained/2x_APISR_RRDB_GAN_generator.pth": "https://github.com/Kiteretsu77/APISR/releases/download/v0.1.0/2x_APISR_RRDB_GAN_generator.pth",
    "pretrained/4x_APISR_DAT_GAN_generator.pth":  "https://github.com/Kiteretsu77/APISR/releases/download/v0.3.0/4x_APISR_DAT_GAN_generator.pth",
}


def auto_download_if_needed(weight_path):
    # Download the requested pretrained weight from its GitHub release if it is not available locally
    if os.path.exists(weight_path):
        return

    if not os.path.exists("pretrained"):
        os.makedirs("pretrained")

    if weight_path in WEIGHT_URLS:
        os.system(f"wget {WEIGHT_URLS[weight_path]}")
        os.system(f"mv {os.path.basename(weight_path)} pretrained")

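# A possible alternative (sketch only, not used by the demo): the same download could be done
# without shell commands via torch.hub, which avoids depending on wget being installed.
# The helper name `download_weight` and its signature are illustrative, not part of APISR;
# it relies on the `torch` and `os` imports at the top of this file.
def download_weight(url, weight_path):
    if not os.path.exists(weight_path):
        os.makedirs(os.path.dirname(weight_path), exist_ok=True)
        torch.hub.download_url_to_file(url, weight_path)
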

def inference(img_path, model_name):
    
    try:
        weight_dtype = torch.float32
        
        # Load the selected generator with its default loader settings (weights are auto-downloaded if missing)
        if model_name == "4xGRL":
            weight_path = "pretrained/4x_APISR_GRL_GAN_generator.pth"
            auto_download_if_needed(weight_path)
            generator = load_grl(weight_path, scale=4)

        elif model_name == "4xRRDB":
            weight_path = "pretrained/4x_APISR_RRDB_GAN_generator.pth"
            auto_download_if_needed(weight_path)
            generator = load_rrdb(weight_path, scale=4)

        elif model_name == "2xRRDB":
            weight_path = "pretrained/2x_APISR_RRDB_GAN_generator.pth"
            auto_download_if_needed(weight_path)
            generator = load_rrdb(weight_path, scale=2)

        elif model_name == "4xDAT":
            weight_path = "pretrained/4x_APISR_DAT_GAN_generator.pth"
            auto_download_if_needed(weight_path)
            generator = load_dat(weight_path, scale=4)

        else:
            raise gr.Error(f"Unsupported model name: {model_name}")
        
        generator = generator.to(dtype=weight_dtype)


        print("Processing:", img_path)
        print("Current time (US/Eastern):", datetime.datetime.now(pytz.timezone('US/Eastern')))

        # By default, automatically crop the input so that it matches the 4x upscaling size requirement
        super_resolved_img = super_resolve_img(generator, img_path, output_path=None, weight_dtype=weight_dtype, downsample_threshold=720, crop_for_4x=True)

        # Save the result to a temporary PNG, read it back, and return it as an RGB numpy array for Gradio
        store_name = str(time.time()) + ".png"
        save_image(super_resolved_img, store_name)
        outputs = cv2.imread(store_name)                    # OpenCV loads the image as BGR
        outputs = cv2.cvtColor(outputs, cv2.COLOR_BGR2RGB)  # convert to RGB for the Gradio output
        os.remove(store_name)
        
        return outputs
    
    
    except Exception as error:
        raise gr.Error(f"global exception: {error}")

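# For a quick sanity check outside the UI, the pipeline could also be called directly, e.g.
# (the input path below is one of the bundled example images; the returned array is RGB,
# so it is converted back to BGR before writing with OpenCV):
#
#   result = inference("__assets__/lr_inputs/naruto.jpg", "4xGRL")
#   cv2.imwrite("result.png", cv2.cvtColor(result, cv2.COLOR_RGB2BGR))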


if __name__ == '__main__':
    
    MARKDOWN = \
    """
    ## <p style='text-align: center'> APISR: Anime Production Inspired Real-World Anime Super-Resolution (CVPR 2024) </p>
    
    [GitHub](https://github.com/Kiteretsu77/APISR) | [Paper](https://arxiv.org/abs/2403.01598)

    APISR aims to restore and enhance low-quality, low-resolution **anime** images and video sources with various real-world degradations.
    
    ### Note: Due to memory restrictions, any image whose shorter side exceeds 720 pixels will be downsampled to 720 pixels while keeping the aspect ratio, e.g., 1920x1080 -> 1280x720.
    ### Note: Please check the [Model Zoo](https://github.com/Kiteretsu77/APISR/blob/main/docs/model_zoo.md) for a description of each weight and [here](https://imgsli.com/MjU0MjI0) for model comparisons.
    
    ### If APISR is helpful, please star the [GitHub Repo](https://github.com/Kiteretsu77/APISR). Thanks! ###
    """

    block = gr.Blocks().queue(max_size=10)
    with block:
        with gr.Row():
            gr.Markdown(MARKDOWN)
        with gr.Row(elem_classes=["container"]):
            with gr.Column(scale=2):
                input_image = gr.Image(type="filepath", label="Input")
                model_name = gr.Dropdown(
                    [
                        "2xRRDB",
                        "4xRRDB",
                        "4xGRL",
                        "4xDAT",
                    ],
                    type="value",
                    value="4xGRL",
                    label="model",
                )
                run_btn = gr.Button(value="Submit")

            with gr.Column(scale=3):
                output_image = gr.Image(type="numpy", label="Output image")

        with gr.Row(elem_classes=["container"]):
            gr.Examples(
                [
                    ["__assets__/lr_inputs/image-00277.png"],
                    ["__assets__/lr_inputs/image-00542.png"],
                    ["__assets__/lr_inputs/41.png"],
                    ["__assets__/lr_inputs/f91.jpg"],
                    ["__assets__/lr_inputs/image-00440.png"],
                    ["__assets__/lr_inputs/image-00164.jpg"],
                    ["__assets__/lr_inputs/img_eva.jpeg"],
                    ["__assets__/lr_inputs/naruto.jpg"],
                ],
                [input_image],
            )

        run_btn.click(inference, inputs=[input_image, model_name], outputs=[output_image])

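    # launch() is used with its defaults here; standard Gradio options such as share=True
    # (to create a public link) or server_name="0.0.0.0" can be passed when running locally.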
    block.launch()