Arrcttacsrks commited on
Commit
0372079
·
verified ·
1 Parent(s): 8d280db

Upload app-22.py

Browse files
Files changed (1) hide show
  1. app-22.py +149 -0
app-22.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, sys
2
+ import cv2
3
+ import time
4
+ import datetime, pytz
5
+ import gradio as gr
6
+ import torch
7
+ import numpy as np
8
+ from torchvision.utils import save_image
9
+
10
+
11
+ # Import files from the local folder
12
+ root_path = os.path.abspath('.')
13
+ sys.path.append(root_path)
14
+ from test_code.inference import super_resolve_img
15
+ from test_code.test_utils import load_grl, load_rrdb, load_dat
16
+
17
+
18
+ def auto_download_if_needed(weight_path):
19
+ if os.path.exists(weight_path):
20
+ return
21
+
22
+ if not os.path.exists("pretrained"):
23
+ os.makedirs("pretrained")
24
+
25
+ if weight_path == "pretrained/4x_APISR_RRDB_GAN_generator.pth":
26
+ os.system("wget https://github.com/Kiteretsu77/APISR/releases/download/v0.2.0/4x_APISR_RRDB_GAN_generator.pth")
27
+ os.system("mv 4x_APISR_RRDB_GAN_generator.pth pretrained")
28
+
29
+ if weight_path == "pretrained/4x_APISR_GRL_GAN_generator.pth":
30
+ os.system("wget https://github.com/Kiteretsu77/APISR/releases/download/v0.1.0/4x_APISR_GRL_GAN_generator.pth")
31
+ os.system("mv 4x_APISR_GRL_GAN_generator.pth pretrained")
32
+
33
+ if weight_path == "pretrained/2x_APISR_RRDB_GAN_generator.pth":
34
+ os.system("wget https://github.com/Kiteretsu77/APISR/releases/download/v0.1.0/2x_APISR_RRDB_GAN_generator.pth")
35
+ os.system("mv 2x_APISR_RRDB_GAN_generator.pth pretrained")
36
+
37
+ if weight_path == "pretrained/4x_APISR_DAT_GAN_generator.pth":
38
+ os.system("wget https://github.com/Kiteretsu77/APISR/releases/download/v0.3.0/4x_APISR_DAT_GAN_generator.pth")
39
+ os.system("mv 4x_APISR_DAT_GAN_generator.pth pretrained")
40
+
41
+
42
+ def inference(img_path, model_name):
43
+
44
+ try:
45
+ # Load the model
46
+ if model_name == "4xGRL":
47
+ weight_path = "pretrained/4x_APISR_GRL_GAN_generator.pth"
48
+ auto_download_if_needed(weight_path)
49
+ generator = load_grl(weight_path, scale=4)
50
+ generator = generator.to(device='cpu')
51
+
52
+ elif model_name == "4xRRDB":
53
+ weight_path = "pretrained/4x_APISR_RRDB_GAN_generator.pth"
54
+ auto_download_if_needed(weight_path)
55
+ generator = load_rrdb(weight_path, scale=4)
56
+ generator = generator.to(device='cpu')
57
+
58
+ elif model_name == "2xRRDB":
59
+ weight_path = "pretrained/2x_APISR_RRDB_GAN_generator.pth"
60
+ auto_download_if_needed(weight_path)
61
+ generator = load_rrdb(weight_path, scale=2)
62
+ generator = generator.to(device='cpu')
63
+
64
+ elif model_name == "4xDAT":
65
+ weight_path = "pretrained/4x_APISR_DAT_GAN_generator.pth"
66
+ auto_download_if_needed(weight_path)
67
+ generator = load_dat(weight_path, scale=4)
68
+ generator = generator.to(device='cpu')
69
+
70
+ else:
71
+ raise gr.Error("We don't support such Model")
72
+
73
+
74
+ print("We are processing ", img_path)
75
+ print("The time now is ", datetime.datetime.now(pytz.timezone('US/Eastern')))
76
+
77
+ # In default, we will automatically use crop to match 4x size
78
+ super_resolved_img = super_resolve_img(generator, img_path, output_path=None, downsample_threshold=720, crop_for_4x=True)
79
+ store_name = str(time.time()) + ".png"
80
+ save_image(super_resolved_img, store_name)
81
+ outputs = cv2.imread(store_name)
82
+ outputs = cv2.cvtColor(outputs, cv2.COLOR_RGB2BGR)
83
+ os.remove(store_name)
84
+
85
+ return outputs
86
+
87
+
88
+ except Exception as error:
89
+ raise gr.Error(f"global exception: {error}")
90
+
91
+
92
+
93
+ if __name__ == '__main__':
94
+
95
+ MARKDOWN = \
96
+ """
97
+ ## <p style='text-align: center'> APISR: Anime Production Inspired Real-World Anime Super-Resolution (CVPR 2024) </p>
98
+
99
+ [GitHub](https://github.com/Kiteretsu77/APISR) | [Paper](https://arxiv.org/abs/2403.01598)
100
+
101
+ APISR aims at restoring and enhancing low-quality low-resolution **anime** images and video sources with various degradations from real-world scenarios.
102
+
103
+ ### Note: Due to memory restriction, all images whose short side is over 720 pixel will be downsampled to 720 pixel with the same aspect ratio. E.g., 1920x1080 -> 1280x720
104
+ ### Note: Please check [Model Zoo](https://github.com/Kiteretsu77/APISR/blob/main/docs/model_zoo.md) for the description of each weight and [Here](https://imgsli.com/MjU0MjI0) for model comparisons.
105
+
106
+ ### If APISR is helpful, please help star the [GitHub Repo](https://github.com/Kiteretsu77/APISR). Thanks! ###
107
+ """
108
+
109
+ block = gr.Blocks().queue(max_size=10)
110
+ with block:
111
+ with gr.Row():
112
+ gr.Markdown(MARKDOWN)
113
+ with gr.Row(elem_classes=["container"]):
114
+ with gr.Column(scale=2):
115
+ input_image = gr.Image(type="filepath", label="Input")
116
+ model_name = gr.Dropdown(
117
+ [
118
+ "2xRRDB",
119
+ "4xRRDB",
120
+ "4xGRL",
121
+ "4xDAT",
122
+ ],
123
+ type="value",
124
+ value="4xGRL",
125
+ label="model",
126
+ )
127
+ run_btn = gr.Button(value="Submit")
128
+
129
+ with gr.Column(scale=3):
130
+ output_image = gr.Image(type="numpy", label="Output image")
131
+
132
+ with gr.Row(elem_classes=["container"]):
133
+ gr.Examples(
134
+ [
135
+ ["__assets__/lr_inputs/image-00277.png"],
136
+ ["__assets__/lr_inputs/image-00542.png"],
137
+ ["__assets__/lr_inputs/41.png"],
138
+ ["__assets__/lr_inputs/f91.jpg"],
139
+ ["__assets__/lr_inputs/image-00440.png"],
140
+ ["__assets__/lr_inputs/image-00164.jpg"],
141
+ ["__assets__/lr_inputs/img_eva.jpeg"],
142
+ ["__assets__/lr_inputs/naruto.jpg"],
143
+ ],
144
+ [input_image],
145
+ )
146
+
147
+ run_btn.click(inference, inputs=[input_image, model_name], outputs=[output_image])
148
+
149
+ block.launch()