52Hz commited on
Commit
44b3c84
1 Parent(s): d4f141f

Create predict.py

Browse files
Files changed (1) hide show
  1. predict.py +83 -0
predict.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cog
2
+ import tempfile
3
+ from pathlib import Path
4
+ import argparse
5
+ import shutil
6
+ import os
7
+ import glob
8
+ import torch
9
+ from skimage import img_as_ubyte
10
+ from PIL import Image
11
+ from model.SRMNet import SRMNet
12
+ from main_test_SRMNet import save_img, setup
13
+ import torchvision.transforms.functional as TF
14
+ import torch.nn.functional as F
15
+
16
+ class Predictor(cog.Predictor):
17
+ def setup(self):
18
+ model_dir = 'experiments/pretrained_models/AWGN_denoising_SRMNet.pth'
19
+
20
+ parser = argparse.ArgumentParser(description='Demo Image Denoising')
21
+ parser.add_argument('--input_dir', default='./test/', type=str, help='Input images')
22
+ parser.add_argument('--result_dir', default='./result/', type=str, help='Directory for results')
23
+ parser.add_argument('--weights',
24
+ default='./checkpoints/SRMNet_real_denoise/models/model_bestPSNR.pth', type=str,
25
+ help='Path to weights')
26
+
27
+ self.args = parser.parse_args()
28
+
29
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
30
+
31
+ @cog.input("image", type=Path, help="input image")
32
+
33
+ def predict(self, image):
34
+
35
+ # set input folder
36
+ input_dir = 'input_cog_temp'
37
+ os.makedirs(input_dir, exist_ok=True)
38
+ input_path = os.path.join(input_dir, os.path.basename(image))
39
+ shutil.copy(str(image), input_path)
40
+
41
+ # Load corresponding models architecture and weights
42
+ model = SRMNet()
43
+ model.eval()
44
+ model = model.to(self.device)
45
+
46
+ folder, save_dir = setup(self.args)
47
+ os.makedirs(save_dir, exist_ok=True)
48
+
49
+ out_path = Path(tempfile.mkdtemp()) / "out.png"
50
+ mul = 16
51
+ for file_ in sorted(glob.glob(os.path.join(folder, '*.PNG'))):
52
+ img = Image.open(file_).convert('RGB')
53
+ input_ = TF.to_tensor(img).unsqueeze(0).cuda()
54
+
55
+ # Pad the input if not_multiple_of 8
56
+ h, w = input_.shape[2], input_.shape[3]
57
+ H, W = ((h + mul) // mul) * mul, ((w + mul) // mul) * mul
58
+ padh = H - h if h % mul != 0 else 0
59
+ padw = W - w if w % mul != 0 else 0
60
+ input_ = F.pad(input_, (0, padw, 0, padh), 'reflect')
61
+ with torch.no_grad():
62
+ restored = model(input_)
63
+
64
+ restored = torch.clamp(restored, 0, 1)
65
+ restored = restored[:, :, :h, :w]
66
+ restored = restored.permute(0, 2, 3, 1).cpu().detach().numpy()
67
+ restored = img_as_ubyte(restored[0])
68
+
69
+ save_img(str(out_path), restored)
70
+ clean_folder(input_dir)
71
+ return out_path
72
+
73
+
74
+ def clean_folder(folder):
75
+ for filename in os.listdir(folder):
76
+ file_path = os.path.join(folder, filename)
77
+ try:
78
+ if os.path.isfile(file_path) or os.path.islink(file_path):
79
+ os.unlink(file_path)
80
+ elif os.path.isdir(file_path):
81
+ shutil.rmtree(file_path)
82
+ except Exception as e:
83
+ print('Failed to delete %s. Reason: %s' % (file_path, e))