Commit fca6e54
Parent: dc15506

Update app.py

Files changed (1)
  1. app.py +82 -62
app.py CHANGED
@@ -4,67 +4,87 @@ import os
  from skimage import img_as_ubyte
  import cv2
  import argparse
-
- parser = argparse.ArgumentParser(description='Test Restormer on your own images')
- parser.add_argument('--input_path', default='./temp/image.jpg', type=str, help='Directory of input images or path of single image')
- parser.add_argument('--result_dir', default='./temp/', type=str, help='Directory for restored results')
- parser.add_argument('--task', required=True, type=str, help='Task to run', choices=['Motion_Deblurring',
-                                                                                      'Single_Image_Defocus_Deblurring',
-                                                                                      'Deraining',
-                                                                                      'Real_Denoising',
-                                                                                      'Gaussian_Gray_Denoising',
-                                                                                      'Gaussian_Color_Denoising'])
-
- args = parser.parse_args()
-
-
- task = args.task
- out_dir = os.path.join(args.result_dir, task)
-
- os.makedirs(out_dir, exist_ok=True)
-
-
- if task == 'Motion_Deblurring':
-     model = torch.jit.load('motion_deblurring.pt')
- elif task == 'Single_Image_Defocus_Deblurring':
-     model = torch.jit.load('single_image_defocus_deblurring.pt')
- elif task == 'Deraining':
-     model = torch.jit.load('deraining.pt')
- elif task == 'Real_Denoising':
-     model = torch.jit.load('real_denoising.pt')

- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
- # device = torch.device('cpu')
- # stx()
-
- model = model.to(device)
- model.eval()
-
- img_multiple_of = 8
-
- with torch.inference_mode():
-     if torch.cuda.is_available():
-         torch.cuda.ipc_collect()
-         torch.cuda.empty_cache()

-     img = cv2.cvtColor(cv2.imread(args.input_path), cv2.COLOR_BGR2RGB)
-
-     input_ = torch.from_numpy(img).float().div(255.).permute(2,0,1).unsqueeze(0).to(device)
-
-     # Pad the input if not_multiple_of 8
-     h,w = input_.shape[2], input_.shape[3]
-     H,W = ((h+img_multiple_of)//img_multiple_of)*img_multiple_of, ((w+img_multiple_of)//img_multiple_of)*img_multiple_of
-     padh = H-h if h%img_multiple_of!=0 else 0
-     padw = W-w if w%img_multiple_of!=0 else 0
-     input_ = F.pad(input_, (0,padw,0,padh), 'reflect')
-
-     # print(h,w)
-     restored = torch.clamp(model(input_),0,1)
-
-     # Unpad the output
-     restored = img_as_ubyte(restored[:,:,:h,:w].permute(0, 2, 3, 1).cpu().detach().numpy()[0])
-
-     out_path = os.path.join(out_dir, os.path.split(args.input_path)[-1])
-     cv2.imwrite(out_path,cv2.cvtColor(restored, cv2.COLOR_RGB2BGR))
-
-     # print(f"\nRestored images are saved at {out_dir}")
+ import shutil
+ import gradio as gr
+ from PIL import Image
+
+ examples = [['sample1.png'],
+             ['sample2.png']]
+
+ inference_on = ['Full Resolution Image', 'Downsampled Image']
+
+ title = "Restormer"
+ description = """
+ Gradio demo for <b>Restormer: Efficient Transformer for High-Resolution Image Restoration</b>, CVPR 2022--ORAL. <a href='https://arxiv.org/abs/2111.09881'>[Paper]</a><a href='https://github.com/swz30/Restormer'>[Github Code]</a>\n
+ <b> Note:</b> Since this demo uses CPU, by default it will run on the downsampled version of the input image (for speedup). But if you want to perform inference on the original input, then choose the option "Full Resolution Image" from the dropdown menu.
+ """
+ ##With Restormer, you can perform: (1) Image Denoising, (2) Defocus Deblurring, (3) Motion Deblurring, and (4) Image Deraining.
+ ##To use it, simply upload your own image, or click one of the examples provided below.
+
+ article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2111.09881'>Restormer: Efficient Transformer for High-Resolution Image Restoration </a> | <a href='https://github.com/swz30/Restormer'>Github Repo</a></p>"
+
+
+ def inference(img, task, run_on):
+     if not os.path.exists('temp'):
+         os.system('mkdir temp')
+
+     if run_on == 'Full Resolution Image':
+         img = img
+     else: # 'Downsampled Image'
+         #### Resize the longer edge of the input image
+         max_res = 512
+         width, height = img.size
+         if max(width,height) > max_res:
+             scale = max_res /max(width,height)
+             width = int(scale*width)
+             height = int(scale*height)
+             img = img.resize((width,height), Image.ANTIALIAS)
+
+
+     model = torch.jit.load('deraining.pt')
+     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+     model = model.to(device)
+     model.eval()

+     img_multiple_of = 8

+     with torch.inference_mode():
+         if torch.cuda.is_available():
+             torch.cuda.ipc_collect()
+             torch.cuda.empty_cache()
+
+         img = cv2.cvtColor(cv2.imread(args.input_path), cv2.COLOR_BGR2RGB)
+
+         input_ = torch.from_numpy(img).float().div(255.).permute(2,0,1).unsqueeze(0).to(device)
+
+         # Pad the input if not_multiple_of 8
+         h,w = input_.shape[2], input_.shape[3]
+         H,W = ((h+img_multiple_of)//img_multiple_of)*img_multiple_of, ((w+img_multiple_of)//img_multiple_of)*img_multiple_of
+         padh = H-h if h%img_multiple_of!=0 else 0
+         padw = W-w if w%img_multiple_of!=0 else 0
+         input_ = F.pad(input_, (0,padw,0,padh), 'reflect')
+
+         restored = torch.clamp(model(input_),0,1)
+
+         # Unpad the output
+         restored = img_as_ubyte(restored[:,:,:h,:w].permute(0, 2, 3, 1).cpu().detach().numpy()[0])
+         #convert to pil when returning
+
+     return Image.fromarray(cv2.cvtColor(restored, cv2.COLOR_RGB2BGR))
+
+ gr.Interface(
+     inference,
+     [
+         gr.inputs.Image(type="pil", label="Input"),
+         gr.inputs.Radio(["Deraining"], default="Denoising", label='task'),
+         gr.inputs.Dropdown(choices=inference_on, type="value", default='Downsampled Image', label='Inference on')
+
+     ],
+     gr.outputs.Image(type="pil", label="cleaned and restored"),
+     title=title,
+     description=description,
+     article=article,
+     theme ="huggingface",
+     examples=examples,
+     allow_flagging=False,
+ ).launch(debug=False,enable_queue=True)
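
Two issues in the new inference() are visible in the diff itself: the argparse block (and with it args) is removed in this same commit, yet the function still reads the input with cv2.imread(args.input_path), which would raise a NameError when the demo is used; and the restored array, which is RGB at that point, is passed through cv2.COLOR_RGB2BGR before Image.fromarray, which interprets arrays as RGB, so the returned image would come back with swapped colour channels. The sketch below is one possible way to operate directly on the PIL image that Gradio hands to inference(); the helper name run_restormer and the direct PIL-to-NumPy conversion are my own assumptions, not code from this commit.

# Hypothetical sketch, not the committed code: consume the PIL image directly
# instead of re-reading args.input_path (which no longer exists after this commit).
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from skimage import img_as_ubyte

def run_restormer(img: Image.Image, model, device, img_multiple_of=8) -> Image.Image:
    # Gradio's "pil" input is already RGB, so no BGR<->RGB conversion is needed.
    arr = np.array(img.convert('RGB'))
    input_ = torch.from_numpy(arr).float().div(255.).permute(2, 0, 1).unsqueeze(0).to(device)

    # Pad height/width up to the next multiple of 8 with reflect padding.
    h, w = input_.shape[2], input_.shape[3]
    padh = (img_multiple_of - h % img_multiple_of) % img_multiple_of
    padw = (img_multiple_of - w % img_multiple_of) % img_multiple_of
    input_ = F.pad(input_, (0, padw, 0, padh), 'reflect')

    with torch.inference_mode():
        restored = torch.clamp(model(input_), 0, 1)

    # Unpad, move channels last, convert to uint8.
    restored = img_as_ubyte(restored[:, :, :h, :w].permute(0, 2, 3, 1).cpu().numpy()[0])
    # Return RGB as-is; the RGB2BGR swap in the diff would flip the colours
    # because Image.fromarray interprets the array as RGB.
    return Image.fromarray(restored)

Separately, gr.inputs.Radio(["Deraining"], default="Denoising", label='task') sets a default that is not among its choices, which is worth aligning in a follow-up commit.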
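
The "Downsampled Image" path in inference() caps the longer edge at 512 pixels, as the description promises. A quick illustration of that scaling, with the input size chosen purely for illustration:

# Illustrative check of the "Downsampled Image" resize in inference().
from PIL import Image

max_res = 512
width, height = 2048, 1536  # hypothetical full-resolution input

if max(width, height) > max_res:
    scale = max_res / max(width, height)   # 512 / 2048 = 0.25
    width = int(scale * width)             # 512
    height = int(scale * height)           # 384

# In current Pillow, Image.LANCZOS replaces the removed Image.ANTIALIAS filter:
# img = img.resize((width, height), Image.LANCZOS)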
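
Likewise, the pad-to-multiple-of-8 arithmetic that both versions share rounds the spatial size up and pads only when the size is not already a multiple of 8. A quick numeric check, again with illustrative values:

# Illustrative check of the multiple-of-8 padding arithmetic from app.py.
img_multiple_of = 8
h, w = 500, 333  # hypothetical input height and width

# Round up to the next multiple of 8. The formula overshoots exact multiples
# (e.g. 504 -> 512), but padh/padw are forced to 0 in that case, so no padding
# is applied and the later crop back to [:h, :w] is unaffected.
H = ((h + img_multiple_of) // img_multiple_of) * img_multiple_of  # 504
W = ((w + img_multiple_of) // img_multiple_of) * img_multiple_of  # 336
padh = H - h if h % img_multiple_of != 0 else 0  # 4
padw = W - w if w % img_multiple_of != 0 else 0  # 3

assert (h + padh) % img_multiple_of == 0 and (w + padw) % img_multiple_of == 0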