52Hz commited on
Commit
6f9115d
·
1 Parent(s): 14f7446

Create main_test_SRMNet.py

Browse files
Files changed (1) hide show
  1. main_test_SRMNet.py +94 -0
main_test_SRMNet.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchvision.transforms.functional as TF
3
+ import torch.nn.functional as F
4
+ from PIL import Image
5
+ import os
6
+ from skimage import img_as_ubyte
7
+ from tqdm import tqdm
8
+ from natsort import natsorted
9
+ from glob import glob
10
+ from utils.image_utils import save_img
11
+ from utils.model_utils import load_checkpoint
12
+ import argparse
13
+ from model_arch.SRMNet_SWFF import SRMNet_SWFF
14
+ from model_arch.SRMNet import SRMNet
15
+
16
+ tasks = ['Deblurring_motionblur',
17
+ 'Dehaze_realworld',
18
+ 'Denoise_gaussian',
19
+ 'Denoise_realworld',
20
+ 'Deraining_raindrop',
21
+ 'Deraining_rainstreak',
22
+ 'LLEnhancement',
23
+ 'Retouching']
24
+
25
+ def main():
26
+ parser = argparse.ArgumentParser(description='Quick demo Image Restoration')
27
+ parser.add_argument('--input_dir', default='test/', type=str, help='Input images root')
28
+ parser.add_argument('--result_dir', default='result/', type=str, help='Results images root')
29
+ parser.add_argument('--weights_root', default='pretrained_model', type=str, help='Weights root')
30
+ parser.add_argument('--task', default='Retouching', type=str, help='Restoration task (Above task list)')
31
+
32
+ args = parser.parse_args()
33
+
34
+ # Prepare testing data
35
+ inp_dir = os.path.join(args.input_dir, args.task)
36
+ files = natsorted(glob.glob(os.path.join(inp_dir, '*')))
37
+ if len(files) == 0:
38
+ raise Exception("\nNo images in {} \nPlease enter the following tasks: \n\n{}".format(inp_dir, '\n'.join(tasks)))
39
+
40
+ out_dir = os.path.join(args.result_dir, args.task)
41
+ os.makedirs(out_dir, exist_ok=True)
42
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
43
+ # Build model
44
+ model = define_model(args)
45
+ model.eval()
46
+ model = model.to(device)
47
+
48
+ print('restoring images......')
49
+
50
+ mul = 16
51
+
52
+ for i, file_ in enumerate(tqdm(files)):
53
+ img = Image.open(file_).convert('RGB')
54
+ input_ = TF.to_tensor(img).unsqueeze(0).cuda()
55
+
56
+ # Pad the input if not_multiple_of 8
57
+ h, w = input_.shape[2], input_.shape[3]
58
+ H, W = ((h + mul) // mul) * mul, ((w + mul) // mul) * mul
59
+ padh = H - h if h % mul != 0 else 0
60
+ padw = W - w if w % mul != 0 else 0
61
+ input_ = F.pad(input_, (0, padw, 0, padh), 'reflect')
62
+ with torch.no_grad():
63
+ restored = model(input_)
64
+
65
+ restored = torch.clamp(restored, 0, 1)
66
+ restored = restored[:, :, :h, :w]
67
+ restored = restored.permute(0, 2, 3, 1).cpu().detach().numpy()
68
+ restored = img_as_ubyte(restored[0])
69
+
70
+ f = os.path.splitext(os.path.split(file_)[-1])[0]
71
+ save_img((os.path.join(out_dir, f + '.png')), restored)
72
+
73
+ print(f"Files saved at {out_dir}")
74
+ print('finish !')
75
+
76
+
77
+ def define_model(args):
78
+ # Enhance models
79
+ if args.task in ['LLEnhancement', 'Retouching']:
80
+ model = SRMNet(in_chn=3, wf=96, depth=4)
81
+ weight_path = os.path.join(args.weights_root, args.task + '.pth')
82
+ load_checkpoint(model, weight_path)
83
+
84
+ # Restored models
85
+ else:
86
+ model = SRMNet_SWFF(in_chn=3, wf=96, depth=4)
87
+ weight_path = os.path.join(args.weights_root, args.task + '.pth')
88
+ load_checkpoint(model, weight_path)
89
+
90
+ return model
91
+
92
+
93
+ if __name__ == '__main__':
94
+ main()