52Hz commited on
Commit
de39981
1 Parent(s): 0759c30

Upload predict.py

Browse files
Files changed (1) hide show
  1. predict.py +82 -0
predict.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cog
2
+ import tempfile
3
+ from pathlib import Path
4
+ import argparse
5
+ import shutil
6
+ import os
7
+ import glob
8
+ import torch
9
+ from skimage import img_as_ubyte
10
+ from PIL import Image
11
+ from model.SRMNet import SRMNet
12
+ from main_test_SRMNet import save_img, setup
13
+ import torchvision.transforms.functional as TF
14
+ import torch.nn.functional as F
15
+
16
+
17
+ class Predictor(cog.Predictor):
18
+ def setup(self):
19
+ model_dir = 'experiments/pretrained_models/AWGN_denoising_SRMNet.pth'
20
+
21
+ parser = argparse.ArgumentParser(description='Demo Image Denoising')
22
+ parser.add_argument('--input_dir', default='./test/', type=str, help='Input images')
23
+ parser.add_argument('--result_dir', default='./result/', type=str, help='Directory for results')
24
+ parser.add_argument('--weights',
25
+ default='./checkpoints/SRMNet_real_denoise/models/model_bestPSNR.pth', type=str,
26
+ help='Path to weights')
27
+
28
+ self.args = parser.parse_args()
29
+
30
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
31
+
32
+ @cog.input("image", type=Path, help="input image")
33
+ def predict(self, image):
34
+ # set input folder
35
+ input_dir = 'input_cog_temp'
36
+ os.makedirs(input_dir, exist_ok=True)
37
+ input_path = os.path.join(input_dir, os.path.basename(image))
38
+ shutil.copy(str(image), input_path)
39
+
40
+ # Load corresponding models architecture and weights
41
+ model = SRMNet()
42
+ model.eval()
43
+ model = model.to(self.device)
44
+
45
+ folder, save_dir = setup(self.args)
46
+ os.makedirs(save_dir, exist_ok=True)
47
+
48
+ out_path = Path(tempfile.mkdtemp()) / "out.png"
49
+ mul = 16
50
+ for file_ in sorted(glob.glob(os.path.join(folder, '*.PNG'))):
51
+ img = Image.open(file_).convert('RGB')
52
+ input_ = TF.to_tensor(img).unsqueeze(0).cuda()
53
+
54
+ # Pad the input if not_multiple_of 8
55
+ h, w = input_.shape[2], input_.shape[3]
56
+ H, W = ((h + mul) // mul) * mul, ((w + mul) // mul) * mul
57
+ padh = H - h if h % mul != 0 else 0
58
+ padw = W - w if w % mul != 0 else 0
59
+ input_ = F.pad(input_, (0, padw, 0, padh), 'reflect')
60
+ with torch.no_grad():
61
+ restored = model(input_)
62
+
63
+ restored = torch.clamp(restored, 0, 1)
64
+ restored = restored[:, :, :h, :w]
65
+ restored = restored.permute(0, 2, 3, 1).cpu().detach().numpy()
66
+ restored = img_as_ubyte(restored[0])
67
+
68
+ save_img(str(out_path), restored)
69
+ clean_folder(input_dir)
70
+ return out_path
71
+
72
+
73
+ def clean_folder(folder):
74
+ for filename in os.listdir(folder):
75
+ file_path = os.path.join(folder, filename)
76
+ try:
77
+ if os.path.isfile(file_path) or os.path.islink(file_path):
78
+ os.unlink(file_path)
79
+ elif os.path.isdir(file_path):
80
+ shutil.rmtree(file_path)
81
+ except Exception as e:
82
+ print('Failed to delete %s. Reason: %s' % (file_path, e))