qubvel-hf HF staff committed on
Commit
c509e76
1 Parent(s): b012b87

Init project

Files changed (41)
  1. README.md +59 -0
  2. data/MBD/MBD.py +110 -0
  3. data/MBD/MBD_utils.py +291 -0
  4. data/MBD/infer.py +151 -0
  5. data/MBD/model/__init__.py +50 -0
  6. data/MBD/model/cbam.py +95 -0
  7. data/MBD/model/deep_lab_model/__init__.py +0 -0
  8. data/MBD/model/deep_lab_model/aspp.py +95 -0
  9. data/MBD/model/deep_lab_model/backbone/__init__.py +13 -0
  10. data/MBD/model/deep_lab_model/backbone/drn.py +402 -0
  11. data/MBD/model/deep_lab_model/backbone/mobilenet.py +151 -0
  12. data/MBD/model/deep_lab_model/backbone/resnet.py +170 -0
  13. data/MBD/model/deep_lab_model/backbone/xception.py +288 -0
  14. data/MBD/model/deep_lab_model/decoder.py +59 -0
  15. data/MBD/model/deep_lab_model/deeplab.py +81 -0
  16. data/MBD/model/deep_lab_model/sync_batchnorm/__init__.py +12 -0
  17. data/MBD/model/deep_lab_model/sync_batchnorm/batchnorm.py +282 -0
  18. data/MBD/model/deep_lab_model/sync_batchnorm/comm.py +129 -0
  19. data/MBD/model/deep_lab_model/sync_batchnorm/replicate.py +88 -0
  20. data/MBD/model/deep_lab_model/sync_batchnorm/unittest.py +29 -0
  21. data/MBD/model/densenetccnl.py +382 -0
  22. data/MBD/model/gienet.py +742 -0
  23. data/MBD/model/unetnc.py +86 -0
  24. data/MBD/modify_stn_model/stn_head.py +123 -0
  25. data/MBD/modify_stn_model/tps_spatial_transformer.py +194 -0
  26. data/MBD/stn_model/stn_head.py +123 -0
  27. data/MBD/stn_model/tps_spatial_transformer.py +155 -0
  28. data/MBD/tps_grid_gen.py +70 -0
  29. data/MBD/utils.py +234 -0
  30. data/README.md +135 -0
  31. data/preprocess/crop_merge_image.py +142 -0
  32. data/preprocess/sauvola_binarize.py +91 -0
  33. data/preprocess/shadow_extraction.py +68 -0
  34. eval.py +369 -0
  35. inference.py +341 -0
  36. loaders/docres_loader.py +558 -0
  37. models/restormer_arch.py +308 -0
  38. requirements.txt +10 -0
  39. start_train.sh +1 -0
  40. train.py +221 -0
  41. utils.py +464 -0
README.md ADDED
@@ -0,0 +1,59 @@
+
+ <div align=center>
+
+ # DocRes: A Generalist Model Toward Unifying Document Image Restoration Tasks
+
+ </div>
+
+ <p align="center">
+ <img src="images/motivation.jpg" width="400">
+ </p>
+
+ This is the official implementation of our paper [DocRes: A Generalist Model Toward Unifying Document Image Restoration Tasks](https://arxiv.org/abs/2405.04408).
+
+ ## News
+ 🔥 A comprehensive [Recommendation for Document Image Processing](https://github.com/ZZZHANG-jx/Recommendations-Document-Image-Processing) is available.
+
+
+ ## Inference
+ 1. Put the MBD model weights [mbd.pkl](https://1drv.ms/f/s!Ak15mSdV3Wy4iahoKckhDPVP5e2Czw?e=iClwdK) in `./data/MBD/checkpoint/`
+ 2. Put the DocRes model weights [docres.pkl](https://1drv.ms/f/s!Ak15mSdV3Wy4iahoKckhDPVP5e2Czw?e=iClwdK) in `./checkpoints/`
+ 3. Run the following script; the results will be saved in `./restorted/`. Some distorted examples are provided in `./input/`.
+ ```bash
+ python inference.py --im_path ./input/for_dewarping.png --task dewarping --save_dtsprompt 1
+ ```
+
+ - `--im_path`: path of the input document image
+ - `--task`: the task to execute; must be one of _dewarping_, _deshadowing_, _appearance_, _deblurring_, _binarization_, or _end2end_
+ - `--save_dtsprompt`: whether to save the DTSPrompt
+
+ ## Evaluation
+
+ 1. Dataset preparation: see the [dataset instructions](./data/README.md)
+ 2. Put the MBD model weights [mbd.pkl](https://1drv.ms/f/s!Ak15mSdV3Wy4iahoKckhDPVP5e2Czw?e=iClwdK) in `data/MBD/checkpoint/`
+ 3. Put the DocRes model weights [docres.pkl](https://1drv.ms/f/s!Ak15mSdV3Wy4iahoKckhDPVP5e2Czw?e=iClwdK) in `./checkpoints/`
+ 4. Run the following script
+ ```bash
+ python eval.py --dataset realdae
+ ```
+ - `--dataset`: the dataset to evaluate; it can be set to _dir300_, _kligler_, _jung_, _osr_, _docunet\_docaligner_, _realdae_, _tdd_, or _dibco18_.
+
+ ## Training
+ 1. Dataset preparation: see the [dataset instructions](./data/README.md)
+ 2. Specify `datasets_setting` within `train.py` according to your dataset paths and experimental setting.
+ 3. Run the following script
+ ```bash
+ bash start_train.sh
+ ```
+
+
+ ## Citation
+ ```
+ @inproceedings{zhangdocres2024,
+ Author = {Jiaxin Zhang, Dezhi Peng, Chongyu Liu, Peirong Zhang and Lianwen Jin},
+ Booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
+ Title = {DocRes: A Generalist Model Toward Unifying Document Image Restoration Tasks},
+ Year = {2024}}
+ ```
+ ## ⭐ Star Rising
+ [![Star Rising](https://api.star-history.com/svg?repos=ZZZHANG-jx/DocRes&type=Timeline)](https://star-history.com/#ZZZHANG-jx/DocRes&Timeline)
data/MBD/MBD.py ADDED
@@ -0,0 +1,110 @@
1
+ import cv2
2
+ import numpy as np
3
+ import MBD_utils
4
+ import torch
5
+ import torch.nn.functional as F
6
+
7
+
8
+ def mask_base_dewarper(image,mask):
9
+ '''
10
+ input:
11
+ image -> ndarray HxWx3 uint8
12
+ mask -> ndarray HxW uint8
13
+ return
14
+ dewarped -> ndarray HxWx3 uint8
15
+ grid (optional) -> ndarray HxWx2 -1~1
16
+ '''
17
+
18
+ ## get contours
19
+ # _, contours, hierarchy = cv2.findContours(mask,cv2.RETR_EXTERNAL,cv2.CHAIN_APPROX_NONE) ## cv2.__version__ == 3.x
20
+ contours,hierarchy = cv2.findContours(mask,cv2.RETR_EXTERNAL,method=cv2.CHAIN_APPROX_SIMPLE) ## cv2.__version__ == 4.x
21
+
22
+ ## get biggest contours and four corners based on Douglas-Peucker algorithm
23
+ four_corners, maxArea, contour= MBD_utils.DP_algorithm(contours)
24
+ four_corners = MBD_utils.reorder(four_corners)
25
+
26
+ ## keep only the biggest contour and remove the other noisy contours
27
+ new_mask = np.zeros_like(mask)
28
+ new_mask = cv2.drawContours(new_mask,[contour],-1,255,cv2.FILLED)
29
+
30
+ ## obtain middle points
31
+ # ratios = [0.25,0.5,0.75] # ratios = [0.125,0.25,0.375,0.5,0.625,0.75,0.875]
32
+ ratios = [0.25,0.5,0.75]
33
+ # ratios = [0.0625,0.125,0.1875,0.25,0.3125,0.375,0.4475,0.5,0.5625,0.625,0.06875,0.75,0.8125,0.875,0.9375]
34
+ middle = MBD_utils.findMiddle(corners=four_corners,mask=new_mask,points=ratios)
35
+
36
+ ## all points
37
+ source_points = np.concatenate((four_corners,middle),axis=0) ## all_point = four_corners(topleft,topright,bottom)+top+bottom+left+right
38
+
39
+ ## target points
40
+ h,w = image.shape[:2]
41
+ padding = 0
42
+ target_points = [[padding, padding],[w-padding, padding], [padding, h-padding],[w-padding, h-padding]]
43
+ for ratio in ratios:
44
+ target_points.append([int((w-2*padding)*ratio)+padding,padding])
45
+ for ratio in ratios:
46
+ target_points.append([int((w-2*padding)*ratio)+padding,h-padding])
47
+ for ratio in ratios:
48
+ target_points.append([padding,int((h-2*padding)*ratio)+padding])
49
+ for ratio in ratios:
50
+ target_points.append([w-padding,int((h-2*padding)*ratio)+padding])
51
+
52
+ ## dewarp based on cv2 (thin plate spline, kept commented out for reference)
53
+ # pts1 = np.float32(source_points)
54
+ # pts2 = np.float32(target_points)
55
+ # tps = cv2.createThinPlateSplineShapeTransformer()
56
+ # matches = []
57
+ # N = pts1.shape[0]
58
+ # for i in range(0,N):
59
+ # matches.append(cv2.DMatch(i,i,0))
60
+ # pts1 = pts1.reshape(1,-1,2)
61
+ # pts2 = pts2.reshape(1,-1,2)
62
+ # tps.estimateTransformation(pts2,pts1,matches)
63
+ # dewarped = tps.warpImage(image)
64
+
65
+ ## dewarp based on the generated grid
66
+ source_points = source_points.reshape(-1,2)/np.array([image.shape[:2][::-1]]).reshape(1,2)
67
+ source_points = torch.from_numpy(source_points).float().cuda()
68
+ source_points = source_points.unsqueeze(0)
69
+ source_points = (source_points-0.5)*2
70
+ target_points = np.asarray(target_points).reshape(-1,2)/np.array([image.shape[:2][::-1]]).reshape(1,2)
71
+ target_points = torch.from_numpy(target_points).float()
72
+ target_points = (target_points-0.5)*2
73
+
74
+ model = MBD_utils.TPSGridGen(target_height=256,target_width=256,target_control_points=target_points)
75
+ model = model.cuda()
76
+ grid = model(source_points).view(-1,256,256,2).permute(0,3,1,2)
77
+ grid = F.interpolate(grid,(h,w),mode='bilinear').permute(0,2,3,1)
78
+ dewarped = MBD_utils.torch2cvimg(F.grid_sample(MBD_utils.cvimg2torch(image).cuda(),grid))[0]
79
+ return dewarped,grid[0].cpu().numpy()
80
+
81
+ def mask_base_cropper(image,mask):
82
+ '''
83
+ input:
84
+ image -> ndarray HxWx3 uint8
85
+ mask -> ndarray HxW uint8
86
+ return
87
+ dewarped -> ndarray HxWx3 uint8
88
+ grid (optional) -> ndarray HxWx2 -1~1
89
+ '''
90
+
91
+ ## get contours
92
+ _, contours, hierarchy = cv2.findContours(mask,cv2.RETR_EXTERNAL,cv2.CHAIN_APPROX_NONE) ## cv2.__version__ == 3.x
93
+ # contours,hierarchy = cv2.findContours(mask,cv2.RETR_EXTERNAL,method=cv2.CHAIN_APPROX_SIMPLE) ## cv2.__version__ == 4.x
94
+
95
+ ## get biggest contours and four corners based on Douglas-Peucker algorithm
96
+ four_corners, maxArea, contour= MBD_utils.DP_algorithm(contours)
97
+ four_corners = MBD_utils.reorder(four_corners)
98
+
99
+ ## keep only the biggest contour and remove the other noisy contours
100
+ new_mask = np.zeros_like(mask)
101
+ new_mask = cv2.drawContours(new_mask,[contour],-1,255,cv2.FILLED)
102
+
103
+ ## minimum-area bounding rectangle
+ rect = cv2.minAreaRect(contour) # returns the minimum-area rectangle as (center (x,y), (width, height), rotation angle)
+ box = cv2.boxPoints(rect) # cv2.boxPoints(rect) for OpenCV 3.x; gets the 4 corner points of the minimum-area rectangle
106
+ box = np.int0(box)
107
+ box = box.reshape((4,1,2))
108
+
109
+
110
+
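For orientation, here is a minimal, hedged usage sketch of `mask_base_dewarper` (run from `data/MBD/`): the file names are placeholders, a CUDA device is assumed because the function moves its tensors to the GPU, and `mask_base_cropper` is left unfinished in this commit, so only the dewarper is exercised.

```python
# Hypothetical usage sketch for data/MBD/MBD.py (paths are placeholders, CUDA assumed).
import cv2
from MBD import mask_base_dewarper

image = cv2.imread("page.png")                             # HxWx3 uint8 BGR image
mask = cv2.imread("page_mask.png", cv2.IMREAD_GRAYSCALE)   # HxW uint8 mask, 255 = document region

dewarped, grid = mask_base_dewarper(image, mask)           # grid: HxWx2 sampling grid in [-1, 1]
cv2.imwrite("page_dewarped.png", dewarped)
```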
data/MBD/MBD_utils.py ADDED
@@ -0,0 +1,291 @@
1
+ import cv2
2
+ import numpy as np
3
+ import copy
4
+ import torch
6
+ import itertools
7
+ import torch.nn as nn
8
+ from torch.autograd import Function, Variable
9
+
10
+ def reorder(myPoints):
11
+ myPoints = myPoints.reshape((4, 2))
12
+ myPointsNew = np.zeros((4, 1, 2), dtype=np.int32)
13
+ add = myPoints.sum(1)
14
+ myPointsNew[0] = myPoints[np.argmin(add)]
15
+ myPointsNew[3] =myPoints[np.argmax(add)]
16
+ diff = np.diff(myPoints, axis=1)
17
+ myPointsNew[1] =myPoints[np.argmin(diff)]
18
+ myPointsNew[2] = myPoints[np.argmax(diff)]
19
+ return myPointsNew
20
+
21
+
22
+ def findMiddle(corners,mask,points=[0.25,0.5,0.75]):
23
+ num_middle_points = len(points)
24
+ top = [np.array([])]*num_middle_points
25
+ bottom = [np.array([])]*num_middle_points
26
+ left = [np.array([])]*num_middle_points
27
+ right = [np.array([])]*num_middle_points
28
+
29
+ center_top = []
30
+ center_bottom = []
31
+ center_left = []
32
+ center_right = []
33
+
34
+ center = (int((corners[0][0][1]+corners[3][0][1])/2),int((corners[0][0][0]+corners[3][0][0])/2))
35
+ for ratio in points:
36
+
37
+ center_top.append( (center[0],int(corners[0][0][0]*(1-ratio)+corners[1][0][0]*ratio)) )
38
+
39
+ center_bottom.append( (center[0],int(corners[2][0][0]*(1-ratio)+corners[3][0][0]*ratio)) )
40
+
41
+ center_left.append( (int(corners[0][0][1]*(1-ratio)+corners[2][0][1]*ratio),center[1]) )
42
+
43
+ center_right.append( (int(corners[1][0][1]*(1-ratio)+corners[3][0][1]*ratio),center[1]) )
44
+
45
+ for i in range(0,center[0],1):
46
+ for j in range(num_middle_points):
47
+ if top[j].size==0:
48
+ if mask[i,center_top[j][1]]==255:
49
+ top[j] = np.asarray([center_top[j][1],i])
50
+ top[j] = top[j].reshape(1,2)
51
+
52
+ for i in range(mask.shape[0]-1,center[0],-1):
53
+ for j in range(num_middle_points):
54
+ if bottom[j].size==0:
55
+ if mask[i,center_bottom[j][1]]==255:
56
+ bottom[j] = np.asarray([center_bottom[j][1],i])
57
+ bottom[j] = bottom[j].reshape(1,2)
58
+
59
+ for i in range(mask.shape[1]-1,center[1],-1):
60
+ for j in range(num_middle_points):
61
+ if right[j].size==0:
62
+ if mask[center_right[j][0],i]==255:
63
+ right[j] = np.asarray([i,center_right[j][0]])
64
+ right[j] = right[j].reshape(1,2)
65
+
66
+ for i in range(0,center[1]):
67
+ for j in range(num_middle_points):
68
+ if left[j].size==0:
69
+ if mask[center_left[j][0],i]==255:
70
+ left[j] = np.asarray([i,center_left[j][0]])
71
+ left[j] = left[j].reshape(1,2)
72
+
73
+ return np.asarray(top+bottom+left+right)
74
+
75
+ def DP_algorithmv1(contours):
76
+ biggest = np.array([])
77
+ max_area = 0
78
+ step = 0.001
79
+ count = 0
80
+ # while biggest.size==0:
81
+ while True:
82
+ for i in contours:
83
+ # print(i.shape)
84
+ area = cv2.contourArea(i)
85
+ # print(area,cv2.arcLength(i, True))
86
+ if area > cv2.arcLength(i, True)*10:
87
+ peri = cv2.arcLength(i, True)
88
+ approx = cv2.approxPolyDP(i, (0.01+step*count) * peri, True)
89
+ if area > max_area and len(approx) == 4:
90
+ max_area = area
91
+ biggest_contours = i
92
+ biggest = approx
93
+ break
94
+ if abs(max_area - cv2.contourArea(biggest))/max_area > 0.3:
95
+ biggest = np.array([])
96
+ count += 1
97
+ if count > 200:
98
+ break
99
+ temp = biggest[0]
100
+ return biggest,max_area, biggest_contours
101
+
102
+ def DP_algorithm(contours):
103
+ biggest = np.array([])
104
+ max_area = 0
105
+ step = 0.001
106
+ count = 0
107
+
108
+ ### largest contours
109
+ for i in contours:
110
+ area = cv2.contourArea(i)
111
+ if area > max_area:
112
+ max_area = area
113
+ biggest_contours = i
114
+ peri = cv2.arcLength(biggest_contours, True)
115
+
116
+ ### find four corners
117
+ while True:
118
+ approx = cv2.approxPolyDP(biggest_contours, (0.01+step*count) * peri, True)
119
+ if len(approx) == 4:
120
+ biggest = approx
121
+ break
122
+ # if abs(max_area - cv2.contourArea(biggest))/max_area > 0.2:
123
+ # if abs(max_area - cv2.contourArea(biggest))/max_area > 0.4:
124
+ # biggest = np.array([])
125
+ count += 1
126
+ if count > 200:
127
+ break
128
+ return biggest,max_area, biggest_contours
129
+
130
+ def drawRectangle(img,biggest,color,thickness):
131
+ cv2.line(img, (biggest[0][0][0], biggest[0][0][1]), (biggest[1][0][0], biggest[1][0][1]), color, thickness)
132
+ cv2.line(img, (biggest[0][0][0], biggest[0][0][1]), (biggest[2][0][0], biggest[2][0][1]), color, thickness)
133
+ cv2.line(img, (biggest[3][0][0], biggest[3][0][1]), (biggest[2][0][0], biggest[2][0][1]), color, thickness)
134
+ cv2.line(img, (biggest[3][0][0], biggest[3][0][1]), (biggest[1][0][0], biggest[1][0][1]), color, thickness)
135
+ return img
136
+
137
+ def minAreaRect(contours,img):
138
+ # biggest = np.array([])
139
+ max_area = 0
140
+ for i in contours:
141
+ area = cv2.contourArea(i)
142
+ if area > max_area:
143
+ peri = cv2.arcLength(i, True)
144
+ rect = cv2.minAreaRect(i)
145
+ points = cv2.boxPoints(rect)
146
+ max_area = area
147
+ return points
148
+
149
+ def cropRectangle(img,biggest):
150
+ # print(biggest)
151
+ w = np.abs(biggest[0][0][0] - biggest[1][0][0])
152
+ h = np.abs(biggest[0][0][1] - biggest[2][0][1])
153
+ new_img = np.zeros((w,h,img.shape[-1]),dtype=np.uint8)
154
+ new_img = img[biggest[0][0][1]:biggest[0][0][1]+h,biggest[0][0][0]:biggest[0][0][0]+w]
155
+ return new_img
156
+
157
+ def cvimg2torch(img,min=0,max=1):
158
+ '''
159
+ input:
160
+ im -> ndarray uint8 HxWxC
161
+ return
162
+ tensor -> torch.tensor BxCxHxW
163
+ '''
164
+ if len(img.shape)==2:
165
+ img = np.expand_dims(img,axis=-1)
166
+ img = img.astype(float) / 255.0
167
+ img = img.transpose(2, 0, 1) # NHWC -> NCHW
168
+ img = np.expand_dims(img, 0)
169
+ img = torch.from_numpy(img).float()
170
+ return img
171
+
172
+ def torch2cvimg(tensor,min=0,max=1):
173
+ '''
174
+ input:
175
+ tensor -> torch.tensor BxCxHxW C can be 1,3
176
+ return
177
+ im -> ndarray uint8 HxWxC
178
+ '''
179
+ im_list = []
180
+ for i in range(tensor.shape[0]):
181
+ im = tensor.detach().cpu().data.numpy()[i]
182
+ im = im.transpose(1,2,0)
183
+ im = np.clip(im,min,max)
184
+ im = ((im-min)/(max-min)*255).astype(np.uint8)
185
+ im_list.append(im)
186
+ return im_list
187
+
188
+
189
+
190
+ class TPSGridGen(nn.Module):
191
+ def __init__(self, target_height, target_width, target_control_points):
192
+ '''
193
+ target_control_points -> torch.tensor num_pointx2 -1~1
194
+ source_control_points -> torch.tensor batch_size x num_point x 2 -1~1
195
+ return:
196
+ grid -> batch_size x hw x 2 -1~1
197
+ '''
198
+ super(TPSGridGen, self).__init__()
199
+ assert target_control_points.ndimension() == 2
200
+ assert target_control_points.size(1) == 2
201
+ N = target_control_points.size(0)
202
+ self.num_points = N
203
+ target_control_points = target_control_points.float()
204
+
205
+ # create padded kernel matrix
206
+ forward_kernel = torch.zeros(N + 3, N + 3)
207
+ target_control_partial_repr = self.compute_partial_repr(target_control_points, target_control_points)
208
+ forward_kernel[:N, :N].copy_(target_control_partial_repr)
209
+ forward_kernel[:N, -3].fill_(1)
210
+ forward_kernel[-3, :N].fill_(1)
211
+ forward_kernel[:N, -2:].copy_(target_control_points)
212
+ forward_kernel[-2:, :N].copy_(target_control_points.transpose(0, 1))
213
+ # compute inverse matrix
214
+ inverse_kernel = torch.inverse(forward_kernel)
215
+
216
+ # create target cordinate matrix
217
+ HW = target_height * target_width
218
+ target_coordinate = list(itertools.product(range(target_height), range(target_width)))
219
+ target_coordinate = torch.Tensor(target_coordinate) # HW x 2
220
+ Y, X = target_coordinate.split(1, dim = 1)
221
+ Y = Y * 2 / (target_height - 1) - 1
222
+ X = X * 2 / (target_width - 1) - 1
223
+ target_coordinate = torch.cat([X, Y], dim = 1) # convert from (y, x) to (x, y)
224
+ target_coordinate_partial_repr = self.compute_partial_repr(target_coordinate.to(target_control_points.device), target_control_points)
225
+ target_coordinate_repr = torch.cat([
226
+ target_coordinate_partial_repr, torch.ones(HW, 1), target_coordinate
227
+ ], dim = 1)
228
+
229
+ # register precomputed matrices
230
+ self.register_buffer('inverse_kernel', inverse_kernel)
231
+ self.register_buffer('padding_matrix', torch.zeros(3, 2))
232
+ self.register_buffer('target_coordinate_repr', target_coordinate_repr)
233
+
234
+ def forward(self, source_control_points):
235
+ assert source_control_points.ndimension() == 3
236
+ assert source_control_points.size(1) == self.num_points
237
+ assert source_control_points.size(2) == 2
238
+ batch_size = source_control_points.size(0)
239
+
240
+ Y = torch.cat([source_control_points, Variable(self.padding_matrix.expand(batch_size, 3, 2))], 1)
241
+ mapping_matrix = torch.matmul(Variable(self.inverse_kernel), Y)
242
+ source_coordinate = torch.matmul(Variable(self.target_coordinate_repr), mapping_matrix)
243
+ return source_coordinate
244
+ # phi(x1, x2) = r^2 * log(r), where r = ||x1 - x2||_2
245
+ def compute_partial_repr(self, input_points, control_points):
246
+ N = input_points.size(0)
247
+ M = control_points.size(0)
248
+ pairwise_diff = input_points.view(N, 1, 2) - control_points.view(1, M, 2)
249
+ # original implementation, very slow
250
+ # pairwise_dist = torch.sum(pairwise_diff ** 2, dim = 2) # square of distance
251
+ pairwise_diff_square = pairwise_diff * pairwise_diff
252
+ pairwise_dist = pairwise_diff_square[:, :, 0] + pairwise_diff_square[:, :, 1]
253
+ repr_matrix = 0.5 * pairwise_dist * torch.log(pairwise_dist)
254
+ # fix numerical error for 0 * log(0), substitute all nan with 0
255
+ mask = repr_matrix != repr_matrix
256
+ repr_matrix.masked_fill_(mask, 0)
257
+ return repr_matrix
258
+
259
+
260
+
261
+
262
+
263
+ ### decide whether further processing is needed
264
+ # point_area = cv2.contourArea(np.concatenate((biggest_angle[0].reshape(1,1,2),middle[0:3],biggest_angle[1].reshape(1,1,2),middle[9:12],biggest_angle[3].reshape(1,1,2),middle[3:6][::-1],biggest_angle[2].reshape(1,1,2),middle[6:9][::-1]),axis=0))
265
+ #### minimum-area bounding rectangle
+ # rect = cv2.minAreaRect(contour) # returns the minimum-area rectangle as (center (x,y), (width, height), rotation angle)
+ # box = cv2.boxPoints(rect) # cv2.boxPoints(rect) for OpenCV 3.x; gets the 4 corner points of the minimum-area rectangle
268
+ # box = np.int0(box)
269
+ # box = box.reshape((4,1,2))
270
+ # minrect_area = cv2.contourArea(box)
271
+ # print(abs(minrect_area-point_area)/point_area)
272
+ #### IoU of the four corner points
273
+ # biggest_box = np.concatenate((biggest_angle[0,:,:].reshape(1,1,2),biggest_angle[2,:,:].reshape(1,1,2),biggest_angle[3,:,:].reshape(1,1,2),biggest_angle[1,:,:].reshape(1,1,2)),axis=0)
274
+ # biggest_mask = np.zeros_like(mask)
275
+ # # corner_area = cv2.contourArea(biggest_box)
276
+ # cv2.drawContours(biggest_mask,[biggest_box], -1, color=255, thickness=-1)
277
+
278
+ # smooth = 1e-5
279
+ # biggest_mask_ = biggest_mask > 50
280
+ # mask_ = mask > 50
281
+ # intersection = (biggest_mask_ & mask_).sum()
282
+ # union = (biggest_mask_ | mask_).sum()
283
+ # iou = (intersection + smooth) / (union + smooth)
284
+ # if iou > 0.975:
285
+ # skip = True
286
+ # else:
287
+ # skip = False
288
+ # print(iou)
289
+ # cv2.imshow('mask',cv2.resize(mask,(512,512)))
290
+ # cv2.imshow('biggest_mask',cv2.resize(biggest_mask,(512,512)))
291
+ # cv2.waitKey(0)
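To make the TPS machinery above concrete, here is a small CPU-only sketch mirroring how `mask_base_dewarper` drives `TPSGridGen` in `MBD.py` (run from `data/MBD/`); the control-point values are illustrative, not taken from the repository.

```python
# Illustrative TPSGridGen usage: control points live in [-1, 1]; the module returns
# a B x (H*W) x 2 sampling grid that is consumed by torch.nn.functional.grid_sample.
import torch
import torch.nn.functional as F
from MBD_utils import TPSGridGen

# Target corners (top-left, top-right, bottom-left, bottom-right) in [-1, 1].
target_points = torch.tensor([[-1., -1.], [1., -1.], [-1., 1.], [1., 1.]])
# One batch of source corners, here simply shrunk toward the centre.
source_points = (target_points * 0.9).unsqueeze(0)

tps = TPSGridGen(target_height=256, target_width=256, target_control_points=target_points)
grid = tps(source_points).view(1, 256, 256, 2)             # B x H x W x 2 sampling grid

image = torch.rand(1, 3, 256, 256)                          # stand-in for cvimg2torch(img)
warped = F.grid_sample(image, grid, align_corners=True)     # 1 x 3 x 256 x 256
```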
data/MBD/infer.py ADDED
@@ -0,0 +1,151 @@
1
+ import torch
2
+ import argparse
3
+ import numpy as np
4
+ import torch.nn.functional as F
5
+ import glob
6
+ import cv2
7
+ from tqdm import tqdm
8
+
9
+ import time
10
+ import os
11
+ from model.deep_lab_model.deeplab import *
12
+ from MBD import mask_base_dewarper
13
+ import time
14
+
15
+ from utils import cvimg2torch,torch2cvimg
16
+
17
+
18
+
19
+ def net1_net2_infer(model,img_paths,args):
20
+
21
+ ### validate on the real datasets
22
+ seg_model=model
23
+ seg_model.eval()
24
+ for img_path in tqdm(img_paths):
25
+ if os.path.exists(img_path.replace('_origin','_capture')):
26
+ continue
27
+ t1 = time.time()
28
+ ### segmentation mask predict
29
+ img_org = cv2.imread(img_path)
30
+ h_org,w_org = img_org.shape[:2]
31
+ img = cv2.resize(img_org,(448, 448))
32
+ img = cv2.GaussianBlur(img,(15,15),0,0)
33
+ img = cv2.cvtColor(img,cv2.COLOR_BGR2RGB)
34
+ img = cvimg2torch(img)
35
+
36
+ with torch.no_grad():
37
+ pred = seg_model(img.cuda())
38
+ mask_pred = pred[:,0,:,:].unsqueeze(1)
39
+ mask_pred = F.interpolate(mask_pred,(h_org,w_org))
40
+ mask_pred = mask_pred.squeeze(0).squeeze(0).cpu().numpy()
41
+ mask_pred = (mask_pred*255).astype(np.uint8)
42
+ kernel = np.ones((3,3))
43
+ mask_pred = cv2.dilate(mask_pred,kernel,iterations=3)
44
+ mask_pred = cv2.erode(mask_pred,kernel,iterations=3)
45
+ mask_pred[mask_pred>100] = 255
46
+ mask_pred[mask_pred<100] = 0
47
+ ### TPS transform based on the mask
48
+ # dewarp, grid = mask_base_dewarper(img_org,mask_pred)
49
+ try:
50
+ dewarp, grid = mask_base_dewarper(img_org,mask_pred)
51
+ except:
52
+ print('fail')
53
+ grid = np.meshgrid(np.arange(w_org),np.arange(h_org))/np.array([w_org,h_org]).reshape(2,1,1)
54
+ grid = torch.from_numpy((grid-0.5)*2).float().unsqueeze(0).permute(0,2,3,1)
55
+ dewarp = torch2cvimg(F.grid_sample(cvimg2torch(img_org),grid))[0]
56
+ grid = grid[0].numpy()
57
+ # cv2.imshow('in',cv2.resize(img_org,(512,512)))
58
+ # cv2.imshow('out',cv2.resize(dewarp,(512,512)))
59
+ # cv2.waitKey(0)
60
+ cv2.imwrite(img_path.replace('_origin','_capture'),dewarp)
61
+ cv2.imwrite(img_path.replace('_origin','_mask_new'),mask_pred)
62
+
63
+ grid0 = cv2.resize(grid[:,:,0],(128,128))
64
+ grid1 = cv2.resize(grid[:,:,1],(128,128))
65
+ grid = np.stack((grid0,grid1),axis=-1)
66
+ np.save(img_path.replace('_origin','_grid1'),grid)
67
+
68
+
69
+ def net1_net2_infer_single_im(img,model_path):
70
+ seg_model = DeepLab(num_classes=1,
71
+ backbone='resnet',
72
+ output_stride=16,
73
+ sync_bn=None,
74
+ freeze_bn=False)
75
+ seg_model = torch.nn.DataParallel(seg_model, device_ids=range(torch.cuda.device_count()))
76
+ seg_model.cuda()
77
+ checkpoint = torch.load(model_path)
78
+ seg_model.load_state_dict(checkpoint['model_state'])
79
+ ### validate on the real datasets
80
+ seg_model.eval()
81
+ ### segmentation mask predict
82
+ img_org = img
83
+ h_org,w_org = img_org.shape[:2]
84
+ img = cv2.resize(img_org,(448, 448))
85
+ img = cv2.GaussianBlur(img,(15,15),0,0)
86
+ img = cv2.cvtColor(img,cv2.COLOR_BGR2RGB)
87
+ img = cvimg2torch(img)
88
+
89
+ with torch.no_grad():
90
+ # from torchtoolbox.tools import summary
91
+ # print(summary(seg_model,torch.rand((1, 3, 448, 448)).cuda())) 59.4M 135.6G
92
+
93
+ pred = seg_model(img.cuda())
94
+ mask_pred = pred[:,0,:,:].unsqueeze(1)
95
+ mask_pred = F.interpolate(mask_pred,(h_org,w_org))
96
+ mask_pred = mask_pred.squeeze(0).squeeze(0).cpu().numpy()
97
+ mask_pred = (mask_pred*255).astype(np.uint8)
98
+ kernel = np.ones((3,3))
99
+ mask_pred = cv2.dilate(mask_pred,kernel,iterations=3)
100
+ mask_pred = cv2.erode(mask_pred,kernel,iterations=3)
101
+ mask_pred[mask_pred>100] = 255
102
+ mask_pred[mask_pred<100] = 0
103
+ ### TPS transform based on the mask
104
+ # dewarp, grid = mask_base_dewarper(img_org,mask_pred)
105
+ # try:
106
+ # dewarp, grid = mask_base_dewarper(img_org,mask_pred)
107
+ # except:
108
+ # print('fail')
109
+ # grid = np.meshgrid(np.arange(w_org),np.arange(h_org))/np.array([w_org,h_org]).reshape(2,1,1)
110
+ # grid = torch.from_numpy((grid-0.5)*2).float().unsqueeze(0).permute(0,2,3,1)
111
+ # dewarp = torch2cvimg(F.grid_sample(cvimg2torch(img_org),grid))[0]
112
+ # grid = grid[0].numpy()
113
+ # cv2.imshow('in',cv2.resize(img_org,(512,512)))
114
+ # cv2.imshow('out',cv2.resize(dewarp,(512,512)))
115
+ # cv2.waitKey(0)
116
+ # cv2.imwrite(img_path.replace('_origin','_capture'),dewarp)
117
+ # cv2.imwrite(img_path.replace('_origin','_mask_new'),mask_pred)
118
+
119
+ # grid0 = cv2.resize(grid[:,:,0],(128,128))
120
+ # grid1 = cv2.resize(grid[:,:,1],(128,128))
121
+ # grid = np.stack((grid0,grid1),axis=-1)
122
+ # np.save(img_path.replace('_origin','_grid1'),grid)
123
+ return mask_pred
124
+
125
+
126
+
127
+ if __name__ == '__main__':
128
+ parser = argparse.ArgumentParser(description='Hyperparams')
129
+ parser.add_argument('--img_folder', nargs='?', type=str, default='./all_data',help='Data path to load data')
130
+ parser.add_argument('--img_rows', nargs='?', type=int, default=448,
131
+ help='Height of the input image')
132
+ parser.add_argument('--img_cols', nargs='?', type=int, default=448,
133
+ help='Width of the input image')
134
+ parser.add_argument('--seg_model_path', nargs='?', type=str, default='checkpoints/mbd.pkl',
135
+ help='Path to previous saved model to restart from')
136
+ args = parser.parse_args()
137
+
138
+ seg_model = DeepLab(num_classes=1,
139
+ backbone='resnet',
140
+ output_stride=16,
141
+ sync_bn=None,
142
+ freeze_bn=False)
143
+ seg_model = torch.nn.DataParallel(seg_model, device_ids=range(torch.cuda.device_count()))
144
+ seg_model.cuda()
145
+ checkpoint = torch.load(args.seg_model_path)
146
+ seg_model.load_state_dict(checkpoint['model_state'])
147
+
148
+ im_paths = glob.glob(os.path.join(args.img_folder,'*_origin.*'))
149
+
150
+ net1_net2_infer(seg_model,im_paths,args)
151
+
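As a hedged usage sketch, `net1_net2_infer_single_im` can be called on one image to obtain the document mask; the input file name is a placeholder, `checkpoints/mbd.pkl` is the default weight path from the script's CLI, and a CUDA device is assumed.

```python
# Hypothetical single-image mask prediction with the MBD segmentation model
# (run from data/MBD/ with the mbd.pkl weights in place).
import cv2
from infer import net1_net2_infer_single_im

image = cv2.imread("document.jpg")                          # HxWx3 uint8 BGR
mask = net1_net2_infer_single_im(image, "checkpoints/mbd.pkl")
cv2.imwrite("document_mask.png", mask)                      # HxW uint8, values 0 or 255
```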
data/MBD/model/__init__.py ADDED
@@ -0,0 +1,50 @@
1
+ import torchvision.models as models
2
+ from model.densenetccnl import *
3
+ from model.unetnc import *
4
+ from model.gienet import *
5
+
6
+
7
+ def get_model(name, n_classes=1, filters=64,version=None,in_channels=3, is_batchnorm=True, norm='batch', model_path=None, use_sigmoid=True, layers=3,img_size=512):
8
+ model = _get_model_instance(name)
9
+
10
+
11
+ if name == 'dnetccnl':
12
+ model = model(img_size=128, in_channels=in_channels, out_channels=n_classes, filters=32)
13
+ elif name == 'dnetccnl512':
14
+ model = model(img_size=img_size, in_channels=in_channels, out_channels=n_classes, filters=32)
15
+ elif name == 'unetnc':
16
+ model = model(input_nc=in_channels, output_nc=n_classes, num_downs=7)
17
+ elif name == 'gie':
18
+ model = model(input_nc=in_channels, output_nc=n_classes, num_downs=7)
19
+ elif name == 'giecbam':
20
+ model = model(input_nc=in_channels, output_nc=n_classes, num_downs=7)
21
+ elif name == 'gie2head':
22
+ model = model(input_nc=in_channels, output_nc=n_classes, num_downs=7)
23
+ elif name == 'giemask':
24
+ model = model(input_nc=in_channels, output_nc=n_classes, num_downs=7)
25
+ elif name == 'giemask2':
26
+ model = model(input_nc=in_channels, output_nc=n_classes, num_downs=7)
27
+ elif name == 'giedilated':
28
+ model = model(input_nc=in_channels, output_nc=n_classes, num_downs=7)
29
+ elif name == 'bmp':
30
+ model = model(input_nc=in_channels, output_nc=n_classes, num_downs=7)
31
+ elif name == 'displacement':
32
+ model = model(n_classes=2, num_filter=32, BatchNorm='GN', in_channels=5)
33
+ return model
34
+
35
+ def _get_model_instance(name):
36
+ try:
37
+ return {
38
+ 'dnetccnl': dnetccnl,
39
+ 'dnetccnl512': dnetccnl512,
40
+ 'unetnc': UnetGenerator,
41
+ 'gie':GieGenerator,
42
+ 'giecbam':GiecbamGenerator,
43
+ 'giedilated':DilatedSingleUnet,
44
+ 'gie2head':Gie2headGenerator,
45
+ 'giemask':GiemaskGenerator,
46
+ 'giemask2':Giemask2Generator,
47
+ 'bmp':BmpGenerator,
48
+ }[name]
49
+ except:
50
+ print('Model {} not available'.format(name))
data/MBD/model/cbam.py ADDED
@@ -0,0 +1,95 @@
1
+ import torch
2
+ import math
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+ class BasicConv(nn.Module):
7
+ def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True, bn=True, bias=False):
8
+ super(BasicConv, self).__init__()
9
+ self.out_channels = out_planes
10
+ self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
11
+ self.bn = nn.BatchNorm2d(out_planes,eps=1e-5, momentum=0.01, affine=True) if bn else None
12
+ self.relu = nn.ReLU() if relu else None
13
+
14
+ def forward(self, x):
15
+ x = self.conv(x)
16
+ if self.bn is not None:
17
+ x = self.bn(x)
18
+ if self.relu is not None:
19
+ x = self.relu(x)
20
+ return x
21
+
22
+ class Flatten(nn.Module):
23
+ def forward(self, x):
24
+ return x.view(x.size(0), -1)
25
+
26
+ class ChannelGate(nn.Module):
27
+ def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max']):
28
+ super(ChannelGate, self).__init__()
29
+ self.gate_channels = gate_channels
30
+ self.mlp = nn.Sequential(
31
+ Flatten(),
32
+ nn.Linear(gate_channels, gate_channels // reduction_ratio),
33
+ nn.ReLU(),
34
+ nn.Linear(gate_channels // reduction_ratio, gate_channels)
35
+ )
36
+ self.pool_types = pool_types
37
+ def forward(self, x):
38
+ channel_att_sum = None
39
+ for pool_type in self.pool_types:
40
+ if pool_type=='avg':
41
+ avg_pool = F.avg_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
42
+ channel_att_raw = self.mlp( avg_pool )
43
+ elif pool_type=='max':
44
+ max_pool = F.max_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
45
+ channel_att_raw = self.mlp( max_pool )
46
+ elif pool_type=='lp':
47
+ lp_pool = F.lp_pool2d( x, 2, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
48
+ channel_att_raw = self.mlp( lp_pool )
49
+ elif pool_type=='lse':
50
+ # LSE pool only
51
+ lse_pool = logsumexp_2d(x)
52
+ channel_att_raw = self.mlp( lse_pool )
53
+
54
+ if channel_att_sum is None:
55
+ channel_att_sum = channel_att_raw
56
+ else:
57
+ channel_att_sum = channel_att_sum + channel_att_raw
58
+
59
+ scale = F.sigmoid( channel_att_sum ).unsqueeze(2).unsqueeze(3).expand_as(x)
60
+ return x * scale
61
+
62
+ def logsumexp_2d(tensor):
63
+ tensor_flatten = tensor.view(tensor.size(0), tensor.size(1), -1)
64
+ s, _ = torch.max(tensor_flatten, dim=2, keepdim=True)
65
+ outputs = s + (tensor_flatten - s).exp().sum(dim=2, keepdim=True).log()
66
+ return outputs
67
+
68
+ class ChannelPool(nn.Module):
69
+ def forward(self, x):
70
+ return torch.cat( (torch.max(x,1)[0].unsqueeze(1), torch.mean(x,1).unsqueeze(1)), dim=1 )
71
+
72
+ class SpatialGate(nn.Module):
73
+ def __init__(self):
74
+ super(SpatialGate, self).__init__()
75
+ kernel_size = 7
76
+ self.compress = ChannelPool()
77
+ self.spatial = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size-1) // 2, relu=False)
78
+ def forward(self, x):
79
+ x_compress = self.compress(x)
80
+ x_out = self.spatial(x_compress)
81
+ scale = F.sigmoid(x_out) # broadcasting
82
+ return x * scale
83
+
84
+ class CBAM(nn.Module):
85
+ def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max'], no_spatial=False):
86
+ super(CBAM, self).__init__()
87
+ self.ChannelGate = ChannelGate(gate_channels, reduction_ratio, pool_types)
88
+ self.no_spatial=no_spatial
89
+ if not no_spatial:
90
+ self.SpatialGate = SpatialGate()
91
+ def forward(self, x):
92
+ x_out = self.ChannelGate(x)
93
+ if not self.no_spatial:
94
+ x_out = self.SpatialGate(x_out)
95
+ return x_out
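A short sketch of the block defined above: `CBAM` applies channel attention followed by spatial attention and returns a tensor of the same shape. The import path assumes the `data/MBD` package layout used elsewhere in this commit.

```python
# Illustrative CBAM usage on a random feature map.
import torch
from model.cbam import CBAM

features = torch.rand(2, 64, 32, 32)                        # B x C x H x W
attn = CBAM(gate_channels=64, reduction_ratio=16, pool_types=['avg', 'max'])
refined = attn(features)                                    # torch.Size([2, 64, 32, 32])
```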
data/MBD/model/deep_lab_model/__init__.py ADDED
File without changes
data/MBD/model/deep_lab_model/aspp.py ADDED
@@ -0,0 +1,95 @@
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from model.deep_lab_model.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
6
+
7
+ class _ASPPModule(nn.Module):
8
+ def __init__(self, inplanes, planes, kernel_size, padding, dilation, BatchNorm):
9
+ super(_ASPPModule, self).__init__()
10
+ self.atrous_conv = nn.Conv2d(inplanes, planes, kernel_size=kernel_size,
11
+ stride=1, padding=padding, dilation=dilation, bias=False)
12
+ self.bn = BatchNorm(planes)
13
+ self.relu = nn.ReLU()
14
+
15
+ self._init_weight()
16
+
17
+ def forward(self, x):
18
+ x = self.atrous_conv(x)
19
+ x = self.bn(x)
20
+
21
+ return self.relu(x)
22
+
23
+ def _init_weight(self):
24
+ for m in self.modules():
25
+ if isinstance(m, nn.Conv2d):
26
+ torch.nn.init.kaiming_normal_(m.weight)
27
+ elif isinstance(m, SynchronizedBatchNorm2d):
28
+ m.weight.data.fill_(1)
29
+ m.bias.data.zero_()
30
+ elif isinstance(m, nn.BatchNorm2d):
31
+ m.weight.data.fill_(1)
32
+ m.bias.data.zero_()
33
+
34
+ class ASPP(nn.Module):
35
+ def __init__(self, backbone, output_stride, BatchNorm):
36
+ super(ASPP, self).__init__()
37
+ if backbone == 'drn':
38
+ inplanes = 512
39
+ elif backbone == 'mobilenet':
40
+ inplanes = 320
41
+ else:
42
+ inplanes = 2048
43
+ if output_stride == 16:
44
+ dilations = [1, 6, 12, 18]
45
+ elif output_stride == 8:
46
+ dilations = [1, 12, 24, 36]
47
+ else:
48
+ raise NotImplementedError
49
+
50
+ self.aspp1 = _ASPPModule(inplanes, 256, 1, padding=0, dilation=dilations[0], BatchNorm=BatchNorm)
51
+ self.aspp2 = _ASPPModule(inplanes, 256, 3, padding=dilations[1], dilation=dilations[1], BatchNorm=BatchNorm)
52
+ self.aspp3 = _ASPPModule(inplanes, 256, 3, padding=dilations[2], dilation=dilations[2], BatchNorm=BatchNorm)
53
+ self.aspp4 = _ASPPModule(inplanes, 256, 3, padding=dilations[3], dilation=dilations[3], BatchNorm=BatchNorm)
54
+
55
+ self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)),
56
+ nn.Conv2d(inplanes, 256, 1, stride=1, bias=False),
57
+ BatchNorm(256),
58
+ nn.ReLU())
59
+ self.conv1 = nn.Conv2d(1280, 256, 1, bias=False)
60
+ self.bn1 = BatchNorm(256)
61
+ self.relu = nn.ReLU()
62
+ self.dropout = nn.Dropout(0.5)
63
+ self._init_weight()
64
+
65
+ def forward(self, x):
66
+ x1 = self.aspp1(x)
67
+ x2 = self.aspp2(x)
68
+ x3 = self.aspp3(x)
69
+ x4 = self.aspp4(x)
70
+ x5 = self.global_avg_pool(x)
71
+ x5 = F.interpolate(x5, size=x4.size()[2:], mode='bilinear', align_corners=True)
72
+ x = torch.cat((x1, x2, x3, x4, x5), dim=1)
73
+
74
+ x = self.conv1(x)
75
+ x = self.bn1(x)
76
+ x = self.relu(x)
77
+
78
+ return self.dropout(x)
79
+
80
+ def _init_weight(self):
81
+ for m in self.modules():
82
+ if isinstance(m, nn.Conv2d):
83
+ # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
84
+ # m.weight.data.normal_(0, math.sqrt(2. / n))
85
+ torch.nn.init.kaiming_normal_(m.weight)
86
+ elif isinstance(m, SynchronizedBatchNorm2d):
87
+ m.weight.data.fill_(1)
88
+ m.bias.data.zero_()
89
+ elif isinstance(m, nn.BatchNorm2d):
90
+ m.weight.data.fill_(1)
91
+ m.bias.data.zero_()
92
+
93
+
94
+ def build_aspp(backbone, output_stride, BatchNorm):
95
+ return ASPP(backbone, output_stride, BatchNorm)
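For reference, a hedged sketch of `build_aspp`: with the ResNet backbone and `output_stride=16` the module expects 2048-channel features and returns 256 channels at the same spatial size (28x28 corresponds to the 448x448 inputs used by the MBD scripts).

```python
# Illustrative ASPP head usage (run from data/MBD/).
import torch
import torch.nn as nn
from model.deep_lab_model.aspp import build_aspp

aspp = build_aspp(backbone='resnet', output_stride=16, BatchNorm=nn.BatchNorm2d)
features = torch.rand(1, 2048, 28, 28)                      # stride-16 features of a 448x448 image
out = aspp(features)                                        # torch.Size([1, 256, 28, 28])
```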
data/MBD/model/deep_lab_model/backbone/__init__.py ADDED
@@ -0,0 +1,13 @@
1
+ from model.deep_lab_model.backbone import resnet, xception, drn, mobilenet
2
+
3
+ def build_backbone(backbone, output_stride, BatchNorm):
4
+ if backbone == 'resnet':
5
+ return resnet.ResNet101(output_stride, BatchNorm)
6
+ elif backbone == 'xception':
7
+ return xception.AlignedXception(output_stride, BatchNorm)
8
+ elif backbone == 'drn':
9
+ return drn.drn_d_54(BatchNorm)
10
+ elif backbone == 'mobilenet':
11
+ return mobilenet.MobileNetV2(output_stride, BatchNorm)
12
+ else:
13
+ raise NotImplementedError
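And a hedged sketch of the factory itself: the `'resnet'` branch builds the ResNet-101 used by the MBD DeepLab model and, as written, may download ImageNet-pretrained weights on first use; the printed shapes are expected values, not verified here.

```python
# Illustrative backbone construction (ResNet-101 branch, run from data/MBD/).
import torch
import torch.nn as nn
from model.deep_lab_model.backbone import build_backbone

backbone = build_backbone('resnet', output_stride=16, BatchNorm=nn.BatchNorm2d)
x, low_level_feat = backbone(torch.rand(1, 3, 448, 448))
print(x.shape, low_level_feat.shape)                        # high-level and low-level feature maps
```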
data/MBD/model/deep_lab_model/backbone/drn.py ADDED
1
+ import torch.nn as nn
2
+ import math
3
+ import torch.utils.model_zoo as model_zoo
4
+ from model.deep_lab_model.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
5
+
6
+ webroot = 'http://dl.yf.io/drn/'
7
+
8
+ model_urls = {
9
+ 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
10
+ 'drn-c-26': webroot + 'drn_c_26-ddedf421.pth',
11
+ 'drn-c-42': webroot + 'drn_c_42-9d336e8c.pth',
12
+ 'drn-c-58': webroot + 'drn_c_58-0a53a92c.pth',
13
+ 'drn-d-22': webroot + 'drn_d_22-4bd2f8ea.pth',
14
+ 'drn-d-38': webroot + 'drn_d_38-eebb45f0.pth',
15
+ 'drn-d-54': webroot + 'drn_d_54-0e0534ff.pth',
16
+ 'drn-d-105': webroot + 'drn_d_105-12b40979.pth'
17
+ }
18
+
19
+
20
+ def conv3x3(in_planes, out_planes, stride=1, padding=1, dilation=1):
21
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
22
+ padding=padding, bias=False, dilation=dilation)
23
+
24
+
25
+ class BasicBlock(nn.Module):
26
+ expansion = 1
27
+
28
+ def __init__(self, inplanes, planes, stride=1, downsample=None,
29
+ dilation=(1, 1), residual=True, BatchNorm=None):
30
+ super(BasicBlock, self).__init__()
31
+ self.conv1 = conv3x3(inplanes, planes, stride,
32
+ padding=dilation[0], dilation=dilation[0])
33
+ self.bn1 = BatchNorm(planes)
34
+ self.relu = nn.ReLU(inplace=True)
35
+ self.conv2 = conv3x3(planes, planes,
36
+ padding=dilation[1], dilation=dilation[1])
37
+ self.bn2 = BatchNorm(planes)
38
+ self.downsample = downsample
39
+ self.stride = stride
40
+ self.residual = residual
41
+
42
+ def forward(self, x):
43
+ residual = x
44
+
45
+ out = self.conv1(x)
46
+ out = self.bn1(out)
47
+ out = self.relu(out)
48
+
49
+ out = self.conv2(out)
50
+ out = self.bn2(out)
51
+
52
+ if self.downsample is not None:
53
+ residual = self.downsample(x)
54
+ if self.residual:
55
+ out += residual
56
+ out = self.relu(out)
57
+
58
+ return out
59
+
60
+
61
+ class Bottleneck(nn.Module):
62
+ expansion = 4
63
+
64
+ def __init__(self, inplanes, planes, stride=1, downsample=None,
65
+ dilation=(1, 1), residual=True, BatchNorm=None):
66
+ super(Bottleneck, self).__init__()
67
+ self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
68
+ self.bn1 = BatchNorm(planes)
69
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
70
+ padding=dilation[1], bias=False,
71
+ dilation=dilation[1])
72
+ self.bn2 = BatchNorm(planes)
73
+ self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
74
+ self.bn3 = BatchNorm(planes * 4)
75
+ self.relu = nn.ReLU(inplace=True)
76
+ self.downsample = downsample
77
+ self.stride = stride
78
+
79
+ def forward(self, x):
80
+ residual = x
81
+
82
+ out = self.conv1(x)
83
+ out = self.bn1(out)
84
+ out = self.relu(out)
85
+
86
+ out = self.conv2(out)
87
+ out = self.bn2(out)
88
+ out = self.relu(out)
89
+
90
+ out = self.conv3(out)
91
+ out = self.bn3(out)
92
+
93
+ if self.downsample is not None:
94
+ residual = self.downsample(x)
95
+
96
+ out += residual
97
+ out = self.relu(out)
98
+
99
+ return out
100
+
101
+
102
+ class DRN(nn.Module):
103
+
104
+ def __init__(self, block, layers, arch='D',
105
+ channels=(16, 32, 64, 128, 256, 512, 512, 512),
106
+ BatchNorm=None):
107
+ super(DRN, self).__init__()
108
+ self.inplanes = channels[0]
109
+ self.out_dim = channels[-1]
110
+ self.arch = arch
111
+
112
+ if arch == 'C':
113
+ self.conv1 = nn.Conv2d(3, channels[0], kernel_size=7, stride=1,
114
+ padding=3, bias=False)
115
+ self.bn1 = BatchNorm(channels[0])
116
+ self.relu = nn.ReLU(inplace=True)
117
+
118
+ self.layer1 = self._make_layer(
119
+ BasicBlock, channels[0], layers[0], stride=1, BatchNorm=BatchNorm)
120
+ self.layer2 = self._make_layer(
121
+ BasicBlock, channels[1], layers[1], stride=2, BatchNorm=BatchNorm)
122
+
123
+ elif arch == 'D':
124
+ self.layer0 = nn.Sequential(
125
+ nn.Conv2d(3, channels[0], kernel_size=7, stride=1, padding=3,
126
+ bias=False),
127
+ BatchNorm(channels[0]),
128
+ nn.ReLU(inplace=True)
129
+ )
130
+
131
+ self.layer1 = self._make_conv_layers(
132
+ channels[0], layers[0], stride=1, BatchNorm=BatchNorm)
133
+ self.layer2 = self._make_conv_layers(
134
+ channels[1], layers[1], stride=2, BatchNorm=BatchNorm)
135
+
136
+ self.layer3 = self._make_layer(block, channels[2], layers[2], stride=2, BatchNorm=BatchNorm)
137
+ self.layer4 = self._make_layer(block, channels[3], layers[3], stride=2, BatchNorm=BatchNorm)
138
+ self.layer5 = self._make_layer(block, channels[4], layers[4],
139
+ dilation=2, new_level=False, BatchNorm=BatchNorm)
140
+ self.layer6 = None if layers[5] == 0 else \
141
+ self._make_layer(block, channels[5], layers[5], dilation=4,
142
+ new_level=False, BatchNorm=BatchNorm)
143
+
144
+ if arch == 'C':
145
+ self.layer7 = None if layers[6] == 0 else \
146
+ self._make_layer(BasicBlock, channels[6], layers[6], dilation=2,
147
+ new_level=False, residual=False, BatchNorm=BatchNorm)
148
+ self.layer8 = None if layers[7] == 0 else \
149
+ self._make_layer(BasicBlock, channels[7], layers[7], dilation=1,
150
+ new_level=False, residual=False, BatchNorm=BatchNorm)
151
+ elif arch == 'D':
152
+ self.layer7 = None if layers[6] == 0 else \
153
+ self._make_conv_layers(channels[6], layers[6], dilation=2, BatchNorm=BatchNorm)
154
+ self.layer8 = None if layers[7] == 0 else \
155
+ self._make_conv_layers(channels[7], layers[7], dilation=1, BatchNorm=BatchNorm)
156
+
157
+ self._init_weight()
158
+
159
+ def _init_weight(self):
160
+ for m in self.modules():
161
+ if isinstance(m, nn.Conv2d):
162
+ n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
163
+ m.weight.data.normal_(0, math.sqrt(2. / n))
164
+ elif isinstance(m, SynchronizedBatchNorm2d):
165
+ m.weight.data.fill_(1)
166
+ m.bias.data.zero_()
167
+ elif isinstance(m, nn.BatchNorm2d):
168
+ m.weight.data.fill_(1)
169
+ m.bias.data.zero_()
170
+
171
+
172
+ def _make_layer(self, block, planes, blocks, stride=1, dilation=1,
173
+ new_level=True, residual=True, BatchNorm=None):
174
+ assert dilation == 1 or dilation % 2 == 0
175
+ downsample = None
176
+ if stride != 1 or self.inplanes != planes * block.expansion:
177
+ downsample = nn.Sequential(
178
+ nn.Conv2d(self.inplanes, planes * block.expansion,
179
+ kernel_size=1, stride=stride, bias=False),
180
+ BatchNorm(planes * block.expansion),
181
+ )
182
+
183
+ layers = list()
184
+ layers.append(block(
185
+ self.inplanes, planes, stride, downsample,
186
+ dilation=(1, 1) if dilation == 1 else (
187
+ dilation // 2 if new_level else dilation, dilation),
188
+ residual=residual, BatchNorm=BatchNorm))
189
+ self.inplanes = planes * block.expansion
190
+ for i in range(1, blocks):
191
+ layers.append(block(self.inplanes, planes, residual=residual,
192
+ dilation=(dilation, dilation), BatchNorm=BatchNorm))
193
+
194
+ return nn.Sequential(*layers)
195
+
196
+ def _make_conv_layers(self, channels, convs, stride=1, dilation=1, BatchNorm=None):
197
+ modules = []
198
+ for i in range(convs):
199
+ modules.extend([
200
+ nn.Conv2d(self.inplanes, channels, kernel_size=3,
201
+ stride=stride if i == 0 else 1,
202
+ padding=dilation, bias=False, dilation=dilation),
203
+ BatchNorm(channels),
204
+ nn.ReLU(inplace=True)])
205
+ self.inplanes = channels
206
+ return nn.Sequential(*modules)
207
+
208
+ def forward(self, x):
209
+ if self.arch == 'C':
210
+ x = self.conv1(x)
211
+ x = self.bn1(x)
212
+ x = self.relu(x)
213
+ elif self.arch == 'D':
214
+ x = self.layer0(x)
215
+
216
+ x = self.layer1(x)
217
+ x = self.layer2(x)
218
+
219
+ x = self.layer3(x)
220
+ low_level_feat = x
221
+
222
+ x = self.layer4(x)
223
+ x = self.layer5(x)
224
+
225
+ if self.layer6 is not None:
226
+ x = self.layer6(x)
227
+
228
+ if self.layer7 is not None:
229
+ x = self.layer7(x)
230
+
231
+ if self.layer8 is not None:
232
+ x = self.layer8(x)
233
+
234
+ return x, low_level_feat
235
+
236
+
237
+ class DRN_A(nn.Module):
238
+
239
+ def __init__(self, block, layers, BatchNorm=None):
240
+ self.inplanes = 64
241
+ super(DRN_A, self).__init__()
242
+ self.out_dim = 512 * block.expansion
243
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
244
+ bias=False)
245
+ self.bn1 = BatchNorm(64)
246
+ self.relu = nn.ReLU(inplace=True)
247
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
248
+ self.layer1 = self._make_layer(block, 64, layers[0], BatchNorm=BatchNorm)
249
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2, BatchNorm=BatchNorm)
250
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=1,
251
+ dilation=2, BatchNorm=BatchNorm)
252
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=1,
253
+ dilation=4, BatchNorm=BatchNorm)
254
+
255
+ self._init_weight()
256
+
257
+ def _init_weight(self):
258
+ for m in self.modules():
259
+ if isinstance(m, nn.Conv2d):
260
+ n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
261
+ m.weight.data.normal_(0, math.sqrt(2. / n))
262
+ elif isinstance(m, SynchronizedBatchNorm2d):
263
+ m.weight.data.fill_(1)
264
+ m.bias.data.zero_()
265
+ elif isinstance(m, nn.BatchNorm2d):
266
+ m.weight.data.fill_(1)
267
+ m.bias.data.zero_()
268
+
269
+ def _make_layer(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None):
270
+ downsample = None
271
+ if stride != 1 or self.inplanes != planes * block.expansion:
272
+ downsample = nn.Sequential(
273
+ nn.Conv2d(self.inplanes, planes * block.expansion,
274
+ kernel_size=1, stride=stride, bias=False),
275
+ BatchNorm(planes * block.expansion),
276
+ )
277
+
278
+ layers = []
279
+ layers.append(block(self.inplanes, planes, stride, downsample, BatchNorm=BatchNorm))
280
+ self.inplanes = planes * block.expansion
281
+ for i in range(1, blocks):
282
+ layers.append(block(self.inplanes, planes,
283
+ dilation=(dilation, dilation, ), BatchNorm=BatchNorm))
284
+
285
+ return nn.Sequential(*layers)
286
+
287
+ def forward(self, x):
288
+ x = self.conv1(x)
289
+ x = self.bn1(x)
290
+ x = self.relu(x)
291
+ x = self.maxpool(x)
292
+
293
+ x = self.layer1(x)
294
+ x = self.layer2(x)
295
+ x = self.layer3(x)
296
+ x = self.layer4(x)
297
+
298
+ return x
299
+
300
+ def drn_a_50(BatchNorm, pretrained=True):
301
+ model = DRN_A(Bottleneck, [3, 4, 6, 3], BatchNorm=BatchNorm)
302
+ if pretrained:
303
+ model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
304
+ return model
305
+
306
+
307
+ def drn_c_26(BatchNorm, pretrained=True):
308
+ model = DRN(BasicBlock, [1, 1, 2, 2, 2, 2, 1, 1], arch='C', BatchNorm=BatchNorm)
309
+ if pretrained:
310
+ pretrained = model_zoo.load_url(model_urls['drn-c-26'])
311
+ del pretrained['fc.weight']
312
+ del pretrained['fc.bias']
313
+ model.load_state_dict(pretrained)
314
+ return model
315
+
316
+
317
+ def drn_c_42(BatchNorm, pretrained=True):
318
+ model = DRN(BasicBlock, [1, 1, 3, 4, 6, 3, 1, 1], arch='C', BatchNorm=BatchNorm)
319
+ if pretrained:
320
+ pretrained = model_zoo.load_url(model_urls['drn-c-42'])
321
+ del pretrained['fc.weight']
322
+ del pretrained['fc.bias']
323
+ model.load_state_dict(pretrained)
324
+ return model
325
+
326
+
327
+ def drn_c_58(BatchNorm, pretrained=True):
328
+ model = DRN(Bottleneck, [1, 1, 3, 4, 6, 3, 1, 1], arch='C', BatchNorm=BatchNorm)
329
+ if pretrained:
330
+ pretrained = model_zoo.load_url(model_urls['drn-c-58'])
331
+ del pretrained['fc.weight']
332
+ del pretrained['fc.bias']
333
+ model.load_state_dict(pretrained)
334
+ return model
335
+
336
+
337
+ def drn_d_22(BatchNorm, pretrained=True):
338
+ model = DRN(BasicBlock, [1, 1, 2, 2, 2, 2, 1, 1], arch='D', BatchNorm=BatchNorm)
339
+ if pretrained:
340
+ pretrained = model_zoo.load_url(model_urls['drn-d-22'])
341
+ del pretrained['fc.weight']
342
+ del pretrained['fc.bias']
343
+ model.load_state_dict(pretrained)
344
+ return model
345
+
346
+
347
+ def drn_d_24(BatchNorm, pretrained=True):
348
+ model = DRN(BasicBlock, [1, 1, 2, 2, 2, 2, 2, 2], arch='D', BatchNorm=BatchNorm)
349
+ if pretrained:
350
+ pretrained = model_zoo.load_url(model_urls['drn-d-24'])
351
+ del pretrained['fc.weight']
352
+ del pretrained['fc.bias']
353
+ model.load_state_dict(pretrained)
354
+ return model
355
+
356
+
357
+ def drn_d_38(BatchNorm, pretrained=True):
358
+ model = DRN(BasicBlock, [1, 1, 3, 4, 6, 3, 1, 1], arch='D', BatchNorm=BatchNorm)
359
+ if pretrained:
360
+ pretrained = model_zoo.load_url(model_urls['drn-d-38'])
361
+ del pretrained['fc.weight']
362
+ del pretrained['fc.bias']
363
+ model.load_state_dict(pretrained)
364
+ return model
365
+
366
+
367
+ def drn_d_40(BatchNorm, pretrained=True):
368
+ model = DRN(BasicBlock, [1, 1, 3, 4, 6, 3, 2, 2], arch='D', BatchNorm=BatchNorm)
369
+ if pretrained:
370
+ pretrained = model_zoo.load_url(model_urls['drn-d-40'])
371
+ del pretrained['fc.weight']
372
+ del pretrained['fc.bias']
373
+ model.load_state_dict(pretrained)
374
+ return model
375
+
376
+
377
+ def drn_d_54(BatchNorm, pretrained=True):
378
+ model = DRN(Bottleneck, [1, 1, 3, 4, 6, 3, 1, 1], arch='D', BatchNorm=BatchNorm)
379
+ if pretrained:
380
+ pretrained = model_zoo.load_url(model_urls['drn-d-54'])
381
+ del pretrained['fc.weight']
382
+ del pretrained['fc.bias']
383
+ model.load_state_dict(pretrained)
384
+ return model
385
+
386
+
387
+ def drn_d_105(BatchNorm, pretrained=True):
388
+ model = DRN(Bottleneck, [1, 1, 3, 4, 23, 3, 1, 1], arch='D', BatchNorm=BatchNorm)
389
+ if pretrained:
390
+ pretrained = model_zoo.load_url(model_urls['drn-d-105'])
391
+ del pretrained['fc.weight']
392
+ del pretrained['fc.bias']
393
+ model.load_state_dict(pretrained)
394
+ return model
395
+
396
+ if __name__ == "__main__":
397
+ import torch
398
+ model = drn_a_50(BatchNorm=nn.BatchNorm2d, pretrained=True)
399
+ input = torch.rand(1, 3, 512, 512)
400
+ output, low_level_feat = model(input)
401
+ print(output.size())
402
+ print(low_level_feat.size())
data/MBD/model/deep_lab_model/backbone/mobilenet.py ADDED
@@ -0,0 +1,151 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import torch.nn as nn
4
+ import math
5
+ from model.deep_lab_model.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
6
+ import torch.utils.model_zoo as model_zoo
7
+
8
+ def conv_bn(inp, oup, stride, BatchNorm):
9
+ return nn.Sequential(
10
+ nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
11
+ BatchNorm(oup),
12
+ nn.ReLU6(inplace=True)
13
+ )
14
+
15
+
16
+ def fixed_padding(inputs, kernel_size, dilation):
17
+ kernel_size_effective = kernel_size + (kernel_size - 1) * (dilation - 1)
18
+ pad_total = kernel_size_effective - 1
19
+ pad_beg = pad_total // 2
20
+ pad_end = pad_total - pad_beg
21
+ padded_inputs = F.pad(inputs, (pad_beg, pad_end, pad_beg, pad_end))
22
+ return padded_inputs
23
+
24
+
25
+ class InvertedResidual(nn.Module):
26
+ def __init__(self, inp, oup, stride, dilation, expand_ratio, BatchNorm):
27
+ super(InvertedResidual, self).__init__()
28
+ self.stride = stride
29
+ assert stride in [1, 2]
30
+
31
+ hidden_dim = round(inp * expand_ratio)
32
+ self.use_res_connect = self.stride == 1 and inp == oup
33
+ self.kernel_size = 3
34
+ self.dilation = dilation
35
+
36
+ if expand_ratio == 1:
37
+ self.conv = nn.Sequential(
38
+ # dw
39
+ nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 0, dilation, groups=hidden_dim, bias=False),
40
+ BatchNorm(hidden_dim),
41
+ nn.ReLU6(inplace=True),
42
+ # pw-linear
43
+ nn.Conv2d(hidden_dim, oup, 1, 1, 0, 1, 1, bias=False),
44
+ BatchNorm(oup),
45
+ )
46
+ else:
47
+ self.conv = nn.Sequential(
48
+ # pw
49
+ nn.Conv2d(inp, hidden_dim, 1, 1, 0, 1, bias=False),
50
+ BatchNorm(hidden_dim),
51
+ nn.ReLU6(inplace=True),
52
+ # dw
53
+ nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 0, dilation, groups=hidden_dim, bias=False),
54
+ BatchNorm(hidden_dim),
55
+ nn.ReLU6(inplace=True),
56
+ # pw-linear
57
+ nn.Conv2d(hidden_dim, oup, 1, 1, 0, 1, bias=False),
58
+ BatchNorm(oup),
59
+ )
60
+
61
+ def forward(self, x):
62
+ x_pad = fixed_padding(x, self.kernel_size, dilation=self.dilation)
63
+ if self.use_res_connect:
64
+ x = x + self.conv(x_pad)
65
+ else:
66
+ x = self.conv(x_pad)
67
+ return x
68
+
69
+
70
+ class MobileNetV2(nn.Module):
71
+ def __init__(self, output_stride=8, BatchNorm=None, width_mult=1., pretrained=True):
72
+ super(MobileNetV2, self).__init__()
73
+ block = InvertedResidual
74
+ input_channel = 32
75
+ current_stride = 1
76
+ rate = 1
77
+ interverted_residual_setting = [
78
+ # t, c, n, s
79
+ [1, 16, 1, 1],
80
+ [6, 24, 2, 2],
81
+ [6, 32, 3, 2],
82
+ [6, 64, 4, 2],
83
+ [6, 96, 3, 1],
84
+ [6, 160, 3, 2],
85
+ [6, 320, 1, 1],
86
+ ]
87
+
88
+ # building first layer
89
+ input_channel = int(input_channel * width_mult)
90
+ self.features = [conv_bn(3, input_channel, 2, BatchNorm)]
91
+ current_stride *= 2
92
+ # building inverted residual blocks
93
+ for t, c, n, s in interverted_residual_setting:
94
+ if current_stride == output_stride:
95
+ stride = 1
96
+ dilation = rate
97
+ rate *= s
98
+ else:
99
+ stride = s
100
+ dilation = 1
101
+ current_stride *= s
102
+ output_channel = int(c * width_mult)
103
+ for i in range(n):
104
+ if i == 0:
105
+ self.features.append(block(input_channel, output_channel, stride, dilation, t, BatchNorm))
106
+ else:
107
+ self.features.append(block(input_channel, output_channel, 1, dilation, t, BatchNorm))
108
+ input_channel = output_channel
109
+ self.features = nn.Sequential(*self.features)
110
+ self._initialize_weights()
111
+
112
+ if pretrained:
113
+ self._load_pretrained_model()
114
+
115
+ self.low_level_features = self.features[0:4]
116
+ self.high_level_features = self.features[4:]
117
+
118
+ def forward(self, x):
119
+ low_level_feat = self.low_level_features(x)
120
+ x = self.high_level_features(low_level_feat)
121
+ return x, low_level_feat
122
+
123
+ def _load_pretrained_model(self):
124
+ pretrain_dict = model_zoo.load_url('http://jeff95.me/models/mobilenet_v2-6a65762b.pth')
125
+ model_dict = {}
126
+ state_dict = self.state_dict()
127
+ for k, v in pretrain_dict.items():
128
+ if k in state_dict:
129
+ model_dict[k] = v
130
+ state_dict.update(model_dict)
131
+ self.load_state_dict(state_dict)
132
+
133
+ def _initialize_weights(self):
134
+ for m in self.modules():
135
+ if isinstance(m, nn.Conv2d):
136
+ # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
137
+ # m.weight.data.normal_(0, math.sqrt(2. / n))
138
+ torch.nn.init.kaiming_normal_(m.weight)
139
+ elif isinstance(m, SynchronizedBatchNorm2d):
140
+ m.weight.data.fill_(1)
141
+ m.bias.data.zero_()
142
+ elif isinstance(m, nn.BatchNorm2d):
143
+ m.weight.data.fill_(1)
144
+ m.bias.data.zero_()
145
+
146
+ if __name__ == "__main__":
147
+ input = torch.rand(1, 3, 512, 512)
148
+ model = MobileNetV2(output_stride=16, BatchNorm=nn.BatchNorm2d)
149
+ output, low_level_feat = model(input)
150
+ print(output.size())
151
+ print(low_level_feat.size())
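Below is a minimal, self-contained sketch (plain PyTorch, not part of the committed file) of what `fixed_padding` buys the depthwise convolutions above: it reproduces "same" padding for any dilation, so a stride-1 convolution with `padding=0` keeps the spatial size. The tensor and layer sizes are illustrative placeholders.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def fixed_padding(inputs, kernel_size, dilation):
    # same logic as the helper above: pad so a stride-1 (dilated) conv preserves H x W
    kernel_size_effective = kernel_size + (kernel_size - 1) * (dilation - 1)
    pad_total = kernel_size_effective - 1
    pad_beg = pad_total // 2
    return F.pad(inputs, (pad_beg, pad_total - pad_beg, pad_beg, pad_total - pad_beg))

x = torch.rand(1, 8, 33, 33)  # odd spatial size on purpose
conv = nn.Conv2d(8, 8, 3, stride=1, padding=0, dilation=2, groups=8, bias=False)
print(conv(fixed_padding(x, kernel_size=3, dilation=2)).shape)  # torch.Size([1, 8, 33, 33])
```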
data/MBD/model/deep_lab_model/backbone/resnet.py ADDED
@@ -0,0 +1,170 @@
1
+ import math
2
+ import torch.nn as nn
3
+ import torch.utils.model_zoo as model_zoo
4
+ from model.deep_lab_model.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
5
+
6
+ class Bottleneck(nn.Module):
7
+ expansion = 4
8
+
9
+ def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, BatchNorm=None):
10
+ super(Bottleneck, self).__init__()
11
+ self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
12
+ self.bn1 = BatchNorm(planes)
13
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
14
+ dilation=dilation, padding=dilation, bias=False)
15
+ self.bn2 = BatchNorm(planes)
16
+ self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
17
+ self.bn3 = BatchNorm(planes * 4)
18
+ self.relu = nn.ReLU(inplace=True)
19
+ self.downsample = downsample
20
+ self.stride = stride
21
+ self.dilation = dilation
22
+
23
+ def forward(self, x):
24
+ residual = x
25
+
26
+ out = self.conv1(x)
27
+ out = self.bn1(out)
28
+ out = self.relu(out)
29
+
30
+ out = self.conv2(out)
31
+ out = self.bn2(out)
32
+ out = self.relu(out)
33
+
34
+ out = self.conv3(out)
35
+ out = self.bn3(out)
36
+
37
+ if self.downsample is not None:
38
+ residual = self.downsample(x)
39
+
40
+ out += residual
41
+ out = self.relu(out)
42
+
43
+ return out
44
+
45
+ class ResNet(nn.Module):
46
+
47
+ def __init__(self, block, layers, output_stride, BatchNorm, pretrained=True):
48
+ self.inplanes = 64
49
+ super(ResNet, self).__init__()
50
+ blocks = [1, 2, 4]
51
+ if output_stride == 16:
52
+ strides = [1, 2, 2, 1]
53
+ dilations = [1, 1, 1, 2]
54
+ elif output_stride == 8:
55
+ strides = [1, 2, 1, 1]
56
+ dilations = [1, 1, 2, 4]
57
+ else:
58
+ raise NotImplementedError
59
+
60
+ # Modules
61
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
62
+ bias=False)
63
+ self.bn1 = BatchNorm(64)
64
+ self.relu = nn.ReLU(inplace=True)
65
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
66
+
67
+ self.layer1 = self._make_layer(block, 64, layers[0], stride=strides[0], dilation=dilations[0], BatchNorm=BatchNorm)
68
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=strides[1], dilation=dilations[1], BatchNorm=BatchNorm)
69
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=strides[2], dilation=dilations[2], BatchNorm=BatchNorm)
70
+ self.layer4 = self._make_MG_unit(block, 512, blocks=blocks, stride=strides[3], dilation=dilations[3], BatchNorm=BatchNorm)
71
+ # self.layer4 = self._make_layer(block, 512, layers[3], stride=strides[3], dilation=dilations[3], BatchNorm=BatchNorm)
72
+ self._init_weight()
73
+
74
+ # if pretrained:
75
+ # self._load_pretrained_model()
76
+
77
+ def _make_layer(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None):
78
+ downsample = None
79
+ if stride != 1 or self.inplanes != planes * block.expansion:
80
+ downsample = nn.Sequential(
81
+ nn.Conv2d(self.inplanes, planes * block.expansion,
82
+ kernel_size=1, stride=stride, bias=False),
83
+ BatchNorm(planes * block.expansion),
84
+ )
85
+
86
+ layers = []
87
+ layers.append(block(self.inplanes, planes, stride, dilation, downsample, BatchNorm))
88
+ self.inplanes = planes * block.expansion
89
+ for i in range(1, blocks):
90
+ layers.append(block(self.inplanes, planes, dilation=dilation, BatchNorm=BatchNorm))
91
+
92
+ return nn.Sequential(*layers)
93
+
94
+ def _make_MG_unit(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None):
95
+ downsample = None
96
+ if stride != 1 or self.inplanes != planes * block.expansion:
97
+ downsample = nn.Sequential(
98
+ nn.Conv2d(self.inplanes, planes * block.expansion,
99
+ kernel_size=1, stride=stride, bias=False),
100
+ BatchNorm(planes * block.expansion),
101
+ )
102
+
103
+ layers = []
104
+ layers.append(block(self.inplanes, planes, stride, dilation=blocks[0]*dilation,
105
+ downsample=downsample, BatchNorm=BatchNorm))
106
+ self.inplanes = planes * block.expansion
107
+ for i in range(1, len(blocks)):
108
+ layers.append(block(self.inplanes, planes, stride=1,
109
+ dilation=blocks[i]*dilation, BatchNorm=BatchNorm))
110
+
111
+ return nn.Sequential(*layers)
112
+
113
+ def forward(self, input):
114
+ x = self.conv1(input)
115
+ x = self.bn1(x)
116
+ x = self.relu(x)
117
+ x = self.maxpool(x)
118
+
119
+ x = self.layer1(x)
120
+ low_level_feat = x
121
+ x = self.layer2(x)
122
+ x = self.layer3(x)
123
+ x = self.layer4(x)
124
+ return x, low_level_feat
125
+
126
+ def _init_weight(self):
127
+ for m in self.modules():
128
+ if isinstance(m, nn.Conv2d):
129
+ n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
130
+ m.weight.data.normal_(0, math.sqrt(2. / n))
131
+ elif isinstance(m, SynchronizedBatchNorm2d):
132
+ m.weight.data.fill_(1)
133
+ m.bias.data.zero_()
134
+ elif isinstance(m, nn.BatchNorm2d):
135
+ m.weight.data.fill_(1)
136
+ m.bias.data.zero_()
137
+
138
+ def _load_pretrained_model(self):
139
+
140
+ import urllib.request
141
+ import ssl
142
+ ssl._create_default_https_context = ssl._create_unverified_context
143
+ response = urllib.request.urlopen('https://download.pytorch.org/models/resnet101-5d3b4d8f.pth')
144
+
145
+ pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/resnet101-5d3b4d8f.pth')
146
+ model_dict = {}
147
+ state_dict = self.state_dict()
148
+ for k, v in pretrain_dict.items():
149
+ if k in state_dict:
150
+ # if 'conv1' in k:
151
+ # continue
152
+ model_dict[k] = v
153
+ state_dict.update(model_dict)
154
+ self.load_state_dict(state_dict)
155
+
156
+ def ResNet101(output_stride, BatchNorm, pretrained=True):
157
+ """Constructs a ResNet-101 model.
158
+ Args:
159
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
160
+ """
161
+ model = ResNet(Bottleneck, [3, 4, 23, 3], output_stride, BatchNorm, pretrained=pretrained)
162
+ return model
163
+
164
+ if __name__ == "__main__":
165
+ import torch
166
+ model = ResNet101(BatchNorm=nn.BatchNorm2d, pretrained=True, output_stride=8)
167
+ input = torch.rand(1, 3, 512, 512)
168
+ output, low_level_feat = model(input)
169
+ print(output.size())
170
+ print(low_level_feat.size())
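For reference, a small arithmetic sketch (not in the committed file) of how `_make_MG_unit` expands the multi-grid rates `blocks = [1, 2, 4]` for layer4; the base dilations come from the `dilations` lists chosen in `__init__` above.

```python
# layer4 dilation per block = multi-grid rate * base dilation for the chosen output stride
blocks = [1, 2, 4]
for output_stride, base_dilation in [(16, 2), (8, 4)]:
    print(output_stride, [b * base_dilation for b in blocks])
# 16 [2, 4, 8]
# 8 [4, 8, 16]
```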
data/MBD/model/deep_lab_model/backbone/xception.py ADDED
@@ -0,0 +1,288 @@
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import torch.utils.model_zoo as model_zoo
6
+ from model.deep_lab_model.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
7
+
8
+ def fixed_padding(inputs, kernel_size, dilation):
9
+ kernel_size_effective = kernel_size + (kernel_size - 1) * (dilation - 1)
10
+ pad_total = kernel_size_effective - 1
11
+ pad_beg = pad_total // 2
12
+ pad_end = pad_total - pad_beg
13
+ padded_inputs = F.pad(inputs, (pad_beg, pad_end, pad_beg, pad_end))
14
+ return padded_inputs
15
+
16
+
17
+ class SeparableConv2d(nn.Module):
18
+ def __init__(self, inplanes, planes, kernel_size=3, stride=1, dilation=1, bias=False, BatchNorm=None):
19
+ super(SeparableConv2d, self).__init__()
20
+
21
+ self.conv1 = nn.Conv2d(inplanes, inplanes, kernel_size, stride, 0, dilation,
22
+ groups=inplanes, bias=bias)
23
+ self.bn = BatchNorm(inplanes)
24
+ self.pointwise = nn.Conv2d(inplanes, planes, 1, 1, 0, 1, 1, bias=bias)
25
+
26
+ def forward(self, x):
27
+ x = fixed_padding(x, self.conv1.kernel_size[0], dilation=self.conv1.dilation[0])
28
+ x = self.conv1(x)
29
+ x = self.bn(x)
30
+ x = self.pointwise(x)
31
+ return x
32
+
33
+
34
+ class Block(nn.Module):
35
+ def __init__(self, inplanes, planes, reps, stride=1, dilation=1, BatchNorm=None,
36
+ start_with_relu=True, grow_first=True, is_last=False):
37
+ super(Block, self).__init__()
38
+
39
+ if planes != inplanes or stride != 1:
40
+ self.skip = nn.Conv2d(inplanes, planes, 1, stride=stride, bias=False)
41
+ self.skipbn = BatchNorm(planes)
42
+ else:
43
+ self.skip = None
44
+
45
+ self.relu = nn.ReLU(inplace=True)
46
+ rep = []
47
+
48
+ filters = inplanes
49
+ if grow_first:
50
+ rep.append(self.relu)
51
+ rep.append(SeparableConv2d(inplanes, planes, 3, 1, dilation, BatchNorm=BatchNorm))
52
+ rep.append(BatchNorm(planes))
53
+ filters = planes
54
+
55
+ for i in range(reps - 1):
56
+ rep.append(self.relu)
57
+ rep.append(SeparableConv2d(filters, filters, 3, 1, dilation, BatchNorm=BatchNorm))
58
+ rep.append(BatchNorm(filters))
59
+
60
+ if not grow_first:
61
+ rep.append(self.relu)
62
+ rep.append(SeparableConv2d(inplanes, planes, 3, 1, dilation, BatchNorm=BatchNorm))
63
+ rep.append(BatchNorm(planes))
64
+
65
+ if stride != 1:
66
+ rep.append(self.relu)
67
+ rep.append(SeparableConv2d(planes, planes, 3, 2, BatchNorm=BatchNorm))
68
+ rep.append(BatchNorm(planes))
69
+
70
+ if stride == 1 and is_last:
71
+ rep.append(self.relu)
72
+ rep.append(SeparableConv2d(planes, planes, 3, 1, BatchNorm=BatchNorm))
73
+ rep.append(BatchNorm(planes))
74
+
75
+ if not start_with_relu:
76
+ rep = rep[1:]
77
+
78
+ self.rep = nn.Sequential(*rep)
79
+
80
+ def forward(self, inp):
81
+ x = self.rep(inp)
82
+
83
+ if self.skip is not None:
84
+ skip = self.skip(inp)
85
+ skip = self.skipbn(skip)
86
+ else:
87
+ skip = inp
88
+
89
+ x = x + skip
90
+
91
+ return x
92
+
93
+
94
+ class AlignedXception(nn.Module):
95
+ """
96
+ Modified Aligned Xception
97
+ """
98
+ def __init__(self, output_stride, BatchNorm,
99
+ pretrained=True):
100
+ super(AlignedXception, self).__init__()
101
+
102
+ if output_stride == 16:
103
+ entry_block3_stride = 2
104
+ middle_block_dilation = 1
105
+ exit_block_dilations = (1, 2)
106
+ elif output_stride == 8:
107
+ entry_block3_stride = 1
108
+ middle_block_dilation = 2
109
+ exit_block_dilations = (2, 4)
110
+ else:
111
+ raise NotImplementedError
112
+
113
+
114
+ # Entry flow
115
+ self.conv1 = nn.Conv2d(3, 32, 3, stride=2, padding=1, bias=False)
116
+ self.bn1 = BatchNorm(32)
117
+ self.relu = nn.ReLU(inplace=True)
118
+
119
+ self.conv2 = nn.Conv2d(32, 64, 3, stride=1, padding=1, bias=False)
120
+ self.bn2 = BatchNorm(64)
121
+
122
+ self.block1 = Block(64, 128, reps=2, stride=2, BatchNorm=BatchNorm, start_with_relu=False)
123
+ self.block2 = Block(128, 256, reps=2, stride=2, BatchNorm=BatchNorm, start_with_relu=False,
124
+ grow_first=True)
125
+ self.block3 = Block(256, 728, reps=2, stride=entry_block3_stride, BatchNorm=BatchNorm,
126
+ start_with_relu=True, grow_first=True, is_last=True)
127
+
128
+ # Middle flow
129
+ self.block4 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
130
+ BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
131
+ self.block5 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
132
+ BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
133
+ self.block6 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
134
+ BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
135
+ self.block7 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
136
+ BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
137
+ self.block8 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
138
+ BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
139
+ self.block9 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
140
+ BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
141
+ self.block10 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
142
+ BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
143
+ self.block11 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
144
+ BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
145
+ self.block12 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
146
+ BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
147
+ self.block13 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
148
+ BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
149
+ self.block14 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
150
+ BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
151
+ self.block15 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
152
+ BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
153
+ self.block16 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
154
+ BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
155
+ self.block17 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
156
+ BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
157
+ self.block18 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
158
+ BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
159
+ self.block19 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
160
+ BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
161
+
162
+ # Exit flow
163
+ self.block20 = Block(728, 1024, reps=2, stride=1, dilation=exit_block_dilations[0],
164
+ BatchNorm=BatchNorm, start_with_relu=True, grow_first=False, is_last=True)
165
+
166
+ self.conv3 = SeparableConv2d(1024, 1536, 3, stride=1, dilation=exit_block_dilations[1], BatchNorm=BatchNorm)
167
+ self.bn3 = BatchNorm(1536)
168
+
169
+ self.conv4 = SeparableConv2d(1536, 1536, 3, stride=1, dilation=exit_block_dilations[1], BatchNorm=BatchNorm)
170
+ self.bn4 = BatchNorm(1536)
171
+
172
+ self.conv5 = SeparableConv2d(1536, 2048, 3, stride=1, dilation=exit_block_dilations[1], BatchNorm=BatchNorm)
173
+ self.bn5 = BatchNorm(2048)
174
+
175
+ # Init weights
176
+ self._init_weight()
177
+
178
+ # Load pretrained model
179
+ if pretrained:
180
+ self._load_pretrained_model()
181
+
182
+ def forward(self, x):
183
+ # Entry flow
184
+ x = self.conv1(x)
185
+ x = self.bn1(x)
186
+ x = self.relu(x)
187
+
188
+ x = self.conv2(x)
189
+ x = self.bn2(x)
190
+ x = self.relu(x)
191
+
192
+ x = self.block1(x)
193
+ # add relu here
194
+ x = self.relu(x)
195
+ low_level_feat = x
196
+ x = self.block2(x)
197
+ x = self.block3(x)
198
+
199
+ # Middle flow
200
+ x = self.block4(x)
201
+ x = self.block5(x)
202
+ x = self.block6(x)
203
+ x = self.block7(x)
204
+ x = self.block8(x)
205
+ x = self.block9(x)
206
+ x = self.block10(x)
207
+ x = self.block11(x)
208
+ x = self.block12(x)
209
+ x = self.block13(x)
210
+ x = self.block14(x)
211
+ x = self.block15(x)
212
+ x = self.block16(x)
213
+ x = self.block17(x)
214
+ x = self.block18(x)
215
+ x = self.block19(x)
216
+
217
+ # Exit flow
218
+ x = self.block20(x)
219
+ x = self.relu(x)
220
+ x = self.conv3(x)
221
+ x = self.bn3(x)
222
+ x = self.relu(x)
223
+
224
+ x = self.conv4(x)
225
+ x = self.bn4(x)
226
+ x = self.relu(x)
227
+
228
+ x = self.conv5(x)
229
+ x = self.bn5(x)
230
+ x = self.relu(x)
231
+
232
+ return x, low_level_feat
233
+
234
+ def _init_weight(self):
235
+ for m in self.modules():
236
+ if isinstance(m, nn.Conv2d):
237
+ n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
238
+ m.weight.data.normal_(0, math.sqrt(2. / n))
239
+ elif isinstance(m, SynchronizedBatchNorm2d):
240
+ m.weight.data.fill_(1)
241
+ m.bias.data.zero_()
242
+ elif isinstance(m, nn.BatchNorm2d):
243
+ m.weight.data.fill_(1)
244
+ m.bias.data.zero_()
245
+
246
+
247
+ def _load_pretrained_model(self):
248
+ pretrain_dict = model_zoo.load_url('http://data.lip6.fr/cadene/pretrainedmodels/xception-b5690688.pth')
249
+ model_dict = {}
250
+ state_dict = self.state_dict()
251
+
252
+ for k, v in pretrain_dict.items():
253
+ if k in state_dict:
254
+ if 'pointwise' in k:
255
+ v = v.unsqueeze(-1).unsqueeze(-1)
256
+ if k.startswith('block11'):
257
+ model_dict[k] = v
258
+ model_dict[k.replace('block11', 'block12')] = v
259
+ model_dict[k.replace('block11', 'block13')] = v
260
+ model_dict[k.replace('block11', 'block14')] = v
261
+ model_dict[k.replace('block11', 'block15')] = v
262
+ model_dict[k.replace('block11', 'block16')] = v
263
+ model_dict[k.replace('block11', 'block17')] = v
264
+ model_dict[k.replace('block11', 'block18')] = v
265
+ model_dict[k.replace('block11', 'block19')] = v
266
+ elif k.startswith('block12'):
267
+ model_dict[k.replace('block12', 'block20')] = v
268
+ elif k.startswith('bn3'):
269
+ model_dict[k] = v
270
+ model_dict[k.replace('bn3', 'bn4')] = v
271
+ elif k.startswith('conv4'):
272
+ model_dict[k.replace('conv4', 'conv5')] = v
273
+ elif k.startswith('bn4'):
274
+ model_dict[k.replace('bn4', 'bn5')] = v
275
+ else:
276
+ model_dict[k] = v
277
+ state_dict.update(model_dict)
278
+ self.load_state_dict(state_dict)
279
+
280
+
281
+
282
+ if __name__ == "__main__":
283
+ import torch
284
+ model = AlignedXception(BatchNorm=nn.BatchNorm2d, pretrained=True, output_stride=16)
285
+ input = torch.rand(1, 3, 512, 512)
286
+ output, low_level_feat = model(input)
287
+ print(output.size())
288
+ print(low_level_feat.size())
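A rough, self-contained comparison (illustrative only, plain PyTorch) of why the backbone is built from `SeparableConv2d` blocks: factoring a 3x3 convolution into a depthwise 3x3 plus a 1x1 pointwise convolution cuts the parameter count by roughly an order of magnitude at these channel widths.

```python
import torch.nn as nn

cin = cout = 728                                              # middle-flow width above
standard = nn.Conv2d(cin, cout, 3, bias=False)                # dense 3x3 convolution
depthwise = nn.Conv2d(cin, cin, 3, groups=cin, bias=False)    # one 3x3 filter per channel
pointwise = nn.Conv2d(cin, cout, 1, bias=False)               # 1x1 channel mixing
count = lambda m: sum(p.numel() for p in m.parameters())
print(count(standard), count(depthwise) + count(pointwise))   # 4769856 vs 536536 (~8.9x fewer)
```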
data/MBD/model/deep_lab_model/decoder.py ADDED
@@ -0,0 +1,59 @@
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from model.deep_lab_model.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
6
+
7
+ class Decoder(nn.Module):
8
+ def __init__(self, num_classes, backbone, BatchNorm):
9
+ super(Decoder, self).__init__()
10
+ if backbone == 'resnet' or backbone == 'drn':
11
+ low_level_inplanes = 256
12
+ elif backbone == 'xception':
13
+ low_level_inplanes = 128
14
+ elif backbone == 'mobilenet':
15
+ low_level_inplanes = 24
16
+ else:
17
+ raise NotImplementedError
18
+
19
+ self.conv1 = nn.Conv2d(low_level_inplanes, 48, 1, bias=False)
20
+ self.bn1 = BatchNorm(48)
21
+ self.relu = nn.ReLU()
22
+ self.last_conv = nn.Sequential(nn.Conv2d(304, 256, kernel_size=3, stride=1, padding=1, bias=False),
23
+ BatchNorm(256),
24
+ nn.ReLU(),
25
+ nn.Dropout(0.5),
26
+ nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False),
27
+ BatchNorm(256),
28
+ nn.ReLU(),
29
+ nn.Dropout(0.1),
30
+ nn.Conv2d(256, num_classes, kernel_size=1, stride=1),
31
+ nn.Sigmoid()
32
+ )
33
+ self._init_weight()
34
+
35
+
36
+ def forward(self, x, low_level_feat):
37
+ low_level_feat = self.conv1(low_level_feat)
38
+ low_level_feat = self.bn1(low_level_feat)
39
+ low_level_feat = self.relu(low_level_feat)
40
+
41
+ x = F.interpolate(x, size=low_level_feat.size()[2:], mode='bilinear', align_corners=True)
42
+ x = torch.cat((x, low_level_feat), dim=1)
43
+ x = self.last_conv(x)
44
+
45
+ return x
46
+
47
+ def _init_weight(self):
48
+ for m in self.modules():
49
+ if isinstance(m, nn.Conv2d):
50
+ torch.nn.init.kaiming_normal_(m.weight)
51
+ elif isinstance(m, SynchronizedBatchNorm2d):
52
+ m.weight.data.fill_(1)
53
+ m.bias.data.zero_()
54
+ elif isinstance(m, nn.BatchNorm2d):
55
+ m.weight.data.fill_(1)
56
+ m.bias.data.zero_()
57
+
58
+ def build_decoder(num_classes, backbone, BatchNorm):
59
+ return Decoder(num_classes, backbone, BatchNorm)
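An illustrative shape walk-through of `Decoder.forward` (not part of the committed file; the sizes assume a 512x512 input at output stride 16): the ASPP output is upsampled to the low-level feature resolution, the low-level feature is projected to 48 channels, and the concatenation yields the 304 channels expected by `last_conv`.

```python
import torch
import torch.nn.functional as F

aspp_out = torch.rand(1, 256, 32, 32)    # 512 / 16: ASPP output
low_level = torch.rand(1, 48, 128, 128)  # 512 / 4: low-level feature after the 1x1 conv + BN + ReLU
x = F.interpolate(aspp_out, size=low_level.shape[2:], mode='bilinear', align_corners=True)
x = torch.cat((x, low_level), dim=1)
print(x.shape)  # torch.Size([1, 304, 128, 128]) -- 256 + 48 channels into last_conv
```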
data/MBD/model/deep_lab_model/deeplab.py ADDED
@@ -0,0 +1,81 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from model.deep_lab_model.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
5
+ from model.deep_lab_model.aspp import build_aspp
6
+ from model.deep_lab_model.decoder import build_decoder
7
+ from model.deep_lab_model.backbone import build_backbone
8
+
9
+ class DeepLab(nn.Module):
10
+ def __init__(self, backbone='resnet', output_stride=16, num_classes=21,
11
+ sync_bn=True, freeze_bn=False):
12
+ super(DeepLab, self).__init__()
13
+ if backbone == 'drn':
14
+ output_stride = 8
15
+
16
+ if sync_bn == True:
17
+ BatchNorm = SynchronizedBatchNorm2d
18
+ else:
19
+ BatchNorm = nn.BatchNorm2d
20
+
21
+ self.backbone = build_backbone(backbone, output_stride, BatchNorm)
22
+ self.aspp = build_aspp(backbone, output_stride, BatchNorm)
23
+ self.decoder = build_decoder(num_classes, backbone, BatchNorm)
24
+
25
+ self.freeze_bn = freeze_bn
26
+
27
+ def forward(self, input):
28
+ x, low_level_feat = self.backbone(input)
29
+ x = self.aspp(x)
30
+ x = self.decoder(x, low_level_feat)
31
+ x = F.interpolate(x, size=input.size()[2:], mode='bilinear', align_corners=True)
32
+
33
+ return x
34
+
35
+ # NOTE: this method is shadowed by the boolean attribute `self.freeze_bn` assigned in
+ # __init__, so calling model.freeze_bn() on an instance fails unless one of the two is renamed;
+ # get_1x_lr_params/get_10x_lr_params rely on the attribute.
+ def freeze_bn(self):
36
+ for m in self.modules():
37
+ if isinstance(m, SynchronizedBatchNorm2d):
38
+ m.eval()
39
+ elif isinstance(m, nn.BatchNorm2d):
40
+ m.eval()
41
+
42
+ def get_1x_lr_params(self):
43
+ modules = [self.backbone]
44
+ for i in range(len(modules)):
45
+ for m in modules[i].named_modules():
46
+ if self.freeze_bn:
47
+ if isinstance(m[1], nn.Conv2d):
48
+ for p in m[1].parameters():
49
+ if p.requires_grad:
50
+ yield p
51
+ else:
52
+ if isinstance(m[1], nn.Conv2d) or isinstance(m[1], SynchronizedBatchNorm2d) \
53
+ or isinstance(m[1], nn.BatchNorm2d):
54
+ for p in m[1].parameters():
55
+ if p.requires_grad:
56
+ yield p
57
+
58
+ def get_10x_lr_params(self):
59
+ modules = [self.aspp, self.decoder]
60
+ for i in range(len(modules)):
61
+ for m in modules[i].named_modules():
62
+ if self.freeze_bn:
63
+ if isinstance(m[1], nn.Conv2d):
64
+ for p in m[1].parameters():
65
+ if p.requires_grad:
66
+ yield p
67
+ else:
68
+ if isinstance(m[1], nn.Conv2d) or isinstance(m[1], SynchronizedBatchNorm2d) \
69
+ or isinstance(m[1], nn.BatchNorm2d):
70
+ for p in m[1].parameters():
71
+ if p.requires_grad:
72
+ yield p
73
+
74
+ if __name__ == "__main__":
75
+ model = DeepLab(backbone='mobilenet', output_stride=16)
76
+ model.eval()
77
+ input = torch.rand(1, 3, 513, 513)
78
+ output = model(input)
79
+ print(output.size())
80
+
81
+
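A hypothetical multi-GPU usage sketch (not from the repository; it assumes `data/MBD` is the working directory so the `model` package resolves, and `num_classes=1` is only a placeholder). With `sync_bn=True` the network contains `SynchronizedBatchNorm2d` layers, so the callback-aware DataParallel wrapper is needed for the replicas to actually share batch statistics.

```python
import torch
from model.deep_lab_model.deeplab import DeepLab
from model.deep_lab_model.sync_batchnorm import DataParallelWithCallback

model = DeepLab(backbone='resnet', output_stride=16, num_classes=1, sync_bn=True)
if torch.cuda.is_available() and torch.cuda.device_count() > 1:
    # plain nn.DataParallel would not trigger __data_parallel_replicate__
    model = DataParallelWithCallback(model.cuda(), device_ids=list(range(torch.cuda.device_count())))
```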
data/MBD/model/deep_lab_model/sync_batchnorm/__init__.py ADDED
@@ -0,0 +1,12 @@
1
+ # -*- coding: utf-8 -*-
2
+ # File : __init__.py
3
+ # Author : Jiayuan Mao
4
+ # Email : maojiayuan@gmail.com
5
+ # Date : 27/01/2018
6
+ #
7
+ # This file is part of Synchronized-BatchNorm-PyTorch.
8
+ # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9
+ # Distributed under MIT License.
10
+
11
+ from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d
12
+ from .replicate import DataParallelWithCallback, patch_replication_callback
data/MBD/model/deep_lab_model/sync_batchnorm/batchnorm.py ADDED
@@ -0,0 +1,282 @@
1
+ # -*- coding: utf-8 -*-
2
+ # File : batchnorm.py
3
+ # Author : Jiayuan Mao
4
+ # Email : maojiayuan@gmail.com
5
+ # Date : 27/01/2018
6
+ #
7
+ # This file is part of Synchronized-BatchNorm-PyTorch.
8
+ # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9
+ # Distributed under MIT License.
10
+
11
+ import collections
12
+
13
+ import torch
14
+ import torch.nn.functional as F
15
+
16
+ from torch.nn.modules.batchnorm import _BatchNorm
17
+ from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast
18
+
19
+ from .comm import SyncMaster
20
+
21
+ __all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d']
22
+
23
+
24
+ def _sum_ft(tensor):
25
+ """sum over the first and last dimention"""
26
+ return tensor.sum(dim=0).sum(dim=-1)
27
+
28
+
29
+ def _unsqueeze_ft(tensor):
30
+ """add new dementions at the front and the tail"""
31
+ return tensor.unsqueeze(0).unsqueeze(-1)
32
+
33
+
34
+ _ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size'])
35
+ _MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std'])
36
+
37
+
38
+ class _SynchronizedBatchNorm(_BatchNorm):
39
+ def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
40
+ super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine)
41
+
42
+ self._sync_master = SyncMaster(self._data_parallel_master)
43
+
44
+ self._is_parallel = False
45
+ self._parallel_id = None
46
+ self._slave_pipe = None
47
+
48
+ def forward(self, input):
49
+ # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation.
50
+ if not (self._is_parallel and self.training):
51
+ return F.batch_norm(
52
+ input, self.running_mean, self.running_var, self.weight, self.bias,
53
+ self.training, self.momentum, self.eps)
54
+
55
+ # Resize the input to (B, C, -1).
56
+ input_shape = input.size()
57
+ input = input.view(input.size(0), self.num_features, -1)
58
+
59
+ # Compute the sum and square-sum.
60
+ sum_size = input.size(0) * input.size(2)
61
+ input_sum = _sum_ft(input)
62
+ input_ssum = _sum_ft(input ** 2)
63
+
64
+ # Reduce-and-broadcast the statistics.
65
+ if self._parallel_id == 0:
66
+ mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size))
67
+ else:
68
+ mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size))
69
+
70
+ # Compute the output.
71
+ if self.affine:
72
+ # MJY:: Fuse the multiplication for speed.
73
+ output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)
74
+ else:
75
+ output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)
76
+
77
+ # Reshape it.
78
+ return output.view(input_shape)
79
+
80
+ def __data_parallel_replicate__(self, ctx, copy_id):
81
+ self._is_parallel = True
82
+ self._parallel_id = copy_id
83
+
84
+ # parallel_id == 0 means master device.
85
+ if self._parallel_id == 0:
86
+ ctx.sync_master = self._sync_master
87
+ else:
88
+ self._slave_pipe = ctx.sync_master.register_slave(copy_id)
89
+
90
+ def _data_parallel_master(self, intermediates):
91
+ """Reduce the sum and square-sum, compute the statistics, and broadcast it."""
92
+
93
+ # Always using same "device order" makes the ReduceAdd operation faster.
94
+ # Thanks to:: Tete Xiao (http://tetexiao.com/)
95
+ intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())
96
+
97
+ to_reduce = [i[1][:2] for i in intermediates]
98
+ to_reduce = [j for i in to_reduce for j in i] # flatten
99
+ target_gpus = [i[1].sum.get_device() for i in intermediates]
100
+
101
+ sum_size = sum([i[1].sum_size for i in intermediates])
102
+ sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
103
+ mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)
104
+
105
+ broadcasted = Broadcast.apply(target_gpus, mean, inv_std)
106
+
107
+ outputs = []
108
+ for i, rec in enumerate(intermediates):
109
+ outputs.append((rec[0], _MasterMessage(*broadcasted[i * 2:i * 2 + 2])))
110
+
111
+ return outputs
112
+
113
+ def _compute_mean_std(self, sum_, ssum, size):
114
+ """Compute the mean and standard-deviation with sum and square-sum. This method
115
+ also maintains the moving average on the master device."""
116
+ assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
117
+ mean = sum_ / size
118
+ sumvar = ssum - sum_ * mean
119
+ unbias_var = sumvar / (size - 1)
120
+ bias_var = sumvar / size
121
+
122
+ self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
123
+ self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data
124
+
125
+ return mean, bias_var.clamp(self.eps) ** -0.5
126
+
127
+
128
+ class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
129
+ r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a
130
+ mini-batch.
131
+ .. math::
132
+ y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
133
+ This module differs from the built-in PyTorch BatchNorm1d as the mean and
134
+ standard-deviation are reduced across all devices during training.
135
+ For example, when one uses `nn.DataParallel` to wrap the network during
136
+ training, PyTorch's implementation normalizes the tensor on each device using
137
+ the statistics only on that device, which accelerates the computation and
138
+ is also easy to implement, but the statistics might be inaccurate.
139
+ Instead, in this synchronized version, the statistics will be computed
140
+ over all training samples distributed on multiple devices.
141
+
142
+ Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
143
+ as the built-in PyTorch implementation.
144
+ The mean and standard-deviation are calculated per-dimension over
145
+ the mini-batches and gamma and beta are learnable parameter vectors
146
+ of size C (where C is the input size).
147
+ During training, this layer keeps a running estimate of its computed mean
148
+ and variance. The running sum is kept with a default momentum of 0.1.
149
+ During evaluation, this running mean/variance is used for normalization.
150
+ Because the BatchNorm is done over the `C` dimension, computing statistics
151
+ on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm
152
+ Args:
153
+ num_features: num_features from an expected input of size
154
+ `batch_size x num_features [x width]`
155
+ eps: a value added to the denominator for numerical stability.
156
+ Default: 1e-5
157
+ momentum: the value used for the running_mean and running_var
158
+ computation. Default: 0.1
159
+ affine: a boolean value that when set to ``True``, gives the layer learnable
160
+ affine parameters. Default: ``True``
161
+ Shape:
162
+ - Input: :math:`(N, C)` or :math:`(N, C, L)`
163
+ - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)
164
+ Examples:
165
+ >>> # With Learnable Parameters
166
+ >>> m = SynchronizedBatchNorm1d(100)
167
+ >>> # Without Learnable Parameters
168
+ >>> m = SynchronizedBatchNorm1d(100, affine=False)
169
+ >>> input = torch.autograd.Variable(torch.randn(20, 100))
170
+ >>> output = m(input)
171
+ """
172
+
173
+ def _check_input_dim(self, input):
174
+ if input.dim() != 2 and input.dim() != 3:
175
+ raise ValueError('expected 2D or 3D input (got {}D input)'
176
+ .format(input.dim()))
177
+ super(SynchronizedBatchNorm1d, self)._check_input_dim(input)
178
+
179
+
180
+ class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
181
+ r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch
182
+ of 3d inputs
183
+ .. math::
184
+ y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
185
+ This module differs from the built-in PyTorch BatchNorm2d as the mean and
186
+ standard-deviation are reduced across all devices during training.
187
+ For example, when one uses `nn.DataParallel` to wrap the network during
188
+ training, PyTorch's implementation normalizes the tensor on each device using
189
+ the statistics only on that device, which accelerates the computation and
190
+ is also easy to implement, but the statistics might be inaccurate.
191
+ Instead, in this synchronized version, the statistics will be computed
192
+ over all training samples distributed on multiple devices.
193
+
194
+ Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
195
+ as the built-in PyTorch implementation.
196
+ The mean and standard-deviation are calculated per-dimension over
197
+ the mini-batches and gamma and beta are learnable parameter vectors
198
+ of size C (where C is the input size).
199
+ During training, this layer keeps a running estimate of its computed mean
200
+ and variance. The running sum is kept with a default momentum of 0.1.
201
+ During evaluation, this running mean/variance is used for normalization.
202
+ Because the BatchNorm is done over the `C` dimension, computing statistics
203
+ on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm
204
+ Args:
205
+ num_features: num_features from an expected input of
206
+ size batch_size x num_features x height x width
207
+ eps: a value added to the denominator for numerical stability.
208
+ Default: 1e-5
209
+ momentum: the value used for the running_mean and running_var
210
+ computation. Default: 0.1
211
+ affine: a boolean value that when set to ``True``, gives the layer learnable
212
+ affine parameters. Default: ``True``
213
+ Shape:
214
+ - Input: :math:`(N, C, H, W)`
215
+ - Output: :math:`(N, C, H, W)` (same shape as input)
216
+ Examples:
217
+ >>> # With Learnable Parameters
218
+ >>> m = SynchronizedBatchNorm2d(100)
219
+ >>> # Without Learnable Parameters
220
+ >>> m = SynchronizedBatchNorm2d(100, affine=False)
221
+ >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45))
222
+ >>> output = m(input)
223
+ """
224
+
225
+ def _check_input_dim(self, input):
226
+ if input.dim() != 4:
227
+ raise ValueError('expected 4D input (got {}D input)'
228
+ .format(input.dim()))
229
+ super(SynchronizedBatchNorm2d, self)._check_input_dim(input)
230
+
231
+
232
+ class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
233
+ r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch
234
+ of 4d inputs
235
+ .. math::
236
+ y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
237
+ This module differs from the built-in PyTorch BatchNorm3d as the mean and
238
+ standard-deviation are reduced across all devices during training.
239
+ For example, when one uses `nn.DataParallel` to wrap the network during
240
+ training, PyTorch's implementation normalizes the tensor on each device using
241
+ the statistics only on that device, which accelerates the computation and
242
+ is also easy to implement, but the statistics might be inaccurate.
243
+ Instead, in this synchronized version, the statistics will be computed
244
+ over all training samples distributed on multiple devices.
245
+
246
+ Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
247
+ as the built-in PyTorch implementation.
248
+ The mean and standard-deviation are calculated per-dimension over
249
+ the mini-batches and gamma and beta are learnable parameter vectors
250
+ of size C (where C is the input size).
251
+ During training, this layer keeps a running estimate of its computed mean
252
+ and variance. The running sum is kept with a default momentum of 0.1.
253
+ During evaluation, this running mean/variance is used for normalization.
254
+ Because the BatchNorm is done over the `C` dimension, computing statistics
255
+ on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm
256
+ or Spatio-temporal BatchNorm
257
+ Args:
258
+ num_features: num_features from an expected input of
259
+ size batch_size x num_features x depth x height x width
260
+ eps: a value added to the denominator for numerical stability.
261
+ Default: 1e-5
262
+ momentum: the value used for the running_mean and running_var
263
+ computation. Default: 0.1
264
+ affine: a boolean value that when set to ``True``, gives the layer learnable
265
+ affine parameters. Default: ``True``
266
+ Shape:
267
+ - Input: :math:`(N, C, D, H, W)`
268
+ - Output: :math:`(N, C, D, H, W)` (same shape as input)
269
+ Examples:
270
+ >>> # With Learnable Parameters
271
+ >>> m = SynchronizedBatchNorm3d(100)
272
+ >>> # Without Learnable Parameters
273
+ >>> m = SynchronizedBatchNorm3d(100, affine=False)
274
+ >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10))
275
+ >>> output = m(input)
276
+ """
277
+
278
+ def _check_input_dim(self, input):
279
+ if input.dim() != 5:
280
+ raise ValueError('expected 5D input (got {}D input)'
281
+ .format(input.dim()))
282
+ super(SynchronizedBatchNorm3d, self)._check_input_dim(input)
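A quick single-device sanity check of the claim repeated in the docstrings above: outside of data-parallel replication, `SynchronizedBatchNorm2d` falls back to `F.batch_norm` and matches the built-in `nn.BatchNorm2d`. This sketch assumes `data/MBD` is the working directory; it is not part of the committed file.

```python
import torch
import torch.nn as nn
from model.deep_lab_model.sync_batchnorm import SynchronizedBatchNorm2d

torch.manual_seed(0)
x = torch.randn(4, 16, 8, 8)
sync_bn, bn = SynchronizedBatchNorm2d(16), nn.BatchNorm2d(16)
bn.load_state_dict(sync_bn.state_dict())              # copy affine parameters and buffers
print(torch.allclose(sync_bn(x), bn(x), atol=1e-6))   # True
```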
data/MBD/model/deep_lab_model/sync_batchnorm/comm.py ADDED
@@ -0,0 +1,129 @@
1
+ # -*- coding: utf-8 -*-
2
+ # File : comm.py
3
+ # Author : Jiayuan Mao
4
+ # Email : maojiayuan@gmail.com
5
+ # Date : 27/01/2018
6
+ #
7
+ # This file is part of Synchronized-BatchNorm-PyTorch.
8
+ # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9
+ # Distributed under MIT License.
10
+
11
+ import queue
12
+ import collections
13
+ import threading
14
+
15
+ __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']
16
+
17
+
18
+ class FutureResult(object):
19
+ """A thread-safe future implementation. Used only as one-to-one pipe."""
20
+
21
+ def __init__(self):
22
+ self._result = None
23
+ self._lock = threading.Lock()
24
+ self._cond = threading.Condition(self._lock)
25
+
26
+ def put(self, result):
27
+ with self._lock:
28
+ assert self._result is None, 'Previous result hasn\'t been fetched.'
29
+ self._result = result
30
+ self._cond.notify()
31
+
32
+ def get(self):
33
+ with self._lock:
34
+ if self._result is None:
35
+ self._cond.wait()
36
+
37
+ res = self._result
38
+ self._result = None
39
+ return res
40
+
41
+
42
+ _MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
43
+ _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])
44
+
45
+
46
+ class SlavePipe(_SlavePipeBase):
47
+ """Pipe for master-slave communication."""
48
+
49
+ def run_slave(self, msg):
50
+ self.queue.put((self.identifier, msg))
51
+ ret = self.result.get()
52
+ self.queue.put(True)
53
+ return ret
54
+
55
+
56
+ class SyncMaster(object):
57
+ """An abstract `SyncMaster` object.
58
+ - During replication, data parallel triggers a callback on each module, so every slave device should
59
+ call `register(id)` and obtain a `SlavePipe` to communicate with the master.
60
+ - During the forward pass, the master device invokes `run_master`; all messages from the slave devices are collected
61
+ and passed to a registered callback.
62
+ - After receiving the messages, the master device gathers the information and determines the message to be passed
63
+ back to each slave device.
64
+ """
65
+
66
+ def __init__(self, master_callback):
67
+ """
68
+ Args:
69
+ master_callback: a callback to be invoked after having collected messages from slave devices.
70
+ """
71
+ self._master_callback = master_callback
72
+ self._queue = queue.Queue()
73
+ self._registry = collections.OrderedDict()
74
+ self._activated = False
75
+
76
+ def __getstate__(self):
77
+ return {'master_callback': self._master_callback}
78
+
79
+ def __setstate__(self, state):
80
+ self.__init__(state['master_callback'])
81
+
82
+ def register_slave(self, identifier):
83
+ """
84
+ Register a slave device.
85
+ Args:
86
+ identifier: an identifier, usually the device id.
87
+ Returns: a `SlavePipe` object which can be used to communicate with the master device.
88
+ """
89
+ if self._activated:
90
+ assert self._queue.empty(), 'Queue is not clean before next initialization.'
91
+ self._activated = False
92
+ self._registry.clear()
93
+ future = FutureResult()
94
+ self._registry[identifier] = _MasterRegistry(future)
95
+ return SlavePipe(identifier, self._queue, future)
96
+
97
+ def run_master(self, master_msg):
98
+ """
99
+ Main entry for the master device in each forward pass.
100
+ The messages are first collected from each device (including the master device), and then
101
+ a callback is invoked to compute the message to be sent back to each device
102
+ (including the master device).
103
+ Args:
104
+ master_msg: the message that the master wants to send to itself. This will be placed as the first
105
+ message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.
106
+ Returns: the message to be sent back to the master device.
107
+ """
108
+ self._activated = True
109
+
110
+ intermediates = [(0, master_msg)]
111
+ for i in range(self.nr_slaves):
112
+ intermediates.append(self._queue.get())
113
+
114
+ results = self._master_callback(intermediates)
115
+ assert results[0][0] == 0, 'The first result should belong to the master.'
116
+
117
+ for i, res in results:
118
+ if i == 0:
119
+ continue
120
+ self._registry[i].result.put(res)
121
+
122
+ for i in range(self.nr_slaves):
123
+ assert self._queue.get() is True
124
+
125
+ return results[0][1]
126
+
127
+ @property
128
+ def nr_slaves(self):
129
+ return len(self._registry)
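A minimal illustration (plain Python threads standing in for GPU replicas, assuming `data/MBD` is the working directory; not part of the committed file) of the master/slave handshake documented above: each slave sends its message through `run_slave`, the master collects them in `run_master`, and the callback's reply is routed back to every participant.

```python
import threading
from model.deep_lab_model.sync_batchnorm.comm import SyncMaster

def callback(intermediates):
    # intermediates = [(copy_id, message), ...]; reply with the global sum to everyone
    total = sum(msg for _, msg in intermediates)
    return [(copy_id, total) for copy_id, _ in intermediates]

master = SyncMaster(callback)
pipes = {i: master.register_slave(i) for i in (1, 2)}
results = {}

def slave(i, value):
    results[i] = pipes[i].run_slave(value)

threads = [threading.Thread(target=slave, args=(i, 10 * i)) for i in (1, 2)]
for t in threads:
    t.start()
results[0] = master.run_master(5)   # master contributes 5; 5 + 10 + 20 = 35
for t in threads:
    t.join()
print(results)                      # {0: 35, 1: 35, 2: 35}
```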
data/MBD/model/deep_lab_model/sync_batchnorm/replicate.py ADDED
@@ -0,0 +1,88 @@
1
+ # -*- coding: utf-8 -*-
2
+ # File : replicate.py
3
+ # Author : Jiayuan Mao
4
+ # Email : maojiayuan@gmail.com
5
+ # Date : 27/01/2018
6
+ #
7
+ # This file is part of Synchronized-BatchNorm-PyTorch.
8
+ # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9
+ # Distributed under MIT License.
10
+
11
+ import functools
12
+
13
+ from torch.nn.parallel.data_parallel import DataParallel
14
+
15
+ __all__ = [
16
+ 'CallbackContext',
17
+ 'execute_replication_callbacks',
18
+ 'DataParallelWithCallback',
19
+ 'patch_replication_callback'
20
+ ]
21
+
22
+
23
+ class CallbackContext(object):
24
+ pass
25
+
26
+
27
+ def execute_replication_callbacks(modules):
28
+ """
29
+ Execute a replication callback `__data_parallel_replicate__` on each module created by the original replication.
30
+ The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
31
+ Note that, as all modules are isomorphic, we assign each sub-module a context
32
+ (shared among multiple copies of this module on different devices).
33
+ Through this context, different copies can share some information.
34
+ We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback
35
+ of any slave copies.
36
+ """
37
+ master_copy = modules[0]
38
+ nr_modules = len(list(master_copy.modules()))
39
+ ctxs = [CallbackContext() for _ in range(nr_modules)]
40
+
41
+ for i, module in enumerate(modules):
42
+ for j, m in enumerate(module.modules()):
43
+ if hasattr(m, '__data_parallel_replicate__'):
44
+ m.__data_parallel_replicate__(ctxs[j], i)
45
+
46
+
47
+ class DataParallelWithCallback(DataParallel):
48
+ """
49
+ Data Parallel with a replication callback.
50
+ A replication callback `__data_parallel_replicate__` of each module will be invoked after being created by
51
+ the original `replicate` function.
52
+ The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
53
+ Examples:
54
+ > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
55
+ > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
56
+ # sync_bn.__data_parallel_replicate__ will be invoked.
57
+ """
58
+
59
+ def replicate(self, module, device_ids):
60
+ modules = super(DataParallelWithCallback, self).replicate(module, device_ids)
61
+ execute_replication_callbacks(modules)
62
+ return modules
63
+
64
+
65
+ def patch_replication_callback(data_parallel):
66
+ """
67
+ Monkey-patch an existing `DataParallel` object. Add the replication callback.
68
+ Useful when you have a customized `DataParallel` implementation.
69
+ Examples:
70
+ > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
71
+ > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
72
+ > patch_replication_callback(sync_bn)
73
+ # this is equivalent to
74
+ > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
75
+ > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
76
+ """
77
+
78
+ assert isinstance(data_parallel, DataParallel)
79
+
80
+ old_replicate = data_parallel.replicate
81
+
82
+ @functools.wraps(old_replicate)
83
+ def new_replicate(module, device_ids):
84
+ modules = old_replicate(module, device_ids)
85
+ execute_replication_callbacks(modules)
86
+ return modules
87
+
88
+ data_parallel.replicate = new_replicate
data/MBD/model/deep_lab_model/sync_batchnorm/unittest.py ADDED
@@ -0,0 +1,29 @@
1
+ # -*- coding: utf-8 -*-
2
+ # File : unittest.py
3
+ # Author : Jiayuan Mao
4
+ # Email : maojiayuan@gmail.com
5
+ # Date : 27/01/2018
6
+ #
7
+ # This file is part of Synchronized-BatchNorm-PyTorch.
8
+ # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9
+ # Distributed under MIT License.
10
+
11
+ import unittest
12
+
13
+ import numpy as np
14
+ from torch.autograd import Variable
15
+
16
+
17
+ def as_numpy(v):
18
+ if isinstance(v, Variable):
19
+ v = v.data
20
+ return v.cpu().numpy()
21
+
22
+
23
+ class TorchTestCase(unittest.TestCase):
24
+ def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3):
25
+ npa, npb = as_numpy(a), as_numpy(b)
26
+ self.assertTrue(
27
+ np.allclose(npa, npb, atol=atol),
28
+ 'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max())
29
+ )
data/MBD/model/densenetccnl.py ADDED
@@ -0,0 +1,382 @@
1
+ # DenseNet encoder-decoder with intermediate fully connected layers and dropout
2
+
3
+ import torch
4
+ import torch.backends.cudnn as cudnn
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ import functools
8
+ from torch.autograd import gradcheck
9
+ from torch.autograd import Function
10
+ from torch.autograd import Variable
11
+ from torch.autograd import gradcheck
12
+ from torch.autograd import Function
13
+ import numpy as np
14
+
15
+
16
+ def add_coordConv_channels(t):
17
+ n,c,h,w=t.size()
18
+ xx_channel=np.ones((h, w))
19
+ xx_range=np.array(range(h))
20
+ xx_range=np.expand_dims(xx_range,-1)
21
+ xx_coord=xx_channel*xx_range
22
+ yy_coord=xx_coord.transpose()
23
+
24
+ xx_coord=xx_coord/(h-1)
25
+ yy_coord=yy_coord/(h-1)
26
+ xx_coord=xx_coord*2 - 1
27
+ yy_coord=yy_coord*2 - 1
28
+ xx_coord=torch.from_numpy(xx_coord).float()
29
+ yy_coord=torch.from_numpy(yy_coord).float()
30
+
31
+ if t.is_cuda:
32
+ xx_coord=xx_coord.cuda()
33
+ yy_coord=yy_coord.cuda()
34
+
35
+ xx_coord=xx_coord.unsqueeze(0).unsqueeze(0).repeat(n,1,1,1)
36
+ yy_coord=yy_coord.unsqueeze(0).unsqueeze(0).repeat(n,1,1,1)
37
+
38
+ t_cc=torch.cat((t,xx_coord,yy_coord),dim=1)
39
+
40
+ return t_cc
41
+
42
+
43
+
44
+ class DenseBlockEncoder(nn.Module):
45
+ def __init__(self, n_channels, n_convs, activation=nn.ReLU, args=[False]):
46
+ super(DenseBlockEncoder, self).__init__()
47
+ assert(n_convs > 0)
48
+
49
+ self.n_channels = n_channels
50
+ self.n_convs = n_convs
51
+ self.layers = nn.ModuleList()
52
+ for i in range(n_convs):
53
+ self.layers.append(nn.Sequential(
54
+ nn.BatchNorm2d(n_channels),
55
+ activation(*args),
56
+ nn.Conv2d(n_channels, n_channels, 3, stride=1, padding=1, bias=False),))
57
+
58
+ def forward(self, inputs):
59
+ outputs = []
60
+
61
+ for i, layer in enumerate(self.layers):
62
+ if i > 0:
63
+ next_output = 0
64
+ for no in outputs:
65
+ next_output = next_output + no
66
+ outputs.append(layer(next_output)) # each later layer consumes the dense sum of all previous outputs
67
+ else:
68
+ outputs.append(layer(inputs))
69
+ return outputs[-1]
70
+
71
+ # Dense block in decoder.
72
+ class DenseBlockDecoder(nn.Module):
73
+ def __init__(self, n_channels, n_convs, activation=nn.ReLU, args=[False]):
74
+ super(DenseBlockDecoder, self).__init__()
75
+ assert(n_convs > 0)
76
+
77
+ self.n_channels = n_channels
78
+ self.n_convs = n_convs
79
+ self.layers = nn.ModuleList()
80
+ for i in range(n_convs):
81
+ self.layers.append(nn.Sequential(
82
+ nn.BatchNorm2d(n_channels),
83
+ activation(*args),
84
+ nn.ConvTranspose2d(n_channels, n_channels, 3, stride=1, padding=1, bias=False),))
85
+
86
+ def forward(self, inputs):
87
+ outputs = []
88
+
89
+ for i, layer in enumerate(self.layers):
90
+ if i > 0:
91
+ next_output = 0
92
+ for no in outputs:
93
+ next_output = next_output + no
94
+ outputs.append(layer(next_output)) # each later layer consumes the dense sum of all previous outputs
95
+ else:
96
+ outputs.append(layer(inputs))
97
+ return outputs[-1]
98
+
99
+ class DenseTransitionBlockEncoder(nn.Module):
100
+ def __init__(self, n_channels_in, n_channels_out, mp, activation=nn.ReLU, args=[False]):
101
+ super(DenseTransitionBlockEncoder, self).__init__()
102
+ self.n_channels_in = n_channels_in
103
+ self.n_channels_out = n_channels_out
104
+ self.mp = mp
105
+ self.main = nn.Sequential(
106
+ nn.BatchNorm2d(n_channels_in),
107
+ activation(*args),
108
+ nn.Conv2d(n_channels_in, n_channels_out, 1, stride=1, padding=0, bias=False),
109
+ nn.MaxPool2d(mp),
110
+ )
111
+ def forward(self, inputs):
112
+ # print(inputs.shape,'222222222222222',self.main(inputs).shape)
113
+ return self.main(inputs)
114
+
115
+
116
+ class DenseTransitionBlockDecoder(nn.Module):
117
+ def __init__(self, n_channels_in, n_channels_out, activation=nn.ReLU, args=[False]):
118
+ super(DenseTransitionBlockDecoder, self).__init__()
119
+ self.n_channels_in = n_channels_in
120
+ self.n_channels_out = n_channels_out
121
+ self.main = nn.Sequential(
122
+ nn.BatchNorm2d(n_channels_in),
123
+ activation(*args),
124
+ nn.ConvTranspose2d(n_channels_in, n_channels_out, 4, stride=2, padding=1, bias=False),
125
+ )
126
+ def forward(self, inputs):
127
+ # print(inputs.shape,'333333333333',self.main(inputs).shape)
128
+ return self.main(inputs)
129
+
130
+ ## Dense encoders and decoders for image of size 128 128
131
+ class waspDenseEncoder128(nn.Module):
132
+ def __init__(self, nc=1, ndf = 32, ndim = 128, activation=nn.LeakyReLU, args=[0.2, False], f_activation=nn.Tanh, f_args=[]):
133
+ super(waspDenseEncoder128, self).__init__()
134
+ self.ndim = ndim
135
+
136
+ self.main = nn.Sequential(
137
+ # input is (nc) x 128 x 128
138
+ nn.BatchNorm2d(nc),
139
+ nn.ReLU(True),
140
+ nn.Conv2d(nc, ndf, 4, stride=2, padding=1),
141
+
142
+ # state size. (ndf) x 64 x 64
143
+ DenseBlockEncoder(ndf, 6),
144
+ DenseTransitionBlockEncoder(ndf, ndf*2, 2, activation=activation, args=args),
145
+
146
+ # state size. (ndf*2) x 32 x 32
147
+ DenseBlockEncoder(ndf*2, 12),
148
+ DenseTransitionBlockEncoder(ndf*2, ndf*4, 2, activation=activation, args=args),
149
+
150
+ # state size. (ndf*4) x 16 x 16
151
+ DenseBlockEncoder(ndf*4, 16),
152
+ DenseTransitionBlockEncoder(ndf*4, ndf*8, 2, activation=activation, args=args),
153
+
154
+ # state size. (ndf*4) x 8 x 8
155
+ DenseBlockEncoder(ndf*8, 16),
156
+ DenseTransitionBlockEncoder(ndf*8, ndf*8, 2, activation=activation, args=args),
157
+
158
+ # state size. (ndf*8) x 4 x 4
159
+ DenseBlockEncoder(ndf*8, 16),
160
+ DenseTransitionBlockEncoder(ndf*8, ndim, 4, activation=activation, args=args),
161
+ f_activation(*f_args),
162
+ )
163
+
164
+ def forward(self, input):
165
+ input=add_coordConv_channels(input)
166
+ output = self.main(input).view(-1,self.ndim)
167
+ #print(output.size())
168
+ return output
169
+
170
+ class waspDenseDecoder128(nn.Module):
171
+ def __init__(self, nz=128, nc=1, ngf=32, lb=0, ub=1, activation=nn.ReLU, args=[False], f_activation=nn.Hardtanh, f_args=[]):
172
+ super(waspDenseDecoder128, self).__init__()
173
+ self.main = nn.Sequential(
174
+ # input is Z, going into convolution
175
+ nn.BatchNorm2d(nz),
176
+ activation(*args),
177
+ nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
178
+
179
+ # state size. (ngf*8) x 4 x 4
180
+ DenseBlockDecoder(ngf*8, 16),
181
+ DenseTransitionBlockDecoder(ngf*8, ngf*8),
182
+
183
+ # state size. (ngf*4) x 8 x 8
184
+ DenseBlockDecoder(ngf*8, 16),
185
+ DenseTransitionBlockDecoder(ngf*8, ngf*4),
186
+
187
+ # state size. (ngf*2) x 16 x 16
188
+ DenseBlockDecoder(ngf*4, 12),
189
+ DenseTransitionBlockDecoder(ngf*4, ngf*2),
190
+
191
+ # state size. (ngf) x 32 x 32
192
+ DenseBlockDecoder(ngf*2, 6),
193
+ DenseTransitionBlockDecoder(ngf*2, ngf),
194
+
195
+ # state size. (ngf) x 64 x 64
196
+ DenseBlockDecoder(ngf, 6),
197
+ DenseTransitionBlockDecoder(ngf, ngf),
198
+
199
+ # state size (ngf) x 128 x 128
200
+ nn.BatchNorm2d(ngf),
201
+ activation(*args),
202
+ nn.ConvTranspose2d(ngf, nc, 3, stride=1, padding=1, bias=False),
203
+ f_activation(*f_args),
204
+ )
205
+ # self.smooth=nn.Sequential(
206
+ # nn.Conv2d(nc, nc, 1, stride=1, padding=0, bias=False),
207
+ # f_activation(*f_args),
208
+ # )
209
+ def forward(self, inputs):
210
+ # return self.smooth(self.main(inputs))
211
+ return self.main(inputs)
212
+
213
+
214
+
215
+ ## Dense encoders and decoders for image of size 512 512
216
+ class waspDenseEncoder512(nn.Module):
217
+ def __init__(self, nc=1, ndf = 32, ndim = 128, activation=nn.LeakyReLU, args=[0.2, False], f_activation=nn.Tanh, f_args=[]):
218
+ super(waspDenseEncoder512, self).__init__()
219
+ self.ndim = ndim
220
+
221
+ self.main = nn.Sequential(
222
+ # input is (nc) x 128 x 128 > *4
223
+ nn.BatchNorm2d(nc),
224
+ nn.ReLU(True),
225
+ nn.Conv2d(nc, ndf, 4, stride=2, padding=1),
226
+
227
+ # state size. (ndf) x 64 x 64 > *4
228
+ DenseBlockEncoder(ndf, 6),
229
+ DenseTransitionBlockEncoder(ndf, ndf*2, 2, activation=activation, args=args),
230
+
231
+ # state size. (ndf*2) x 32 x 32 > *4
232
+ DenseBlockEncoder(ndf*2, 12),
233
+ DenseTransitionBlockEncoder(ndf*2, ndf*4, 2, activation=activation, args=args),
234
+
235
+ # state size. (ndf*4) x 16 x 16 > *4
236
+ DenseBlockEncoder(ndf*4, 16),
237
+ DenseTransitionBlockEncoder(ndf*4, ndf*8, 2, activation=activation, args=args),
238
+
239
+ # state size. (ndf*8) x 8 x 8 *4
240
+ DenseBlockEncoder(ndf*8, 16),
241
+ DenseTransitionBlockEncoder(ndf*8, ndf*8, 2, activation=activation, args=args),
242
+
243
+ # state size. (ndf*8) x 4 x 4 > *4
244
+ DenseBlockEncoder(ndf*8, 16),
245
+ DenseTransitionBlockEncoder(ndf*8, ndf*8, 4, activation=activation, args=args),
246
+ f_activation(*f_args),
247
+
248
+ # state size. (ndf*8) x 2 x 2 > *4
249
+ DenseBlockEncoder(ndf*8, 16),
250
+ DenseTransitionBlockEncoder(ndf*8, ndim, 4, activation=activation, args=args),
251
+ f_activation(*f_args),
252
+ )
253
+
254
+ def forward(self, input):
255
+ input=add_coordConv_channels(input)
256
+ output = self.main(input).view(-1,self.ndim)
257
+ # output = self.main(input).view(8,-1)
258
+ # print(input.shape,'---------------------')
259
+ #print(output.size())
260
+ return output
261
+
262
+ class waspDenseDecoder512(nn.Module):
263
+ def __init__(self, nz=128, nc=1, ngf=32, lb=0, ub=1, activation=nn.ReLU, args=[False], f_activation=nn.Tanh, f_args=[]):
264
+ super(waspDenseDecoder512, self).__init__()
265
+ self.main = nn.Sequential(
266
+ # input is Z, going into convolution
267
+ nn.BatchNorm2d(nz),
268
+ activation(*args),
269
+ nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
270
+
271
+ # state size. (ngf*8) x 4 x 4
272
+ DenseBlockDecoder(ngf*8, 16),
273
+ DenseTransitionBlockDecoder(ngf*8, ngf*8),
274
+
275
+ # state size. (ngf*8) x 8 x 8
276
+ DenseBlockDecoder(ngf*8, 16),
277
+ DenseTransitionBlockDecoder(ngf*8, ngf*8),
278
+
279
+ # state size. (ngf*8) x 16 x 16
280
+ DenseBlockDecoder(ngf*8, 16),
281
+ DenseTransitionBlockDecoder(ngf*8, ngf*4),
282
+
283
+ # state size. (ngf*4) x 32 x 32
284
+ DenseBlockDecoder(ngf*4, 12),
285
+ DenseTransitionBlockDecoder(ngf*4, ngf*2),
286
+
287
+ # state size. (ngf*2) x 64 x 64
288
+ DenseBlockDecoder(ngf*2, 6),
289
+ DenseTransitionBlockDecoder(ngf*2, ngf),
290
+
291
+ # state size. (ngf) x 128 x 128
292
+ DenseBlockDecoder(ngf, 6),
293
+ DenseTransitionBlockDecoder(ngf, ngf),
294
+
295
+ # state size. (ngf) x 256 x 256
296
+ DenseBlockDecoder(ngf, 6),
297
+ DenseTransitionBlockDecoder(ngf, ngf),
298
+
299
+ # state size (ngf) x 512 x 512
300
+ nn.BatchNorm2d(ngf),
301
+ activation(*args),
302
+ nn.ConvTranspose2d(ngf, nc, 3, stride=1, padding=1, bias=False),
303
+ f_activation(*f_args),
304
+ )
305
+ # self.smooth=nn.Sequential(
306
+ # nn.Conv2d(nc, nc, 1, stride=1, padding=0, bias=False),
307
+ # f_activation(*f_args),
308
+ # )
309
+ def forward(self, inputs):
310
+ # return self.smooth(self.main(inputs))
311
+ return self.main(inputs)
312
+
313
+
314
+ class dnetccnl(nn.Module):
315
+ #in_channels -> nc | encoder first layer
316
+ #filters -> ndf | encoder first layer
317
+ #img_size(h,w) -> ndim
318
+ #out_channels -> optical flow (x,y)
319
+
320
+ def __init__(self, img_size=448, in_channels=3, out_channels=2, filters=32,fc_units=100):
321
+ super(dnetccnl, self).__init__()
322
+ self.nc=in_channels
323
+ self.nf=filters
324
+ self.ndim=img_size
325
+ self.oc=out_channels
326
+ self.fcu=fc_units
327
+
328
+ self.encoder=waspDenseEncoder128(nc=self.nc+2,ndf=self.nf,ndim=self.ndim)
329
+ self.decoder=waspDenseDecoder128(nz=self.ndim,nc=self.oc,ngf=self.nf)
330
+ # self.fc_layers= nn.Sequential(nn.Linear(self.ndim, self.fcu),
331
+ # nn.ReLU(True),
332
+ # nn.Dropout(0.25),
333
+ # nn.Linear(self.fcu,self.ndim),
334
+ # nn.ReLU(True),
335
+ # nn.Dropout(0.25),
336
+ # )
337
+
338
+ def forward(self, inputs):
339
+
340
+ encoded=self.encoder(inputs)
341
+ encoded=encoded.unsqueeze(-1).unsqueeze(-1)
342
+ decoded=self.decoder(encoded)
343
+ # print torch.max(decoded)
344
+ # print torch.min(decoded)
345
+ # print(decoded.shape,'11111111111111111',encoded.shape)
346
+
347
+ return decoded
348
+
349
+ class dnetccnl512(nn.Module):
350
+ #in_channels -> nc | encoder first layer
351
+ #filters -> ndf | encoder first layer
352
+ #img_size(h,w) -> ndim
353
+ #out_channels -> optical flow (x,y)
354
+
355
+ def __init__(self, img_size=448, in_channels=3, out_channels=2, filters=32,fc_units=100):
356
+ super(dnetccnl512, self).__init__()
357
+ self.nc=in_channels
358
+ self.nf=filters
359
+ self.ndim=img_size
360
+ self.oc=out_channels
361
+ self.fcu=fc_units
362
+
363
+ self.encoder=waspDenseEncoder512(nc=self.nc+2,ndf=self.nf,ndim=self.ndim)
364
+ self.decoder=waspDenseDecoder512(nz=self.ndim,nc=self.oc,ngf=self.nf)
365
+ # self.fc_layers= nn.Sequential(nn.Linear(self.ndim, self.fcu),
366
+ # nn.ReLU(True),
367
+ # nn.Dropout(0.25),
368
+ # nn.Linear(self.fcu,self.ndim),
369
+ # nn.ReLU(True),
370
+ # nn.Dropout(0.25),
371
+ # )
372
+
373
+ def forward(self, inputs):
374
+
375
+ encoded=self.encoder(inputs)
376
+ encoded=encoded.unsqueeze(-1).unsqueeze(-1)
377
+ decoded=self.decoder(encoded)
378
+ # print torch.max(decoded)
379
+ # print torch.min(decoded)
380
+ # print(decoded.shape,'11111111111111111',encoded.shape)
381
+
382
+ return decoded
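The two wrappers above pair the dense encoders and decoders into flow-regression networks: `dnetccnl` works on 128x128 crops, `dnetccnl512` on 512x512 crops, and both prepend two coordConv channels before encoding. A minimal usage sketch (illustrative only, not part of the commit); it assumes the script runs from `data/MBD` so that `model.densenetccnl` and the `add_coordConv_channels` helper used inside the encoder resolve as they do in the repo.

```python
# Illustrative sketch: exercising dnetccnl on a 128x128 crop.
import torch
from model.densenetccnl import dnetccnl  # assumed import path (run from data/MBD)

net = dnetccnl(img_size=448, in_channels=3, out_channels=2, filters=32)  # img_size sets the latent width (ndim), not the crop size
net.eval()                                  # use running BN stats for a single sample
x = torch.randn(1, 3, 128, 128)             # the 128-variant encoder expects 128x128 inputs
with torch.no_grad():
    flow = net(x)                           # 2-channel flow-style output in [-1, 1]
print(flow.shape)                           # expected: torch.Size([1, 2, 128, 128])
```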
data/MBD/model/gienet.py ADDED
@@ -0,0 +1,742 @@
1
+ from math import log
2
+ import torch
3
+ import torch.nn as nn
4
+ from torch.nn import init
5
+ import functools
6
+ from model.cbam import CBAM
7
+ # Defines the Unet generator.
8
+ # |num_downs|: number of downsamplings in UNet. For example,
9
+ # if |num_downs| == 7, image of size 128x128 will become of size 1x1
10
+ # at the bottleneck
11
+ class SingleConv(nn.Module):
12
+ """(convolution => [BN] => ReLU) * 2"""
13
+
14
+ def __init__(self, in_channels, out_channels):
15
+ super().__init__()
16
+ self.double_conv = nn.Sequential(
17
+ nn.ReflectionPad2d(1),
18
+ nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=0,stride=1),
19
+ nn.BatchNorm2d(out_channels),
20
+ nn.ReLU(inplace=True),
21
+ # nn.ReflectionPad2d(1),
22
+ # nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=0,stride=1),
23
+ # nn.BatchNorm2d(out_channels),
24
+ # nn.ReLU(inplace=True)
25
+ )
26
+
27
+ def forward(self, x):
28
+ return self.double_conv(x)
29
+ class Down_single(nn.Module):
30
+ """Downscaling with maxpool then double conv"""
31
+
32
+ def __init__(self, in_channels, out_channels):
33
+ super().__init__()
34
+ self.maxpool_conv = nn.Sequential(
35
+ nn.MaxPool2d(2),
36
+ SingleConv(in_channels, out_channels)
37
+ )
38
+
39
+ def forward(self, x):
40
+ return self.maxpool_conv(x)
41
+ class Up_single(nn.Module):
42
+ """Upscaling then double conv"""
43
+ def __init__(self, in_channels, out_channels, bilinear=True):
44
+ super().__init__()
45
+ self.up = nn.Upsample(scale_factor=2, mode='nearest')
46
+ self.conv = SingleConv(in_channels, out_channels)
47
+ self.deconv = nn.ConvTranspose2d(in_channels, out_channels,kernel_size=4, stride=2,padding=1, bias=True)
48
+ def forward(self, x1, x2):
49
+ x1 = self.deconv(x1)
50
+ # input is BCHW
51
+ x = torch.cat([x2, x1], dim=1)
52
+ return self.conv(x)
53
+ class DoubleConv(nn.Module):
54
+ """(convolution => [BN] => ReLU) * 2"""
55
+
56
+ def __init__(self, in_channels, out_channels):
57
+ super().__init__()
58
+ self.double_conv = nn.Sequential(
59
+ nn.ReflectionPad2d(1),
60
+ nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=0,stride=1),
61
+ nn.BatchNorm2d(out_channels),
62
+ nn.ReLU(inplace=True),
63
+ nn.ReflectionPad2d(1),
64
+ nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=0,stride=1),
65
+ nn.BatchNorm2d(out_channels),
66
+ nn.ReLU(inplace=True)
67
+ )
68
+
69
+ def forward(self, x):
70
+ return self.double_conv(x)
71
+ class Down(nn.Module):
72
+ """Downscaling with maxpool then double conv"""
73
+
74
+ def __init__(self, in_channels, out_channels):
75
+ super().__init__()
76
+ self.maxpool_conv = nn.Sequential(
77
+ nn.MaxPool2d(2),
78
+ DoubleConv(in_channels, out_channels)
79
+ )
80
+
81
+ def forward(self, x):
82
+ return self.maxpool_conv(x)
83
+ class Up(nn.Module):
84
+ """Upscaling then double conv"""
85
+ def __init__(self, in_channels, out_channels, bilinear=True):
86
+ super().__init__()
87
+ self.up = nn.Upsample(scale_factor=2, mode='nearest')
88
+ self.conv = DoubleConv(in_channels, out_channels)
89
+ self.deconv = nn.ConvTranspose2d(in_channels, out_channels,kernel_size=4, stride=2,padding=1, bias=True)
90
+ def forward(self, x1, x2):
91
+ x1 = self.deconv(x1)
92
+ # input is BCHW
93
+ x = torch.cat([x2, x1], dim=1)
94
+ return self.conv(x)
95
+
96
+ class OutConv(nn.Module):
97
+ def __init__(self, in_channels, out_channels):
98
+ super(OutConv, self).__init__()
99
+ self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
100
+ self.tanh = nn.Tanh()
101
+ self.hardtanh = nn.Hardtanh()
102
+ self.sigmoid = nn.Sigmoid()
103
+
104
+ def forward(self, x1):
105
+ x = self.conv(x1)
106
+ # x = self.sigmoid(x)
107
+ # x = self.hardtanh(x)
108
+ # x = (x+1)/2
109
+ return x
110
+ class GiemaskGenerator(nn.Module):
111
+ """Create a Unet-based generator"""
112
+
113
+ def __init__(self, input_nc, output_nc, num_downs, ngf=64, biline=True, norm_layer=nn.BatchNorm2d, use_dropout=False):
114
+ """Construct a Unet generator
115
+ Parameters:
116
+ input_nc (int) -- the number of channels in input images
117
+ output_nc (int) -- the number of channels in output images
118
+ num_downs (int) -- the number of downsamplings in UNet. For example, # if |num_downs| == 7,
119
+ image of size 128x128 will become of size 1x1 # at the bottleneck
120
+ ngf (int) -- the number of filters in the last conv layer
121
+ norm_layer -- normalization layer
122
+
123
+ We construct the U-Net from the innermost layer to the outermost layer.
124
+ It is a recursive process.
125
+ """
126
+ super(GiemaskGenerator, self).__init__()
127
+ self.init_channel =32
128
+ self.inc = DoubleConv(3,self.init_channel)
129
+ self.down1 = Down(self.init_channel, self.init_channel*2)
130
+ self.down2 = Down(self.init_channel*2, self.init_channel*4)
131
+ self.down3 = Down(self.init_channel*4, self.init_channel*8)
132
+ self.down4 = Down(self.init_channel*8, self.init_channel*16)
133
+ self.down5 = Down(self.init_channel*16, self.init_channel*32)
134
+
135
+ self.up1 = Up(self.init_channel*32, self.init_channel*16)
136
+ self.up2 = Up(self.init_channel*16, self.init_channel*8)
137
+ self.up3 = Up(self.init_channel*8, self.init_channel*4)
138
+ self.up4 = Up(self.init_channel*4,self.init_channel*2)
139
+ self.up5 = Up(self.init_channel*2, self.init_channel)
140
+ self.outc = OutConv(self.init_channel, 1)
141
+ self.up1_1 = Up_single(self.init_channel*32, self.init_channel*16)
142
+ self.up2_1 = Up_single(self.init_channel*16, self.init_channel*8)
143
+ self.up3_1 = Up_single(self.init_channel*8, self.init_channel*4)
144
+ self.up4_1 = Up_single(self.init_channel*4,self.init_channel*2)
145
+ self.up5_1 = Up_single(self.init_channel*2, self.init_channel)
146
+ self.outc_1 = OutConv(self.init_channel, 1)
147
+ # self.dropout = nn.Dropout(p=0.5)
148
+ def forward(self, input):
149
+ x1 = self.inc(input)
150
+ x2 = self.down1(x1)
151
+ x3 = self.down2(x2)
152
+ x4 = self.down3(x3)
153
+ x5 = self.down4(x4)
154
+ x6 = self.down5(x5)
155
+
156
+
157
+ x_1 = self.up1_1(x6, x5)
158
+ x_1 = self.up2_1(x_1, x4)
159
+ x_1 = self.up3_1(x_1, x3)
160
+ x_1 = self.up4_1(x_1, x2)
161
+ x_1 = self.up5_1(x_1, x1)
162
+ mask = self.outc_1(x_1)
163
+
164
+ x = self.up1(x6, x5)
165
+ # x = self.dropout(x)
166
+ x = self.up2(x, x4)
167
+ # x = self.dropout(x)
168
+ x = self.up3(x, x3)
169
+ # x = self.dropout(x)
170
+ x = self.up4(x, x2)
171
+ # x = self.dropout(x)
172
+ x = self.up5(x, x1)
173
+ # x = self.dropout(x)
174
+ depth = self.outc(x)
175
+ return depth,mask
176
+ """Create a Unet-based generator"""
177
+ class Giemask2Generator(nn.Module):
178
+ """Create a Unet-based generator"""
179
+
180
+ def __init__(self, input_nc, output_nc, num_downs, ngf=64, biline=True, norm_layer=nn.BatchNorm2d, use_dropout=False):
181
+ """Construct a Unet generator
182
+ Parameters:
183
+ input_nc (int) -- the number of channels in input images
184
+ output_nc (int) -- the number of channels in output images
185
+ num_downs (int) -- the number of downsamplings in UNet. For example, # if |num_downs| == 7,
186
+ image of size 128x128 will become of size 1x1 # at the bottleneck
187
+ ngf (int) -- the number of filters in the last conv layer
188
+ norm_layer -- normalization layer
189
+
190
+ We construct the U-Net from the innermost layer to the outermost layer.
191
+ It is a recursive process.
192
+ """
193
+ super(Giemask2Generator, self).__init__()
194
+ self.init_channel =32
195
+ self.inc = DoubleConv(3,self.init_channel)
196
+ self.down1 = Down(self.init_channel, self.init_channel*2)
197
+ self.down2 = Down(self.init_channel*2, self.init_channel*4)
198
+ self.down3 = Down(self.init_channel*4, self.init_channel*8)
199
+ self.down4 = Down(self.init_channel*8, self.init_channel*16)
200
+ self.down5 = Down(self.init_channel*16, self.init_channel*32)
201
+
202
+ self.up1 = Up(self.init_channel*32, self.init_channel*16)
203
+ self.up2 = Up(self.init_channel*16, self.init_channel*8)
204
+ self.up3 = Up(self.init_channel*8, self.init_channel*4)
205
+ self.up4 = Up(self.init_channel*4,self.init_channel*2)
206
+ self.up5 = Up(self.init_channel*2, self.init_channel)
207
+ self.outc = OutConv(self.init_channel, 1)
208
+ self.up1_1 = Up_single(self.init_channel*32, self.init_channel*16)
209
+ self.up2_1 = Up_single(self.init_channel*16, self.init_channel*8)
210
+ self.up3_1 = Up_single(self.init_channel*8, self.init_channel*4)
211
+ self.up4_1 = Up_single(self.init_channel*4,self.init_channel*2)
212
+ self.up5_1 = Up_single(self.init_channel*2, self.init_channel)
213
+ self.outc_1 = OutConv(self.init_channel, 1)
214
+ self.outc_2 = OutConv(self.init_channel, 1)
215
+ # self.dropout = nn.Dropout(p=0.5)
216
+ def forward(self, input):
217
+ x1 = self.inc(input)
218
+ x2 = self.down1(x1)
219
+ x3 = self.down2(x2)
220
+ x4 = self.down3(x3)
221
+ x5 = self.down4(x4)
222
+ x6 = self.down5(x5)
223
+
224
+
225
+ x_1 = self.up1_1(x6, x5)
226
+ x_1 = self.up2_1(x_1, x4)
227
+ x_1 = self.up3_1(x_1, x3)
228
+ x_1 = self.up4_1(x_1, x2)
229
+ x_1 = self.up5_1(x_1, x1)
230
+ mask = self.outc_1(x_1)
231
+ edge = self.outc_2(x_1)
232
+
233
+ x = self.up1(x6, x5)
234
+ # x = self.dropout(x)
235
+ x = self.up2(x, x4)
236
+ # x = self.dropout(x)
237
+ x = self.up3(x, x3)
238
+ # x = self.dropout(x)
239
+ x = self.up4(x, x2)
240
+ # x = self.dropout(x)
241
+ x = self.up5(x, x1)
242
+ # x = self.dropout(x)
243
+ depth = self.outc(x)
244
+ return depth,mask,edge
245
+ """Create a Unet-based generator"""
246
+ class GieGenerator(nn.Module):
247
+ def __init__(self, input_nc, output_nc, num_downs, ngf=64, biline=True, norm_layer=nn.BatchNorm2d, use_dropout=False):
248
+ """Construct a Unet generator
249
+ Parameters:
250
+ input_nc (int) -- the number of channels in input images
251
+ output_nc (int) -- the number of channels in output images
252
+ num_downs (int) -- the number of downsamplings in UNet. For example, # if |num_downs| == 7,
253
+ image of size 128x128 will become of size 1x1 # at the bottleneck
254
+ ngf (int) -- the number of filters in the last conv layer
255
+ norm_layer -- normalization layer
256
+
257
+ We construct the U-Net from the innermost layer to the outermost layer.
258
+ It is a recursive process.
259
+ """
260
+ super(GieGenerator, self).__init__()
261
+ self.init_channel =32
262
+ self.inc = DoubleConv(input_nc,self.init_channel)
263
+ self.down1 = Down(self.init_channel, self.init_channel*2)
264
+ self.down2 = Down(self.init_channel*2, self.init_channel*4)
265
+ self.down3 = Down(self.init_channel*4, self.init_channel*8)
266
+ self.down4 = Down(self.init_channel*8, self.init_channel*16)
267
+ self.down5 = Down(self.init_channel*16, self.init_channel*32)
268
+
269
+ self.up1 = Up(self.init_channel*32, self.init_channel*16)
270
+ self.up2 = Up(self.init_channel*16, self.init_channel*8)
271
+ self.up3 = Up(self.init_channel*8, self.init_channel*4)
272
+ self.up4 = Up(self.init_channel*4,self.init_channel*2)
273
+ self.up5 = Up(self.init_channel*2, self.init_channel)
274
+ self.outc = OutConv(self.init_channel, 2)
275
+ # self.dropout = nn.Dropout(p=0.5)
276
+ def forward(self, input):
277
+ x1 = self.inc(input)
278
+ x2 = self.down1(x1)
279
+ x3 = self.down2(x2)
280
+ x4 = self.down3(x3)
281
+ x5 = self.down4(x4)
282
+ x6 = self.down5(x5)
283
+
284
+ x = self.up1(x6, x5)
285
+ # x = self.dropout(x)
286
+ x = self.up2(x, x4)
287
+ # x = self.dropout(x)
288
+ x = self.up3(x, x3)
289
+ # x = self.dropout(x)
290
+ x = self.up4(x, x2)
291
+ # x = self.dropout(x)
292
+ x = self.up5(x, x1)
293
+ # x = self.dropout(x)
294
+ logits1 = self.outc(x)
295
+ return logits1
296
+
297
+
298
+ class GiecbamGenerator(nn.Module):
299
+ def __init__(self, input_nc, output_nc, num_downs, ngf=64, biline=True, norm_layer=nn.BatchNorm2d, use_dropout=False):
300
+ """Construct a Unet generator
301
+ Parameters:
302
+ input_nc (int) -- the number of channels in input images
303
+ output_nc (int) -- the number of channels in output images
304
+ num_downs (int) -- the number of downsamplings in UNet. For example, # if |num_downs| == 7,
305
+ image of size 128x128 will become of size 1x1 # at the bottleneck
306
+ ngf (int) -- the number of filters in the last conv layer
307
+ norm_layer -- normalization layer
308
+
309
+ We construct the U-Net from the innermost layer to the outermost layer.
310
+ It is a recursive process.
311
+ """
312
+ super(GiecbamGenerator, self).__init__()
313
+ self.init_channel =32
314
+ self.inc = DoubleConv(input_nc,self.init_channel)
315
+ self.down1 = Down(self.init_channel, self.init_channel*2)
316
+ self.down2 = Down(self.init_channel*2, self.init_channel*4)
317
+ self.down3 = Down(self.init_channel*4, self.init_channel*8)
318
+ self.down4 = Down(self.init_channel*8, self.init_channel*16)
319
+ self.down5 = Down(self.init_channel*16, self.init_channel*32)
320
+ self.cbam = CBAM(gate_channels=self.init_channel*32)
321
+ self.up1 = Up(self.init_channel*32, self.init_channel*16)
322
+ self.up2 = Up(self.init_channel*16, self.init_channel*8)
323
+ self.up3 = Up(self.init_channel*8, self.init_channel*4)
324
+ self.up4 = Up(self.init_channel*4,self.init_channel*2)
325
+ self.up5 = Up(self.init_channel*2, self.init_channel)
326
+ self.outc = OutConv(self.init_channel, 2)
327
+ self.dropout = nn.Dropout(p=0.1)
328
+ def forward(self, input):
329
+ x1 = self.inc(input)
330
+ x2 = self.down1(x1)
331
+ x3 = self.down2(x2)
332
+ x4 = self.down3(x3)
333
+ x5 = self.down4(x4)
334
+ x6 = self.down5(x5)
335
+ x6 = self.cbam(x6)
336
+ x = self.up1(x6, x5)
337
+ x = self.up2(x, x4)
338
+ x = self.up3(x, x3)
339
+ x = self.up4(x, x2)
340
+ x = self.up5(x, x1)
341
+ x = self.dropout(x)
342
+ logits1 = self.outc(x)
343
+ return logits1
344
+
345
+
346
+
347
+
348
+ class Gie2headGenerator(nn.Module):
349
+ def __init__(self, input_nc, output_nc, num_downs, ngf=64, biline=True, norm_layer=nn.BatchNorm2d, use_dropout=False):
350
+ """Construct a Unet generator
351
+ Parameters:
352
+ input_nc (int) -- the number of channels in input images
353
+ output_nc (int) -- the number of channels in output images
354
+ num_downs (int) -- the number of downsamplings in UNet. For example, # if |num_downs| == 7,
355
+ image of size 128x128 will become of size 1x1 # at the bottleneck
356
+ ngf (int) -- the number of filters in the last conv layer
357
+ norm_layer -- normalization layer
358
+
359
+ We construct the U-Net from the innermost layer to the outermost layer.
360
+ It is a recursive process.
361
+ """
362
+ super(Gie2headGenerator, self).__init__()
363
+ self.init_channel =32
364
+ self.inc = DoubleConv(input_nc,self.init_channel)
365
+ self.down1 = Down(self.init_channel, self.init_channel*2)
366
+ self.down2 = Down(self.init_channel*2, self.init_channel*4)
367
+ self.down3 = Down(self.init_channel*4, self.init_channel*8)
368
+ self.down4 = Down(self.init_channel*8, self.init_channel*16)
369
+ self.down5 = Down(self.init_channel*16, self.init_channel*32)
370
+
371
+ self.up1_1 = Up(self.init_channel*32, self.init_channel*16)
372
+ self.up2_1 = Up(self.init_channel*16, self.init_channel*8)
373
+ self.up3_1 = Up(self.init_channel*8, self.init_channel*4)
374
+ self.up4_1 = Up(self.init_channel*4,self.init_channel*2)
375
+ self.up5_1 = Up(self.init_channel*2, self.init_channel)
376
+ self.outc_1 = OutConv(self.init_channel, 1)
377
+
378
+ self.up1_2 = Up(self.init_channel*32, self.init_channel*16)
379
+ self.up2_2 = Up(self.init_channel*16, self.init_channel*8)
380
+ self.up3_2 = Up(self.init_channel*8, self.init_channel*4)
381
+ self.up4_2 = Up(self.init_channel*4,self.init_channel*2)
382
+ self.up5_2 = Up(self.init_channel*2, self.init_channel)
383
+ self.outc_2 = OutConv(self.init_channel, 1)
384
+
385
+ def forward(self, input):
386
+ x1 = self.inc(input)
387
+ x2 = self.down1(x1)
388
+ x3 = self.down2(x2)
389
+ x4 = self.down3(x3)
390
+ x5 = self.down4(x4)
391
+ x6 = self.down5(x5)
392
+
393
+ x_1 = self.up1_1(x6, x5)
394
+ x_1 = self.up2_1(x_1, x4)
395
+ x_1 = self.up3_1(x_1, x3)
396
+ x_1 = self.up4_1(x_1, x2)
397
+ x_1 = self.up5_1(x_1, x1)
398
+ logits_1 = self.outc_1(x_1)
399
+
400
+ x_2 = self.up1_2(x6, x5)
401
+ x_2 = self.up2_2(x_2, x4)
402
+ x_2 = self.up3_2(x_2, x3)
403
+ x_2 = self.up4_2(x_2, x2)
404
+ x_2 = self.up5_2(x_2, x1)
405
+ logits_2 = self.outc_2(x_2)
406
+
407
+ logits = torch.cat((logits_1,logits_2),1)
408
+
409
+ return logits
410
+
411
+
412
+
413
+ class BmpGenerator(nn.Module):
414
+ def __init__(self, input_nc, output_nc, num_downs, ngf=64, biline=True, norm_layer=nn.BatchNorm2d, use_dropout=False):
415
+ """Construct a Unet generator
416
+ Parameters:
417
+ input_nc (int) -- the number of channels in input images
418
+ output_nc (int) -- the number of channels in output images
419
+ num_downs (int) -- the number of downsamplings in UNet. For example, # if |num_downs| == 7,
420
+ image of size 128x128 will become of size 1x1 # at the bottleneck
421
+ ngf (int) -- the number of filters in the last conv layer
422
+ norm_layer -- normalization layer
423
+
424
+ We construct the U-Net from the innermost layer to the outermost layer.
425
+ It is a recursive process.
426
+ """
427
+ super(BmpGenerator, self).__init__()
428
+ self.init_channel =32
429
+ self.output_nc = output_nc
430
+ self.inc = DoubleConv(input_nc,self.init_channel)
431
+ self.down1 = Down(self.init_channel, self.init_channel*2)
432
+ self.down2 = Down(self.init_channel*2, self.init_channel*4)
433
+ self.down3 = Down(self.init_channel*4, self.init_channel*8)
434
+ self.down4 = Down(self.init_channel*8, self.init_channel*16)
435
+ self.down5 = Down(self.init_channel*16, self.init_channel*32)
436
+
437
+ self.up1 = Up(self.init_channel*32, self.init_channel*16)
438
+ self.up2 = Up(self.init_channel*16, self.init_channel*8)
439
+ self.up3 = Up(self.init_channel*8, self.init_channel*4)
440
+ self.up4 = Up(self.init_channel*4,self.init_channel*2)
441
+ self.up5 = Up(self.init_channel*2, self.init_channel)
442
+ self.outc = OutConv(self.init_channel, self.output_nc)
443
+ # self.dropout = nn.Dropout(p=0.5)
444
+ def forward(self, input):
445
+ x1 = self.inc(input)
446
+ x2 = self.down1(x1)
447
+ x3 = self.down2(x2)
448
+ x4 = self.down3(x3)
449
+ x5 = self.down4(x4)
450
+ x6 = self.down5(x5)
451
+
452
+ x = self.up1(x6, x5)
453
+ # x = self.dropout(x)
454
+ x = self.up2(x, x4)
455
+ # x = self.dropout(x)
456
+ x = self.up3(x, x3)
457
+ # x = self.dropout(x)
458
+ x = self.up4(x, x2)
459
+ # x = self.dropout(x)
460
+ x = self.up5(x, x1)
461
+ # x = self.dropout(x)
462
+ logits1 = self.outc(x)
463
+ return logits1
464
+ class Bmp2Generator(nn.Module):
465
+ """Create a Unet-based generator"""
466
+
467
+ def __init__(self, input_nc, output_nc, num_downs, ngf=64, biline=True, norm_layer=nn.BatchNorm2d, use_dropout=False):
468
+ """Construct a Unet generator
469
+ Parameters:
470
+ input_nc (int) -- the number of channels in input images
471
+ output_nc (int) -- the number of channels in output images
472
+ num_downs (int) -- the number of downsamplings in UNet. For example, # if |num_downs| == 7,
473
+ image of size 128x128 will become of size 1x1 # at the bottleneck
474
+ ngf (int) -- the number of filters in the last conv layer
475
+ norm_layer -- normalization layer
476
+
477
+ We construct the U-Net from the innermost layer to the outermost layer.
478
+ It is a recursive process.
479
+ """
480
+ super(Bmp2Generator, self).__init__()
481
+ #gienet
482
+ self.init_channel =32
483
+ self.inc = DoubleConv(3,self.init_channel)
484
+ self.down1 = Down(self.init_channel, self.init_channel*2)
485
+ self.down2 = Down(self.init_channel*2, self.init_channel*4)
486
+ self.down3 = Down(self.init_channel*4, self.init_channel*8)
487
+ self.down4 = Down(self.init_channel*8, self.init_channel*16)
488
+ self.down5 = Down(self.init_channel*16, self.init_channel*32)
489
+
490
+ self.up1 = Up(self.init_channel*32, self.init_channel*16)
491
+ self.up2 = Up(self.init_channel*16, self.init_channel*8)
492
+ self.up3 = Up(self.init_channel*8, self.init_channel*4)
493
+ self.up4 = Up(self.init_channel*4,self.init_channel*2)
494
+ self.up5 = Up(self.init_channel*2, self.init_channel)
495
+ self.outc = OutConv(self.init_channel, 1)
496
+ self.up1_1 = Up_single(self.init_channel*32, self.init_channel*16)
497
+ self.up2_1 = Up_single(self.init_channel*16, self.init_channel*8)
498
+ self.up3_1 = Up_single(self.init_channel*8, self.init_channel*4)
499
+ self.up4_1 = Up_single(self.init_channel*4,self.init_channel*2)
500
+ self.up5_1 = Up_single(self.init_channel*2, self.init_channel)
501
+ self.outc_1 = OutConv(self.init_channel, 1)
502
+ self.outc_2 = OutConv(self.init_channel, 1)
503
+
504
+ #bpm net
505
+ self.inc_b = DoubleConv(4,self.init_channel)
506
+ self.down1_b = Down(self.init_channel, self.init_channel*2)
507
+ self.down2_b = Down(self.init_channel*2, self.init_channel*4)
508
+ self.down3_b = Down(self.init_channel*4, self.init_channel*8)
509
+ self.down4_b = Down(self.init_channel*8, self.init_channel*16)
510
+ self.down5_b = Down(self.init_channel*16, self.init_channel*32)
511
+
512
+ self.up1_b = Up(self.init_channel*32, self.init_channel*16)
513
+ self.up2_b = Up(self.init_channel*16, self.init_channel*8)
514
+ self.up3_b = Up(self.init_channel*8, self.init_channel*4)
515
+ self.up4_b = Up(self.init_channel*4,self.init_channel*2)
516
+ self.up5_b = Up(self.init_channel*2, self.init_channel)
517
+ self.outc_b = OutConv(self.init_channel, 2)
518
+ # self.dropout = nn.Dropout(p=0.5)
519
+ def forward(self, input):
520
+ #gienet
521
+ x1 = self.inc(input)
522
+ x2 = self.down1(x1)
523
+ x3 = self.down2(x2)
524
+ x4 = self.down3(x3)
525
+ x5 = self.down4(x4)
526
+ x6 = self.down5(x5)
527
+
528
+ x_1 = self.up1_1(x6, x5)
529
+ x_1 = self.up2_1(x_1, x4)
530
+ x_1 = self.up3_1(x_1, x3)
531
+ x_1 = self.up4_1(x_1, x2)
532
+ x_1 = self.up5_1(x_1, x1)
533
+ mask = self.outc_1(x_1)
534
+ edge = self.outc_2(x_1)
535
+
536
+ x = self.up1(x6, x5)
537
+ x = self.up2(x, x4)
538
+ x = self.up3(x, x3)
539
+ x = self.up4(x, x2)
540
+ x = self.up5(x, x1)
541
+ depth = self.outc(x)
542
+
543
+ #bmpnet
544
+ mask[mask>0.5]=1.
545
+ mask[mask<=0.5]=0.
546
+ image_cat_depth = torch.cat((input*mask,depth*mask),dim=1)
547
+ x1_b = self.inc_b(image_cat_depth)
548
+ x2_b = self.down1_b(x1_b)
549
+ x3_b = self.down2_b(x2_b)
550
+ x4_b = self.down3_b(x3_b)
551
+ x5_b = self.down4_b(x4_b)
552
+ x6_b = self.down5_b(x5_b)
553
+ x_b = self.up1_b(x6_b, x5_b)
554
+ x_b = self.up2_b(x_b, x4_b)
555
+ x_b = self.up3_b(x_b, x3_b)
556
+ x_b = self.up4_b(x_b, x2_b)
557
+ x_b = self.up5_b(x_b, x1_b)
558
+ bm = self.outc_b(x_b)
559
+ # return depth,mask,edge,bm
560
+ return bm
561
+ class UnetGenerator(nn.Module):
562
+ def __init__(self, input_nc, output_nc, num_downs, ngf=64,
563
+ norm_layer=nn.BatchNorm2d, use_dropout=False):
564
+ super(UnetGenerator, self).__init__()
565
+
566
+ # construct unet structure
567
+ unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True)
568
+ for i in range(num_downs - 5):
569
+ unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout)
570
+ unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
571
+ unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
572
+ unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
573
+ unet_block = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer)
574
+
575
+ self.model = unet_block
576
+
577
+ def forward(self, input):
578
+ return self.model(input)
579
+
580
+ #class GieGenerator(nn.Module):
581
+ # def __init__(self, input_nc, output_nc, num_downs, ngf=64,
582
+ # norm_layer=nn.BatchNorm2d, use_dropout=False):
583
+ # super(GieGenerator, self).__init__()
584
+ #
585
+ # # construct unet structure
586
+ # unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True)
587
+ # for i in range(num_downs - 5):
588
+ # unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout)
589
+ # unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
590
+ # unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
591
+ # unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
592
+ # unet_block = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer)
593
+ #
594
+ # self.model = unet_block
595
+ #
596
+ # def forward(self, input):
597
+ # return self.model(input)
598
+
599
+ # Defines the submodule with skip connection.
600
+ # X -------------------identity---------------------- X
601
+ # |-- downsampling -- |submodule| -- upsampling --|
602
+ class UnetSkipConnectionBlock(nn.Module):
603
+ def __init__(self, outer_nc, inner_nc, input_nc=None,
604
+ submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
605
+ super(UnetSkipConnectionBlock, self).__init__()
606
+ self.outermost = outermost
607
+ if type(norm_layer) == functools.partial:
608
+ use_bias = norm_layer.func == nn.InstanceNorm2d
609
+ else:
610
+ use_bias = norm_layer == nn.InstanceNorm2d
611
+ if input_nc is None:
612
+ input_nc = outer_nc
613
+ downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
614
+ stride=2, padding=1, bias=use_bias)
615
+ downrelu = nn.LeakyReLU(0.2, True)
616
+ downnorm = norm_layer(inner_nc)
617
+ uprelu = nn.ReLU(True)
618
+ upnorm = norm_layer(outer_nc)
619
+
620
+ if outermost:
621
+ upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
622
+ kernel_size=4, stride=2,
623
+ padding=1)
624
+ down = [downconv]
625
+ up = [uprelu, upconv, nn.Tanh()]
626
+ model = down + [submodule] + up
627
+ elif innermost:
628
+ # resize = nn.Upsample(scale_factor=2)
629
+ # conv = nn.Conv2d(inner_nc,outer_nc,kernel_size=4,stride=2,padding=1,bias=use_bias)
630
+ upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
631
+ kernel_size=4, stride=2,
632
+ padding=1, bias=use_bias)
633
+ down = [downrelu, downconv]
634
+ up = [uprelu, upconv, upnorm]
635
+ #up = [uprelu, resize, conv, upnorm]
636
+ model = down + up
637
+ else:
638
+ upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
639
+ kernel_size=4, stride=2,
640
+ padding=1, bias=use_bias)
641
+ down = [downrelu, downconv, downnorm]
642
+ up = [uprelu, upconv, upnorm]
643
+
644
+ if use_dropout:
645
+ model = down + [submodule] + up + [nn.Dropout(0.5)]
646
+ else:
647
+ model = down + [submodule] + up
648
+
649
+ self.model = nn.Sequential(*model)
650
+
651
+ def forward(self, x):
652
+ if self.outermost:
653
+ return self.model(x)
654
+ else:
655
+ return torch.cat([x, self.model(x)], 1)
656
+
657
+
658
+
659
+ ##===================================================================================================
660
+ class DilatedDoubleConv(nn.Module):
661
+ """(convolution => [BN] => ReLU) * 2"""
662
+
663
+ def __init__(self, in_channels, out_channels):
664
+ super().__init__()
665
+ self.double_conv = nn.Sequential(
666
+ nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=4,stride=1,dilation=4),
667
+ nn.BatchNorm2d(out_channels),
668
+ nn.ReLU(inplace=True),
669
+ nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=4,stride=1,dilation=4),
670
+ nn.BatchNorm2d(out_channels),
671
+ nn.ReLU(inplace=True)
672
+ )
673
+
674
+ def forward(self, x):
675
+ return self.double_conv(x)
676
+
677
+ class DilatedDown(nn.Module):
678
+ """Downscaling with maxpool then double conv"""
679
+
680
+ def __init__(self, in_channels, out_channels):
681
+ super().__init__()
682
+ self.maxpool_conv = nn.Sequential(
683
+ nn.MaxPool2d(2),
684
+ DilatedDoubleConv(in_channels, out_channels)
685
+ )
686
+
687
+ def forward(self, x):
688
+ return self.maxpool_conv(x)
689
+
690
+ class DilatedUp(nn.Module):
691
+ """Upscaling then double conv"""
692
+ def __init__(self, in_channels, out_channels, bilinear=True):
693
+ super().__init__()
694
+ self.up = nn.Upsample(scale_factor=2, mode='nearest')
695
+ self.conv = DilatedDoubleConv(in_channels, out_channels)
696
+
697
+ self.conv1 = nn.Sequential(
698
+ nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=4,stride=1,dilation=4),
699
+ nn.BatchNorm2d(out_channels),
700
+ nn.ReLU(inplace=True),
701
+ )
702
+ # self.deconv = nn.ConvTranspose2d(in_channels, out_channels,kernel_size=4, stride=2,padding=1, bias=True)
703
+ def forward(self, x1, x2):
704
+ x1 = self.up(x1)
705
+ x1 = self.conv1(x1)
706
+ # x1 = self.deconv(x1)
707
+ # input is BCHW
708
+ x = torch.cat([x2, x1], dim=1)
709
+ return self.conv(x)
710
+ class DilatedSingleUnet(nn.Module):
711
+ def __init__(self, input_nc, output_nc, num_downs, ngf=64, biline=True, norm_layer=nn.BatchNorm2d, use_dropout=False):
712
+ super(DilatedSingleUnet, self).__init__()
713
+ self.init_channel = 32
714
+ self.inc = DilatedDoubleConv(input_nc,self.init_channel)
715
+ self.down1 = DilatedDown(self.init_channel, self.init_channel*2)
716
+ self.down2 = DilatedDown(self.init_channel*2, self.init_channel*4)
717
+ self.down3 = DilatedDown(self.init_channel*4, self.init_channel*8)
718
+ self.down4 = DilatedDown(self.init_channel*8, self.init_channel*16)
719
+ self.down5 = DilatedDown(self.init_channel*16, self.init_channel*32)
720
+ self.cbam = CBAM(gate_channels=self.init_channel*32)
721
+
722
+ self.up1 = DilatedUp(self.init_channel*32, self.init_channel*16)
723
+ self.up2 = DilatedUp(self.init_channel*16, self.init_channel*8)
724
+ self.up3 = DilatedUp(self.init_channel*8, self.init_channel*4)
725
+ self.up4 = DilatedUp(self.init_channel*4,self.init_channel*2)
726
+ self.up5 = DilatedUp(self.init_channel*2, self.init_channel)
727
+ self.outc = OutConv(self.init_channel, output_nc)
728
+ def forward(self, input):
729
+ x1 = self.inc(input)
730
+ x2 = self.down1(x1)
731
+ x3 = self.down2(x2)
732
+ x4 = self.down3(x3)
733
+ x5 = self.down4(x4)
734
+ x6 = self.down5(x5)
735
+ x6 = self.cbam(x6)
736
+ x = self.up1(x6, x5)
737
+ x = self.up2(x, x4)
738
+ x = self.up3(x, x3)
739
+ x = self.up4(x, x2)
740
+ x = self.up5(x, x1)
741
+ logits1 = self.outc(x)
742
+ return logits1
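The generators in this file share one encoder-decoder template: five max-pool downsamplings, transposed-conv upsampling with skip connections, and one `OutConv` head per output. A minimal sketch of how two of them might be instantiated (illustrative only, not part of the commit; the import paths assume the script runs from `data/MBD`, where `model/cbam.py` provides `CBAM`):

```python
# Illustrative sketch: the gienet U-Net variants on a 256x256 input.
import torch
from model.gienet import GieGenerator, Gie2headGenerator  # assumed import path

# Five MaxPool2d(2) stages, so H and W should be divisible by 32.
x = torch.randn(1, 3, 256, 256)

net = GieGenerator(input_nc=3, output_nc=2, num_downs=5)   # keeps the UNet-style signature; the head always outputs 2 channels
y = net(x)                                                  # expected: (1, 2, 256, 256)

two_head = Gie2headGenerator(input_nc=3, output_nc=2, num_downs=5)
logits = two_head(x)                                        # two 1-channel heads concatenated: (1, 2, 256, 256)
print(y.shape, logits.shape)
```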
data/MBD/model/unetnc.py ADDED
@@ -0,0 +1,86 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.nn import init
4
+ import functools
5
+
6
+ # Defines the Unet generator.
7
+ # |num_downs|: number of downsamplings in UNet. For example,
8
+ # if |num_downs| == 7, image of size 128x128 will become of size 1x1
9
+ # at the bottleneck
10
+ class UnetGenerator(nn.Module):
11
+ def __init__(self, input_nc, output_nc, num_downs, ngf=64,
12
+ norm_layer=nn.BatchNorm2d, use_dropout=False):
13
+ super(UnetGenerator, self).__init__()
14
+
15
+ # construct unet structure
16
+ unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True)
17
+ for i in range(num_downs - 5):
18
+ unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout)
19
+ unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
20
+ unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
21
+ unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
22
+ unet_block = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer)
23
+
24
+ self.model = unet_block
25
+
26
+ def forward(self, input):
27
+ return self.model(input)
28
+
29
+
30
+ def forward(self, input):
31
+ return self.model(input)
32
+
33
+ # Defines the submodule with skip connection.
34
+ # X -------------------identity---------------------- X
35
+ # |-- downsampling -- |submodule| -- upsampling --|
36
+ class UnetSkipConnectionBlock(nn.Module):
37
+ def __init__(self, outer_nc, inner_nc, input_nc=None,
38
+ submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
39
+ super(UnetSkipConnectionBlock, self).__init__()
40
+ self.outermost = outermost
41
+ if type(norm_layer) == functools.partial:
42
+ use_bias = norm_layer.func == nn.InstanceNorm2d
43
+ else:
44
+ use_bias = norm_layer == nn.InstanceNorm2d
45
+ if input_nc is None:
46
+ input_nc = outer_nc
47
+ downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
48
+ stride=2, padding=1, bias=use_bias)
49
+ downrelu = nn.LeakyReLU(0.2, True)
50
+ downnorm = norm_layer(inner_nc)
51
+ uprelu = nn.ReLU(True)
52
+ upnorm = norm_layer(outer_nc)
53
+
54
+ if outermost:
55
+ upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
56
+ kernel_size=4, stride=2,
57
+ padding=1)
58
+ down = [downconv]
59
+ up = [uprelu, upconv, nn.Tanh()]
60
+ model = down + [submodule] + up
61
+ elif innermost:
62
+ upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
63
+ kernel_size=4, stride=2,
64
+ padding=1, bias=use_bias)
65
+ down = [downrelu, downconv]
66
+ up = [uprelu, upconv, upnorm]
67
+ model = down + up
68
+ else:
69
+ upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
70
+ kernel_size=4, stride=2,
71
+ padding=1, bias=use_bias)
72
+ down = [downrelu, downconv, downnorm]
73
+ up = [uprelu, upconv, upnorm]
74
+
75
+ if use_dropout:
76
+ model = down + [submodule] + up + [nn.Dropout(0.5)]
77
+ else:
78
+ model = down + [submodule] + up
79
+
80
+ self.model = nn.Sequential(*model)
81
+
82
+ def forward(self, x):
83
+ if self.outermost:
84
+ return self.model(x)
85
+ else:
86
+ return torch.cat([x, self.model(x)], 1)
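Unlike the hand-unrolled U-Nets in `gienet.py`, this file builds the network recursively from `UnetSkipConnectionBlock`s, so an input of side length `2**num_downs` reaches a 1x1 bottleneck. A minimal sketch (illustrative only, not part of the commit; the import path assumes the script runs from `data/MBD`):

```python
# Illustrative sketch: the recursive U-Net with 7 downsamplings on a 128x128 input.
import torch
from model.unetnc import UnetGenerator  # assumed import path

net = UnetGenerator(input_nc=3, output_nc=2, num_downs=7, ngf=64)
x = torch.randn(1, 3, 128, 128)          # 128 = 2**7, so the bottleneck is 1x1
y = net(x)                                # outermost block ends in Tanh
print(y.shape)                            # expected: torch.Size([1, 2, 128, 128])
```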
data/MBD/modify_stn_model/stn_head.py ADDED
@@ -0,0 +1,123 @@
1
+ from __future__ import absolute_import
2
+
3
+ import math
4
+ import numpy as np
5
+ import sys
6
+
7
+ import torch
8
+ from torch import nn
9
+ from torch.nn import functional as F
10
+ from torch.nn import init
11
+
12
+
13
+ def conv3x3_block(in_planes, out_planes, stride=1):
14
+ """3x3 convolution with padding"""
15
+ conv_layer = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=1, padding=1)
16
+
17
+ block = nn.Sequential(
18
+ conv_layer,
19
+ nn.BatchNorm2d(out_planes),
20
+ nn.ReLU(inplace=True),
21
+ )
22
+ return block
23
+
24
+
25
+ class STNHead(nn.Module):
26
+ def __init__(self, in_planes, num_ctrlpoints, activation='none'):
27
+ super(STNHead, self).__init__()
28
+
29
+ self.in_planes = in_planes
30
+ self.num_ctrlpoints = num_ctrlpoints
31
+ self.activation = activation
32
+ self.stn_convnet = nn.Sequential(
33
+ conv3x3_block(in_planes, 32), # 32*64
34
+ nn.MaxPool2d(kernel_size=2, stride=2),
35
+ conv3x3_block(32, 64), # 16*32
36
+ nn.MaxPool2d(kernel_size=2, stride=2),
37
+ conv3x3_block(64, 128), # 8*16
38
+ nn.MaxPool2d(kernel_size=2, stride=2),
39
+ conv3x3_block(128, 256), # 4*8
40
+ nn.MaxPool2d(kernel_size=2, stride=2),
41
+ conv3x3_block(256, 256), # 2*4,
42
+ nn.MaxPool2d(kernel_size=2, stride=2),
43
+ conv3x3_block(256, 256)) # 1*2 > 256*8*8
44
+
45
+ self.stn_fc1 = nn.Sequential(
46
+ # nn.Linear(2*256, 512),
47
+ nn.Linear(8*8*256, 512),
48
+ nn.BatchNorm1d(512),
49
+ nn.ReLU(inplace=True))
50
+ self.stn_fc2 = nn.Linear(512, num_ctrlpoints*2)
51
+
52
+ self.init_weights(self.stn_convnet)
53
+ self.init_weights(self.stn_fc1)
54
+ self.init_stn(self.stn_fc2)
55
+
56
+ def init_weights(self, module):
57
+ for m in module.modules():
58
+ if isinstance(m, nn.Conv2d):
59
+ n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
60
+ m.weight.data.normal_(0, math.sqrt(2. / n))
61
+ if m.bias is not None:
62
+ m.bias.data.zero_()
63
+ elif isinstance(m, nn.BatchNorm2d):
64
+ m.weight.data.fill_(1)
65
+ m.bias.data.zero_()
66
+ elif isinstance(m, nn.Linear):
67
+ m.weight.data.normal_(0, 0.001)
68
+ m.bias.data.zero_()
69
+
70
+ def init_stn(self, stn_fc2):
71
+ # margin = 0.01
72
+ # sampling_num_per_side = int(self.num_ctrlpoints / 2)
73
+ # ctrl_pts_x = np.linspace(margin, 1.-margin, sampling_num_per_side)
74
+ # ctrl_pts_y_top = np.ones(sampling_num_per_side) * margin
75
+ # ctrl_pts_y_bottom = np.ones(sampling_num_per_side) * (1-margin)
76
+ # ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1)
77
+ # ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1)
78
+ # ctrl_points = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0).astype(np.float32)
79
+
80
+ margin_x, margin_y = 0.35,0.35
81
+ # margin_x, margin_y = 0,0
82
+ num_ctrl_pts_per_side = (self.num_ctrlpoints-4) // 4 +2
83
+ ctrl_pts_x = np.linspace(margin_x, 1.0 - margin_x, num_ctrl_pts_per_side)
84
+ ctrl_pts_y_top = np.ones(num_ctrl_pts_per_side) * margin_y
85
+ ctrl_pts_y_bottom = np.ones(num_ctrl_pts_per_side) * (1.0 - margin_y)
86
+ ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1)
87
+ ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1)
88
+
89
+ ctrl_pts_x_left = np.ones(num_ctrl_pts_per_side) * margin_x
90
+ ctrl_pts_x_right = np.ones(num_ctrl_pts_per_side) * (1.0-margin_x)
91
+ ctrl_pts_left = np.stack([ctrl_pts_x_left[1:-1], ctrl_pts_x[1:-1]], axis=1)
92
+ ctrl_pts_right = np.stack([ctrl_pts_x_right[1:-1], ctrl_pts_x[1:-1]], axis=1)
93
+
94
+ ctrl_points = np.concatenate([ctrl_pts_top, ctrl_pts_bottom, ctrl_pts_left, ctrl_pts_right], axis=0).astype(np.float32)
95
+
96
+
97
+ if self.activation == 'none':
98
+ pass
99
+ elif self.activation == 'sigmoid':
100
+ ctrl_points = -np.log(1. / ctrl_points - 1.)
101
+ stn_fc2.weight.data.zero_()
102
+ stn_fc2.bias.data = torch.Tensor(ctrl_points).view(-1)
103
+
104
+ def forward(self, x):
105
+ x = self.stn_convnet(x)
106
+ batch_size, _, h, w = x.size()
107
+ x = x.view(batch_size, -1)
108
+ img_feat = self.stn_fc1(x)
109
+ x = self.stn_fc2(0.1 * img_feat)
110
+ if self.activation == 'sigmoid':
111
+ x = F.sigmoid(x)
112
+ x = x.view(-1, self.num_ctrlpoints, 2)
113
+ return img_feat, x
114
+
115
+
116
+ if __name__ == "__main__":
117
+ in_planes = 3
118
+ num_ctrlpoints = 20
119
+ activation='none' # 'sigmoid'
120
+ stn_head = STNHead(in_planes, num_ctrlpoints, activation)
121
+ input = torch.randn(10, 3, 256, 256) # stn_fc1 expects an 8x8x256 feature map, so the input must be 256x256
122
+ img_feat, control_points = stn_head(input) # forward returns (image features, control points)
123
+ print(control_points.size())
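For orientation, `init_stn` above biases the regressor toward a rectangular layout with a 0.35 margin: for N control points, `(N-4)//4 + 2` of them sit on the top edge and on the bottom edge, and the remainder split between the left and right edges. A standalone sketch of that arithmetic (illustrative only, mirroring the code above, including its reuse of the x linspace for the side-point y coordinates):

```python
# Illustrative sketch: the initial control-point layout used to bias stn_fc2.
import numpy as np

def initial_ctrl_points(num_ctrlpoints=20, margin=0.35):
    per_side = (num_ctrlpoints - 4) // 4 + 2
    xs = np.linspace(margin, 1.0 - margin, per_side)
    top = np.stack([xs, np.full(per_side, margin)], axis=1)
    bottom = np.stack([xs, np.full(per_side, 1.0 - margin)], axis=1)
    left = np.stack([np.full(per_side - 2, margin), xs[1:-1]], axis=1)
    right = np.stack([np.full(per_side - 2, 1.0 - margin), xs[1:-1]], axis=1)
    return np.concatenate([top, bottom, left, right], axis=0).astype(np.float32)

print(initial_ctrl_points(20).shape)   # (20, 2): 6 top, 6 bottom, 4 left, 4 right
```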
data/MBD/modify_stn_model/tps_spatial_transformer.py ADDED
@@ -0,0 +1,194 @@
1
+ from __future__ import absolute_import
2
+
3
+ import numpy as np
4
+ import itertools
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+ def grid_sample(input, grid, canvas = None):
11
+ output = F.grid_sample(input, grid)
12
+ if canvas is None:
13
+ return output
14
+ else:
15
+ input_mask = input.data.new(input.size()).fill_(1)
16
+ output_mask = F.grid_sample(input_mask, grid)
17
+ padded_output = output * output_mask + canvas * (1 - output_mask)
18
+ return padded_output
19
+
20
+
21
+ # phi(x1, x2) = r^2 * log(r), where r = ||x1 - x2||_2
22
+ def compute_partial_repr(input_points, control_points):
23
+ N = input_points.size(0)
24
+ M = control_points.size(0)
25
+ pairwise_diff = input_points.view(N, 1, 2) - control_points.view(1, M, 2)
26
+ # original implementation, very slow
27
+ # pairwise_dist = torch.sum(pairwise_diff ** 2, dim = 2) # square of distance
28
+ pairwise_diff_square = pairwise_diff * pairwise_diff
29
+ pairwise_dist = pairwise_diff_square[:, :, 0] + pairwise_diff_square[:, :, 1]
30
+ repr_matrix = 0.5 * pairwise_dist * torch.log(pairwise_dist)
31
+ # fix numerical error for 0 * log(0), substitute all nan with 0
32
+ mask = repr_matrix != repr_matrix
33
+ repr_matrix.masked_fill_(mask, 0)
34
+ return repr_matrix
35
+
36
+
37
+ # # output_ctrl_pts are specified, according to our task.
38
+ # def build_output_control_points(num_control_points, margins):
39
+ # margin_x, margin_y = margins
40
+ # margin_x, margin_y = 0,0
41
+ # num_ctrl_pts_per_side = num_control_points // 2
42
+ # ctrl_pts_x = np.linspace(margin_x, 1.0 - margin_x, num_ctrl_pts_per_side)
43
+ # ctrl_pts_y_top = np.ones(num_ctrl_pts_per_side) * margin_y
44
+ # ctrl_pts_y_bottom = np.ones(num_ctrl_pts_per_side) * (1.0 - margin_y)
45
+ # ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1)
46
+ # ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1)
47
+ # # ctrl_pts_top = ctrl_pts_top[1:-1,:]
48
+ # # ctrl_pts_bottom = ctrl_pts_bottom[1:-1,:]
49
+ # output_ctrl_pts_arr = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0)
50
+ # output_ctrl_pts = torch.Tensor(output_ctrl_pts_arr)
51
+ # return output_ctrl_pts
52
+
53
+ # output_ctrl_pts are specified, according to our task.
54
+ # def build_output_control_points(num_control_points, margins):
55
+ # margin_x, margin_y = margins
56
+ # # margin_x, margin_y = 0,0
57
+ # num_ctrl_pts_per_side = (num_control_points-4) // 4 +2
58
+ # ctrl_pts_x = np.linspace(margin_x, 1.0 - margin_x, num_ctrl_pts_per_side)
59
+ # ctrl_pts_y_top = np.ones(num_ctrl_pts_per_side) * margin_y
60
+ # ctrl_pts_y_bottom = np.ones(num_ctrl_pts_per_side) * (1.0 - margin_y)
61
+ # ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1)
62
+ # ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1)
63
+
64
+ # ctrl_pts_x_left = np.ones(num_ctrl_pts_per_side) * margin_x
65
+ # ctrl_pts_x_right = np.ones(num_ctrl_pts_per_side) * (1.0-margin_x)
66
+ # ctrl_pts_left = np.stack([ctrl_pts_x_left[1:-1], ctrl_pts_x[1:-1]], axis=1)
67
+ # ctrl_pts_right = np.stack([ctrl_pts_x_right[1:-1], ctrl_pts_x[1:-1]], axis=1)
68
+
69
+ # output_ctrl_pts_arr = np.concatenate([ctrl_pts_top, ctrl_pts_bottom, ctrl_pts_left, ctrl_pts_right], axis=0)
70
+ # output_ctrl_pts = torch.Tensor(output_ctrl_pts_arr)
71
+ # return output_ctrl_pts
72
+
73
+ def build_output_control_points(num_control_points, margins):
74
+ points = [0.25,0.5,0.75]
75
+ pts2 = [[0, 0],[1, 0], [0, 1],[1, 1]]
76
+ # pts22 = []
77
+ for ratio in points:
78
+ pts2.append([1*ratio,0])
79
+ for ratio in points:
80
+ pts2.append([1*ratio,1])
81
+ for ratio in points:
82
+ pts2.append([0,1*ratio])
83
+ for ratio in points:
84
+ pts2.append([1,1*ratio])
85
+ pts2 = np.float32(pts2)
86
+ margin_x, margin_y = margins
87
+ # margin_x, margin_y = 0,0
88
+ num_ctrl_pts_per_side = (num_control_points-4) // 4 +2
89
+ ctrl_pts_x = np.linspace(margin_x, 1.0 - margin_x, num_ctrl_pts_per_side)
90
+ ctrl_pts_y_top = np.ones(num_ctrl_pts_per_side) * margin_y
91
+ ctrl_pts_y_bottom = np.ones(num_ctrl_pts_per_side) * (1.0 - margin_y)
92
+ ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1)
93
+ ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1)
94
+
95
+ ctrl_pts_x_left = np.ones(num_ctrl_pts_per_side) * margin_x
96
+ ctrl_pts_x_right = np.ones(num_ctrl_pts_per_side) * (1.0-margin_x)
97
+ ctrl_pts_left = np.stack([ctrl_pts_x_left[1:-1], ctrl_pts_x[1:-1]], axis=1)
98
+ ctrl_pts_right = np.stack([ctrl_pts_x_right[1:-1], ctrl_pts_x[1:-1]], axis=1)
99
+
100
+ output_ctrl_pts_arr = np.concatenate([ctrl_pts_top, ctrl_pts_bottom, ctrl_pts_left, ctrl_pts_right], axis=0)
101
+ # output_ctrl_pts_arr = np.asarray([[0,0],[1,0],[1,1],[0,1],
102
+ # [],[],[],[],
103
+ # [],[],[],[],
104
+ # [],[],[],[]])
105
+ output_ctrl_pts_arr = pts2
106
+ # print(output_ctrl_pts_arr.shape,'=================')
107
+ output_ctrl_pts = torch.FloatTensor(output_ctrl_pts_arr)
108
+ return output_ctrl_pts
109
+
110
+
111
+
112
+ # demo: ~/test/models/test_tps_transformation.py
113
+ class TPSSpatialTransformer(nn.Module):
114
+
115
+ def __init__(self, output_image_size=None, num_control_points=None, margins=None):
116
+ super(TPSSpatialTransformer, self).__init__()
117
+ self.output_image_size = output_image_size
118
+ self.num_control_points = num_control_points
119
+ self.margins = margins
120
+
121
+ self.target_height, self.target_width = output_image_size
122
+ target_control_points = build_output_control_points(num_control_points, margins)
123
+ N = num_control_points
124
+ # N = N - 4
125
+
126
+ # create padded kernel matrix
127
+ forward_kernel = torch.zeros(N + 3, N + 3)
128
+ target_control_partial_repr = compute_partial_repr(target_control_points, target_control_points)
129
+ forward_kernel[:N, :N].copy_(target_control_partial_repr)
130
+ forward_kernel[:N, -3].fill_(1)
131
+ forward_kernel[-3, :N].fill_(1)
132
+ forward_kernel[:N, -2:].copy_(target_control_points)
133
+ forward_kernel[-2:, :N].copy_(target_control_points.transpose(0, 1))
134
+ # compute inverse matrix
135
+ # print(forward_kernel.shape)
136
+ inverse_kernel = torch.inverse(forward_kernel)
137
+
138
+ # create target cordinate matrix
139
+ HW = self.target_height * self.target_width
140
+ target_coordinate = list(itertools.product(range(self.target_height), range(self.target_width)))
141
+ target_coordinate = torch.Tensor(target_coordinate) # HW x 2
142
+ Y, X = target_coordinate.split(1, dim = 1)
143
+ Y = Y / (self.target_height - 1)
144
+ X = X / (self.target_width - 1)
145
+ target_coordinate = torch.cat([X, Y], dim = 1) # convert from (y, x) to (x, y)
146
+ target_coordinate_partial_repr = compute_partial_repr(target_coordinate, target_control_points)
147
+ target_coordinate_repr = torch.cat([
148
+ target_coordinate_partial_repr, torch.ones(HW, 1), target_coordinate
149
+ ], dim = 1)
150
+
151
+ # register precomputed matrices
152
+ self.register_buffer('inverse_kernel', inverse_kernel)
153
+ self.register_buffer('padding_matrix', torch.zeros(3, 2))
154
+ self.register_buffer('target_coordinate_repr', target_coordinate_repr)
155
+ self.register_buffer('target_control_points', target_control_points)
156
+
157
+ def forward(self, input, source_control_points,direction='dewarp'):
158
+ if direction == 'dewarp':
159
+ assert source_control_points.ndimension() == 3
160
+ assert source_control_points.size(1) == self.num_control_points
161
+ assert source_control_points.size(2) == 2
162
+ batch_size = source_control_points.size(0)
163
+
164
+ Y = torch.cat([source_control_points, self.padding_matrix.expand(batch_size, 3, 2)], 1)
165
+ mapping_matrix = torch.matmul(self.inverse_kernel, Y)
166
+ source_coordinate = torch.matmul(self.target_coordinate_repr, mapping_matrix)
167
+
168
+ grid = source_coordinate.view(-1, self.target_height, self.target_width, 2)
169
+ grid = torch.clamp(grid, 0, 1) # the source_control_points may be out of [0, 1].
170
+ # the input to grid_sample is normalized [-1, 1], but what we get is [0, 1]
171
+ grid = 2.0 * grid - 1.0
172
+ output = grid_sample(input, grid, canvas=None)
173
+ return output, grid
174
+
175
+ # elif direction == 'warp':
176
+ # target_control_points = source_control_points.clone()
177
+ # source_control_points = (build_output_control_points(self.num_control_points, self.margins)).clone()
178
+ # source_control_points = source_control_points.unsqueeze(0)
179
+ # source_control_points = source_control_points.expand(target_control_points.size(0),self.num_control_points,2)
180
+ # assert source_control_points.ndimension() == 3
181
+ # assert source_control_points.size(1) == self.num_control_points
182
+ # assert source_control_points.size(2) == 2
183
+ # batch_size = source_control_points.size(0)
184
+
185
+ # Y = torch.cat([source_control_points.to('cuda'), self.padding_matrix.expand(batch_size, 3, 2)], 1)
186
+ # mapping_matrix = torch.matmul(self.inverse_kernel, Y)
187
+ # source_coordinate = torch.matmul(self.target_coordinate_repr, mapping_matrix)
188
+
189
+ # grid = source_coordinate.view(-1, self.target_height, self.target_width, 2)
190
+ # grid = torch.clamp(grid, 0, 1) # the source_control_points may be out of [0, 1].
191
+ # # the input to grid_sample is normalized [-1, 1], but what we get is [0, 1]
192
+ # grid = 2.0 * grid - 1.0
193
+ # output_maps = grid_sample(input, grid, canvas=None)
194
+ # return output_maps, source_coordinate
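The transformer above pins its target control points via `build_output_control_points`, which returns 16 fixed points (the 4 corners plus 3 intermediate points per edge), so `num_control_points` must be 16 when it is paired with `STNHead`. A minimal end-to-end sketch under that assumption (illustrative only, not part of the commit; the import paths assume the script runs from `data/MBD`):

```python
# Illustrative sketch: predicting control points with STNHead and rectifying with TPS.
import torch
from modify_stn_model.stn_head import STNHead                      # assumed import path
from modify_stn_model.tps_spatial_transformer import TPSSpatialTransformer

num_ctrl = 16                                   # matches the 16 target points built above
head = STNHead(in_planes=3, num_ctrlpoints=num_ctrl, activation='sigmoid')
tps = TPSSpatialTransformer(output_image_size=(256, 256),
                            num_control_points=num_ctrl,
                            margins=(0.35, 0.35))

img = torch.randn(2, 3, 256, 256)               # 256x256 keeps the head's 8x8x256 fc input consistent
_, ctrl_pts = head(img)                         # predicted source control points in [0, 1]
rectified, grid = tps(img, ctrl_pts)            # dewarped image and the sampling grid
print(rectified.shape, grid.shape)              # (2, 3, 256, 256) and (2, 256, 256, 2)
```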
data/MBD/stn_model/stn_head.py ADDED
@@ -0,0 +1,123 @@
1
+ from __future__ import absolute_import
2
+
3
+ import math
4
+ import numpy as np
5
+ import sys
6
+
7
+ import torch
8
+ from torch import nn
9
+ from torch.nn import functional as F
10
+ from torch.nn import init
11
+
12
+
13
+ def conv3x3_block(in_planes, out_planes, stride=1):
14
+ """3x3 convolution with padding"""
15
+ conv_layer = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=1, padding=1)
16
+
17
+ block = nn.Sequential(
18
+ conv_layer,
19
+ nn.BatchNorm2d(out_planes),
20
+ nn.ReLU(inplace=True),
21
+ )
22
+ return block
23
+
24
+
25
+ class STNHead(nn.Module):
26
+ def __init__(self, in_planes, num_ctrlpoints, activation='none'):
27
+ super(STNHead, self).__init__()
28
+
29
+ self.in_planes = in_planes
30
+ self.num_ctrlpoints = num_ctrlpoints
31
+ self.activation = activation
32
+ self.stn_convnet = nn.Sequential(
33
+ conv3x3_block(in_planes, 32), # 32*64
34
+ nn.MaxPool2d(kernel_size=2, stride=2),
35
+ conv3x3_block(32, 64), # 16*32
36
+ nn.MaxPool2d(kernel_size=2, stride=2),
37
+ conv3x3_block(64, 128), # 8*16
38
+ nn.MaxPool2d(kernel_size=2, stride=2),
39
+ conv3x3_block(128, 256), # 4*8
40
+ nn.MaxPool2d(kernel_size=2, stride=2),
41
+ conv3x3_block(256, 256), # 2*4,
42
+ nn.MaxPool2d(kernel_size=2, stride=2),
43
+ conv3x3_block(256, 256)) # 1*2 > 256*8*8
44
+
45
+ self.stn_fc1 = nn.Sequential(
46
+ # nn.Linear(2*256, 512),
47
+ nn.Linear(8*8*256, 512),
48
+ nn.BatchNorm1d(512),
49
+ nn.ReLU(inplace=True))
50
+ self.stn_fc2 = nn.Linear(512, num_ctrlpoints*2)
51
+
52
+ self.init_weights(self.stn_convnet)
53
+ self.init_weights(self.stn_fc1)
54
+ self.init_stn(self.stn_fc2)
55
+
56
+ def init_weights(self, module):
57
+ for m in module.modules():
58
+ if isinstance(m, nn.Conv2d):
59
+ n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
60
+ m.weight.data.normal_(0, math.sqrt(2. / n))
61
+ if m.bias is not None:
62
+ m.bias.data.zero_()
63
+ elif isinstance(m, nn.BatchNorm2d):
64
+ m.weight.data.fill_(1)
65
+ m.bias.data.zero_()
66
+ elif isinstance(m, nn.Linear):
67
+ m.weight.data.normal_(0, 0.001)
68
+ m.bias.data.zero_()
69
+
70
+ def init_stn(self, stn_fc2):
71
+ # margin = 0.01
72
+ # sampling_num_per_side = int(self.num_ctrlpoints / 2)
73
+ # ctrl_pts_x = np.linspace(margin, 1.-margin, sampling_num_per_side)
74
+ # ctrl_pts_y_top = np.ones(sampling_num_per_side) * margin
75
+ # ctrl_pts_y_bottom = np.ones(sampling_num_per_side) * (1-margin)
76
+ # ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1)
77
+ # ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1)
78
+ # ctrl_points = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0).astype(np.float32)
79
+
80
+ margin_x, margin_y = 0.35,0.35
81
+ # margin_x, margin_y = 0,0
82
+ num_ctrl_pts_per_side = (self.num_ctrlpoints-4) // 4 +2
83
+ ctrl_pts_x = np.linspace(margin_x, 1.0 - margin_x, num_ctrl_pts_per_side)
84
+ ctrl_pts_y_top = np.ones(num_ctrl_pts_per_side) * margin_y
85
+ ctrl_pts_y_bottom = np.ones(num_ctrl_pts_per_side) * (1.0 - margin_y)
86
+ ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1)
87
+ ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1)
88
+
89
+ ctrl_pts_x_left = np.ones(num_ctrl_pts_per_side) * margin_x
90
+ ctrl_pts_x_right = np.ones(num_ctrl_pts_per_side) * (1.0-margin_x)
91
+ ctrl_pts_left = np.stack([ctrl_pts_x_left[1:-1], ctrl_pts_x[1:-1]], axis=1)
92
+ ctrl_pts_right = np.stack([ctrl_pts_x_right[1:-1], ctrl_pts_x[1:-1]], axis=1)
93
+
94
+ ctrl_points = np.concatenate([ctrl_pts_top, ctrl_pts_bottom, ctrl_pts_left, ctrl_pts_right], axis=0).astype(np.float32)
95
+
96
+
97
+ if self.activation == 'none':
98
+ pass
99
+ elif self.activation == 'sigmoid':
100
+ ctrl_points = -np.log(1. / ctrl_points - 1.)
101
+ stn_fc2.weight.data.zero_()
102
+ stn_fc2.bias.data = torch.Tensor(ctrl_points).view(-1)
103
+
104
+ def forward(self, x):
105
+ x = self.stn_convnet(x)
106
+ batch_size, _, h, w = x.size()
107
+ x = x.view(batch_size, -1)
108
+ img_feat = self.stn_fc1(x)
109
+ x = self.stn_fc2(0.1 * img_feat)
110
+ if self.activation == 'sigmoid':
111
+ x = F.sigmoid(x)
112
+ x = x.view(-1, self.num_ctrlpoints, 2)
113
+ return img_feat, x
114
+
115
+
116
+ if __name__ == "__main__":
117
+ in_planes = 3
118
+ num_ctrlpoints = 20
119
+ activation='none' # 'sigmoid'
120
+ stn_head = STNHead(in_planes, num_ctrlpoints, activation)
121
+ input = torch.randn(10, 3, 256, 256) # a 256x256 input matches the 8*8*256 features expected by stn_fc1
122
+ img_feat, control_points = stn_head(input) # forward returns (img_feat, control_points)
123
+ print(control_points.size())
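A quick check of the control-point layout built by `init_stn` above (it mirrors `build_output_control_points` in `tps_spatial_transformer.py`): the four sides of the margin rectangle share their corner points, so each side gets `(num_ctrlpoints - 4) // 4 + 2` points and the left/right columns drop their two corners. A minimal, self-contained sketch of the arithmetic (not part of the repository):

```python
# Standalone sanity check of the point-count arithmetic used by init_stn.
num_ctrlpoints = 20
per_side = (num_ctrlpoints - 4) // 4 + 2           # 6 points on the top and bottom edges
total = 2 * per_side + 2 * (per_side - 2)          # left/right edges reuse the 4 corners
assert total == num_ctrlpoints
```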
data/MBD/stn_model/tps_spatial_transformer.py ADDED
@@ -0,0 +1,155 @@
1
+ from __future__ import absolute_import
2
+
3
+ import numpy as np
4
+ import itertools
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+ def grid_sample(input, grid, canvas = None):
11
+ output = F.grid_sample(input, grid)
12
+ if canvas is None:
13
+ return output
14
+ else:
15
+ input_mask = input.data.new(input.size()).fill_(1)
16
+ output_mask = F.grid_sample(input_mask, grid)
17
+ padded_output = output * output_mask + canvas * (1 - output_mask)
18
+ return padded_output
19
+
20
+
21
+ # phi(x1, x2) = r^2 * log(r), where r = ||x1 - x2||_2
22
+ def compute_partial_repr(input_points, control_points):
23
+ N = input_points.size(0)
24
+ M = control_points.size(0)
25
+ pairwise_diff = input_points.view(N, 1, 2) - control_points.view(1, M, 2)
26
+ # original implementation, very slow
27
+ # pairwise_dist = torch.sum(pairwise_diff ** 2, dim = 2) # square of distance
28
+ pairwise_diff_square = pairwise_diff * pairwise_diff
29
+ pairwise_dist = pairwise_diff_square[:, :, 0] + pairwise_diff_square[:, :, 1]
30
+ repr_matrix = 0.5 * pairwise_dist * torch.log(pairwise_dist)
31
+ # fix numerical error for 0 * log(0), substitute all nan with 0
32
+ mask = repr_matrix != repr_matrix
33
+ repr_matrix.masked_fill_(mask, 0)
34
+ return repr_matrix
35
+
36
+
37
+ # # output_ctrl_pts are specified, according to our task.
38
+ # def build_output_control_points(num_control_points, margins):
39
+ # margin_x, margin_y = margins
40
+ # margin_x, margin_y = 0,0
41
+ # num_ctrl_pts_per_side = num_control_points // 2
42
+ # ctrl_pts_x = np.linspace(margin_x, 1.0 - margin_x, num_ctrl_pts_per_side)
43
+ # ctrl_pts_y_top = np.ones(num_ctrl_pts_per_side) * margin_y
44
+ # ctrl_pts_y_bottom = np.ones(num_ctrl_pts_per_side) * (1.0 - margin_y)
45
+ # ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1)
46
+ # ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1)
47
+ # # ctrl_pts_top = ctrl_pts_top[1:-1,:]
48
+ # # ctrl_pts_bottom = ctrl_pts_bottom[1:-1,:]
49
+ # output_ctrl_pts_arr = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0)
50
+ # output_ctrl_pts = torch.Tensor(output_ctrl_pts_arr)
51
+ # return output_ctrl_pts
52
+
53
+ # output_ctrl_pts are specified, according to our task.
54
+ def build_output_control_points(num_control_points, margins):
55
+ margin_x, margin_y = margins
56
+ # margin_x, margin_y = 0,0
57
+ num_ctrl_pts_per_side = (num_control_points-4) // 4 +2
58
+ ctrl_pts_x = np.linspace(margin_x, 1.0 - margin_x, num_ctrl_pts_per_side)
59
+ ctrl_pts_y_top = np.ones(num_ctrl_pts_per_side) * margin_y
60
+ ctrl_pts_y_bottom = np.ones(num_ctrl_pts_per_side) * (1.0 - margin_y)
61
+ ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1)
62
+ ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1)
63
+
64
+ ctrl_pts_x_left = np.ones(num_ctrl_pts_per_side) * margin_x
65
+ ctrl_pts_x_right = np.ones(num_ctrl_pts_per_side) * (1.0-margin_x)
66
+ ctrl_pts_left = np.stack([ctrl_pts_x_left[1:-1], ctrl_pts_x[1:-1]], axis=1)
67
+ ctrl_pts_right = np.stack([ctrl_pts_x_right[1:-1], ctrl_pts_x[1:-1]], axis=1)
68
+
69
+ output_ctrl_pts_arr = np.concatenate([ctrl_pts_top, ctrl_pts_bottom, ctrl_pts_left, ctrl_pts_right], axis=0)
70
+ output_ctrl_pts = torch.Tensor(output_ctrl_pts_arr)
71
+ return output_ctrl_pts
72
+
73
+ # demo: ~/test/models/test_tps_transformation.py
74
+ class TPSSpatialTransformer(nn.Module):
75
+
76
+ def __init__(self, output_image_size=None, num_control_points=None, margins=None):
77
+ super(TPSSpatialTransformer, self).__init__()
78
+ self.output_image_size = output_image_size
79
+ self.num_control_points = num_control_points
80
+ self.margins = margins
81
+
82
+ self.target_height, self.target_width = output_image_size
83
+ target_control_points = build_output_control_points(num_control_points, margins)
84
+ N = num_control_points
85
+ # N = N - 4
86
+
87
+ # create padded kernel matrix
88
+ forward_kernel = torch.zeros(N + 3, N + 3)
89
+ target_control_partial_repr = compute_partial_repr(target_control_points, target_control_points)
90
+ forward_kernel[:N, :N].copy_(target_control_partial_repr)
91
+ forward_kernel[:N, -3].fill_(1)
92
+ forward_kernel[-3, :N].fill_(1)
93
+ forward_kernel[:N, -2:].copy_(target_control_points)
94
+ forward_kernel[-2:, :N].copy_(target_control_points.transpose(0, 1))
95
+ # compute inverse matrix
96
+ # print(forward_kernel.shape)
97
+ inverse_kernel = torch.inverse(forward_kernel)
98
+
99
+ # create target coordinate matrix
100
+ HW = self.target_height * self.target_width
101
+ target_coordinate = list(itertools.product(range(self.target_height), range(self.target_width)))
102
+ target_coordinate = torch.Tensor(target_coordinate) # HW x 2
103
+ Y, X = target_coordinate.split(1, dim = 1)
104
+ Y = Y / (self.target_height - 1)
105
+ X = X / (self.target_width - 1)
106
+ target_coordinate = torch.cat([X, Y], dim = 1) # convert from (y, x) to (x, y)
107
+ target_coordinate_partial_repr = compute_partial_repr(target_coordinate, target_control_points)
108
+ target_coordinate_repr = torch.cat([
109
+ target_coordinate_partial_repr, torch.ones(HW, 1), target_coordinate
110
+ ], dim = 1)
111
+
112
+ # register precomputed matrices
113
+ self.register_buffer('inverse_kernel', inverse_kernel)
114
+ self.register_buffer('padding_matrix', torch.zeros(3, 2))
115
+ self.register_buffer('target_coordinate_repr', target_coordinate_repr)
116
+ self.register_buffer('target_control_points', target_control_points)
117
+
118
+ def forward(self, input, source_control_points,direction='dewarp'):
119
+ if direction == 'dewarp':
120
+ assert source_control_points.ndimension() == 3
121
+ assert source_control_points.size(1) == self.num_control_points
122
+ assert source_control_points.size(2) == 2
123
+ batch_size = source_control_points.size(0)
124
+
125
+ Y = torch.cat([source_control_points, self.padding_matrix.expand(batch_size, 3, 2)], 1)
126
+ mapping_matrix = torch.matmul(self.inverse_kernel, Y)
127
+ source_coordinate = torch.matmul(self.target_coordinate_repr, mapping_matrix)
128
+
129
+ grid = source_coordinate.view(-1, self.target_height, self.target_width, 2)
130
+ grid = torch.clamp(grid, 0, 1) # the source_control_points may be out of [0, 1].
131
+ # the input to grid_sample is normalized [-1, 1], but what we get is [0, 1]
132
+ grid = 2.0 * grid - 1.0
133
+ output_maps = grid_sample(input, grid, canvas=None)
134
+ return output_maps, source_coordinate
135
+
136
+ # elif direction == 'warp':
137
+ # target_control_points = source_control_points.clone()
138
+ # source_control_points = (build_output_control_points(self.num_control_points, self.margins)).clone()
139
+ # source_control_points = source_control_points.unsqueeze(0)
140
+ # source_control_points = source_control_points.expand(target_control_points.size(0),self.num_control_points,2)
141
+ # assert source_control_points.ndimension() == 3
142
+ # assert source_control_points.size(1) == self.num_control_points
143
+ # assert source_control_points.size(2) == 2
144
+ # batch_size = source_control_points.size(0)
145
+
146
+ # Y = torch.cat([source_control_points.to('cuda'), self.padding_matrix.expand(batch_size, 3, 2)], 1)
147
+ # mapping_matrix = torch.matmul(self.inverse_kernel, Y)
148
+ # source_coordinate = torch.matmul(self.target_coordinate_repr, mapping_matrix)
149
+
150
+ # grid = source_coordinate.view(-1, self.target_height, self.target_width, 2)
151
+ # grid = torch.clamp(grid, 0, 1) # the source_control_points may be out of [0, 1].
152
+ # # the input to grid_sample is normalized [-1, 1], but what we get is [0, 1]
153
+ # grid = 2.0 * grid - 1.0
154
+ # output_maps = grid_sample(input, grid, canvas=None)
155
+ # return output_maps, source_coordinate
data/MBD/tps_grid_gen.py ADDED
@@ -0,0 +1,70 @@
1
+ # encoding: utf-8
2
+
3
+ import torch
4
+ import itertools
5
+ import torch.nn as nn
6
+ from torch.autograd import Function, Variable
7
+
8
+ class TPSGridGen(nn.Module):
9
+
10
+ def __init__(self, target_height, target_width, target_control_points):
11
+ super(TPSGridGen, self).__init__()
12
+ assert target_control_points.ndimension() == 2
13
+ assert target_control_points.size(1) == 2
14
+ N = target_control_points.size(0)
15
+ self.num_points = N
16
+ target_control_points = target_control_points.float()
17
+
18
+ # create padded kernel matrix
19
+ forward_kernel = torch.zeros(N + 3, N + 3)
20
+ target_control_partial_repr = self.compute_partial_repr(target_control_points, target_control_points)
21
+ forward_kernel[:N, :N].copy_(target_control_partial_repr)
22
+ forward_kernel[:N, -3].fill_(1)
23
+ forward_kernel[-3, :N].fill_(1)
24
+ forward_kernel[:N, -2:].copy_(target_control_points)
25
+ forward_kernel[-2:, :N].copy_(target_control_points.transpose(0, 1))
26
+ # compute inverse matrix
27
+ inverse_kernel = torch.inverse(forward_kernel)
28
+
29
+ # create target coordinate matrix
30
+ HW = target_height * target_width
31
+ target_coordinate = list(itertools.product(range(target_height), range(target_width)))
32
+ target_coordinate = torch.Tensor(target_coordinate) # HW x 2
33
+ Y, X = target_coordinate.split(1, dim = 1)
34
+ Y = Y * 2 / (target_height - 1) - 1
35
+ X = X * 2 / (target_width - 1) - 1
36
+ target_coordinate = torch.cat([X, Y], dim = 1) # convert from (y, x) to (x, y)
37
+ target_coordinate_partial_repr = self.compute_partial_repr(target_coordinate, target_control_points)
38
+ target_coordinate_repr = torch.cat([
39
+ target_coordinate_partial_repr, torch.ones(HW, 1), target_coordinate
40
+ ], dim = 1)
41
+
42
+ # register precomputed matrices
43
+ self.register_buffer('inverse_kernel', inverse_kernel)
44
+ self.register_buffer('padding_matrix', torch.zeros(3, 2))
45
+ self.register_buffer('target_coordinate_repr', target_coordinate_repr)
46
+
47
+ def forward(self, source_control_points):
48
+ assert source_control_points.ndimension() == 3
49
+ assert source_control_points.size(1) == self.num_points
50
+ assert source_control_points.size(2) == 2
51
+ batch_size = source_control_points.size(0)
52
+
53
+ Y = torch.cat([source_control_points, Variable(self.padding_matrix.expand(batch_size, 3, 2))], 1)
54
+ mapping_matrix = torch.matmul(Variable(self.inverse_kernel), Y)
55
+ source_coordinate = torch.matmul(Variable(self.target_coordinate_repr), mapping_matrix)
56
+ return source_coordinate
57
+ # phi(x1, x2) = r^2 * log(r), where r = ||x1 - x2||_2
58
+ def compute_partial_repr(self, input_points, control_points):
59
+ N = input_points.size(0)
60
+ M = control_points.size(0)
61
+ pairwise_diff = input_points.view(N, 1, 2) - control_points.view(1, M, 2)
62
+ # original implementation, very slow
63
+ # pairwise_dist = torch.sum(pairwise_diff ** 2, dim = 2) # square of distance
64
+ pairwise_diff_square = pairwise_diff * pairwise_diff
65
+ pairwise_dist = pairwise_diff_square[:, :, 0] + pairwise_diff_square[:, :, 1]
66
+ repr_matrix = 0.5 * pairwise_dist * torch.log(pairwise_dist)
67
+ # fix numerical error for 0 * log(0), substitute all nan with 0
68
+ mask = repr_matrix != repr_matrix
69
+ repr_matrix.masked_fill_(mask, 0)
70
+ return repr_matrix
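For reference, the padded `forward_kernel` assembled in `TPSGridGen.__init__` (and in the `TPSSpatialTransformer` classes above) is the standard thin-plate-spline interpolation system. The following is a sketch written out from the code itself, not quoted from the paper:

```latex
% TPS mapping: f(p) = a_0 + A p + \sum_i w_i \, \phi(\lVert p - p_i \rVert),
% with radial basis \phi(r) = r^2 \log r
% (compute_partial_repr evaluates 0.5 * d^2 * log(d^2), which is the same value).
\begin{bmatrix} K & \mathbf{1} & P \\ \mathbf{1}^{\top} & 0 & 0 \\ P^{\top} & 0 & 0 \end{bmatrix}
\begin{bmatrix} W \\ a_0 \\ A \end{bmatrix}
=
\begin{bmatrix} Y \\ 0 \\ 0 \end{bmatrix},
\qquad K_{ij} = \phi(\lVert p_i - p_j \rVert)
```

`inverse_kernel` caches the inverse of the left-hand matrix, so `forward()` recovers `[W; a_0; A]` from the source control points `Y` and evaluates the mapping at every target pixel through `target_coordinate_repr`.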
data/MBD/utils.py ADDED
@@ -0,0 +1,234 @@
1
+ '''
2
+ Misc Utility functions
3
+ '''
4
+ from collections import OrderedDict
5
+ import os
6
+ import numpy as np
7
+ import torch
8
+ import random
9
+ import torchvision
+ import matplotlib.pyplot as plt # needed by vistensor and show_uloss below
10
+
11
+ def recursive_glob(rootdir='.', suffix=''):
12
+ """Performs recursive glob with given suffix and rootdir
13
+ :param rootdir is the root directory
14
+ :param suffix is the suffix to be searched
15
+ """
16
+ return [os.path.join(looproot, filename)
17
+ for looproot, _, filenames in os.walk(rootdir)
18
+ for filename in filenames if filename.endswith(suffix)]
19
+
20
+ def poly_lr_scheduler(optimizer, init_lr, iter, lr_decay_iter=1, max_iter=30000, power=0.9,):
21
+ """Polynomial decay of learning rate
22
+ :param init_lr is base learning rate
23
+ :param iter is a current iteration
24
+ :param lr_decay_iter how frequently decay occurs, default is 1
25
+ :param max_iter is number of maximum iterations
26
+ :param power is a polynomial power
27
+
28
+ """
29
+ if iter % lr_decay_iter or iter > max_iter:
30
+ return optimizer
31
+
32
+ for param_group in optimizer.param_groups:
33
+ param_group['lr'] = init_lr*(1 - iter/max_iter)**power
34
+
35
+
36
+ def adjust_learning_rate(optimizer, init_lr, epoch):
37
+ """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
38
+ lr = init_lr * (0.1 ** (epoch // 30))
39
+ for param_group in optimizer.param_groups:
40
+ param_group['lr'] = lr
41
+
42
+
43
+ def alpha_blend(input_image, segmentation_mask, alpha=0.5):
44
+ """Alpha Blending utility to overlay RGB masks on RGB images
45
+ :param input_image is a np.ndarray with 3 channels
46
+ :param segmentation_mask is a np.ndarray with 3 channels
47
+ :param alpha is a float value
48
+
49
+ """
50
+ blended = np.zeros(input_image.size, dtype=np.float32)
51
+ blended = input_image * alpha + segmentation_mask * (1 - alpha)
52
+ return blended
53
+
54
+ def convert_state_dict(state_dict):
55
+ """Converts a state dict saved from a dataParallel module to normal
56
+ module state_dict inplace
57
+ :param state_dict is the loaded DataParallel model_state
58
+
59
+ """
60
+ new_state_dict = OrderedDict()
61
+ for k, v in state_dict.items():
62
+ name = k[7:] # remove `module.`
63
+ new_state_dict[name] = v
64
+ return new_state_dict
65
+
66
+
67
+ class ImagePool():
68
+ def __init__(self, pool_size):
69
+ self.pool_size = pool_size
70
+ if self.pool_size > 0:
71
+ self.num_imgs = 0
72
+ self.images = []
73
+
74
+ def query(self, images):
75
+ if self.pool_size == 0:
76
+ return images
77
+ return_images = []
78
+ for image in images:
79
+ image = torch.unsqueeze(image.data, 0)
80
+ if self.num_imgs < self.pool_size:
81
+ self.num_imgs = self.num_imgs + 1
82
+ self.images.append(image)
83
+ return_images.append(image)
84
+ else:
85
+ p = random.uniform(0, 1)
86
+ if p > 0.5:
87
+ random_id = random.randint(0, self.pool_size - 1) # randint is inclusive
88
+ tmp = self.images[random_id].clone()
89
+ self.images[random_id] = image
90
+ return_images.append(tmp)
91
+ else:
92
+ return_images.append(image)
93
+ return_images = torch.cat(return_images, 0)
94
+ return return_images
95
+
96
+
97
+ def set_requires_grad(nets, requires_grad=False):
98
+ if not isinstance(nets, list):
99
+ nets = [nets]
100
+ for net in nets:
101
+ if net is not None:
102
+ for param in net.parameters():
103
+ param.requires_grad = requires_grad
104
+
105
+
106
+
107
+ def get_lr(optimizer):
108
+ for param_group in optimizer.param_groups:
109
+ return float(param_group['lr'])
110
+
111
+ def visualize(epoch,model,layer):
112
+ #get conv layers
113
+ conv_layers=[]
114
+ for m in model.modules():
115
+ if isinstance(m,torch.nn.modules.conv.Conv2d):
116
+ conv_layers.append(m)
117
+
118
+ # print conv_layers[layer].weight.data.cpu().numpy().shape
119
+ tensor=conv_layers[layer].weight.data.cpu()
120
+ vistensor(tensor, epoch, ch=0, allkernels=False, nrow=8, padding=1)
121
+
122
+
123
+ def vistensor(tensor, epoch, ch=0, allkernels=False, nrow=8, padding=1):
124
+ '''
125
+ vistensor: visualize a tensor
126
+ @ch: visualization channel
127
+ @allkernels: visualize all kernels
128
+ https://github.com/pedrodiamel/pytorchvision/blob/a14672fe4b07995e99f8af755de875daf8aababb/pytvision/visualization.py#L325
129
+ '''
130
+
131
+ n,c,w,h = tensor.shape
132
+ if allkernels: tensor = tensor.view(n*c,-1,w,h )
133
+ elif c != 3: tensor = tensor[:,ch,:,:].unsqueeze(dim=1)
134
+
135
+ rows = np.min( (tensor.shape[0]//nrow + 1, 64 ) )
136
+ # print rows
137
+ # print tensor.shape
138
+ grid = torchvision.utils.make_grid(tensor, nrow=8, normalize=True, padding=padding)
139
+ # print grid.shape
140
+ plt.figure( figsize=(10,10), dpi=200 )
141
+ plt.imshow(grid.numpy().transpose((1, 2, 0)))
142
+ plt.savefig('./generated/filters_layer1_dwuv_'+str(epoch)+'.png')
143
+ plt.close()
144
+
145
+
146
+ def show_uloss(uwpred,uworg,inp_img, samples=7):
147
+
148
+ n,c,h,w=inp_img.shape
149
+ # print(labels.shape)
150
+ uwpred=uwpred.detach().cpu().numpy()
151
+ uworg=uworg.detach().cpu().numpy()
152
+ inp_img=inp_img.detach().cpu().numpy()
153
+
154
+ #NCHW->NHWC
155
+ uwpred=uwpred.transpose((0, 2, 3, 1))
156
+ uworg=uworg.transpose((0, 2, 3, 1))
157
+
158
+ choices=random.sample(range(n), min(n,samples))
159
+ f, axarr = plt.subplots(samples, 3)
160
+ for j in range(samples):
161
+ # print(np.min(labels[j]))
162
+ # print imgs[j].shape
163
+ img=inp_img[j].transpose(1,2,0)
164
+ axarr[j][0].imshow(img[:,:,::-1])
165
+ axarr[j][1].imshow(uworg[j])
166
+ axarr[j][2].imshow(uwpred[j])
167
+
168
+ plt.savefig('./generated/unwarp.png')
169
+ plt.close()
170
+
171
+
172
+ def show_uloss_visdom(vis,uwpred,uworg,labels_win,out_win,labelopts,outopts,args):
173
+ samples=7
174
+ n,c,h,w=uwpred.shape
175
+ uwpred=uwpred.detach().cpu().numpy()
176
+ uworg=uworg.detach().cpu().numpy()
177
+ out_arr=np.full((samples,3,args.img_rows,args.img_cols),0.0)
178
+ label_arr=np.full((samples,3,args.img_rows,args.img_cols),0.0)
179
+ choices=random.sample(range(n), min(n,samples))
180
+ idx=0
181
+ for c in choices:
182
+ out_arr[idx,:,:,:]=uwpred[c]
183
+ label_arr[idx,:,:,:]=uworg[c]
184
+ idx+=1
185
+
186
+ vis.images(out_arr,
187
+ win=out_win,
188
+ opts=outopts)
189
+ vis.images(label_arr,
190
+ win=labels_win,
191
+ opts=labelopts)
192
+
193
+ def show_unwarp_tnsboard(global_step,writer,uwpred,uworg,grid_samples,gt_tag,pred_tag):
194
+ idxs=torch.LongTensor(random.sample(range(uwpred.shape[0]), min(grid_samples,uwpred.shape[0])))
195
+ grid_uworg = torchvision.utils.make_grid(uworg[idxs],normalize=True, scale_each=True)
196
+ writer.add_image(gt_tag, grid_uworg, global_step)
197
+ grid_uwpr = torchvision.utils.make_grid(uwpred[idxs],normalize=True, scale_each=True)
198
+ writer.add_image(pred_tag, grid_uwpr, global_step)
199
+
200
+ def show_wc_tnsboard(global_step,writer,images,labels, pred, grid_samples,inp_tag, gt_tag, pred_tag):
201
+ idxs=torch.LongTensor(random.sample(range(images.shape[0]), min(grid_samples,images.shape[0])))
202
+ grid_inp = torchvision.utils.make_grid(images[idxs],normalize=True, scale_each=True)
203
+ writer.add_image(inp_tag, grid_inp, global_step)
204
+ grid_lbl = torchvision.utils.make_grid(labels[idxs],normalize=True, scale_each=True)
205
+ writer.add_image(gt_tag, grid_lbl, global_step)
206
+ grid_pred = torchvision.utils.make_grid(pred[idxs],normalize=True, scale_each=True)
207
+ writer.add_image(pred_tag, grid_pred, global_step)
208
+ def torch2cvimg(tensor,min=0,max=1):
209
+ '''
210
+ input:
211
+ tensor -> torch.tensor BxCxHxW C can be 1,3
212
+ return
213
+ im -> ndarray uint8 HxWxC
214
+ '''
215
+ im_list = []
216
+ for i in range(tensor.shape[0]):
217
+ im = tensor.detach().cpu().data.numpy()[i]
218
+ im = im.transpose(1,2,0)
219
+ im = np.clip(im,min,max)
220
+ im = ((im-min)/(max-min)*255).astype(np.uint8)
221
+ im_list.append(im)
222
+ return im_list
223
+ def cvimg2torch(img,min=0,max=1):
224
+ '''
225
+ input:
226
+ im -> ndarray uint8 HxWxC
227
+ return
228
+ tensor -> torch.tensor BxCxHxW
229
+ '''
230
+ img = img.astype(float) / 255.0
231
+ img = img.transpose(2, 0, 1) # NHWC -> NCHW
232
+ img = np.expand_dims(img, 0)
233
+ img = torch.from_numpy(img).float()
234
+ return img
data/README.md ADDED
@@ -0,0 +1,135 @@
1
+ # Dataset Preparation
2
+ The data file tree should look like:
3
+ ```
4
+ data/
5
+ eval/
6
+ dir300/
7
+ 1_in.png
8
+ 1_gt.png
9
+ ...
10
+ kligler/
11
+ jung/
12
+ osr/
13
+ realdae/
14
+ docunet_docaligner/
15
+ dibco18/
16
+ train/
17
+ dewarping/
18
+ doc3d/
19
+ deshadowing/
20
+ fsdsrd/
21
+ tdd/
22
+ appearance/
23
+ clean_pdfs/
24
+ realdae/
25
+ deblurring/
26
+ tdd/
27
+ binarization/
28
+ bickly/
29
+ dibco/
30
+ noise_office/
31
+ phibd/
32
+ msi/
33
+ ```
34
+
35
+ ## Evaluation Dataset
36
+ You can find the links for downloading the datasets we used for evaluation (Tables 1 and 2) in [this](https://github.com/ZZZHANG-jx/Recommendations-Document-Image-Processing/tree/master) repository, including DIR300 (300 samples), Kligler (300 samples), Jung (87 samples), OSR (237 samples), RealDAE (150 samples), DocUNet_DocAligner (150 samples), TDD (16000 samples) and DIBCO18 (10 samples). After downloading, add the suffixes `_in` and `_gt` to the input images and ground-truth images respectively, and place them in the folder of the corresponding dataset; a minimal renaming sketch is given below.
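A minimal renaming sketch (the original file names differ per dataset, so the `_input`/`_target` patterns below are placeholders to adapt):

```python
import glob
import os

# Hypothetical original suffixes; adjust the folder and patterns for each downloaded dataset.
for src, dst in [('_input', '_in'), ('_target', '_gt')]:
    for path in glob.glob(os.path.join('data/eval/dir300', f'*{src}.*')):
        os.rename(path, path.replace(src, dst))
```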
37
+
38
+
39
+ ## Training Dataset
40
+ You can find the links for downloading the dataset we used for training in [this](https://github.com/ZZZHANG-jx/Recommendations-Document-Image-Processing/tree/master) repository.
41
+ ### Dewarping
42
+ - Doc3D
43
+ - Mask extraction: you should extract the mask for each image from the uv data in Doc3D (see the sketch after the JSON example below)
44
+ - Background preparation: you can download the background data from [here](https://www.robots.ox.ac.uk/~vgg/data/dtd/) and specify it for self.background_paths in `loaders/docres_loader.py`
45
+ - JSON preparation:
46
+ ```
47
+
48
+ [
49
+ ## you need to specify the paths of 'in_path', 'mask_path' and 'gt_path':
50
+ {
51
+ "in_path": "dewarping/doc3d/img/1/102_1-pp_Page_048-xov0001.png",
52
+ "mask_path": "dewarping/doc3d/mask/1/102_1-pp_Page_048-xov0001.png",
53
+ "gt_path": "dewarping/doc3d/bm/1/102_1-pp_Page_048-xov0001.npy"
54
+ }
55
+ ]
56
+
57
+ ```
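A minimal sketch of the mask extraction step mentioned above; the UV file location and the assumption that background pixels are exactly zero are ours, not taken from the Doc3D documentation:

```python
import os
os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1'  # Doc3D uv maps are stored as EXR files
import cv2
import numpy as np

# Hypothetical paths mirroring the JSON example above.
uv = cv2.imread('dewarping/doc3d/uv/1/102_1-pp_Page_048-xov0001.exr', cv2.IMREAD_UNCHANGED)
mask = (np.abs(uv).sum(axis=2) > 0).astype(np.uint8) * 255  # any non-zero uv value -> document pixel
cv2.imwrite('dewarping/doc3d/mask/1/102_1-pp_Page_048-xov0001.png', mask)
```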
58
+ ### Deshadowing
59
+ - RDD
60
+ - FSDSRD
61
+ - JSON preparation
62
+ ```
63
+ [ ## you need to specify the paths of 'in_path' and 'gt_path', for example:
64
+ {
65
+ "in_path": "deshadowing/fsdsrd/im/00004.png",
66
+ "gt_path": "deshadowing/fsdsrd/gt/00004.png"
67
+ },
68
+ {
69
+ "in_path": "deshadowing/rdd/im/00004.png",
70
+ "gt_path": "deshadowing/rdd/gt/00004.png"
71
+ }
72
+ ]
73
+ ```
74
+ ### Appearance enhancement
75
+ - Doc3DShade
76
+ - Clean PDFs collection: You should collect PDF files from the internet and convert them to images to serve as the source for synthesis.
77
+ - Shadow extraction: extract shadows from Doc3DShade by using `data/preprocess/shadow_extraction.py` and dewarp the obtained shadows by using `data/MBD/infer.py`. Then you should specify self.shadow_paths in `loaders/docres_loader.py`
78
+ - RealDAE
79
+ - JSON preparation:
80
+ ```
81
+ [
82
+ ## for Doc3DShade dataset, you only need to specify the path of image from PDF, for example:
83
+ {
84
+ 'gt_path':'appearance/clean_pdfs/1.jpg'
85
+ },
86
+
87
+ ## for RealDAE dataset, you need to specify the paths of both input and gt, for example:
88
+ {
89
+ 'in_path': 'appearance/realdae/1_in.jpg',
90
+ 'gt_path': 'appearance/realdae/1_gt.jpg'
91
+ }
92
+ ]
93
+
94
+ ```
95
+
96
+ ### Deblurring
97
+ - TDD
98
+ - JSON preparation
99
+ ```
100
+ [ ## you need to specify the paths of 'in_path' and 'gt_path', for example:
101
+ {
102
+ "in_path": "debluring/tdd/im/00004.png",
103
+ "gt_path": "debluring/tdd/gt/00004.png"
104
+ },
105
+ ]
106
+ ```
107
+ ### Binarization
108
+ - Bickly
109
+ - DTPrompt preparation: Since the DTPrompt for binarization is time-consuming to compute, we obtain it offline before training. Use `data/preprocess/sauvola_binarize.py`
110
+ - DIBCO
111
+ - DTPrompt preparation: the same as Bickly
112
+ - Noise Office
113
+ - DTPrompt preparation: the same as Bickly
114
+ - PHIDB
115
+ - DTPrompt preparation: the same as Bickly
116
+ - MSI
117
+ - DTPrompt preparation: the same as Bickly
118
+ - JSON preparation
119
+ ```
120
+ [
121
+ ## you need to specify the paths of 'in_path', 'gt_path', 'bin_path', 'thr_path' and 'gradient_path', for example:
122
+ {
123
+ "in_path": "binarization/noise_office/imgs/1.png",
124
+ "gt_path": "binarization/noise_office/gt_imgs/1.png",
125
+ "bin_path": "binarization/noise_office/imgs/1_bin.png",
126
+ "thr_path": "binarization/noise_office/imgs/1_thr.png",
127
+ "gradient_path": "binarization/noise_office/imgs/1_gradient.png"
128
+ },
129
+ ]
130
+ ```
131
+
132
+ After all the data are prepared, you should specify `dataset_setting` in `train.py`; a hypothetical example of what such a setting might look like is sketched below.
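The exact structure of `dataset_setting` lives in `train.py` (not reproduced here); as a purely hypothetical illustration, one plausible shape is a mapping from each task to the JSON annotation file(s) prepared above:

```python
# Hypothetical illustration only -- check train.py for the real keys and format.
dataset_setting = {
    'dewarping':    ['dewarping/doc3d.json'],
    'deshadowing':  ['deshadowing/rdd.json', 'deshadowing/fsdsrd.json'],
    'appearance':   ['appearance/clean_pdfs.json', 'appearance/realdae.json'],
    'deblurring':   ['deblurring/tdd.json'],
    'binarization': ['binarization/all.json'],
}
```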
133
+
134
+
135
+
data/preprocess/crop_merge_image.py ADDED
@@ -0,0 +1,142 @@
1
+ import os
2
+
3
+ import cv2
4
+ import numpy as np
5
+ # SIZE =256
6
+ # BATCH_SIZE = 32
7
+ # STRIDES = 256
8
+
9
+ def split_img(img, size_x, size_y, strides):
10
+ max_y, max_x = img.shape[:2]
11
+ border_y = 0
12
+ if max_y % size_y != 0:
13
+ border_y = size_y - (max_y % size_y)
14
+ img = cv2.copyMakeBorder(img,border_y,0,0,0,cv2.BORDER_REPLICATE)
15
+ # img = cv2.copyMakeBorder(img, border_y, 0, 0, 0, cv2.BORDER_CONSTANT, value=[255,255,255])
16
+ border_x = 0
17
+ if max_x % size_x != 0:
18
+ border_x = size_x - (max_x % size_x)
19
+ # img = cv2.copyMakeBorder(img, 0, 0, border_x, 0, cv2.BORDER_CONSTANT, value=[255,255,255])
20
+ img = cv2.copyMakeBorder(img,0,0,border_x,0,cv2.BORDER_REPLICATE)
21
+ # h,w
22
+ max_y, max_x = img.shape[:2]
23
+ parts = []
24
+ curr_y = 0
25
+ x = 0
26
+ y = 0
27
+ # TODO: rewrite with generators.
28
+ while (curr_y + size_y) <= max_y:
29
+ curr_x = 0
30
+ while (curr_x + size_x) <= max_x:
31
+ parts.append(img[curr_y:curr_y + size_y, curr_x:curr_x + size_x])
32
+ curr_x += strides
33
+ y += 1
34
+ curr_y += strides
35
+ # parts is a list
36
+ # (windows_number_x*windows_number_y,SIZE,SIZE,3)
37
+ # print(max_y,max_x)
38
+ # print(y,x)
39
+ # print(np.array(parts).shape)
40
+ return parts, border_x, border_y, max_x, max_y
41
+
42
+
43
+ def combine_imgs(border_x,border_y,imgs, max_y, max_x,size_x, size_y, strides):
44
+
45
+ # weighted_img
46
+
47
+ index = int(size_x / strides)
48
+ weight_img = np.ones(shape=(max_y,max_x))
49
+ weight_img[0:strides] = index
50
+ weight_img[-strides:] = index
51
+ weight_img[:,0:strides]=index
52
+ weight_img[:,-strides:]=index
53
+
54
+ # borders
55
+ i = 0
56
+ for j in range(1,index+1):
57
+ # top-left corner
58
+ weight_img[0:strides,i:i+strides] = np.ones(shape=(strides,strides))*j
59
+ weight_img[i:i+strides,0:strides] = np.ones(shape=(strides,strides))*j
60
+ # top-right corner
61
+ weight_img[i:i+strides,-strides:] = np.ones(shape=(strides,strides))*j
62
+ if i == 0:
63
+ weight_img[0:strides,-strides:] = np.ones(shape=(strides,strides))*j
64
+ else:
65
+ weight_img[0:strides,-strides-i:-i] = np.ones(shape=(strides,strides))*j
66
+ # bottom-left corner
67
+ weight_img[-strides:,i:i+strides] = np.ones(shape=(strides,strides))*j
68
+ if i == 0:
69
+ weight_img[-strides:,0:strides] = np.ones(shape=(strides,strides))*j
70
+ else:
71
+ weight_img[-strides-i:-i:,0:strides] = np.ones(shape=(strides,strides))*j
72
+ # bottom-right corner
73
+ if i == 0:
74
+ weight_img[-strides:,-strides:] = np.ones(shape=(strides,strides))*j
75
+ else:
76
+ weight_img[-strides-i:-i,-strides:] = np.ones(shape=(strides,strides))*j
77
+ weight_img[-strides:,-strides-i:-i] = np.ones(shape=(strides,strides))*j
78
+
79
+
80
+ i += strides
81
+
82
+ for i in range(strides,max_y-strides,strides):
83
+ for j in range(strides,max_x-strides,strides):
84
+ weight_img[i:i+strides,j:j+strides] = np.ones(shape=(strides,strides))*weight_img[i][0]*weight_img[0][j]
85
+
86
+
87
+ if len(imgs[0].shape)==2:
88
+ new_img = np.zeros(shape=(max_y,max_x))
89
+ weight_img = (1 / weight_img)
90
+ else:
91
+ new_img = np.zeros(shape=(max_y,max_x,imgs[0].shape[-1]))
92
+ weight_img = (1 / weight_img).reshape((max_y,max_x,1))
93
+ weight_img = np.tile(weight_img,(1,1,imgs[0].shape[-1]))
94
+
95
+ curr_y = 0
96
+ x = 0
97
+ y = 0
98
+ i = 0
99
+ # TODO: rewrite with generators.
100
+ while (curr_y + size_y) <= max_y:
101
+ curr_x = 0
102
+ while (curr_x + size_x) <= max_x:
103
+ new_img[curr_y:curr_y + size_y, curr_x:curr_x + size_x] += weight_img[curr_y:curr_y + size_y, curr_x:curr_x + size_x]*imgs[i]
104
+ i += 1
105
+ curr_x += strides
106
+ y += 1
107
+ curr_y += strides
108
+
109
+
110
+ new_img = new_img[border_y:, border_x:]
111
+ # print(border_y,border_x)
112
+
113
+ return new_img
114
+
115
+
116
+ def stride_integral(img,stride=32):
117
+ h,w = img.shape[:2]
118
+
119
+ if (h%stride)!=0:
120
+ padding_h = stride - (h%stride)
121
+ img = cv2.copyMakeBorder(img,padding_h,0,0,0,borderType=cv2.BORDER_REPLICATE)
122
+ else:
123
+ padding_h = 0
124
+
125
+ if (w%stride)!=0:
126
+ padding_w = stride - (w%stride)
127
+ img = cv2.copyMakeBorder(img,0,0,padding_w,0,borderType=cv2.BORDER_REPLICATE)
128
+ else:
129
+ padding_w = 0
130
+
131
+ return img,padding_h,padding_w
132
+
133
+
134
+ def mkdir_s(path: str):
135
+ """Create directory in specified path, if not exists."""
136
+ if not os.path.exists(path):
137
+ os.makedirs(path)
138
+
139
+
140
+ if __name__ =='__main__':
141
+ im = cv2.imread('demo.png') # hypothetical demo path; replace with a real image before running
+ parts, border_x, border_y, max_x, max_y = split_img(im,512,512,strides=512)
142
+ result = combine_imgs(border_x,border_y,parts, max_y, max_x,512, 512, 512)
data/preprocess/sauvola_binarize.py ADDED
@@ -0,0 +1,91 @@
1
+ import cv2
2
+ # importing required libraries
3
+ import numpy as np
4
5
+ from skimage.filters import threshold_sauvola
6
+ import glob
7
+ from tqdm import tqdm
8
+ import os
9
+ from skimage import io
10
+
11
+ def SauvolaModBinarization(image,n1=51,n2=51,k1=0.3,k2=0.3,default=True):
12
+ '''
13
+ Binarization using Sauvola's algorithm
14
+ @name : SauvolaModBinarization
15
+ parameters
16
+ @param image (numpy array of shape (3/1) of type np.uint8): color or gray scale image
17
+ optional parameters
18
+ @param n1 (int) : window size for running sauvola during the first pass
19
+ @param n2 (int): window size for running sauvola during the second pass
20
+ @param k1 (float): k value corresponding to sauvola during the first pass
21
+ @param k2 (float): k value corresponding to sauvola during the second pass
22
+ @param default (bool) : boolean variable to use the default values of the above parameters.
23
+ when @param default is set to True, the default values of the above optional parameters (n1,n2,k1,k2) are set to
24
+ n1 = 5 % of min(image height, image width)
25
+ n2 = 10 % of min(image height, image width)
26
+ k1 = 0.5
27
+ k2 = 0.5
28
+ Returns
29
+ @return A binary image of same size as @param image
30
+
31
+ @cite https://drive.google.com/file/d/1D3CyI5vtodPJeZaD2UV5wdcaIMtkBbdZ/view?usp=sharing
32
+ '''
33
+
34
+ if(default):
35
+ n1 = int(0.05*min(image.shape[0],image.shape[1]))
36
+ if (n1%2==0):
37
+ n1 = n1+1
38
+ n2 = int(0.1*min(image.shape[0],image.shape[1]))
39
+ if (n2%2==0):
40
+ n2 = n2+1
41
+ k1 = 0.5
42
+ k2 = 0.5
43
+ if(image.ndim==3):
44
+ gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
45
+ else:
46
+ gray = np.copy(image)
47
+ T1 = threshold_sauvola(gray, window_size=n1,k=k1)
48
+ max_val = np.amax(gray)
49
+ min_val = np.amin(gray)
50
+ C = np.copy(T1)
51
+ C = C.astype(np.float32)
52
+ C[gray > T1] = (gray[gray > T1] - T1[gray > T1])/(max_val - T1[gray > T1])
53
+ C[gray <= T1] = 0
54
+ C = C * 255.0
55
+ new_in = np.copy(C.astype(np.uint8))
56
+ T2 = threshold_sauvola(new_in, window_size=n2,k=k2)
57
+ binary = np.copy(gray)
58
+ binary[new_in <= T2] = 0
59
+ binary[new_in > T2] = 255
60
+ return binary,T2
61
+
62
+
63
+ def dtprompt(img):
64
+ x = cv2.Sobel(img,cv2.CV_16S,1,0)
65
+ y = cv2.Sobel(img,cv2.CV_16S,0,1)
66
+ absX = cv2.convertScaleAbs(x) # convert back to uint8
67
+ absY = cv2.convertScaleAbs(y)
68
+ high_frequency = cv2.addWeighted(absX,0.5,absY,0.5,0)
69
+ high_frequency = cv2.cvtColor(high_frequency,cv2.COLOR_BGR2GRAY)
70
+ return high_frequency
71
+
72
+
73
+ im_paths = glob.glob('imgs/*')
74
+
75
+
76
+
77
+ for im_path in tqdm(im_paths):
78
+ if '_bin.' in im_path:
79
+ continue
80
+ if '_thr.' in im_path:
81
+ continue
82
+ if '_gradient.' in im_path:
83
+ continue
84
+
85
+ im = cv2.imread(im_path)
86
+ result,thresh = SauvolaModBinarization(im)
87
+ gradient = dtprompt(im)
88
+ thresh = thresh.astype(np.uint8)
89
+ cv2.imwrite(im_path.replace('.','_bin.'),result)
90
+ cv2.imwrite(im_path.replace('.','_thr.'),thresh)
91
+ cv2.imwrite(im_path.replace('.','_gradient.'),gradient)
data/preprocess/shadow_extraction.py ADDED
@@ -0,0 +1,68 @@
1
+ import cv2
2
+ import numpy as np
3
+ import glob
4
+ import os
5
+ from tqdm import tqdm
6
+ import random
7
+
8
+ im_paths = glob.glob('./img/*/*')
9
+
10
+ random.shuffle(im_paths)
11
+
12
+ for im_path in tqdm(im_paths):
13
+ # im_path = './img/1/23-180_5-y4_Page_034-wVO0001-L1_3-T_6600-I_5535.png'
14
+ if '-L1_' in im_path:
15
+ alb_path = im_path.split('-L1_')[0].replace('img/','alb/') + '.png'
16
+ else:
17
+ alb_path = im_path.split('-L2_')[0].replace('img/','alb/') + '.png'
18
+
19
+ if not os.path.exists(alb_path):
20
+ print(im_path)
21
+ print(alb_path)
22
+
23
+ im = cv2.imread(im_path)
24
+ alb = cv2.imread(alb_path)
25
+ _, mask = cv2.threshold(cv2.cvtColor(alb,cv2.COLOR_BGR2GRAY), 1, 255, cv2.THRESH_BINARY)
26
+
27
+
28
+ ## clean
29
+ # std = np.max(np.std(alb,axis=-1))
30
+ # print(std)
31
+ im_min = np.min(im,axis=-1)
32
+ kernel = np.ones((3,3))
33
+ mask_erode = cv2.dilate(mask,kernel=kernel)
34
+ mask_erode = cv2.erode(mask_erode,kernel=kernel)
35
+ mask_erode = cv2.erode(mask_erode,iterations=4,kernel=kernel)
36
+ metric = np.min(im_min[mask_erode==255])
37
+ metric_num = 0
38
+ if metric==0 or metric==1:
39
+ metric_num = np.sum(im_min[mask_erode==255]==metric)
40
+ if metric_num>=20:
41
+ alb_temp = alb.astype(np.float64)
42
+ alb_temp[alb_temp==0] = alb_temp[alb_temp==0]+1e-5
43
+ shadow = np.clip(im.astype(np.float64)/alb_temp,0,1)
44
+ shadow = (shadow*255).astype(np.uint8)
45
+
46
+ shadow_path = im_path.replace('img/','temp/')
47
+ cv2.imwrite(shadow_path,shadow)
48
+ continue
49
+
50
+
51
+ alb_temp = alb.astype(np.float64)
52
+ alb_temp[alb_temp==0] = alb_temp[alb_temp==0]+1e-5
53
+ shadow = np.clip(im.astype(np.float64)/alb_temp,0,1)
54
+ shadow = (shadow*255).astype(np.uint8)
55
+
56
+ shadow_path = im_path.replace('img/','shadow/')
57
+ cv2.imwrite(shadow_path,shadow)
58
+
59
+ mask_path = im_path.replace('img/','mask/')
60
+ cv2.imwrite(mask_path,mask)
61
+
62
+ # cv2.imshow('im',im)
63
+ # cv2.imshow('alb',alb)
64
+ # cv2.imshow('shadow',shadow)
65
+ # cv2.imshow('mask_erode',mask_erode)
66
+ # print(im_min[mask_erode==255])
67
+ # print(metric,metric_num)
68
+ # cv2.waitKey(0)
eval.py ADDED
@@ -0,0 +1,369 @@
1
+ import os
2
+ import cv2
3
+ import glob
4
+ import utils
5
+ import argparse
6
+ import numpy as np
7
+ from tqdm import tqdm
8
+ from skimage.metrics import structural_similarity,peak_signal_noise_ratio
9
+
10
+ import torch
11
+
12
+ from utils import convert_state_dict
13
+ from models import restormer_arch
14
+ from data.preprocess.crop_merge_image import stride_integral
15
+
16
+ os.sys.path.append('./data/MBD/')
17
+ from data.MBD.infer import net1_net2_infer_single_im
18
+
19
+
20
+ def dewarp_prompt(img):
21
+ mask = net1_net2_infer_single_im(img,'data/MBD/checkpoint/mbd.pkl')
22
+ base_coord = utils.getBasecoord(256,256)/256
23
+ img[mask==0]=0
24
+ mask = cv2.resize(mask,(256,256))/255
25
+ return img,np.concatenate((base_coord,np.expand_dims(mask,-1)),-1)
26
+
27
+ def deshadow_prompt(img):
28
+ h,w = img.shape[:2]
29
+ # img = cv2.resize(img,(128,128))
30
+ img = cv2.resize(img,(1024,1024))
31
+ rgb_planes = cv2.split(img)
32
+ result_planes = []
33
+ result_norm_planes = []
34
+ bg_imgs = []
35
+ for plane in rgb_planes:
36
+ dilated_img = cv2.dilate(plane, np.ones((7,7), np.uint8))
37
+ bg_img = cv2.medianBlur(dilated_img, 21)
38
+ bg_imgs.append(bg_img)
39
+ diff_img = 255 - cv2.absdiff(plane, bg_img)
40
+ norm_img = cv2.normalize(diff_img,None, alpha=0, beta=255, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_8UC1)
41
+ result_planes.append(diff_img)
42
+ result_norm_planes.append(norm_img)
43
+ bg_imgs = cv2.merge(bg_imgs)
44
+ bg_imgs = cv2.resize(bg_imgs,(w,h))
45
+ # result = cv2.merge(result_planes)
46
+ result_norm = cv2.merge(result_norm_planes)
47
+ result_norm[result_norm==0]=1
48
+ shadow_map = np.clip(img.astype(float)/result_norm.astype(float)*255,0,255).astype(np.uint8)
49
+ shadow_map = cv2.resize(shadow_map,(w,h))
50
+ shadow_map = cv2.cvtColor(shadow_map,cv2.COLOR_BGR2GRAY)
51
+ shadow_map = cv2.cvtColor(shadow_map,cv2.COLOR_GRAY2BGR)
52
+ # return shadow_map
53
+ return bg_imgs
54
+
55
+ def deblur_prompt(img):
56
+ x = cv2.Sobel(img,cv2.CV_16S,1,0)
57
+ y = cv2.Sobel(img,cv2.CV_16S,0,1)
58
+ absX = cv2.convertScaleAbs(x) # convert back to uint8
59
+ absY = cv2.convertScaleAbs(y)
60
+ high_frequency = cv2.addWeighted(absX,0.5,absY,0.5,0)
61
+ high_frequency = cv2.cvtColor(high_frequency,cv2.COLOR_BGR2GRAY)
62
+ high_frequency = cv2.cvtColor(high_frequency,cv2.COLOR_GRAY2BGR)
63
+ return high_frequency
64
+
65
+ def appearance_prompt(img):
66
+ h,w = img.shape[:2]
67
+ # img = cv2.resize(img,(128,128))
68
+ img = cv2.resize(img,(1024,1024))
69
+ rgb_planes = cv2.split(img)
70
+ result_planes = []
71
+ result_norm_planes = []
72
+ for plane in rgb_planes:
73
+ dilated_img = cv2.dilate(plane, np.ones((7,7), np.uint8))
74
+ bg_img = cv2.medianBlur(dilated_img, 21)
75
+ diff_img = 255 - cv2.absdiff(plane, bg_img)
76
+ norm_img = cv2.normalize(diff_img,None, alpha=0, beta=255, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_8UC1)
77
+ result_planes.append(diff_img)
78
+ result_norm_planes.append(norm_img)
79
+ result_norm = cv2.merge(result_norm_planes)
80
+ result_norm = cv2.resize(result_norm,(w,h))
81
+ return result_norm
82
+
83
+ def binarization_promptv2(img):
84
+ result,thresh = utils.SauvolaModBinarization(img)
85
+ thresh = thresh.astype(np.uint8)
86
+ result[result>155]=255
87
+ result[result<=155]=0
88
+
89
+ x = cv2.Sobel(img,cv2.CV_16S,1,0)
90
+ y = cv2.Sobel(img,cv2.CV_16S,0,1)
91
+ absX = cv2.convertScaleAbs(x) # convert back to uint8
92
+ absY = cv2.convertScaleAbs(y)
93
+ high_frequency = cv2.addWeighted(absX,0.5,absY,0.5,0)
94
+ high_frequency = cv2.cvtColor(high_frequency,cv2.COLOR_BGR2GRAY)
95
+ return np.concatenate((np.expand_dims(thresh,-1),np.expand_dims(high_frequency,-1),np.expand_dims(result,-1)),-1)
96
+
97
+ def dewarping(model,im_path):
98
+ INPUT_SIZE=256
99
+ im_org = cv2.imread(im_path)
100
+ im_masked, prompt_org = dewarp_prompt(im_org.copy())
101
+
102
+ h,w = im_masked.shape[:2]
103
+ im_masked = im_masked.copy()
104
+ im_masked = cv2.resize(im_masked,(INPUT_SIZE,INPUT_SIZE))
105
+ im_masked = im_masked / 255.0
106
+ im_masked = torch.from_numpy(im_masked.transpose(2,0,1)).unsqueeze(0)
107
+ im_masked = im_masked.float().to(DEVICE)
108
+
109
+ prompt = torch.from_numpy(prompt_org.transpose(2,0,1)).unsqueeze(0)
110
+ prompt = prompt.float().to(DEVICE)
111
+
112
+ in_im = torch.cat((im_masked,prompt),dim=1)
113
+
114
+ # inference
115
+ base_coord = utils.getBasecoord(INPUT_SIZE,INPUT_SIZE)/INPUT_SIZE
116
+ model = model.float()
117
+ with torch.no_grad():
118
+ pred = model(in_im)
119
+ pred = pred[0][:2].permute(1,2,0).cpu().numpy()
120
+ pred = pred+base_coord
121
+ ## smooth
122
+ for i in range(15):
123
+ pred = cv2.blur(pred,(3,3),borderType=cv2.BORDER_REPLICATE)
124
+ pred = cv2.resize(pred,(w,h))*(w,h)
125
+ pred = pred.astype(np.float32)
126
+ out_im = cv2.remap(im_org,pred[:,:,0],pred[:,:,1],cv2.INTER_LINEAR)
127
+
128
+ prompt_org = (prompt_org*255).astype(np.uint8)
129
+ prompt_org = cv2.resize(prompt_org,im_org.shape[:2][::-1])
130
+
131
+ return prompt_org[:,:,0],prompt_org[:,:,1],prompt_org[:,:,2],out_im
132
+
133
+ def appearance(model,im_path):
134
+ MAX_SIZE=1600
135
+ # obtain im and prompt
136
+ im_org = cv2.imread(im_path)
137
+ h,w = im_org.shape[:2]
138
+ prompt = appearance_prompt(im_org)
139
+ in_im = np.concatenate((im_org,prompt),-1)
140
+
141
+ # constrain the max resolution
142
+ if max(w,h) < MAX_SIZE:
143
+ in_im,padding_h,padding_w = stride_integral(in_im,8)
144
+ else:
145
+ in_im = cv2.resize(in_im,(MAX_SIZE,MAX_SIZE))
146
+
147
+ # normalize
148
+ in_im = in_im / 255.0
149
+ in_im = torch.from_numpy(in_im.transpose(2,0,1)).unsqueeze(0)
150
+
151
+ # inference
152
+ in_im = in_im.half().to(DEVICE)
153
+ model = model.half()
154
+ with torch.no_grad():
155
+ pred = model(in_im)
156
+ pred = torch.clamp(pred,0,1)
157
+ pred = pred[0].permute(1,2,0).cpu().numpy()
158
+ pred = (pred*255).astype(np.uint8)
159
+
160
+ if max(w,h) < MAX_SIZE:
161
+ out_im = pred[padding_h:,padding_w:]
162
+ else:
163
+ pred[pred==0] = 1
164
+ shadow_map = cv2.resize(im_org,(MAX_SIZE,MAX_SIZE)).astype(float)/pred.astype(float)
165
+ shadow_map = cv2.resize(shadow_map,(w,h))
166
+ shadow_map[shadow_map==0]=0.00001
167
+ out_im = np.clip(im_org.astype(float)/shadow_map,0,255).astype(np.uint8)
168
+
169
+ return prompt[:,:,0],prompt[:,:,1],prompt[:,:,2],out_im
170
+
171
+
172
+ def deshadowing(model,im_path):
173
+ MAX_SIZE=1600
174
+ # obtain im and prompt
175
+ im_org = cv2.imread(im_path)
176
+ h,w = im_org.shape[:2]
177
+ prompt = deshadow_prompt(im_org)
178
+ in_im = np.concatenate((im_org,prompt),-1)
179
+
180
+ # constrain the max resolution
181
+ if max(w,h) < MAX_SIZE:
182
+ in_im,padding_h,padding_w = stride_integral(in_im,8)
183
+ else:
184
+ in_im = cv2.resize(in_im,(MAX_SIZE,MAX_SIZE))
185
+
186
+ # normalize
187
+ in_im = in_im / 255.0
188
+ in_im = torch.from_numpy(in_im.transpose(2,0,1)).unsqueeze(0)
189
+
190
+ # inference
191
+ in_im = in_im.half().to(DEVICE)
192
+ model = model.half()
193
+ with torch.no_grad():
194
+ pred = model(in_im)
195
+ pred = torch.clamp(pred,0,1)
196
+ pred = pred[0].permute(1,2,0).cpu().numpy()
197
+ pred = (pred*255).astype(np.uint8)
198
+
199
+ if max(w,h) < MAX_SIZE:
200
+ out_im = pred[padding_h:,padding_w:]
201
+ else:
202
+ pred[pred==0]=1
203
+ shadow_map = cv2.resize(im_org,(MAX_SIZE,MAX_SIZE)).astype(float)/pred.astype(float)
204
+ shadow_map = cv2.resize(shadow_map,(w,h))
205
+ shadow_map[shadow_map==0]=0.00001
206
+ out_im = np.clip(im_org.astype(float)/shadow_map,0,255).astype(np.uint8)
207
+
208
+ return prompt[:,:,0],prompt[:,:,1],prompt[:,:,2],out_im
209
+
210
+
211
+ def deblurring(model,im_path):
212
+ # setup image
213
+ im_org = cv2.imread(im_path)
214
+ in_im,padding_h,padding_w = stride_integral(im_org,8)
215
+ prompt = deblur_prompt(in_im)
216
+ in_im = np.concatenate((in_im,prompt),-1)
217
+ in_im = in_im / 255.0
218
+ in_im = torch.from_numpy(in_im.transpose(2,0,1)).unsqueeze(0)
219
+ in_im = in_im.half().to(DEVICE)
220
+ # inference
221
+ model.to(DEVICE)
222
+ model.eval()
223
+ model = model.half()
224
+ with torch.no_grad():
225
+ pred = model(in_im)
226
+ pred = torch.clamp(pred,0,1)
227
+ pred = pred[0].permute(1,2,0).cpu().numpy()
228
+ pred = (pred*255).astype(np.uint8)
229
+ out_im = pred[padding_h:,padding_w:]
230
+
231
+ return prompt[:,:,0],prompt[:,:,1],prompt[:,:,2],out_im
232
+
233
+
234
+
235
+ def binarization(model,im_path):
236
+ im_org = cv2.imread(im_path)
237
+ im,padding_h,padding_w = stride_integral(im_org,8)
238
+ prompt = binarization_promptv2(im)
239
+ h,w = im.shape[:2]
240
+ in_im = np.concatenate((im,prompt),-1)
241
+
242
+ in_im = in_im / 255.0
243
+ in_im = torch.from_numpy(in_im.transpose(2,0,1)).unsqueeze(0)
244
+ in_im = in_im.to(DEVICE)
245
+ model = model.half()
246
+ in_im = in_im.half()
247
+ with torch.no_grad():
248
+ pred = model(in_im,'binarization')
249
+ pred = pred[:,:2,:,:]
250
+ pred = torch.max(torch.softmax(pred,1),1)[1]
251
+ pred = pred[0].cpu().numpy()
252
+ pred = (pred*255).astype(np.uint8)
253
+ pred = cv2.resize(pred,(w,h))
254
+ out_im = pred[padding_h:,padding_w:]
255
+
256
+ return prompt[:,:,0],prompt[:,:,1],prompt[:,:,2],out_im
257
+
258
+
259
+
260
+
261
+
262
+ def get_args():
263
+ parser = argparse.ArgumentParser(description='Params')
264
+ parser.add_argument('--model_path', nargs='?', type=str, default='./checkpoints/docres.pkl',help='Path of the saved checkpoint')
265
+ parser.add_argument('--dataset', nargs='?', type=str, default='dir300',help='Name of the evaluation dataset; must be one of the registered datasets in all_datasets')
266
+ args = parser.parse_args()
267
+ assert args.dataset in all_datasets.keys(), 'Unregisted dataset, dataset must be one of '+', '.join(all_datasets)
268
+ return args
269
+
270
+ def model_init(args):
271
+ # prepare model
272
+ model = restormer_arch.Restormer(
273
+ inp_channels=6,
274
+ out_channels=3,
275
+ dim = 48,
276
+ num_blocks = [2,3,3,4],
277
+ num_refinement_blocks = 4,
278
+ heads = [1,2,4,8],
279
+ ffn_expansion_factor = 2.66,
280
+ bias = False,
281
+ LayerNorm_type = 'WithBias',
282
+ dual_pixel_task = True
283
+ )
284
+
285
+ if DEVICE.type == 'cpu':
286
+ state = convert_state_dict(torch.load(args.model_path, map_location='cpu')['model_state'])
287
+ else:
288
+ state = convert_state_dict(torch.load(args.model_path, map_location='cuda:0')['model_state'])
289
+ model.load_state_dict(state)
290
+
291
+ model.eval()
292
+ model = model.to(DEVICE)
293
+ return model
294
+
295
+ def inference_one_im(model,im_path,task):
296
+ if task=='dewarping':
297
+ prompt1,prompt2,prompt3,restorted = dewarping(model,im_path)
298
+ elif task=='deshadowing':
299
+ prompt1,prompt2,prompt3,restorted = deshadowing(model,im_path)
300
+ elif task=='appearance':
301
+ prompt1,prompt2,prompt3,restorted = appearance(model,im_path)
302
+ elif task=='deblurring':
303
+ prompt1,prompt2,prompt3,restorted = deblurring(model,im_path)
304
+ elif task=='binarization':
305
+ prompt1,prompt2,prompt3,restorted = binarization(model,im_path)
306
+ elif task=='end2end':
307
+ prompt1,prompt2,prompt3,restorted = dewarping(model,im_path)
308
+ cv2.imwrite('./temp.jpg',restorted)
309
+ prompt1,prompt2,prompt3,restorted = deshadowing(model,'./temp.jpg')
310
+ cv2.imwrite('./temp.jpg',restorted)
311
+ prompt1,prompt2,prompt3,restorted = appearance(model,'./temp.jpg')
312
+ os.remove('./temp.jpg')
313
+
314
+ return prompt1,prompt2,prompt3,restorted
315
+
316
+
317
+
318
+ if __name__ == '__main__':
319
+ all_datasets = {'dir300':'dewarping','kligler':'deshadowing','jung':'deshadowing','osr':'deshadowing','docunet_docaligner':'appearance','realdae':'appearance','tdd':'deblurring','dibco18':'binarization'}
320
+
321
+ ## model init
322
+ DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
323
+ args = get_args()
324
+ model = model_init(args)
325
+
326
+ ## inference
327
+ print('Predicting')
328
+ task = all_datasets[args.dataset]
329
+ im_paths = glob.glob(os.path.join('./data/eval/',args.dataset,'*_in.*'))
330
+ for im_path in tqdm(im_paths):
331
+ _,_,_,restorted = inference_one_im(model,im_path,task)
332
+ cv2.imwrite(im_path.replace('_in','_docres'),restorted)
333
+
334
+ ## obtain metric
335
+ print('Metric calculating')
336
+ if task == 'dewarping':
337
+ exit()
338
+ elif task=='deshadowing' or task=='appearance' or task=='deblurring':
339
+ psnr = []
340
+ ssim = []
341
+ for im_path in tqdm(im_paths):
342
+ pred = cv2.imread(im_path.replace('_in','_docres'))
343
+ gt = cv2.imread(im_path.replace('_in','_gt'))
344
+ ssim.append(structural_similarity(pred,gt,multichannel=True))
345
+ psnr.append(peak_signal_noise_ratio(pred, gt))
346
+ print(args.dataset)
347
+ print('ssim:',np.mean(ssim))
348
+ print('psnr:',np.mean(psnr))
349
+ elif task=='binarization':
350
+ fmeasures, pfmeasures,psnrs = [],[],[]
351
+ for im_path in tqdm(im_paths):
352
+ pred = cv2.imread(im_path.replace('_in','_docres'))
353
+ gt = cv2.imread(im_path.replace('_in','_gt'))
354
+ pred = cv2.cvtColor(pred,cv2.COLOR_BGR2GRAY)
355
+ gt = cv2.cvtColor(gt,cv2.COLOR_BGR2GRAY)
356
+ pred[pred>155]=255
357
+ pred[pred<=155]=0
358
+ gt[gt>155]=255
359
+ gt[gt<=155]=0
360
+ fmeasure, pfmeasure,psnr,_,_,_ = utils.bin_metric(pred,gt)
361
+ fmeasures.append(fmeasure)
362
+ pfmeasures.append(pfmeasure)
363
+ psnrs.append(psnr)
364
+ print(args.dataset)
365
+ print('fmeasure:',np.mean(fmeasures))
366
+ print('pfmeasure:',np.mean(pfmeasures))
367
+ print('psnr:',np.mean(psnrs))
368
+
369
+
inference.py ADDED
@@ -0,0 +1,341 @@
1
+ import os
2
+ import cv2
3
+ import utils
4
+ import argparse
5
+ import numpy as np
6
+
7
+ import torch
8
+
9
+ from utils import convert_state_dict
10
+ from models import restormer_arch
11
+ from data.preprocess.crop_merge_image import stride_integral
12
+
13
+ os.sys.path.append('./data/MBD/')
14
+ from data.MBD.infer import net1_net2_infer_single_im
15
+
16
+
17
+ def dewarp_prompt(img):
18
+ mask = net1_net2_infer_single_im(img,'data/MBD/checkpoint/mbd.pkl')
19
+ base_coord = utils.getBasecoord(256,256)/256
20
+ img[mask==0]=0
21
+ mask = cv2.resize(mask,(256,256))/255
22
+ return img,np.concatenate((base_coord,np.expand_dims(mask,-1)),-1)
23
+
24
+ def deshadow_prompt(img):
25
+ h,w = img.shape[:2]
26
+ # img = cv2.resize(img,(128,128))
27
+ img = cv2.resize(img,(1024,1024))
28
+ rgb_planes = cv2.split(img)
29
+ result_planes = []
30
+ result_norm_planes = []
31
+ bg_imgs = []
32
+ for plane in rgb_planes:
33
+ dilated_img = cv2.dilate(plane, np.ones((7,7), np.uint8))
34
+ bg_img = cv2.medianBlur(dilated_img, 21)
35
+ bg_imgs.append(bg_img)
36
+ diff_img = 255 - cv2.absdiff(plane, bg_img)
37
+ norm_img = cv2.normalize(diff_img,None, alpha=0, beta=255, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_8UC1)
38
+ result_planes.append(diff_img)
39
+ result_norm_planes.append(norm_img)
40
+ bg_imgs = cv2.merge(bg_imgs)
41
+ bg_imgs = cv2.resize(bg_imgs,(w,h))
42
+ # result = cv2.merge(result_planes)
43
+ result_norm = cv2.merge(result_norm_planes)
44
+ result_norm[result_norm==0]=1
45
+ shadow_map = np.clip(img.astype(float)/result_norm.astype(float)*255,0,255).astype(np.uint8)
46
+ shadow_map = cv2.resize(shadow_map,(w,h))
47
+ shadow_map = cv2.cvtColor(shadow_map,cv2.COLOR_BGR2GRAY)
48
+ shadow_map = cv2.cvtColor(shadow_map,cv2.COLOR_GRAY2BGR)
49
+ # return shadow_map
50
+ return bg_imgs
51
+
52
+ def deblur_prompt(img):
53
+ x = cv2.Sobel(img,cv2.CV_16S,1,0)
54
+ y = cv2.Sobel(img,cv2.CV_16S,0,1)
55
+ absX = cv2.convertScaleAbs(x) # convert back to uint8
56
+ absY = cv2.convertScaleAbs(y)
57
+ high_frequency = cv2.addWeighted(absX,0.5,absY,0.5,0)
58
+ high_frequency = cv2.cvtColor(high_frequency,cv2.COLOR_BGR2GRAY)
59
+ high_frequency = cv2.cvtColor(high_frequency,cv2.COLOR_GRAY2BGR)
60
+ return high_frequency
61
+
62
+ def appearance_prompt(img):
63
+ h,w = img.shape[:2]
64
+ # img = cv2.resize(img,(128,128))
65
+ img = cv2.resize(img,(1024,1024))
66
+ rgb_planes = cv2.split(img)
67
+ result_planes = []
68
+ result_norm_planes = []
69
+ for plane in rgb_planes:
70
+ dilated_img = cv2.dilate(plane, np.ones((7,7), np.uint8))
71
+ bg_img = cv2.medianBlur(dilated_img, 21)
72
+ diff_img = 255 - cv2.absdiff(plane, bg_img)
73
+ norm_img = cv2.normalize(diff_img,None, alpha=0, beta=255, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_8UC1)
74
+ result_planes.append(diff_img)
75
+ result_norm_planes.append(norm_img)
76
+ result_norm = cv2.merge(result_norm_planes)
77
+ result_norm = cv2.resize(result_norm,(w,h))
78
+ return result_norm
79
+
80
+ def binarization_promptv2(img):
81
+ result,thresh = utils.SauvolaModBinarization(img)
82
+ thresh = thresh.astype(np.uint8)
83
+ result[result>155]=255
84
+ result[result<=155]=0
85
+
86
+ x = cv2.Sobel(img,cv2.CV_16S,1,0)
87
+ y = cv2.Sobel(img,cv2.CV_16S,0,1)
88
+ absX = cv2.convertScaleAbs(x) # convert back to uint8
89
+ absY = cv2.convertScaleAbs(y)
90
+ high_frequency = cv2.addWeighted(absX,0.5,absY,0.5,0)
91
+ high_frequency = cv2.cvtColor(high_frequency,cv2.COLOR_BGR2GRAY)
92
+ return np.concatenate((np.expand_dims(thresh,-1),np.expand_dims(high_frequency,-1),np.expand_dims(result,-1)),-1)
93
+
94
+ def dewarping(model,im_path):
95
+ INPUT_SIZE=256
96
+ im_org = cv2.imread(im_path)
97
+ im_masked, prompt_org = dewarp_prompt(im_org.copy())
98
+
99
+ h,w = im_masked.shape[:2]
100
+ im_masked = im_masked.copy()
101
+ im_masked = cv2.resize(im_masked,(INPUT_SIZE,INPUT_SIZE))
102
+ im_masked = im_masked / 255.0
103
+ im_masked = torch.from_numpy(im_masked.transpose(2,0,1)).unsqueeze(0)
104
+ im_masked = im_masked.float().to(DEVICE)
105
+
106
+ prompt = torch.from_numpy(prompt_org.transpose(2,0,1)).unsqueeze(0)
107
+ prompt = prompt.float().to(DEVICE)
108
+
109
+ in_im = torch.cat((im_masked,prompt),dim=1)
110
+
111
+ # inference
112
+ base_coord = utils.getBasecoord(INPUT_SIZE,INPUT_SIZE)/INPUT_SIZE
113
+ model = model.float()
114
+ with torch.no_grad():
115
+ pred = model(in_im)
116
+ pred = pred[0][:2].permute(1,2,0).cpu().numpy()
117
+ pred = pred+base_coord
118
+ ## smooth
119
+ for i in range(15):
120
+ pred = cv2.blur(pred,(3,3),borderType=cv2.BORDER_REPLICATE)
121
+ pred = cv2.resize(pred,(w,h))*(w,h)
122
+ pred = pred.astype(np.float32)
123
+ out_im = cv2.remap(im_org,pred[:,:,0],pred[:,:,1],cv2.INTER_LINEAR)
124
+
125
+ prompt_org = (prompt_org*255).astype(np.uint8)
126
+ prompt_org = cv2.resize(prompt_org,im_org.shape[:2][::-1])
127
+
128
+ return prompt_org[:,:,0],prompt_org[:,:,1],prompt_org[:,:,2],out_im
129
+
130
+ def appearance(model,im_path):
131
+ MAX_SIZE=1600
132
+ # obtain im and prompt
133
+ im_org = cv2.imread(im_path)
134
+ h,w = im_org.shape[:2]
135
+ prompt = appearance_prompt(im_org)
136
+ in_im = np.concatenate((im_org,prompt),-1)
137
+
138
+ # constrain the max resolution
139
+ if max(w,h) < MAX_SIZE:
140
+ in_im,padding_h,padding_w = stride_integral(in_im,8)
141
+ else:
142
+ in_im = cv2.resize(in_im,(MAX_SIZE,MAX_SIZE))
143
+
144
+ # normalize
145
+ in_im = in_im / 255.0
146
+ in_im = torch.from_numpy(in_im.transpose(2,0,1)).unsqueeze(0)
147
+
148
+ # inference
149
+ in_im = in_im.half().to(DEVICE)
150
+ model = model.half()
151
+ with torch.no_grad():
152
+ pred = model(in_im)
153
+ pred = torch.clamp(pred,0,1)
154
+ pred = pred[0].permute(1,2,0).cpu().numpy()
155
+ pred = (pred*255).astype(np.uint8)
156
+
157
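+ # For inputs below MAX_SIZE the padded prediction is cropped directly; larger
+ # inputs are predicted at MAX_SIZE and mapped back to full resolution through a
+ # shadow map (original / prediction ratio).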
+ if max(w,h) < MAX_SIZE:
158
+ out_im = pred[padding_h:,padding_w:]
159
+ else:
160
+ pred[pred==0] = 1
161
+ shadow_map = cv2.resize(im_org,(MAX_SIZE,MAX_SIZE)).astype(float)/pred.astype(float)
162
+ shadow_map = cv2.resize(shadow_map,(w,h))
163
+ shadow_map[shadow_map==0]=0.00001
164
+ out_im = np.clip(im_org.astype(float)/shadow_map,0,255).astype(np.uint8)
165
+
166
+ return prompt[:,:,0],prompt[:,:,1],prompt[:,:,2],out_im
167
+
168
+
169
+ def deshadowing(model,im_path):
170
+ MAX_SIZE=1600
171
+ # obtain im and prompt
172
+ im_org = cv2.imread(im_path)
173
+ h,w = im_org.shape[:2]
174
+ prompt = deshadow_prompt(im_org)
175
+ in_im = np.concatenate((im_org,prompt),-1)
176
+
177
+ # constrain the max resolution
178
+ if max(w,h) < MAX_SIZE:
179
+ in_im,padding_h,padding_w = stride_integral(in_im,8)
180
+ else:
181
+ in_im = cv2.resize(in_im,(MAX_SIZE,MAX_SIZE))
182
+
183
+ # normalize
184
+ in_im = in_im / 255.0
185
+ in_im = torch.from_numpy(in_im.transpose(2,0,1)).unsqueeze(0)
186
+
187
+ # inference
188
+ in_im = in_im.half().to(DEVICE)
189
+ model = model.half()
190
+ with torch.no_grad():
191
+ pred = model(in_im)
192
+ pred = torch.clamp(pred,0,1)
193
+ pred = pred[0].permute(1,2,0).cpu().numpy()
194
+ pred = (pred*255).astype(np.uint8)
195
+
196
+ if max(w,h) < MAX_SIZE:
197
+ out_im = pred[padding_h:,padding_w:]
198
+ else:
199
+ pred[pred==0]=1
200
+ shadow_map = cv2.resize(im_org,(MAX_SIZE,MAX_SIZE)).astype(float)/pred.astype(float)
201
+ shadow_map = cv2.resize(shadow_map,(w,h))
202
+ shadow_map[shadow_map==0]=0.00001
203
+ out_im = np.clip(im_org.astype(float)/shadow_map,0,255).astype(np.uint8)
204
+
205
+ return prompt[:,:,0],prompt[:,:,1],prompt[:,:,2],out_im
206
+
207
+
208
+ def deblurring(model,im_path):
209
+ # setup image
210
+ im_org = cv2.imread(im_path)
211
+ in_im,padding_h,padding_w = stride_integral(im_org,8)
212
+ prompt = deblur_prompt(in_im)
213
+ in_im = np.concatenate((in_im,prompt),-1)
214
+ in_im = in_im / 255.0
215
+ in_im = torch.from_numpy(in_im.transpose(2,0,1)).unsqueeze(0)
216
+ in_im = in_im.half().to(DEVICE)
217
+ # inference
218
+ model.to(DEVICE)
219
+ model.eval()
220
+ model = model.half()
221
+ with torch.no_grad():
222
+ pred = model(in_im)
223
+ pred = torch.clamp(pred,0,1)
224
+ pred = pred[0].permute(1,2,0).cpu().numpy()
225
+ pred = (pred*255).astype(np.uint8)
226
+ out_im = pred[padding_h:,padding_w:]
227
+
228
+ return prompt[:,:,0],prompt[:,:,1],prompt[:,:,2],out_im
229
+
230
+
231
+
232
+ def binarization(model,im_path):
233
+ im_org = cv2.imread(im_path)
234
+ im,padding_h,padding_w = stride_integral(im_org,8)
235
+ prompt = binarization_promptv2(im)
236
+ h,w = im.shape[:2]
237
+ in_im = np.concatenate((im,prompt),-1)
238
+
239
+ in_im = in_im / 255.0
240
+ in_im = torch.from_numpy(in_im.transpose(2,0,1)).unsqueeze(0)
241
+ in_im = in_im.to(DEVICE)
242
+ model = model.half()
243
+ in_im = in_im.half()
244
+ with torch.no_grad():
245
+ pred = model(in_im)
246
+ pred = pred[:,:2,:,:]
247
+ pred = torch.max(torch.softmax(pred,1),1)[1]
248
+ pred = pred[0].cpu().numpy()
249
+ pred = (pred*255).astype(np.uint8)
250
+ pred = cv2.resize(pred,(w,h))
251
+ out_im = pred[padding_h:,padding_w:]
252
+
253
+ return prompt[:,:,0],prompt[:,:,1],prompt[:,:,2],out_im
254
+
255
+
256
+
257
+
258
+
259
+ def get_args():
260
+ parser = argparse.ArgumentParser(description='Params')
261
+ parser.add_argument('--model_path', nargs='?', type=str, default='./checkpoints/docres.pkl',help='Path of the saved checkpoint')
262
+ parser.add_argument('--im_path', nargs='?', type=str, default='./distorted/',
263
+ help='Path of input document image')
264
+ parser.add_argument('--out_folder', nargs='?', type=str, default='./restorted/',
265
+ help='Folder of the output images')
266
+ parser.add_argument('--task', nargs='?', type=str, default='dewarping',
267
+ help='Task to be executed; one of dewarping, deshadowing, appearance, deblurring, binarization, or end2end')
268
+ parser.add_argument('--save_dtsprompt', nargs='?', type=int, default=0,
269
+ help='Whether to save the DTSPrompt channels (0 or 1)')
270
+ args = parser.parse_args()
271
+ possible_tasks = ['dewarping','deshadowing','appearance','deblurring','binarization','end2end']
272
+ assert args.task in possible_tasks, 'Unsupported task, task must be one of '+', '.join(possible_tasks)
273
+ return args
274
+
275
+ def model_init(args):
276
+ # prepare model
277
+ model = restormer_arch.Restormer(
278
+ inp_channels=6,
279
+ out_channels=3,
280
+ dim = 48,
281
+ num_blocks = [2,3,3,4],
282
+ num_refinement_blocks = 4,
283
+ heads = [1,2,4,8],
284
+ ffn_expansion_factor = 2.66,
285
+ bias = False,
286
+ LayerNorm_type = 'WithBias',
287
+ dual_pixel_task = True
288
+ )
289
+
290
+ if DEVICE.type == 'cpu':
291
+ state = convert_state_dict(torch.load(args.model_path, map_location='cpu')['model_state'])
292
+ else:
293
+ state = convert_state_dict(torch.load(args.model_path, map_location='cuda:0')['model_state'])
294
+ model.load_state_dict(state)
295
+
296
+ model.eval()
297
+ model = model.to(DEVICE)
298
+ return model
299
+
300
+ def inference_one_im(model,im_path,task):
301
+ if task=='dewarping':
302
+ prompt1,prompt2,prompt3,restorted = dewarping(model,im_path)
303
+ elif task=='deshadowing':
304
+ prompt1,prompt2,prompt3,restorted = deshadowing(model,im_path)
305
+ elif task=='appearance':
306
+ prompt1,prompt2,prompt3,restorted = appearance(model,im_path)
307
+ elif task=='deblurring':
308
+ prompt1,prompt2,prompt3,restorted = deblurring(model,im_path)
309
+ elif task=='binarization':
310
+ prompt1,prompt2,prompt3,restorted = binarization(model,im_path)
311
+ elif task=='end2end':
312
+ prompt1,prompt2,prompt3,restorted = dewarping(model,im_path)
313
+ cv2.imwrite('restorted/step1.jpg',restorted)
314
+ prompt1,prompt2,prompt3,restorted = deshadowing(model,'restorted/step1.jpg')
315
+ cv2.imwrite('restorted/step2.jpg',restorted)
316
+ prompt1,prompt2,prompt3,restorted = appearance(model,'restorted/step2.jpg')
317
+ # os.remove('restorted/step1.jpg')
318
+ # os.remove('restorted/step2.jpg')
319
+
320
+ return prompt1,prompt2,prompt3,restorted
321
+
322
+
323
+
324
+ if __name__ == '__main__':
325
+ ## model init
326
+ DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
327
+ args = get_args()
328
+ model = model_init(args)
329
+
330
+ ## inference
331
+ prompt1,prompt2,prompt3,restorted = inference_one_im(model,args.im_path,args.task)
332
+
333
+ ## results saving
+ os.makedirs(args.out_folder, exist_ok=True)
334
+ im_name = os.path.split(args.im_path)[-1]
335
+ im_format = '.'+im_name.split('.')[-1]
336
+ save_path = os.path.join(args.out_folder,im_name.replace(im_format,'_'+args.task+im_format))
337
+ cv2.imwrite(save_path,restorted)
338
+ if args.save_dtsprompt:
339
+ cv2.imwrite(save_path.replace(im_format,'_prompt1'+im_format),prompt1)
340
+ cv2.imwrite(save_path.replace(im_format,'_prompt2'+im_format),prompt2)
341
+ cv2.imwrite(save_path.replace(im_format,'_prompt3'+im_format),prompt3)
loaders/docres_loader.py ADDED
@@ -0,0 +1,558 @@
1
+ import os
2
+ from os.path import join as pjoin
3
+ import collections
4
+ import json
5
+ from numpy.lib.histograms import histogram_bin_edges
6
+ import torch
7
+ import numpy as np
8
+ import cv2
9
+ import random
10
+ import torch.nn.functional as F
11
+ from torch.utils import data
12
+ import glob
13
+
14
+ class DocResTrainDataset(data.Dataset):
15
+ def __init__(self, dataset={}, img_size=512,):
16
+ json_paths = dataset['json_paths']
17
+ self.task = dataset['task']
18
+ self.size = img_size
19
+ self.im_path = dataset['im_path']
20
+
21
+ self.datas = []
22
+ for json_path in json_paths:
23
+ with open(json_path,'r') as f:
24
+ data = json.load(f)
25
+ self.datas += data
26
+
27
+ self.background_paths = glob.glob('/data2/jiaxin/Training_Data/dewarping/doc_3d/background/*/*/*')
28
+ self.shadow_paths = glob.glob('/data2/jiaxin/Training_Data/illumination/doc3dshadow/new_shadow/*/*')
29
+
30
+ def __len__(self):
31
+ return len(self.datas)
32
+
33
+ def __getitem__(self, index):
34
+ data = self.datas[index]
35
+ in_im,gt_im,dtsprompt = self.data_processing(self.task,data)
36
+
37
+ return torch.cat((in_im,dtsprompt),0), gt_im
38
+
39
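+ # Builds one training sample for the given task: loads the input/GT pair and
+ # computes the matching 3-channel DTSPrompt; __getitem__ concatenates the prompt
+ # with the RGB input to form the 6-channel model input.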
+ def data_processing(self,task,data):
40
+
41
+ if task=='deblurring':
42
+ ## image prepare
43
+ in_im = cv2.imread(os.path.join(self.im_path,data['in_path']))
44
+ gt_im = cv2.imread(os.path.join(self.im_path,data['gt_path']))
45
+ dtsprompt = self.deblur_dtsprompt(in_im)
46
+ ## get prompt
47
+ in_im, gt_im,dtsprompt = self.randomcrop([in_im,gt_im,dtsprompt])
48
+ in_im = self.rgbim_transform(in_im)
49
+ gt_im = self.rgbim_transform(gt_im)
50
+ dtsprompt = self.rgbim_transform(dtsprompt)
51
+ elif task =='dewarping':
52
+ ## image prepare
53
+ in_im = cv2.imread(os.path.join(self.im_path,data['in_path']))
54
+ mask = cv2.imread(os.path.join(self.im_path,data['mask_path']))[:,:,0]
55
+ bm = np.load(os.path.join(self.im_path,data['gt_path'])).astype(np.float64) #-> 0-448
56
+ bm = cv2.resize(bm,(448,448))
57
+ ## add background
58
+ background = cv2.imread(random.choice(self.background_paths))
59
+ min_length = min(background.shape[:2])
60
+ crop_size = random.randint(int(min_length*0.5),min_length-1)
61
+ shift_y = np.random.randint(0,background.shape[1]-crop_size)
62
+ shift_x = np.random.randint(0,background.shape[0]-crop_size)
63
+ background = background[shift_x:shift_x+crop_size,shift_y:shift_y+crop_size,:]
64
+ background = cv2.resize(background,(448,448))
65
+ if np.mean(in_im[mask==0])<10:
66
+ in_im[mask==0]=background[mask==0]
67
+ ## random crop and get prompt
68
+ in_im,mask,bm = self.random_margin_bm(in_im,mask,bm) # bm-> 0-1
69
+ in_im = cv2.resize(in_im,(self.size,self.size))
70
+ mask = cv2.resize(mask,(self.size,self.size))
71
+ mask_aug = self.mask_augment(mask)
72
+ in_im[mask_aug==0]=0
73
+ bm = cv2.resize(bm,(self.size,self.size)) # bm-> 0-1
74
+ bm_shift = (bm*self.size - self.getBasecoord(self.size,self.size))/self.size
75
+ base_coord = self.getBasecoord(self.size,self.size)/self.size
76
+
77
+ in_im = self.rgbim_transform(in_im)
78
+ base_coord = base_coord.transpose(2, 0, 1)
79
+ base_coord = torch.from_numpy(base_coord)
80
+
81
+ bm_shift = bm_shift.transpose(2, 0, 1)
82
+ bm_shift = torch.from_numpy(bm_shift)
83
+
84
+ mask[mask>155] = 255
85
+ mask[mask<=155] = 0
86
+ mask = mask/255
87
+ mask = np.expand_dims(mask,-1)
88
+ mask = mask.transpose(2, 0, 1)
89
+ mask = torch.from_numpy(mask)
90
+
91
+ mask_aug[mask_aug>155] = 255
92
+ mask_aug[mask_aug<=155] = 0
93
+ mask_aug = mask_aug/255
94
+ mask_aug = np.expand_dims(mask_aug,-1)
95
+ mask_aug = mask_aug.transpose(2, 0, 1)
96
+ mask_aug = torch.from_numpy(mask_aug)
97
+
98
+ in_im = in_im
99
+ gt_im = torch.cat((bm_shift,mask),0)
100
+ dtsprompt = torch.cat((base_coord,mask_aug),0)
101
+
102
+ elif task == 'binarization':
103
+ ## image prepare
104
+ in_im = cv2.imread(os.path.join(self.im_path,data['in_path']))
105
+ gt_im = cv2.imread(os.path.join(self.im_path,data['gt_path']))
106
+ ## get prompt
107
+ thr = cv2.imread(os.path.join(self.im_path,data['thr_path']))
108
+ bin_map = cv2.imread(os.path.join(self.im_path,data['bin_path']))
109
+ gradient = cv2.imread(os.path.join(self.im_path,data['gradient_path']))
110
+ bin_map[bin_map>155]=255
111
+ bin_map[bin_map<=155]=0
112
+ in_im, gt_im,thr,bin_map,gradient = self.randomcrop([in_im,gt_im,thr,bin_map,gradient])
113
+ in_im = self.randomAugment_binarization(in_im)
114
+ gt_im[gt_im>155]=255
115
+ gt_im[gt_im<=155]=0
116
+ gt_im = gt_im[:,:,0]
117
+ ## transform
118
+ in_im = self.rgbim_transform(in_im)
119
+ thr = self.rgbim_transform(thr)
120
+ gradient = self.rgbim_transform(gradient)
121
+ bin_map = self.rgbim_transform(bin_map)
122
+ gt_im = gt_im.astype(np.float64)/255.
123
+ gt_im = torch.from_numpy(gt_im)
124
+ gt_im = gt_im.unsqueeze(0)
125
+ dtsprompt = torch.cat((thr[0].unsqueeze(0),gradient[0].unsqueeze(0),bin_map[0].unsqueeze(0)),0)
126
+ elif task == 'deshadowing':
127
+
128
+ in_im = cv2.imread(os.path.join(self.im_path,data['in_path']))
129
+ gt_im = cv2.imread(os.path.join(self.im_path,data['gt_path']))
130
+ shadow_im = self.deshadow_dtsprompt(in_im)
131
+ if 'fsdsrd' in data['in_path']:
132
+ in_im = cv2.resize(in_im,(512,512))
133
+ gt_im = cv2.resize(gt_im,(512,512))
134
+ shadow_im = cv2.resize(shadow_im,(512,512))
135
+ in_im, gt_im,shadow_im = self.randomcrop([in_im,gt_im,shadow_im])
136
+ else:
137
+ in_im, gt_im,shadow_im = self.randomcrop([in_im,gt_im,shadow_im])
138
+ in_im = self.rgbim_transform(in_im)
139
+ gt_im = self.rgbim_transform(gt_im)
140
+ shadow_im = self.rgbim_transform(shadow_im)
141
+ dtsprompt = shadow_im
142
+
143
+ elif task == 'appearance':
144
+ if 'in_path' in data.keys():
145
+ cap_im = cv2.imread(os.path.join(self.im_path,data['in_path']))
146
+ gt_im = cv2.imread(os.path.join(self.im_path,data['gt_path']))
147
+ gt_im,cap_im = self.randomcrop_realdae(gt_im,cap_im)
148
+ cap_im = self.appearance_randomAugmentv1(cap_im)
149
+ enhance_result = self.appearance_dtsprompt(cap_im)
150
+ else:
151
+ gt_im = cv2.imread(os.path.join(self.im_path,data['gt_path']))
152
+ bleed_im = cv2.imread(os.path.join(self.im_path,random.choice(self.datas)['gt_path']))
153
+ bleed_im = cv2.resize(bleed_im,gt_im.shape[:2][::-1])
154
+ gt_im = self.randomcrop([gt_im])[0]
155
+ bleed_im = self.randomcrop([bleed_im])[0]
156
+ cap_im = self.bleed_trough(gt_im,bleed_im)
157
+
158
+ shadow_path = random.choice(self.shadow_paths)
159
+ shadow_im = cv2.imread(shadow_path)
160
+ cap_im = self.appearance_randomAugmentv2(cap_im,shadow_im)
161
+ enhance_result = self.appearance_dtsprompt(cap_im)
162
+
163
+
164
+ in_im = self.rgbim_transform(cap_im)
165
+ gt_im = self.rgbim_transform(gt_im)
166
+ dtsprompt = self.rgbim_transform(enhance_result)
167
+
168
+ return in_im, gt_im,dtsprompt
169
+
170
+ def randomcrop(self,im_list):
171
+ im_num = len(im_list)
172
+ ## random scale rotate
173
+ if random.uniform(0,1) <= 0.8:
174
+ y,x = im_list[0].shape[:2]
175
+ angle = random.uniform(-180,180)
176
+ scale = random.uniform(0.7,1.5)
177
+ M = cv2.getRotationMatrix2D((int(x/2),int(y/2)),angle,scale)
178
+ for i in range(im_num):
179
+ im_list[i] = cv2.warpAffine(im_list[i],M,(x,y),borderValue=(255,255,255))
180
+
181
+ ## random crop
182
+ crop_size = self.size
183
+ for i in range(im_num):
184
+ h,w = im_list[i].shape[:2]
185
+ h = max(h,crop_size)
186
+ w = max(w,crop_size)
187
+ im_list[i] = cv2.resize(im_list[i],(w,h))
188
+
189
+ if h==crop_size:
190
+ shift_y=0
191
+ else:
192
+ shift_y = np.random.randint(0,h-crop_size)
193
+ if w==crop_size:
194
+ shift_x=0
195
+ else:
196
+ shift_x = np.random.randint(0,w-crop_size)
197
+ for i in range(im_num):
198
+ im_list[i] = im_list[i][shift_y:shift_y+crop_size,shift_x:shift_x+crop_size,:]
199
+ return im_list
200
+
201
+ def deblur_dtsprompt(self,img):
202
+ x = cv2.Sobel(img,cv2.CV_16S,1,0)
203
+ y = cv2.Sobel(img,cv2.CV_16S,0,1)
204
+ absX = cv2.convertScaleAbs(x) # convert back to uint8
205
+ absY = cv2.convertScaleAbs(y)
206
+ high_frequency = cv2.addWeighted(absX,0.5,absY,0.5,0)
207
+ high_frequency = cv2.cvtColor(high_frequency,cv2.COLOR_BGR2GRAY)
208
+ high_frequency = cv2.cvtColor(high_frequency,cv2.COLOR_GRAY2BGR)
209
+ return high_frequency
210
+
211
+
212
+ def appearance_dtsprompt(self,img):
213
+ h,w = img.shape[:2]
214
+ img = cv2.resize(img,(1024,1024))
215
+ rgb_planes = cv2.split(img)
216
+ result_planes = []
217
+ result_norm_planes = []
218
+ for plane in rgb_planes:
219
+ dilated_img = cv2.dilate(plane, np.ones((7,7), np.uint8))
220
+ bg_img = cv2.medianBlur(dilated_img, 21)
221
+ diff_img = 255 - cv2.absdiff(plane, bg_img)
222
+ norm_img = cv2.normalize(diff_img,None, alpha=0, beta=255, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_8UC1)
223
+ result_planes.append(diff_img)
224
+ result_norm_planes.append(norm_img)
225
+ result_norm = cv2.merge(result_norm_planes)
226
+ result_norm = cv2.resize(result_norm,(w,h))
227
+ return result_norm
228
+
229
+
230
+ def rgbim_transform(self,im):
231
+ im = im.astype(np.float64)/255.
232
+ im = im.transpose(2, 0, 1)
233
+ im = torch.from_numpy(im)
234
+ return im
235
+
236
+
237
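+ # Randomly trims the margins around the document mask and shifts/rescales the
+ # backward map (bm) accordingly so that it remains normalised to [0,1].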
+ def random_margin_bm(self,in_im,msk,bm):
238
+ size = in_im.shape[:2]
239
+ [y, x] = (msk).nonzero()
240
+ minx = min(x)
241
+ maxx = max(x)
242
+ miny = min(y)
243
+ maxy = max(y)
244
+
245
+ s = 20
246
+ s = int(20*size[0]/128)
247
+ difference = int(5*size[0]/128)
248
+ cx1 = random.randint(0, s - difference)
249
+ cx2 = random.randint(0, s - difference) + 1
250
+ cy1 = random.randint(0, s - difference)
251
+ cy2 = random.randint(0, s - difference) + 1
252
+
253
+ t = miny-s+cy1
254
+ b = size[0]-maxy-s+cy2
255
+ l = minx-s+cx1
256
+ r = size[1]-maxx-s+cx2
257
+
258
+ t = max(0,t)
259
+ b = max(0,b)
260
+ l = max(0,l)
261
+ r = max(0,r)
262
+
263
+ in_im = in_im[t:size[0]-b,l:size[1]-r]
264
+ msk = msk[t:size[0]-b,l:size[1]-r]
265
+ bm[:,:,1]=bm[:,:,1]-t
266
+ bm[:,:,0]=bm[:,:,0]-l
267
+ bm=bm/np.array([448-l-r, 448-t-b])
268
+
269
+ return in_im,msk,bm
270
+
271
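+ # Randomly downsamples the mask (to 64 or 128) before resizing it back to 256,
+ # so the dewarping prompt is trained with coarse, imperfect masks.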
+ def mask_augment(self,mask):
272
+ if random.uniform(0,1) <= 0.6:
273
+ if random.uniform(0,1) <= 0.5:
274
+ mask = cv2.resize(mask,(64,64))
275
+ else:
276
+ mask = cv2.resize(mask,(128,128))
277
+ mask = cv2.resize(mask,(256,256))
278
+ mask[mask>155] = 255
279
+ mask[mask<=155] = 0
280
+ return mask
281
+
282
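+ # Simulates ink bleed-through by alpha-blending a blurred, horizontally flipped
+ # second page into the clean image.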
+ def bleed_trough(self, in_im, bleed_im):
283
+ if random.uniform(0,1) <= 0.5:
284
+ if random.uniform(0,1) <= 0.8:
285
+ ksize = np.random.randint(1,2)*2 + 1
286
+ bleed_im = cv2.blur(bleed_im,(ksize,ksize))
287
+ bleed_im = cv2.flip(bleed_im,1)
288
+ alpha = random.uniform(0.75,1)
289
+ in_im = cv2.addWeighted(in_im,alpha,bleed_im,1-alpha,0)
290
+ return in_im
291
+
292
+ def getBasecoord(self,h,w):
293
+ base_coord0 = np.tile(np.arange(h).reshape(h,1),(1,w)).astype(np.float32)
294
+ base_coord1 = np.tile(np.arange(w).reshape(1,w),(h,1)).astype(np.float32)
295
+ base_coord = np.concatenate((np.expand_dims(base_coord1,-1),np.expand_dims(base_coord0,-1)),-1)
296
+ return base_coord
297
+
298
+
299
+ def randomcrop_realdae(self,gt_im,cap_im):
300
+ if random.uniform(0,1) <= 0.5:
301
+ y,x = gt_im.shape[:2]
302
+ angle = random.uniform(-30,30)
303
+ scale = random.uniform(0.8,1.5)
304
+ M = cv2.getRotationMatrix2D((int(x/2),int(y/2)),angle,scale)
305
+ gt_im = cv2.warpAffine(gt_im,M,(x,y),borderValue=(255,255,255))
306
+ cap_im = cv2.warpAffine(cap_im,M,(x,y),borderValue=(255,255,255))
307
+ crop_size = self.size
308
+ if gt_im.shape[0] <= crop_size:
309
+ gt_im = cv2.copyMakeBorder(gt_im,crop_size-gt_im.shape[0]+1,0,0,0,borderType=cv2.BORDER_CONSTANT,value=(255,255,255))
310
+ cap_im = cv2.copyMakeBorder(cap_im,crop_size-cap_im.shape[0]+1,0,0,0,borderType=cv2.BORDER_CONSTANT,value=(255,255,255))
311
+ if gt_im.shape[1] <= crop_size:
312
+ gt_im = cv2.copyMakeBorder(gt_im,0,0,crop_size-gt_im.shape[1]+1,0,borderType=cv2.BORDER_CONSTANT,value=(255,255,255))
313
+ cap_im = cv2.copyMakeBorder(cap_im,0,0,crop_size-cap_im.shape[1]+1,0,borderType=cv2.BORDER_CONSTANT,value=(255,255,255))
314
+ shift_y = np.random.randint(0,gt_im.shape[1]-crop_size)
315
+ shift_x = np.random.randint(0,gt_im.shape[0]-crop_size)
316
+ gt_im = gt_im[shift_x:shift_x+crop_size,shift_y:shift_y+crop_size,:]
317
+ cap_im = cap_im[shift_x:shift_x+crop_size,shift_y:shift_y+crop_size,:]
318
+ return gt_im,cap_im
319
+
320
+
321
+ def randomAugment_binarization(self,in_img):
322
+ h,w = in_img.shape[:2]
323
+ ## brightness
324
+ if random.uniform(0,1) <= 0.5:
325
+ high = 1.3
326
+ low = 0.8
327
+ ratio = np.random.uniform(low,high)
328
+ in_img = in_img.astype(np.float64)*ratio
329
+ in_img = np.clip(in_img,0,255).astype(np.uint8)
330
+ ## contrast
331
+ if random.uniform(0,1) <= 0.5:
332
+ high = 1.3
333
+ low = 0.8
334
+ ratio = np.random.uniform(low,high)
335
+ gray = cv2.cvtColor(in_img,cv2.COLOR_BGR2GRAY)
336
+ mean = np.mean(gray)
337
+ mean_array = np.ones_like(in_img).astype(np.float64)*mean
338
+ in_img = in_img.astype(np.float64)*ratio + mean_array*(1-ratio)
339
+ in_img = np.clip(in_img,0,255).astype(np.uint8)
340
+ ## color
341
+ if random.uniform(0,1) <= 0.5:
342
+ high = 0.2
343
+ low = 0.1
344
+ ratio = np.random.uniform(0.1,0.3)
345
+ random_color = np.random.randint(50,200,3).reshape(1,1,3)
346
+ random_color = (random_color*ratio).astype(np.uint8)
347
+ random_color = np.tile(random_color,(self.size,self.size,1))
348
+ in_img = in_img.astype(np.float64)*(1-ratio) + random_color
349
+ in_img = np.clip(in_img,0,255).astype(np.uint8)
350
+ return in_img
351
+
352
+
353
+ def deshadow_dtsprompt(self,img):
354
+ h,w = img.shape[:2]
355
+ img = cv2.resize(img,(1024,1024))
356
+ rgb_planes = cv2.split(img)
357
+ result_planes = []
358
+ result_norm_planes = []
359
+ bg_imgs = []
360
+ for plane in rgb_planes:
361
+ dilated_img = cv2.dilate(plane, np.ones((7,7), np.uint8))
362
+ bg_img = cv2.medianBlur(dilated_img, 21)
363
+ bg_imgs.append(bg_img)
364
+ diff_img = 255 - cv2.absdiff(plane, bg_img)
365
+ norm_img = cv2.normalize(diff_img,None, alpha=0, beta=255, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_8UC1)
366
+ result_planes.append(diff_img)
367
+ result_norm_planes.append(norm_img)
368
+ result_norm = cv2.merge(result_norm_planes)
369
+ bg_imgs = cv2.merge(bg_imgs)
370
+ bg_imgs = cv2.resize(bg_imgs,(w,h))
371
+ return bg_imgs
372
+
373
+
374
+
375
+
376
+
377
+
378
+
379
+
380
+
381
+ def randomAugment(self,in_img,gt_img,shadow_img):
382
+ h,w = in_img.shape[:2]
383
+ # random crop
384
+ crop_size = random.randint(128,1024)
385
+ if shadow_img.shape[0] <= crop_size:
386
+ shadow_img = cv2.copyMakeBorder(shadow_img,crop_size-shadow_img.shape[0]+1,0,0,0,borderType=cv2.BORDER_CONSTANT,value=(128,128,128))
387
+ if shadow_img.shape[1] <= crop_size:
388
+ shadow_img = cv2.copyMakeBorder(shadow_img,0,0,crop_size-shadow_img.shape[1]+1,0,borderType=cv2.BORDER_CONSTANT,value=(128,128,128))
389
+
390
+ shift_y = np.random.randint(0,shadow_img.shape[1]-crop_size)
391
+ shift_x = np.random.randint(0,shadow_img.shape[0]-crop_size)
392
+ shadow_img = shadow_img[shift_x:shift_x+crop_size,shift_y:shift_y+crop_size,:]
393
+ shadow_img = cv2.resize(shadow_img,(w,h))
394
+ in_img = in_img.astype(np.float64)*(shadow_img.astype(np.float64)+1)/255
395
+ in_img = np.clip(in_img,0,255).astype(np.uint8)
396
+
397
+ ## brightness
398
+ if random.uniform(0,1) <= 0.5:
399
+ high = 1.3
400
+ low = 0.8
401
+ ratio = np.random.uniform(low,high)
402
+ in_img = in_img.astype(np.float64)*ratio
403
+ in_img = np.clip(in_img,0,255).astype(np.uint8)
404
+ ## contrast
405
+ if random.uniform(0,1) <= 0.5:
406
+ high = 1.3
407
+ low = 0.8
408
+ ratio = np.random.uniform(low,high)
409
+ gray = cv2.cvtColor(in_img,cv2.COLOR_BGR2GRAY)
410
+ mean = np.mean(gray)
411
+ mean_array = np.ones_like(in_img).astype(np.float64)*mean
412
+ in_img = in_img.astype(np.float64)*ratio + mean_array*(1-ratio)
413
+ in_img = np.clip(in_img,0,255).astype(np.uint8)
414
+ ## color
415
+ if random.uniform(0,1) <= 0.5:
416
+ high = 0.2
417
+ low = 0.1
418
+ ratio = np.random.uniform(0.1,0.3)
419
+ random_color = np.random.randint(50,200,3).reshape(1,1,3)
420
+ random_color = (random_color*ratio).astype(np.uint8)
421
+ random_color = np.tile(random_color,(h,w,1))
422
+ in_img = in_img.astype(np.float64)*(1-ratio) + random_color
423
+ in_img = np.clip(in_img,0,255).astype(np.uint8)
424
+ ## scale and rotate
425
+ if random.uniform(0,1) <= 0:
426
+ y,x = self.img_size
427
+ angle = random.uniform(-180,180)
428
+ scale = random.uniform(0.5,1.5)
429
+ M = cv2.getRotationMatrix2D((int(x/2),int(y/2)),angle,scale)
430
+ in_img = cv2.warpAffine(in_img,M,(x,y),borderValue=0)
431
+ gt_img = cv2.warpAffine(gt_img,M,(x,y),borderValue=0)
432
+ # add noise
433
+ ## jpegcompression
434
+ quality_high = 95
435
+ quality_low = 45
436
+ quality = int(np.random.randint(quality_low,quality_high))
437
+ encode_param = [int(cv2.IMWRITE_JPEG_QUALITY),quality]
438
+ result, encimg = cv2.imencode('.jpg',in_img,encode_param)
439
+ in_img = cv2.imdecode(encimg,1).astype(np.uint8)
440
+ ## gaussiannoise
441
+ mean = 0
442
+ sigma = 0.02
443
+ noise_ratio = 0.004
444
+ num_noise = int(np.ceil(noise_ratio*w))
445
+ coords = [np.random.randint(0,i-1,int(num_noise)) for i in [h,w]]
446
+ gauss = np.random.normal(mean,sigma,num_noise*3)*255
447
+ gauss = np.reshape(gauss,(-1,3))
448
+ in_img = in_img.astype(np.float64)
449
+ in_img[tuple(coords)] += gauss
450
+ in_img = np.clip(in_img,0,255).astype(np.uint8)
451
+ ## blur
452
+ ksize = np.random.randint(1,2)*2 + 1
453
+ in_img = cv2.blur(in_img,(ksize,ksize))
454
+
455
+ ## erase
456
+ if random.uniform(0,1) <= 0.7:
457
+ for i in range(100):
458
+ area = int(np.random.uniform(0.01,0.05)*h*w)
459
+ ration = np.random.uniform(0.3,1/0.3)
460
+ h_shift = int(np.sqrt(area*ration))
461
+ w_shift = int(np.sqrt(area/ration))
462
+ if (h_shift<h) and (w_shift<w):
463
+ break
464
+ h_start = np.random.randint(0,h-h_shift)
465
+ w_start = np.random.randint(0,w-w_shift)
466
+ randm_area = np.random.randint(low=0,high=255,size=(h_shift,w_shift,3))
467
+ in_img[h_start:h_start+h_shift,w_start:w_start+w_shift,:] = randm_area
468
+
469
+
470
+ return in_img, gt_img
471
+
472
+
473
+ def appearance_randomAugmentv1(self,in_img):
474
+
475
+ ## brightness
476
+ if random.uniform(0,1) <= 0.8:
477
+ high = 1.3
478
+ low = 0.5
479
+ ratio = np.random.uniform(low,high)
480
+ in_img = in_img.astype(np.float64)*ratio
481
+ in_img = np.clip(in_img,0,255).astype(np.uint8)
482
+ ## contrast
483
+ if random.uniform(0,1) <= 0.8:
484
+ high = 1.3
485
+ low = 0.5
486
+ ratio = np.random.uniform(low,high)
487
+ gray = cv2.cvtColor(in_img,cv2.COLOR_BGR2GRAY)
488
+ mean = np.mean(gray)
489
+ mean_array = np.ones_like(in_img).astype(np.float64)*mean
490
+ in_img = in_img.astype(np.float64)*ratio + mean_array*(1-ratio)
491
+ in_img = np.clip(in_img,0,255).astype(np.uint8)
492
+ ## color
493
+ if random.uniform(0,1) <= 0.8:
494
+ high = 0.2
495
+ low = 0.1
496
+ ratio = np.random.uniform(0.1,0.3)
497
+ random_color = np.random.randint(50,200,3).reshape(1,1,3)
498
+ random_color = (random_color*ratio).astype(np.uint8)
499
+ random_color = np.tile(random_color,(self.size,self.size,1))
500
+ in_img = in_img.astype(np.float64)*(1-ratio) + random_color
501
+ in_img = np.clip(in_img,0,255).astype(np.uint8)
502
+
503
+ return in_img
504
+
505
+
506
+ def appearance_randomAugmentv2(self,in_img,shadow_img):
507
+ h,w = in_img.shape[:2]
508
+ # random crop
509
+ crop_size = random.randint(96,1024)
510
+ if shadow_img.shape[0] <= crop_size:
511
+ shadow_img = cv2.resize(shadow_img,(crop_size+1,crop_size+1))
512
+ if shadow_img.shape[1] <= crop_size:
513
+ shadow_img = cv2.resize(shadow_img,(crop_size+1,crop_size+1))
514
+
515
+ shift_y = np.random.randint(0,shadow_img.shape[1]-crop_size)
516
+ shift_x = np.random.randint(0,shadow_img.shape[0]-crop_size)
517
+ shadow_img = shadow_img[shift_x:shift_x+crop_size,shift_y:shift_y+crop_size,:]
518
+ shadow_img = cv2.resize(shadow_img,(w,h))
519
+ in_img = in_img.astype(np.float64)*(shadow_img.astype(np.float64)+1)/255
520
+ in_img = np.clip(in_img,0,255).astype(np.uint8)
521
+
522
+ ## brightness
523
+ if random.uniform(0,1) <= 0.8:
524
+ high = 1.3
525
+ low = 0.5
526
+ ratio = np.random.uniform(low,high)
527
+ in_img = in_img.astype(np.float64)*ratio
528
+ in_img = np.clip(in_img,0,255).astype(np.uint8)
529
+ ## contrast
530
+ if random.uniform(0,1) <= 0.8:
531
+ high = 1.3
532
+ low = 0.5
533
+ ratio = np.random.uniform(low,high)
534
+ gray = cv2.cvtColor(in_img,cv2.COLOR_BGR2GRAY)
535
+ mean = np.mean(gray)
536
+ mean_array = np.ones_like(in_img).astype(np.float64)*mean
537
+ in_img = in_img.astype(np.float64)*ratio + mean_array*(1-ratio)
538
+ in_img = np.clip(in_img,0,255).astype(np.uint8)
539
+ ## color
540
+ if random.uniform(0,1) <= 0.8:
541
+ high = 0.2
542
+ low = 0.1
543
+ ratio = np.random.uniform(0.1,0.3)
544
+ random_color = np.random.randint(50,200,3).reshape(1,1,3)
545
+ random_color = (random_color*ratio).astype(np.uint8)
546
+ random_color = np.tile(random_color,(h,w,1))
547
+ in_img = in_img.astype(np.float64)*(1-ratio) + random_color
548
+ in_img = np.clip(in_img,0,255).astype(np.uint8)
549
+
550
+ if random.uniform(0,1) <= 0.8:
551
+ quality_high = 95
552
+ quality_low = 45
553
+ quality = int(np.random.randint(quality_low,quality_high))
554
+ encode_param = [int(cv2.IMWRITE_JPEG_QUALITY),quality]
555
+ result, encimg = cv2.imencode('.jpg',in_img,encode_param)
556
+ in_img = cv2.imdecode(encimg,1).astype(np.uint8)
557
+
558
+ return in_img
models/restormer_arch.py ADDED
@@ -0,0 +1,308 @@
1
+ ## Restormer: Efficient Transformer for High-Resolution Image Restoration
2
+ ## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, and Ming-Hsuan Yang
3
+ ## https://arxiv.org/abs/2111.09881
4
+
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from pdb import set_trace as stx
10
+ import numbers
11
+
12
+ from einops import rearrange
13
+
14
+
15
+
16
+ ##########################################################################
17
+ ## Layer Norm
18
+
19
+ def to_3d(x):
20
+ return rearrange(x, 'b c h w -> b (h w) c')
21
+
22
+ def to_4d(x,h,w):
23
+ return rearrange(x, 'b (h w) c -> b c h w',h=h,w=w)
24
+
25
+ class BiasFree_LayerNorm(nn.Module):
26
+ def __init__(self, normalized_shape):
27
+ super(BiasFree_LayerNorm, self).__init__()
28
+ if isinstance(normalized_shape, numbers.Integral):
29
+ normalized_shape = (normalized_shape,)
30
+ normalized_shape = torch.Size(normalized_shape)
31
+
32
+ assert len(normalized_shape) == 1
33
+
34
+ self.weight = nn.Parameter(torch.ones(normalized_shape))
35
+ self.normalized_shape = normalized_shape
36
+
37
+ def forward(self, x):
38
+ sigma = x.var(-1, keepdim=True, unbiased=False)
39
+ return x / torch.sqrt(sigma+1e-5) * self.weight
40
+
41
+ class WithBias_LayerNorm(nn.Module):
42
+ def __init__(self, normalized_shape):
43
+ super(WithBias_LayerNorm, self).__init__()
44
+ if isinstance(normalized_shape, numbers.Integral):
45
+ normalized_shape = (normalized_shape,)
46
+ normalized_shape = torch.Size(normalized_shape)
47
+
48
+ assert len(normalized_shape) == 1
49
+
50
+ self.weight = nn.Parameter(torch.ones(normalized_shape))
51
+ self.bias = nn.Parameter(torch.zeros(normalized_shape))
52
+ self.normalized_shape = normalized_shape
53
+
54
+ def forward(self, x):
55
+ mu = x.mean(-1, keepdim=True)
56
+ sigma = x.var(-1, keepdim=True, unbiased=False)
57
+ return (x - mu) / torch.sqrt(sigma+1e-5) * self.weight + self.bias
58
+
59
+
60
+ class LayerNorm(nn.Module):
61
+ def __init__(self, dim, LayerNorm_type):
62
+ super(LayerNorm, self).__init__()
63
+ if LayerNorm_type =='BiasFree':
64
+ self.body = BiasFree_LayerNorm(dim)
65
+ else:
66
+ self.body = WithBias_LayerNorm(dim)
67
+
68
+ def forward(self, x):
69
+ h, w = x.shape[-2:]
70
+ return to_4d(self.body(to_3d(x)), h, w)
71
+
72
+
73
+
74
+ ##########################################################################
75
+ ## Gated-Dconv Feed-Forward Network (GDFN)
76
+ class FeedForward(nn.Module):
77
+ def __init__(self, dim, ffn_expansion_factor, bias):
78
+ super(FeedForward, self).__init__()
79
+
80
+ hidden_features = int(dim*ffn_expansion_factor)
81
+
82
+ self.project_in = nn.Conv2d(dim, hidden_features*2, kernel_size=1, bias=bias)
83
+
84
+ self.dwconv = nn.Conv2d(hidden_features*2, hidden_features*2, kernel_size=3, stride=1, padding=1, groups=hidden_features*2, bias=bias)
85
+
86
+ self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)
87
+
88
+ def forward(self, x):
89
+ x = self.project_in(x)
90
+ x1, x2 = self.dwconv(x).chunk(2, dim=1)
91
+ x = F.gelu(x1) * x2
92
+ x = self.project_out(x)
93
+ return x
94
+
95
+
96
+
97
+ ##########################################################################
98
+ ## Multi-DConv Head Transposed Self-Attention (MDTA)
99
+ class Attention(nn.Module):
100
+ def __init__(self, dim, num_heads, bias):
101
+ super(Attention, self).__init__()
102
+ self.num_heads = num_heads
103
+ self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
104
+
105
+ self.qkv = nn.Conv2d(dim, dim*3, kernel_size=1, bias=bias)
106
+ self.qkv_dwconv = nn.Conv2d(dim*3, dim*3, kernel_size=3, stride=1, padding=1, groups=dim*3, bias=bias)
107
+ self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
108
+
109
+
110
+
111
+ def forward(self, x):
112
+ b,c,h,w = x.shape
113
+
114
+ qkv = self.qkv_dwconv(self.qkv(x))
115
+ q,k,v = qkv.chunk(3, dim=1)
116
+
117
+ q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
118
+ k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
119
+ v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
120
+
121
+ q = torch.nn.functional.normalize(q, dim=-1)
122
+ k = torch.nn.functional.normalize(k, dim=-1)
123
+
124
+ attn = (q @ k.transpose(-2, -1)) * self.temperature
125
+ attn = attn.softmax(dim=-1)
126
+
127
+ out = (attn @ v)
128
+
129
+ out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w)
130
+
131
+ out = self.project_out(out)
132
+ return out
133
+
134
+
135
+
136
+ ##########################################################################
137
+ class TransformerBlock(nn.Module):
138
+ def __init__(self, dim, num_heads, ffn_expansion_factor, bias, LayerNorm_type):
139
+ super(TransformerBlock, self).__init__()
140
+
141
+ self.norm1 = LayerNorm(dim, LayerNorm_type)
142
+ self.attn = Attention(dim, num_heads, bias)
143
+ self.norm2 = LayerNorm(dim, LayerNorm_type)
144
+ self.ffn = FeedForward(dim, ffn_expansion_factor, bias)
145
+
146
+ def forward(self, x):
147
+ x = x + self.attn(self.norm1(x))
148
+ x = x + self.ffn(self.norm2(x))
149
+
150
+ return x
151
+
152
+
153
+
154
+ ##########################################################################
155
+ ## Overlapped image patch embedding with 3x3 Conv
156
+ class OverlapPatchEmbed(nn.Module):
157
+ def __init__(self, in_c=3, embed_dim=48, bias=False):
158
+ super(OverlapPatchEmbed, self).__init__()
159
+
160
+ self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias)
161
+
162
+ def forward(self, x):
163
+ x = self.proj(x)
164
+
165
+ return x
166
+
167
+
168
+
169
+ ##########################################################################
170
+ ## Resizing modules
171
+ class Downsample(nn.Module):
172
+ def __init__(self, n_feat):
173
+ super(Downsample, self).__init__()
174
+
175
+ self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat//2, kernel_size=3, stride=1, padding=1, bias=False),
176
+ nn.PixelUnshuffle(2))
177
+
178
+ def forward(self, x):
179
+ return self.body(x)
180
+
181
+ class Upsample(nn.Module):
182
+ def __init__(self, n_feat):
183
+ super(Upsample, self).__init__()
184
+
185
+ self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat*2, kernel_size=3, stride=1, padding=1, bias=False),
186
+ nn.PixelShuffle(2))
187
+
188
+ def forward(self, x):
189
+ return self.body(x)
190
+
191
+ ##########################################################################
192
+ ##---------- Restormer -----------------------
193
+ class Restormer(nn.Module):
194
+ def __init__(self,
195
+ inp_channels=3,
196
+ out_channels=3,
197
+ dim = 48,
198
+ num_blocks = [4,6,6,8],
199
+ num_refinement_blocks = 4,
200
+ heads = [1,2,4,8],
201
+ ffn_expansion_factor = 2.66,
202
+ bias = False,
203
+ LayerNorm_type = 'WithBias', ## Other option 'BiasFree'
204
+ dual_pixel_task = True ## True for dual-pixel defocus deblurring only. Also set inp_channels=6
205
+ ):
206
+
207
+ super(Restormer, self).__init__()
208
+
209
+ self.patch_embed = OverlapPatchEmbed(inp_channels, dim)
210
+
211
+ self.encoder_level1 = nn.Sequential(*[TransformerBlock(dim=dim, num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])])
212
+
213
+ self.down1_2 = Downsample(dim) ## From Level 1 to Level 2
214
+ self.encoder_level2 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])])
215
+
216
+ self.down2_3 = Downsample(int(dim*2**1)) ## From Level 2 to Level 3
217
+ self.encoder_level3 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[2])])
218
+
219
+ self.down3_4 = Downsample(int(dim*2**2)) ## From Level 3 to Level 4
220
+ self.latent = nn.Sequential(*[TransformerBlock(dim=int(dim*2**3), num_heads=heads[3], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[3])])
221
+
222
+ self.up4_3 = Upsample(int(dim*2**3)) ## From Level 4 to Level 3
223
+ self.reduce_chan_level3 = nn.Conv2d(int(dim*2**3), int(dim*2**2), kernel_size=1, bias=bias)
224
+ self.decoder_level3 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[2])])
225
+
226
+
227
+ self.up3_2 = Upsample(int(dim*2**2)) ## From Level 3 to Level 2
228
+ self.reduce_chan_level2 = nn.Conv2d(int(dim*2**2), int(dim*2**1), kernel_size=1, bias=bias)
229
+ self.decoder_level2 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])])
230
+
231
+ self.up2_1 = Upsample(int(dim*2**1)) ## From Level 2 to Level 1 (NO 1x1 conv to reduce channels)
232
+
233
+ self.decoder_level1 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])])
234
+
235
+ self.refinement = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_refinement_blocks)])
236
+
237
+ #### For Dual-Pixel Defocus Deblurring Task ####
238
+ self.dual_pixel_task = dual_pixel_task
239
+ if self.dual_pixel_task:
240
+ self.skip_conv = nn.Conv2d(dim, int(dim*2**1), kernel_size=1, bias=bias)
241
+ ###########################
242
+
243
+
244
+ self.output = nn.Conv2d(int(dim*2**1), out_channels, kernel_size=3, stride=1, padding=1, bias=bias)
245
+
246
+ def forward(self, inp_img,task=''):
247
+
248
+ inp_enc_level1 = self.patch_embed(inp_img)
249
+ out_enc_level1 = self.encoder_level1(inp_enc_level1)
250
+
251
+ inp_enc_level2 = self.down1_2(out_enc_level1)
252
+ out_enc_level2 = self.encoder_level2(inp_enc_level2)
253
+
254
+ inp_enc_level3 = self.down2_3(out_enc_level2)
255
+ out_enc_level3 = self.encoder_level3(inp_enc_level3)
256
+
257
+ inp_enc_level4 = self.down3_4(out_enc_level3)
258
+ latent = self.latent(inp_enc_level4)
259
+
260
+
261
+ inp_dec_level3 = self.up4_3(latent)
262
+ inp_dec_level3 = torch.cat([inp_dec_level3, out_enc_level3], 1)
263
+ inp_dec_level3 = self.reduce_chan_level3(inp_dec_level3)
264
+ out_dec_level3 = self.decoder_level3(inp_dec_level3)
265
+
266
+ inp_dec_level2 = self.up3_2(out_dec_level3)
267
+ inp_dec_level2 = torch.cat([inp_dec_level2, out_enc_level2], 1)
268
+ inp_dec_level2 = self.reduce_chan_level2(inp_dec_level2)
269
+ out_dec_level2 = self.decoder_level2(inp_dec_level2)
270
+
271
+ inp_dec_level1 = self.up2_1(out_dec_level2)
272
+ inp_dec_level1 = torch.cat([inp_dec_level1, out_enc_level1], 1)
273
+ out_dec_level1 = self.decoder_level1(inp_dec_level1)
274
+
275
+ out_dec_level1 = self.refinement(out_dec_level1)
276
+
277
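+ # Dual-pixel variant (used by DocRes): add a 1x1-projected skip connection from
+ # the patch embedding before the final output convolution.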
+ out_dec_level1 = out_dec_level1 + self.skip_conv(inp_enc_level1)
278
+ out_dec_level1 = self.output(out_dec_level1)
279
+
280
+ return out_dec_level1
281
+
282
+
283
+
284
+ if __name__ == '__main__':
285
+ from torchtoolbox.tools import summary
286
+ model = Restormer(
287
+ inp_channels=6,
288
+ out_channels=3,
289
+ dim = 48,
290
+ # num_blocks = [4,6,6,8],
291
+ num_blocks = [2,3,3,4],
292
+ num_refinement_blocks = 4,
293
+ heads = [1,2,4,8],
294
+ ffn_expansion_factor = 2.66,
295
+ bias = False,
296
+ LayerNorm_type = 'WithBias', ## Other option 'BiasFree'
297
+ dual_pixel_task = True ## True for dual-pixel defocus deblurring only. Also set inp_channels=6
298
+ )
299
+ # model = Restormer(num_blocks=[4, 6, 6, 8], num_heads=[1, 2, 4, 8], channels=[48, 96, 192, 384], num_refinement=4, expansion_factor=2.66)
300
+ print(summary(model,torch.rand((1, 6, 256, 256))))
301
+
302
+ from thop import profile
303
+ input = torch.rand((1, 6, 256, 256))
304
+ gflops,params = profile(model,inputs=(input,))
305
+ gflops = gflops*2 / 10**9
306
+ params = params / 10**6
307
+ print(gflops,'==============')
308
+ print(params,'==============')
requirements.txt ADDED
@@ -0,0 +1,10 @@
1
+ --extra-index-url https://download.pytorch.org/whl/cu113
2
+ numpy==1.21.6
3
+ opencv-python-headless>=4.2.0
4
+ scikit-image>=0.19.3
5
+ torch==1.11.0+cu113
6
+ torchvision==0.12.0+cu113
7
+ einops
8
+ tqdm
9
+ gradio
10
+ Pillow
start_train.sh ADDED
@@ -0,0 +1 @@
1
+ CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch --nproc_per_node 8 --master_port 26413 train.py
train.py ADDED
@@ -0,0 +1,221 @@
1
+ import os
2
+ import cv2
3
+ import time
4
+ import random
5
+ import datetime
6
+ import argparse
7
+ import numpy as np
8
+ from tqdm import tqdm
9
+ from piq import ssim,psnr
10
+ from itertools import cycle
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ from torch.utils import data
15
+ import torch.distributed as dist
16
+ from torch.utils.data.distributed import DistributedSampler
17
+ from torch.nn.parallel import DistributedDataParallel as DDP
+ from torch.utils.tensorboard import SummaryWriter
18
+
19
+
20
+ from utils import dict2string,mkdir,get_lr,torch2cvimg,second2hours
21
+ from loaders import docres_loader
22
+ from models import restormer_arch
23
+
24
+
25
+ def seed_torch(seed=1029):
26
+ random.seed(seed)
27
+ os.environ['PYTHONHASHSEED'] = str(seed)
28
+ np.random.seed(seed)
29
+ torch.manual_seed(seed)
30
+ torch.cuda.manual_seed(seed)
31
+ torch.cuda.manual_seed_all(seed)
32
+ torch.backends.cudnn.benchmark = False
33
+ torch.backends.cudnn.deterministic = True
34
+ #torch.use_deterministic_algorithms(True)
35
+ # seed_torch()
36
+
37
+
38
+ def getBasecoord(h,w):
39
+ base_coord0 = np.tile(np.arange(h).reshape(h,1),(1,w)).astype(np.float32)
40
+ base_coord1 = np.tile(np.arange(w).reshape(1,w),(h,1)).astype(np.float32)
41
+ base_coord = np.concatenate((np.expand_dims(base_coord1,-1),np.expand_dims(base_coord0,-1)),-1)
42
+ return base_coord
43
+
44
+ def train(args):
45
+
46
+ ## DDP init
47
+ dist.init_process_group(backend='nccl',init_method='env://',timeout=datetime.timedelta(seconds=36000))
48
+ torch.cuda.set_device(args.local_rank)
49
+ device = torch.device('cuda',args.local_rank)
50
+ torch.cuda.manual_seed_all(42)
51
+
52
+ ### Log file:
53
+ mkdir(args.logdir)
54
+ mkdir(os.path.join(args.logdir,args.experiment_name))
55
+ log_file_path=os.path.join(args.logdir,args.experiment_name,'log.txt')
56
+ log_file=open(log_file_path,'a')
57
+ log_file.write('\n--------------- '+args.experiment_name+' ---------------\n')
58
+ log_file.close()
59
+
60
+ ### Setup tensorboard for visualization
61
+ if args.tboard:
62
+ writer = SummaryWriter(os.path.join(args.logdir,args.experiment_name,'runs'),args.experiment_name)
63
+
64
+ ### Setup Dataloader
65
+ datasets_setting = [
66
+ {'task':'deblurring','ratio':1,'im_path':'/home/jiaxin/Training_Data/DocRes_data/train/deblurring/','json_paths':['/home/jiaxin/Training_Data/DocRes_data/train/deblurring/tdd/train.json']},
67
+ {'task':'dewarping','ratio':1,'im_path':'/home/jiaxin/Training_Data/DocRes_data/train/dewarping/','json_paths':['/home/jiaxin/Training_Data/DocRes_data/train/dewarping/doc3d/train_1_19.json']},
68
+ {'task':'binarization','ratio':1,'im_path':'/home/jiaxin/Training_Data/DocRes_data/train/binarization/','json_paths':['/home/jiaxin/Training_Data/DocRes_data/train/binarization/train.json']},
69
+ {'task':'deshadowing','ratio':1,'im_path':'/home/jiaxin/Training_Data/DocRes_data/train/deshadowing/','json_paths':['/home/jiaxin/Training_Data/DocRes_data/train/deshadowing/train.json']},
70
+ {'task':'appearance','ratio':1,'im_path':'/home/jiaxin/Training_Data/DocRes_data/train/appearance/','json_paths':['/home/jiaxin/Training_Data/DocRes_data/train/appearance/trainv2.json']}
71
+ ]
72
+
73
+
74
+ ratios = [dataset_setting['ratio'] for dataset_setting in datasets_setting]
75
+ datasets = [docres_loader.DocResTrainDataset(dataset=dataset_setting,img_size=args.im_size) for dataset_setting in datasets_setting]
76
+ trainloaders = [{'task':datasets_setting[i],'loader':data.DataLoader(dataset=datasets[i], sampler=DistributedSampler(datasets[i]), batch_size=args.batch_size, num_workers=2, pin_memory=True,drop_last=True),'iter_loader':iter(data.DataLoader(dataset=datasets[i], sampler=DistributedSampler(datasets[i]), batch_size=args.batch_size, num_workers=2, pin_memory=True,drop_last=True))} for i in range(len(datasets))]
77
+
78
+
79
+ ### test loader
80
+ # for i in tqdm(range(args.total_iter)):
81
+ # loader_index = random.choices(list(range(len(trainloaders))),ratios)[0]
82
+ # in_im,gt_im = next(trainloaders[loader_index]['iter_loader'])
83
+
84
+
85
+ ### Setup Model
86
+ model = restormer_arch.Restormer(
87
+ inp_channels=6,
88
+ out_channels=3,
89
+ dim = 48,
90
+ num_blocks = [2,3,3,4],
91
+ num_refinement_blocks = 4,
92
+ heads = [1,2,4,8],
93
+ ffn_expansion_factor = 2.66,
94
+ bias = False,
95
+ LayerNorm_type = 'WithBias',
96
+ dual_pixel_task = True
97
+ )
98
+ model=DDP(model.cuda(),device_ids=[args.local_rank],output_device=args.local_rank)
99
+
100
+ ### Optimizer
101
+ optimizer= torch.optim.AdamW(model.parameters(),lr=args.l_rate,weight_decay=5e-4)
102
+
103
+ ### LR Scheduler
104
+ sched = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.total_iter, eta_min=1e-6, last_epoch=-1)
105
+
106
+ ### load checkpoint
107
+ iter_start=0
108
+ if args.resume is not None:
109
+ print("Loading model and optimizer from checkpoint '{}'".format(args.resume))
110
+ checkpoint = torch.load(args.resume, map_location='cpu')
+ x = checkpoint['model_state']
111
+ model.load_state_dict(x,strict=False)
112
+ iter_start=checkpoint['iters']
113
+ print("Loaded checkpoint '{}' (iter {})".format(args.resume, iter_start))
114
+
115
+ ###-----------------------------------------Training-----------------------------------------
116
+ ##initialize
117
+ scaler = torch.cuda.amp.GradScaler()
118
+ loss_dict = {}
119
+ total_step = 0
120
+ l2 = nn.MSELoss()
121
+ l1 = nn.L1Loss()
122
+ ce = nn.CrossEntropyLoss()
123
+ bce = nn.BCEWithLogitsLoss()
124
+ m = nn.Sigmoid()
125
+ best = 0
126
+ best_ce = 999
127
+
128
+ ## total_steps
129
+ for iters in range(iter_start,args.total_iter):
130
+ start_time = time.time()
131
+ loader_index = random.choices(list(range(len(trainloaders))),ratios)[0]
132
+
133
+ try:
134
+ in_im,gt_im = next(trainloaders[loader_index]['iter_loader'])
135
+ except StopIteration:
136
+ trainloaders[loader_index]['iter_loader']=iter(trainloaders[loader_index]['loader'])
137
+ in_im,gt_im = next(trainloaders[loader_index]['iter_loader'])
138
+ in_im = in_im.float().cuda()
139
+ gt_im = gt_im.float().cuda()
140
+
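+ # Task-specific losses: cross-entropy over the first two output channels for
+ # binarization, L1 on the first two channels (backward-map shift) for dewarping,
+ # and plain L1 for appearance, deblurring and deshadowing.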
141
+ binarization_loss,appearance_loss,dewarping_loss,deblurring_loss,deshadowing_loss = 0,0,0,0,0
142
+ with torch.cuda.amp.autocast():
143
+ pred_im = model(in_im,trainloaders[loader_index]['task']['task'])
144
+ if trainloaders[loader_index]['task']['task'] == 'binarization':
145
+ gt_im = gt_im.long()
146
+ binarization_loss = ce(pred_im[:,:2,:,:], gt_im[:,0,:,:])
147
+ loss = binarization_loss
148
+ elif trainloaders[loader_index]['task']['task'] == 'dewarping':
149
+ dewarping_loss = l1(pred_im[:,:2,:,:], gt_im[:,:2,:,:])
150
+ loss = dewarping_loss
151
+ elif trainloaders[loader_index]['task']['task'] == 'appearance':
152
+ appearance_loss = l1(pred_im, gt_im)
153
+ loss = appearance_loss
154
+ elif trainloaders[loader_index]['task']['task'] == 'deblurring':
155
+ deblurring_loss = l1(pred_im, gt_im)
156
+ loss = deblurring_loss
157
+ elif trainloaders[loader_index]['task']['task'] == 'deshadowing':
158
+ deshadowing_loss = l1(pred_im, gt_im)
159
+ loss = deshadowing_loss
160
+
161
+ optimizer.zero_grad()
162
+ scaler.scale(loss).backward()
163
+ scaler.step(optimizer)
164
+ scaler.update()
165
+
166
+ loss_dict['dew_loss']=dewarping_loss.item() if isinstance(dewarping_loss,torch.Tensor) else 0
167
+ loss_dict['app_loss']=appearance_loss.item() if isinstance(appearance_loss,torch.Tensor) else 0
168
+ loss_dict['des_loss']=deshadowing_loss.item() if isinstance(deshadowing_loss,torch.Tensor) else 0
169
+ loss_dict['deb_loss']=deblurring_loss.item() if isinstance(deblurring_loss,torch.Tensor) else 0
170
+ loss_dict['bin_loss']=binarization_loss.item() if isinstance(binarization_loss,torch.Tensor) else 0
171
+ end_time = time.time()
172
+ duration = end_time-start_time
173
+ ## log
174
+ if (iters+1) % 10 == 0:
175
+ ## print
176
+ print('iters [{}/{}] -- '.format(iters+1,args.total_iter)+dict2string(loss_dict)+' --lr {:6f}'.format(get_lr(optimizer))+' -- time {}'.format(second2hours(duration*(args.total_iter-iters))))
177
+ ## tbord
178
+ if args.tboard:
179
+ for key,value in loss_dict.items():
180
+ writer.add_scalar('Train '+key+'/Iterations', value, iters+1)
181
+ ## logfile
182
+ with open(log_file_path,'a') as f:
183
+ f.write('iters [{}/{}] -- '.format(iters+1,args.total_iter)+dict2string(loss_dict)+' --lr {:6f}'.format(get_lr(optimizer))+' -- time {}'.format(second2hours(duration*(args.total_iter-iters)))+'\n')
184
+
185
+
186
+ if (iters+1) % 5000 == 0:
187
+ state = {'iters': iters+1,
188
+ 'model_state': model.state_dict(),
189
+ 'optimizer_state' : optimizer.state_dict(),}
190
+ if not os.path.exists(os.path.join(args.logdir,args.experiment_name)):
191
+ os.system('mkdir ' + os.path.join(args.logdir,args.experiment_name))
192
+ if torch.distributed.get_rank()==0:
193
+ torch.save(state, os.path.join(args.logdir,args.experiment_name,"{}.pkl".format(iters+1)))
194
+
195
+ sched.step()
196
+
197
+
198
+
199
+ if __name__ == '__main__':
200
+ parser = argparse.ArgumentParser(description='Hyperparams')
201
+ parser.add_argument('--im_size', nargs='?', type=int, default=256,
202
+ help='Height of the input image')
203
+ parser.add_argument('--total_iter', nargs='?', type=int, default=100000,
204
+ help='# of the epochs')
205
+ parser.add_argument('--batch_size', nargs='?', type=int, default=10,
206
+ help='Batch Size')
207
+ parser.add_argument('--l_rate', nargs='?', type=float, default=2e-4,
208
+ help='Learning Rate')
209
+ parser.add_argument('--resume', nargs='?', type=str, default=None,
210
+ help='Path to previous saved model to restart from')
211
+ parser.add_argument('--logdir', nargs='?', type=str, default='./checkpoints/',
212
+ help='Path to store the loss logs')
213
+ parser.add_argument('--tboard', dest='tboard', action='store_true',
214
+ help='Enable visualization(s) on tensorboard | False by default')
215
+ parser.add_argument('--local_rank',type=int,default=0,metavar='N')
216
+ parser.add_argument('--experiment_name', nargs='?', type=str,default='experiment_name',
217
+ help='the name of this experiment')
218
+ parser.set_defaults(tboard=False)
219
+ args = parser.parse_args()
220
+
221
+ train(args)
utils.py ADDED
@@ -0,0 +1,464 @@
+ from collections import OrderedDict
+ import os
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+ from skimage.filters import threshold_sauvola
+ import cv2
+
+
+ def second2hours(seconds):
+     h = seconds // 3600
+     seconds %= 3600
+     m = seconds // 60
+     seconds %= 60
+     hms = '{:d} H : {:d} Min'.format(int(h), int(m))
+     return hms
+
+
+ def dict2string(loss_dict):
+     loss_string = ''
+     for key, value in loss_dict.items():
+         loss_string += key + ' {:.4f}, '.format(value)
+     return loss_string[:-2]
+
+
+ def mkdir(dir):
+     if not os.path.exists(dir):
+         os.makedirs(dir)
+
+
+ def convert_state_dict(state_dict):
+     """Convert a state dict saved from a DataParallel module into a normal
+     module state dict, i.e. strip the `module.` prefix from every key.
+     :param state_dict: the loaded DataParallel model_state
+     """
+     new_state_dict = OrderedDict()
+     for k, v in state_dict.items():
+         name = k[7:]  # remove `module.`
+         new_state_dict[name] = v
+     return new_state_dict
+
+
+ def get_lr(optimizer):
+     for param_group in optimizer.param_groups:
+         return float(param_group['lr'])
+
+
+ def torch2cvimg(tensor, min=0, max=1):
+     '''
+     input:
+         tensor -> torch.Tensor BxCxHxW, C can be 1 or 3
+     return:
+         im_list -> list of ndarray uint8 HxWxC
+     '''
+     im_list = []
+     for i in range(tensor.shape[0]):
+         im = tensor.detach().cpu().data.numpy()[i]
+         im = im.transpose(1, 2, 0)
+         im = np.clip(im, min, max)
+         im = ((im - min) / (max - min) * 255).astype(np.uint8)
+         im_list.append(im)
+     return im_list
+
+
+ def cvimg2torch(img, min=0, max=1):
+     '''
+     input:
+         img -> ndarray uint8 HxWxC
+     return:
+         tensor -> torch.Tensor 1xCxHxW
+     '''
+     img = img.astype(float) / 255.0
+     img = img.transpose(2, 0, 1)   # HWC -> CHW
+     img = np.expand_dims(img, 0)   # add batch dimension
+     img = torch.from_numpy(img).float()
+     return img
+
+
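A quick round-trip sketch for the two converters above; the image path is illustrative.

```python
import cv2
from utils import cvimg2torch, torch2cvimg

img = cv2.imread('example.png')          # HxWx3 uint8, BGR
tensor = cvimg2torch(img)                # 1x3xHxW float tensor in [0, 1]
restored = torch2cvimg(tensor)[0]        # back to HxWx3 uint8
assert restored.shape == img.shape
```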
+ def setup_seed(seed):
+     # np.random.seed(seed)
+     # random.seed(seed)
+     # torch.manual_seed(seed)  # CPU
+     # torch.cuda.manual_seed_all(seed)  # all (parallel) GPUs
+     torch.backends.cudnn.deterministic = True  # make CPU/GPU results reproducible
+     # torch.backends.cudnn.benchmark = False  # speeds up training when input sizes do not vary much
+
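As committed, `setup_seed` only sets the cuDNN determinism flag; the commented-out lines hint at the fuller recipe. For reference, a sketch of seeding every source of randomness, under the assumption that this is the intended behavior; the name `setup_seed_full` is illustrative.

```python
import random
import numpy as np
import torch

def setup_seed_full(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)            # CPU
    torch.cuda.manual_seed_all(seed)   # all GPUs
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
```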
+ def SauvolaModBinarization(image, n1=51, n2=51, k1=0.3, k2=0.3, default=True):
+     '''
+     Binarization using Sauvola's algorithm
+     @name : SauvolaModBinarization
+     parameters
+     @param image (numpy array, 1- or 3-channel, np.uint8): color or grayscale image
+     optional parameters
+     @param n1 (int) : window size for running Sauvola during the first pass
+     @param n2 (int) : window size for running Sauvola during the second pass
+     @param k1 (float) : k value for Sauvola during the first pass
+     @param k2 (float) : k value for Sauvola during the second pass
+     @param default (bool) : boolean flag; when True the optional parameters (n1, n2, k1, k2) are overridden with
+         n1 = 5% of min(image height, image width)
+         n2 = 10% of min(image height, image width)
+         k1 = 0.5
+         k2 = 0.5
+     Returns
+     @return a binary image of the same size as @param image, and the second-pass threshold map
+
+     @cite https://drive.google.com/file/d/1D3CyI5vtodPJeZaD2UV5wdcaIMtkBbdZ/view?usp=sharing
+     '''
+     if default:
+         n1 = int(0.05 * min(image.shape[0], image.shape[1]))
+         if n1 % 2 == 0:
+             n1 = n1 + 1
+         n2 = int(0.1 * min(image.shape[0], image.shape[1]))
+         if n2 % 2 == 0:
+             n2 = n2 + 1
+         k1 = 0.5
+         k2 = 0.5
+     if image.ndim == 3:
+         gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
+     else:
+         gray = np.copy(image)
+     # first pass: Sauvola threshold, then build a normalized contrast image
+     T1 = threshold_sauvola(gray, window_size=n1, k=k1)
+     max_val = np.amax(gray)
+     min_val = np.amin(gray)
+     C = np.copy(T1)
+     C = C.astype(np.float32)
+     C[gray > T1] = (gray[gray > T1] - T1[gray > T1]) / (max_val - T1[gray > T1])
+     C[gray <= T1] = 0
+     C = C * 255.0
+     new_in = np.copy(C.astype(np.uint8))
+     # second pass: Sauvola threshold on the contrast image
+     T2 = threshold_sauvola(new_in, window_size=n2, k=k2)
+     binary = np.copy(gray)
+     binary[new_in <= T2] = 0
+     binary[new_in > T2] = 255
+     return binary, T2
+
+
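A minimal usage sketch for the two-pass binarization above; the file paths are illustrative.

```python
import cv2
from utils import SauvolaModBinarization

img = cv2.imread('page.png')                          # BGR or grayscale document image
binary, threshold_map = SauvolaModBinarization(img)   # default=True derives window sizes from the image size
cv2.imwrite('page_binary.png', binary)
```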
+ def getBasecoord(h, w):
+     base_coord0 = np.tile(np.arange(h).reshape(h, 1), (1, w)).astype(np.float32)
+     base_coord1 = np.tile(np.arange(w).reshape(1, w), (h, 1)).astype(np.float32)
+     base_coord = np.concatenate((np.expand_dims(base_coord1, -1), np.expand_dims(base_coord0, -1)), -1)
+     return base_coord
+
+
+ from scipy import ndimage as ndi
+
+ # lookup tables for bwmorph
+
+ G123_LUT = np.array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1,
153
+ 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
154
+ 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0,
155
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0,
156
+ 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0,
157
+ 0, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
158
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
159
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
160
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
161
+ 0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0,
162
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 1,
163
+ 0, 0, 0], dtype=np.bool_)
164
+
165
+ G123P_LUT = np.array([0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0,
166
+ 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
167
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0,
168
+ 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
169
+ 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
170
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0,
171
+ 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0,
172
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
173
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0,
174
+ 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1,
175
+ 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
176
+ 0, 0, 0], dtype=np.bool_)
177
+
178
+ def bwmorph(image, n_iter=None):
179
+ """
180
+ Perform morphological thinning of a binary image
181
+
182
+ Parameters
183
+ ----------
184
+ image : binary (M, N) ndarray
185
+ The image to be thinned.
186
+
187
+ n_iter : int, number of iterations, optional
188
+ Regardless of the value of this parameter, the thinned image
189
+ is returned immediately if an iteration produces no change.
190
+ If this parameter is specified it thus sets an upper bound on
191
+ the number of iterations performed.
192
+
193
+ Returns
194
+ -------
195
+ out : ndarray of bools
196
+ Thinned image.
197
+
198
+ See also
199
+ --------
200
+ skeletonize
201
+
202
+ Notes
203
+ -----
204
+ This algorithm [1]_ works by making multiple passes over the image,
205
+ removing pixels matching a set of criteria designed to thin
206
+ connected regions while preserving eight-connected components and
207
+ 2 x 2 squares [2]_. In each of the two sub-iterations the algorithm
208
+ correlates the intermediate skeleton image with a neighborhood mask,
209
+ then looks up each neighborhood in a lookup table indicating whether
210
+ the central pixel should be deleted in that sub-iteration.
211
+
212
+ References
213
+ ----------
214
+ .. [1] Z. Guo and R. W. Hall, "Parallel thinning with
215
+ two-subiteration algorithms," Comm. ACM, vol. 32, no. 3,
216
+ pp. 359-373, 1989.
217
+ .. [2] Lam, L., Seong-Whan Lee, and Ching Y. Suen, "Thinning
218
+ Methodologies-A Comprehensive Survey," IEEE Transactions on
219
+ Pattern Analysis and Machine Intelligence, Vol 14, No. 9,
220
+ September 1992, p. 879
221
+
222
+ Examples
223
+ --------
224
+ >>> square = np.zeros((7, 7), dtype=np.uint8)
225
+ >>> square[1:-1, 2:-2] = 1
226
+ >>> square[0,1] = 1
227
+ >>> square
228
+ array([[0, 1, 0, 0, 0, 0, 0],
229
+ [0, 0, 1, 1, 1, 0, 0],
230
+ [0, 0, 1, 1, 1, 0, 0],
231
+ [0, 0, 1, 1, 1, 0, 0],
232
+ [0, 0, 1, 1, 1, 0, 0],
233
+ [0, 0, 1, 1, 1, 0, 0],
234
+ [0, 0, 0, 0, 0, 0, 0]], dtype=uint8)
235
+ >>> skel = bwmorph(square)
236
+ >>> skel.astype(np.uint8)
237
+ array([[0, 1, 0, 0, 0, 0, 0],
238
+ [0, 0, 1, 0, 0, 0, 0],
239
+ [0, 0, 0, 1, 0, 0, 0],
240
+ [0, 0, 0, 1, 0, 0, 0],
241
+ [0, 0, 0, 1, 0, 0, 0],
242
+ [0, 0, 0, 0, 0, 0, 0],
243
+ [0, 0, 0, 0, 0, 0, 0]], dtype=uint8)
244
+ """
245
+ # check parameters
246
+ if n_iter is None:
247
+ n = -1
248
+ elif n_iter <= 0:
249
+ raise ValueError('n_iter must be > 0')
250
+ else:
251
+ n = n_iter
252
+
253
+ # check that we have a 2d binary image, and convert it
254
+ # to uint8
255
+ skel = np.array(image).astype(np.uint8)
256
+
257
+ if skel.ndim != 2:
258
+ raise ValueError('2D array required')
259
+ if not np.all(np.in1d(image.flat,(0,1))):
260
+ raise ValueError('Image contains values other than 0 and 1')
261
+
262
+ # neighborhood mask
263
+ mask = np.array([[ 8, 4, 2],
264
+ [16, 0, 1],
265
+ [32, 64,128]],dtype=np.uint8)
266
+
267
+ # iterate either 1) indefinitely or 2) up to iteration limit
268
+ while n != 0:
269
+ before = np.sum(skel) # count points before thinning
270
+
271
+ # for each subiteration
272
+ for lut in [G123_LUT, G123P_LUT]:
273
+ # correlate image with neighborhood mask
274
+ N = ndi.correlate(skel, mask, mode='constant')
275
+ # take deletion decision from this subiteration's LUT
276
+ D = np.take(lut, N)
277
+ # perform deletion
278
+ skel[D] = 0
279
+
280
+ after = np.sum(skel) # count points after thinning
281
+
282
+ if before == after:
283
+ # iteration had no effect: finish
284
+ break
285
+
286
+ # count down to iteration limit (or endlessly negative)
287
+ n -= 1
288
+
289
+ return skel.astype(np.bool_)
290
+
291
+ """
292
+ # here's how to make the LUTs
293
+ def nabe(n):
294
+ return np.array([n>>i&1 for i in range(0,9)]).astype(np.bool_)
295
+ def hood(n):
296
+ return np.take(nabe(n), np.array([[3, 2, 1],
297
+ [4, 8, 0],
298
+ [5, 6, 7]]))
299
+ def G1(n):
300
+ s = 0
301
+ bits = nabe(n)
302
+ for i in (0,2,4,6):
303
+ if not(bits[i]) and (bits[i+1] or bits[(i+2) % 8]):
304
+ s += 1
305
+ return s==1
306
+
307
+ g1_lut = np.array([G1(n) for n in range(256)])
308
+ def G2(n):
309
+ n1, n2 = 0, 0
310
+ bits = nabe(n)
311
+ for k in (1,3,5,7):
312
+ if bits[k] or bits[k-1]:
313
+ n1 += 1
314
+ if bits[k] or bits[(k+1) % 8]:
315
+ n2 += 1
316
+ return min(n1,n2) in [2,3]
317
+ g2_lut = np.array([G2(n) for n in range(256)])
318
+ g12_lut = g1_lut & g2_lut
319
+ def G3(n):
320
+ bits = nabe(n)
321
+ return not((bits[1] or bits[2] or not(bits[7])) and bits[0])
322
+ def G3p(n):
323
+ bits = nabe(n)
324
+ return not((bits[5] or bits[6] or not(bits[3])) and bits[4])
325
+ g3_lut = np.array([G3(n) for n in range(256)])
326
+ g3p_lut = np.array([G3p(n) for n in range(256)])
327
+ g123_lut = g12_lut & g3_lut
328
+ g123p_lut = g12_lut & g3p_lut
329
+ """
330
+
331
+ """
332
+ author : Peb Ruswono Aryan
333
+
334
+ metric for evaluating binarization algorithms
335
+ implemented :
336
+
337
+ * F-Measure
338
+ * pseudo F-Measure (as in H-DIBCO 2010 & 2012)
339
+ * Peak Signal to Noise Ratio (PSNR)
340
+ * Negative Rate Measure (NRM)
341
+ * Misclassification Penalty Measure (MPM)
342
+ * Distance Reciprocal Distortion (DRD)
343
+
344
+ usage:
345
+ python metric.py test-image.png ground-truth-image.png
346
+ """
347
+
348
+
349
+ def drd_fn(im, im_gt):
350
+ height, width = im.shape
351
+ neg = np.zeros(im.shape)
352
+ neg[im_gt!=im] = 1
353
+ y, x = np.unravel_index(np.flatnonzero(neg), im.shape)
354
+
355
+ n = 2
356
+ m = n*2+1
357
+ W = np.zeros((m,m), dtype=np.uint8)
358
+ W[n,n] = 1.
359
+ W = cv2.distanceTransform(1-W, cv2.DIST_L2, cv2.DIST_MASK_PRECISE)
360
+ W[n,n] = 1.
361
+ W = 1./W
362
+ W[n,n] = 0.
363
+ W /= W.sum()
364
+
365
+ nubn = 0.
366
+ block_size = 8
367
+ for y1 in range(0, height, block_size):
368
+ for x1 in range(0, width, block_size):
369
+ y2 = min(y1+block_size-1,height-1)
370
+ x2 = min(x1+block_size-1,width-1)
371
+ block_dim = (x2-x1+1)*(y2-y1+1)
372
+ block = 1-im_gt[y1:y2, x1:x2]
373
+ block_sum = np.sum(block)
374
+ if block_sum>0 and block_sum<block_dim:
375
+ nubn += 1
376
+
377
+ drd_sum= 0.
378
+ tmp = np.zeros(W.shape)
379
+ for i in range(len(y)): # DRD is accumulated over every mismatching pixel
380
+ tmp[:,:] = 0
381
+
382
+ x1 = max(0, x[i]-n)
383
+ y1 = max(0, y[i]-n)
384
+ x2 = min(width-1, x[i]+n)
385
+ y2 = min(height-1, y[i]+n)
386
+
387
+ yy1 = y1-y[i]+n
388
+ yy2 = y2-y[i]+n
389
+ xx1 = x1-x[i]+n
390
+ xx2 = x2-x[i]+n
391
+
392
+ tmp[yy1:yy2+1,xx1:xx2+1] = np.abs(im[y[i],x[i]]-im_gt[y1:y2+1,x1:x2+1])
393
+ tmp *= W
394
+
395
+ drd_sum += np.sum(tmp)
396
+ return drd_sum/nubn
397
+
398
+ def bin_metric(im,im_gt):
399
+ height, width = im.shape
400
+ npixel = height*width
401
+
402
+ im[im>0] = 1
403
+ gt_mask = im_gt==0
404
+ im_gt[im_gt>0] = 1
405
+
406
+ sk = bwmorph(1-im_gt)
407
+ im_sk = np.ones(im_gt.shape)
408
+ im_sk[sk] = 0
409
+
410
+ kernel = np.ones((3,3), dtype=np.uint8)
411
+ im_dil = cv2.erode(im_gt, kernel)
412
+ im_gtb = im_gt-im_dil
413
+ im_gtbd = cv2.distanceTransform(1-im_gtb, cv2.DIST_L2, 3)
414
+
415
+ nd = im_gtbd.sum()
416
+
417
+ ptp = np.zeros(im_gt.shape)
418
+ ptp[(im==0) & (im_sk==0)] = 1
419
+ numptp = ptp.sum()
420
+
421
+ tp = np.zeros(im_gt.shape)
422
+ tp[(im==0) & (im_gt==0)] = 1
423
+ numtp = tp.sum()
424
+
425
+ tn = np.zeros(im_gt.shape)
426
+ tn[(im==1) & (im_gt==1)] = 1
427
+ numtn = tn.sum()
428
+
429
+ fp = np.zeros(im_gt.shape)
430
+ fp[(im==0) & (im_gt==1)] = 1
431
+ numfp = fp.sum()
432
+
433
+ fn = np.zeros(im_gt.shape)
434
+ fn[(im==1) & (im_gt==0)] = 1
435
+ numfn = fn.sum()
436
+
437
+ precision = numtp / (numtp + numfp)
438
+ recall = numtp / (numtp + numfn)
439
+ precall = numptp / np.sum(1-im_sk)
440
+ fmeasure = (2*recall*precision)/(recall+precision)
441
+ pfmeasure = (2*precall*precision)/(precall+precision)
442
+
443
+ mse = (numfp+numfn)/npixel
444
+ psnr = 10.*np.log10(1./mse)
445
+
446
+ nrfn = numfn / (numfn + numtp)
447
+ nrfp = numfp / (numfp + numtn)
448
+ nrm = (nrfn + nrfp)/2
449
+
450
+ im_dn = im_gtbd.copy()
451
+ im_dn[fn==0] = 0
452
+ dn = np.sum(im_dn)
453
+ mpfn = dn / nd
454
+
455
+ im_dp = im_gtbd.copy()
456
+ im_dp[fp==0] = 0
457
+ dp = np.sum(im_dp)
458
+ mpfp = dp / nd
459
+
460
+ mpm = (mpfp + mpfn) / 2
461
+ drd = drd_fn(im, im_gt)
462
+
463
+ return fmeasure, pfmeasure,psnr,nrm, mpm,drd
464
+ # print("F-measure\t: {0}\npF-measure\t: {1}\nPSNR\t\t: {2}\nNRM\t\t: {3}\nMPM\t\t: {4}\nDRD\t\t: {5}".format(fmeasure, pfmeasure, psnr, nrm, mpm, drd))
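A usage sketch for the metric above. Both inputs are assumed to be single-channel images in which 0 marks text pixels, which is how the function counts true positives; the paths are illustrative.

```python
import cv2
from utils import bin_metric

pred = cv2.imread('result.png', cv2.IMREAD_GRAYSCALE)        # binarization output, 0 = ink
gt = cv2.imread('ground_truth.png', cv2.IMREAD_GRAYSCALE)    # ground-truth binarization, 0 = ink
fmeasure, pfmeasure, psnr, nrm, mpm, drd = bin_metric(pred, gt)
print('F-measure: {:.4f}, PSNR: {:.2f}, DRD: {:.4f}'.format(fmeasure, psnr, drd))
```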