HaoFeng2019 commited on
Commit
26f23ad
1 Parent(s): 532251c

Upload inference_ill.py

Browse files
Files changed (1) hide show
  1. inference_ill.py +134 -0
inference_ill.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+ import torch
4
+ from skimage.filters.rank import mean_bilateral
5
+ from skimage import morphology
6
+ from PIL import Image
7
+ from PIL import ImageEnhance
8
+
9
+
10
+ def padCropImg(img):
11
+ H = img.shape[0]
12
+ W = img.shape[1]
13
+
14
+ patchRes = 128
15
+ pH = patchRes
16
+ pW = patchRes
17
+ ovlp = int(patchRes * 0.125) # 32
18
+
19
+ padH = (int((H - patchRes) / (patchRes - ovlp) + 1) * (patchRes - ovlp) + patchRes) - H
20
+ padW = (int((W - patchRes) / (patchRes - ovlp) + 1) * (patchRes - ovlp) + patchRes) - W
21
+
22
+ padImg = cv2.copyMakeBorder(img, 0, padH, 0, padW, cv2.BORDER_REPLICATE)
23
+
24
+ ynum = int((padImg.shape[0] - pH) / (pH - ovlp)) + 1
25
+ xnum = int((padImg.shape[1] - pW) / (pW - ovlp)) + 1
26
+
27
+ totalPatch = np.zeros((ynum, xnum, patchRes, patchRes, 3), dtype=np.uint8)
28
+
29
+ for j in range(0, ynum):
30
+ for i in range(0, xnum):
31
+ x = int(i * (pW - ovlp))
32
+ y = int(j * (pH - ovlp))
33
+
34
+ if j == (ynum-1) and i == (xnum-1):
35
+ totalPatch[j, i] = img[-patchRes:, -patchRes:]
36
+ elif j == (ynum-1):
37
+ totalPatch[j, i] = img[-patchRes:, x:int(x + patchRes)]
38
+ elif i == (xnum-1):
39
+ totalPatch[j, i] = img[y:int(y + patchRes), -patchRes:]
40
+ else:
41
+ totalPatch[j, i] = padImg[y:int(y + patchRes), x:int(x + patchRes)]
42
+
43
+ return totalPatch, padH, padW
44
+
45
+
46
+ def illCorrection(model, totalPatch):
47
+ totalPatch = totalPatch.astype(np.float32) / 255.0
48
+
49
+ ynum = totalPatch.shape[0]
50
+ xnum = totalPatch.shape[1]
51
+
52
+ totalResults = np.zeros((ynum, xnum, 128, 128, 3), dtype=np.float32)
53
+
54
+ for j in range(0, ynum):
55
+ for i in range(0, xnum):
56
+ patchImg = torch.from_numpy(totalPatch[j, i]).permute(2,0,1)
57
+ patchImg = patchImg.cuda().view(1, 3, 128, 128)
58
+
59
+ output = model(patchImg)
60
+ output = output.permute(0, 2, 3, 1).data.cpu().numpy()[0]
61
+
62
+ output = output * 255.0
63
+ output = output.astype(np.uint8)
64
+
65
+ totalResults[j, i] = output
66
+
67
+ return totalResults
68
+
69
+
70
+ def composePatch(totalResults, padH, padW, img):
71
+ ynum = totalResults.shape[0]
72
+ xnum = totalResults.shape[1]
73
+ patchRes = totalResults.shape[2]
74
+
75
+ ovlp = int(patchRes * 0.125)
76
+ step = patchRes - ovlp
77
+
78
+ resImg = np.zeros((patchRes + (ynum - 1) * step, patchRes + (xnum - 1) * step, 3), np.uint8)
79
+ resImg = np.zeros_like(img).astype('uint8')
80
+
81
+ for j in range(0, ynum):
82
+ for i in range(0, xnum):
83
+ sy = int(j * step)
84
+ sx = int(i * step)
85
+
86
+ if j == 0 and i != (xnum-1):
87
+ resImg[sy:(sy + patchRes), sx:(sx + patchRes)] = totalResults[j, i]
88
+ elif i == 0 and j != (ynum-1):
89
+ resImg[sy+10:(sy + patchRes), sx:(sx + patchRes)] = totalResults[j, i,10:]
90
+ elif j == (ynum-1) and i == (xnum-1):
91
+ resImg[-patchRes+10:, -patchRes+10:] = totalResults[j, i,10:,10:]
92
+ elif j == (ynum-1) and i == 0:
93
+ resImg[-patchRes+10:, sx:(sx + patchRes)] = totalResults[j, i,10:]
94
+ elif j == (ynum-1) and i != 0:
95
+ resImg[-patchRes+10:, sx+10:(sx + patchRes)] = totalResults[j, i,10:,10:]
96
+ elif i == (xnum-1) and j == 0:
97
+ resImg[sy:(sy + patchRes), -patchRes+10:] = totalResults[j, i,:,10:]
98
+ elif i == (xnum-1) and j != 0:
99
+ resImg[sy+10:(sy + patchRes), -patchRes+10:] = totalResults[j, i,10:,10:]
100
+ else:
101
+ resImg[sy+10:(sy + patchRes), sx+10:(sx + patchRes)] = totalResults[j, i,10:,10:]
102
+
103
+ resImg[0,:,:] = 255
104
+
105
+ return resImg
106
+
107
+
108
+ def preProcess(img):
109
+ img[:,:,0] = mean_bilateral(img[:,:,0], morphology.disk(20), s0=10, s1=10)
110
+ img[:,:,1] = mean_bilateral(img[:,:,1], morphology.disk(20), s0=10, s1=10)
111
+ img[:,:,2] = mean_bilateral(img[:,:,2], morphology.disk(20), s0=10, s1=10)
112
+
113
+ return img
114
+
115
+
116
+ def postProcess(img):
117
+ img = Image.fromarray(img)
118
+ enhancer = ImageEnhance.Contrast(img)
119
+ factor = 2.0
120
+ img = enhancer.enhance(factor)
121
+
122
+ return img
123
+
124
+
125
+ def rec_ill(net, img, saveRecPath):
126
+
127
+ totalPatch, padH, padW = padCropImg(img)
128
+
129
+ totalResults = illCorrection(net, totalPatch)
130
+
131
+ resImg = composePatch(totalResults, padH, padW, img)
132
+ #resImg = postProcess(resImg)
133
+ resImg = Image.fromarray(resImg)
134
+ resImg.save(saveRecPath)