sunshineatnoon committed
Commit
1b2a9b1
1 Parent(s): 1d90a68

Add application file

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the complete change set.
Files changed (50)
  1. data/___init__.py +0 -0
  2. data/color150.mat +0 -0
  3. data/images/108073.jpg +0 -0
  4. data/images/12003.jpg +0 -0
  5. data/images/12074.jpg +0 -0
  6. data/images/134008.jpg +0 -0
  7. data/images/134052.jpg +0 -0
  8. data/images/138032.jpg +0 -0
  9. data/images/145053.jpg +0 -0
  10. data/images/164074.jpg +0 -0
  11. data/images/169012.jpg +0 -0
  12. data/images/198023.jpg +0 -0
  13. data/images/25098.jpg +0 -0
  14. data/images/277095.jpg +0 -0
  15. data/images/45077.jpg +0 -0
  16. data/palette.txt +256 -0
  17. data/test_images/100039.jpg +0 -0
  18. data/test_images/108004.jpg +0 -0
  19. data/test_images/130014.jpg +0 -0
  20. data/test_images/130066.jpg +0 -0
  21. data/test_images/16068.jpg +0 -0
  22. data/test_images/2018.jpg +0 -0
  23. data/test_images/208078.jpg +0 -0
  24. data/test_images/223060.jpg +0 -0
  25. data/test_images/226033.jpg +0 -0
  26. data/test_images/388006.jpg +0 -0
  27. data/test_images/78098.jpg +0 -0
  28. libs/__init__.py +0 -0
  29. libs/__pycache__/__init__.cpython-37.pyc +0 -0
  30. libs/__pycache__/__init__.cpython-38.pyc +0 -0
  31. libs/__pycache__/flow_transforms.cpython-37.pyc +0 -0
  32. libs/__pycache__/flow_transforms.cpython-38.pyc +0 -0
  33. libs/__pycache__/nnutils.cpython-37.pyc +0 -0
  34. libs/__pycache__/nnutils.cpython-38.pyc +0 -0
  35. libs/__pycache__/options.cpython-37.pyc +0 -0
  36. libs/__pycache__/options.cpython-38.pyc +0 -0
  37. libs/__pycache__/test_base.cpython-37.pyc +0 -0
  38. libs/__pycache__/test_base.cpython-38.pyc +0 -0
  39. libs/__pycache__/utils.cpython-37.pyc +0 -0
  40. libs/__pycache__/utils.cpython-38.pyc +0 -0
  41. libs/blocks.py +739 -0
  42. libs/custom_transform.py +249 -0
  43. libs/data_coco_stuff.py +166 -0
  44. libs/data_coco_stuff_geo_pho.py +145 -0
  45. libs/data_geo.py +176 -0
  46. libs/data_geo_pho.py +130 -0
  47. libs/data_slic.py +175 -0
  48. libs/discriminator.py +60 -0
  49. libs/flow_transforms.py +393 -0
  50. libs/losses.py +416 -0
data/___init__.py ADDED
Empty file.
data/color150.mat ADDED
Binary file (502 Bytes).
data/images/108073.jpg ADDED
data/images/12003.jpg ADDED
data/images/12074.jpg ADDED
data/images/134008.jpg ADDED
data/images/134052.jpg ADDED
data/images/138032.jpg ADDED
data/images/145053.jpg ADDED
data/images/164074.jpg ADDED
data/images/169012.jpg ADDED
data/images/198023.jpg ADDED
data/images/25098.jpg ADDED
data/images/277095.jpg ADDED
data/images/45077.jpg ADDED
data/palette.txt ADDED
@@ -0,0 +1,256 @@
0 0 0
128 0 0
0 128 0
128 128 0
0 0 128
128 0 128
0 128 128
128 128 128
64 0 0
191 0 0
64 128 0
191 128 0
64 0 128
191 0 128
64 128 128
191 128 128
0 64 0
128 64 0
0 191 0
128 191 0
0 64 128
128 64 128
22 22 22
23 23 23
24 24 24
25 25 25
26 26 26
27 27 27
28 28 28
29 29 29
30 30 30
31 31 31
32 32 32
33 33 33
34 34 34
35 35 35
36 36 36
37 37 37
38 38 38
39 39 39
40 40 40
41 41 41
42 42 42
43 43 43
44 44 44
45 45 45
46 46 46
47 47 47
48 48 48
49 49 49
50 50 50
51 51 51
52 52 52
53 53 53
54 54 54
55 55 55
56 56 56
57 57 57
58 58 58
59 59 59
60 60 60
61 61 61
62 62 62
63 63 63
64 64 64
65 65 65
66 66 66
67 67 67
68 68 68
69 69 69
70 70 70
71 71 71
72 72 72
73 73 73
74 74 74
75 75 75
76 76 76
77 77 77
78 78 78
79 79 79
80 80 80
81 81 81
82 82 82
83 83 83
84 84 84
85 85 85
86 86 86
87 87 87
88 88 88
89 89 89
90 90 90
91 91 91
92 92 92
93 93 93
94 94 94
95 95 95
96 96 96
97 97 97
98 98 98
99 99 99
100 100 100
101 101 101
102 102 102
103 103 103
104 104 104
105 105 105
106 106 106
107 107 107
108 108 108
109 109 109
110 110 110
111 111 111
112 112 112
113 113 113
114 114 114
115 115 115
116 116 116
117 117 117
118 118 118
119 119 119
120 120 120
121 121 121
122 122 122
123 123 123
124 124 124
125 125 125
126 126 126
127 127 127
128 128 128
129 129 129
130 130 130
131 131 131
132 132 132
133 133 133
134 134 134
135 135 135
136 136 136
137 137 137
138 138 138
139 139 139
140 140 140
141 141 141
142 142 142
143 143 143
144 144 144
145 145 145
146 146 146
147 147 147
148 148 148
149 149 149
150 150 150
151 151 151
152 152 152
153 153 153
154 154 154
155 155 155
156 156 156
157 157 157
158 158 158
159 159 159
160 160 160
161 161 161
162 162 162
163 163 163
164 164 164
165 165 165
166 166 166
167 167 167
168 168 168
169 169 169
170 170 170
171 171 171
172 172 172
173 173 173
174 174 174
175 175 175
176 176 176
177 177 177
178 178 178
179 179 179
180 180 180
181 181 181
182 182 182
183 183 183
184 184 184
185 185 185
186 186 186
187 187 187
188 188 188
189 189 189
190 190 190
191 191 191
192 192 192
193 193 193
194 194 194
195 195 195
196 196 196
197 197 197
198 198 198
199 199 199
200 200 200
201 201 201
202 202 202
203 203 203
204 204 204
205 205 205
206 206 206
207 207 207
208 208 208
209 209 209
210 210 210
211 211 211
212 212 212
213 213 213
214 214 214
215 215 215
216 216 216
217 217 217
218 218 218
219 219 219
220 220 220
221 221 221
222 222 222
223 223 223
224 224 224
225 225 225
226 226 226
227 227 227
228 228 228
229 229 229
230 230 230
231 231 231
232 232 232
233 233 233
234 234 234
235 235 235
236 236 236
237 237 237
238 238 238
239 239 239
240 240 240
241 241 241
242 242 242
243 243 243
244 244 244
245 245 245
246 246 246
247 247 247
248 248 248
249 249 249
250 250 250
251 251 251
252 252 252
253 253 253
254 254 254
255 255 255
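A minimal sketch (not part of this commit) of how a palette file like data/palette.txt is typically consumed: one "R G B" triplet per line, indexed by class id to colorize a segmentation label map. The loadtxt call and the colorize helper below are illustrative assumptions, not code from this repository.

import numpy as np
from PIL import Image

palette = np.loadtxt('data/palette.txt', dtype=np.uint8)   # shape (256, 3)

def colorize(label_map):
    """Map an HxW array of class indices (0-255) to an RGB image via palette lookup."""
    return Image.fromarray(palette[label_map])              # fancy indexing does the per-pixel lookup

# Example with a dummy 4x4 label map:
demo = colorize(np.arange(16, dtype=np.uint8).reshape(4, 4))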
data/test_images/100039.jpg ADDED
data/test_images/108004.jpg ADDED
data/test_images/130014.jpg ADDED
data/test_images/130066.jpg ADDED
data/test_images/16068.jpg ADDED
data/test_images/2018.jpg ADDED
data/test_images/208078.jpg ADDED
data/test_images/223060.jpg ADDED
data/test_images/226033.jpg ADDED
data/test_images/388006.jpg ADDED
data/test_images/78098.jpg ADDED
libs/__init__.py ADDED
Empty file.
libs/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (151 Bytes).
libs/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (155 Bytes).
libs/__pycache__/flow_transforms.cpython-37.pyc ADDED
Binary file (14.1 kB).
libs/__pycache__/flow_transforms.cpython-38.pyc ADDED
Binary file (13.7 kB).
libs/__pycache__/nnutils.cpython-37.pyc ADDED
Binary file (3.39 kB).
libs/__pycache__/nnutils.cpython-38.pyc ADDED
Binary file (3.4 kB).
libs/__pycache__/options.cpython-37.pyc ADDED
Binary file (5.43 kB).
libs/__pycache__/options.cpython-38.pyc ADDED
Binary file (5.49 kB).
libs/__pycache__/test_base.cpython-37.pyc ADDED
Binary file (4.01 kB).
libs/__pycache__/test_base.cpython-38.pyc ADDED
Binary file (4.07 kB).
libs/__pycache__/utils.cpython-37.pyc ADDED
Binary file (4.51 kB).
libs/__pycache__/utils.cpython-38.pyc ADDED
Binary file (4.53 kB).
libs/blocks.py ADDED
@@ -0,0 +1,739 @@
"""Network Modules
- encoder3: vgg encoder up to relu31
- decoder3: mirror decoder to encoder3
- encoder4: vgg encoder up to relu41
- decoder4: mirror decoder to encoder4
- encoder5: vgg encoder up to relu51
- styleLoss: gram matrix loss for all style layers
- styleLossMask: gram matrix loss for all style layers, compare between each part defined by a mask
- GramMatrix: compute gram matrix for one layer
- LossCriterion: style transfer loss that include both content & style losses
- LossCriterionMask: style transfer loss that include both content & style losses, use the styleLossMask
- VQEmbedding: codebook class for VQVAE
"""
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from .vq_functions import vq, vq_st
from collections import OrderedDict

class MetaModule(nn.Module):
    """
    Base class for PyTorch meta-learning modules. These modules accept an
    additional argument `params` in their `forward` method.

    Notes
    -----
    Objects inherited from `MetaModule` are fully compatible with PyTorch
    modules from `torch.nn.Module`. The argument `params` is a dictionary of
    tensors, with full support of the computation graph (for differentiation).
    """
    def meta_named_parameters(self, prefix='', recurse=True):
        gen = self._named_members(
            lambda module: module._parameters.items()
            if isinstance(module, MetaModule) else [],
            prefix=prefix, recurse=recurse)
        for elem in gen:
            yield elem

    def meta_parameters(self, recurse=True):
        for name, param in self.meta_named_parameters(recurse=recurse):
            yield param

class BatchLinear(nn.Linear, MetaModule):
    '''A linear meta-layer that can deal with batched weight matrices and biases, as for instance output by a
    hypernetwork.'''
    __doc__ = nn.Linear.__doc__

    def forward(self, input, params=None):
        if params is None:
            params = OrderedDict(self.named_parameters())

        bias = params.get('bias', None)
        weight = params['weight']

        output = input.matmul(weight.permute(*[i for i in range(len(weight.shape) - 2)], -1, -2))
        output += bias.unsqueeze(-2)
        return output

class decoder1(nn.Module):
    def __init__(self):
        super(decoder1,self).__init__()
        self.reflecPad2 = nn.ReflectionPad2d((1,1,1,1))
        # 226 x 226
        self.conv3 = nn.Conv2d(64,3,3,1,0)
        # 224 x 224

    def forward(self,x):
        out = self.reflecPad2(x)
        out = self.conv3(out)
        return out


class decoder2(nn.Module):
    def __init__(self):
        super(decoder2,self).__init__()
        # decoder
        self.reflecPad5 = nn.ReflectionPad2d((1,1,1,1))
        self.conv5 = nn.Conv2d(128,64,3,1,0)
        self.relu5 = nn.ReLU(inplace=True)
        # 112 x 112

        self.unpool = nn.UpsamplingNearest2d(scale_factor=2)
        # 224 x 224

        self.reflecPad6 = nn.ReflectionPad2d((1,1,1,1))
        self.conv6 = nn.Conv2d(64,64,3,1,0)
        self.relu6 = nn.ReLU(inplace=True)
        # 224 x 224

        self.reflecPad7 = nn.ReflectionPad2d((1,1,1,1))
        self.conv7 = nn.Conv2d(64,3,3,1,0)

    def forward(self,x):
        out = self.reflecPad5(x)
        out = self.conv5(out)
        out = self.relu5(out)
        out = self.unpool(out)
        out = self.reflecPad6(out)
        out = self.conv6(out)
        out = self.relu6(out)
        out = self.reflecPad7(out)
        out = self.conv7(out)
        return out

class encoder3(nn.Module):
    def __init__(self):
        super(encoder3,self).__init__()
        # vgg
        # 224 x 224
        self.conv1 = nn.Conv2d(3,3,1,1,0)
        self.reflecPad1 = nn.ReflectionPad2d((1,1,1,1))
        # 226 x 226

        self.conv2 = nn.Conv2d(3,64,3,1,0)
        self.relu2 = nn.ReLU(inplace=True)
        # 224 x 224

        self.reflecPad3 = nn.ReflectionPad2d((1,1,1,1))
        self.conv3 = nn.Conv2d(64,64,3,1,0)
        self.relu3 = nn.ReLU(inplace=True)
        # 224 x 224

        self.maxPool = nn.MaxPool2d(kernel_size=2,stride=2,return_indices = True)
        # 112 x 112

        self.reflecPad4 = nn.ReflectionPad2d((1,1,1,1))
        self.conv4 = nn.Conv2d(64,128,3,1,0)
        self.relu4 = nn.ReLU(inplace=True)
        # 112 x 112

        self.reflecPad5 = nn.ReflectionPad2d((1,1,1,1))
        self.conv5 = nn.Conv2d(128,128,3,1,0)
        self.relu5 = nn.ReLU(inplace=True)
        # 112 x 112

        self.maxPool2 = nn.MaxPool2d(kernel_size=2,stride=2,return_indices = True)
        # 56 x 56

        self.reflecPad6 = nn.ReflectionPad2d((1,1,1,1))
        self.conv6 = nn.Conv2d(128,256,3,1,0)
        self.relu6 = nn.ReLU(inplace=True)
        # 56 x 56
    def forward(self,x):
        out = self.conv1(x)
        out = self.reflecPad1(out)
        out = self.conv2(out)
        out = self.relu2(out)
        out = self.reflecPad3(out)
        out = self.conv3(out)
        pool1 = self.relu3(out)
        out,pool_idx = self.maxPool(pool1)
        out = self.reflecPad4(out)
        out = self.conv4(out)
        out = self.relu4(out)
        out = self.reflecPad5(out)
        out = self.conv5(out)
        pool2 = self.relu5(out)
        out,pool_idx2 = self.maxPool2(pool2)
        out = self.reflecPad6(out)
        out = self.conv6(out)
        out = self.relu6(out)
        return out

class decoder3(nn.Module):
    def __init__(self):
        super(decoder3,self).__init__()
        # decoder
        self.reflecPad7 = nn.ReflectionPad2d((1,1,1,1))
        self.conv7 = nn.Conv2d(256,128,3,1,0)
        self.relu7 = nn.ReLU(inplace=True)
        # 56 x 56

        self.unpool = nn.UpsamplingNearest2d(scale_factor=2)
        # 112 x 112

        self.reflecPad8 = nn.ReflectionPad2d((1,1,1,1))
        self.conv8 = nn.Conv2d(128,128,3,1,0)
        self.relu8 = nn.ReLU(inplace=True)
        # 112 x 112

        self.reflecPad9 = nn.ReflectionPad2d((1,1,1,1))
        self.conv9 = nn.Conv2d(128,64,3,1,0)
        self.relu9 = nn.ReLU(inplace=True)

        self.unpool2 = nn.UpsamplingNearest2d(scale_factor=2)
        # 224 x 224

        self.reflecPad10 = nn.ReflectionPad2d((1,1,1,1))
        self.conv10 = nn.Conv2d(64,64,3,1,0)
        self.relu10 = nn.ReLU(inplace=True)

        self.reflecPad11 = nn.ReflectionPad2d((1,1,1,1))
        self.conv11 = nn.Conv2d(64,3,3,1,0)

    def forward(self,x):
        output = {}
        out = self.reflecPad7(x)
        out = self.conv7(out)
        out = self.relu7(out)
        out = self.unpool(out)
        out = self.reflecPad8(out)
        out = self.conv8(out)
        out = self.relu8(out)
        out = self.reflecPad9(out)
        out = self.conv9(out)
        out_relu9 = self.relu9(out)
        out = self.unpool2(out_relu9)
        out = self.reflecPad10(out)
        out = self.conv10(out)
        out = self.relu10(out)
        out = self.reflecPad11(out)
        out = self.conv11(out)
        return out

class encoder4(nn.Module):
    def __init__(self):
        super(encoder4,self).__init__()
        # vgg
        # 224 x 224
        self.conv1 = nn.Conv2d(3,3,1,1,0)
        self.reflecPad1 = nn.ReflectionPad2d((1,1,1,1))
        # 226 x 226

        self.conv2 = nn.Conv2d(3,64,3,1,0)
        self.relu2 = nn.ReLU(inplace=True)
        # 224 x 224

        self.reflecPad3 = nn.ReflectionPad2d((1,1,1,1))
        self.conv3 = nn.Conv2d(64,64,3,1,0)
        self.relu3 = nn.ReLU(inplace=True)
        # 224 x 224

        self.maxPool = nn.MaxPool2d(kernel_size=2,stride=2)
        # 112 x 112

        self.reflecPad4 = nn.ReflectionPad2d((1,1,1,1))
        self.conv4 = nn.Conv2d(64,128,3,1,0)
        self.relu4 = nn.ReLU(inplace=True)
        # 112 x 112

        self.reflecPad5 = nn.ReflectionPad2d((1,1,1,1))
        self.conv5 = nn.Conv2d(128,128,3,1,0)
        self.relu5 = nn.ReLU(inplace=True)
        # 112 x 112

        self.maxPool2 = nn.MaxPool2d(kernel_size=2,stride=2)
        # 56 x 56

        self.reflecPad6 = nn.ReflectionPad2d((1,1,1,1))
        self.conv6 = nn.Conv2d(128,256,3,1,0)
        self.relu6 = nn.ReLU(inplace=True)
        # 56 x 56

        self.reflecPad7 = nn.ReflectionPad2d((1,1,1,1))
        self.conv7 = nn.Conv2d(256,256,3,1,0)
        self.relu7 = nn.ReLU(inplace=True)
        # 56 x 56

        self.reflecPad8 = nn.ReflectionPad2d((1,1,1,1))
        self.conv8 = nn.Conv2d(256,256,3,1,0)
        self.relu8 = nn.ReLU(inplace=True)
        # 56 x 56

        self.reflecPad9 = nn.ReflectionPad2d((1,1,1,1))
        self.conv9 = nn.Conv2d(256,256,3,1,0)
        self.relu9 = nn.ReLU(inplace=True)
        # 56 x 56

        self.maxPool3 = nn.MaxPool2d(kernel_size=2,stride=2)
        # 28 x 28

        self.reflecPad10 = nn.ReflectionPad2d((1,1,1,1))
        self.conv10 = nn.Conv2d(256,512,3,1,0)
        self.relu10 = nn.ReLU(inplace=True)
        # 28 x 28

    def forward(self,x,sF=None,matrix11=None,matrix21=None,matrix31=None):
        output = {}
        out = self.conv1(x)
        out = self.reflecPad1(out)
        out = self.conv2(out)
        output['r11'] = self.relu2(out)
        out = self.reflecPad7(output['r11'])

        out = self.conv3(out)
        output['r12'] = self.relu3(out)

        output['p1'] = self.maxPool(output['r12'])
        out = self.reflecPad4(output['p1'])
        out = self.conv4(out)
        output['r21'] = self.relu4(out)
        out = self.reflecPad7(output['r21'])

        out = self.conv5(out)
        output['r22'] = self.relu5(out)

        output['p2'] = self.maxPool2(output['r22'])
        out = self.reflecPad6(output['p2'])
        out = self.conv6(out)
        output['r31'] = self.relu6(out)
        if(matrix31 is not None):
            feature3,transmatrix3 = matrix31(output['r31'],sF['r31'])
            out = self.reflecPad7(feature3)
        else:
            out = self.reflecPad7(output['r31'])
        out = self.conv7(out)
        output['r32'] = self.relu7(out)

        out = self.reflecPad8(output['r32'])
        out = self.conv8(out)
        output['r33'] = self.relu8(out)

        out = self.reflecPad9(output['r33'])
        out = self.conv9(out)
        output['r34'] = self.relu9(out)

        output['p3'] = self.maxPool3(output['r34'])
        out = self.reflecPad10(output['p3'])
        out = self.conv10(out)
        output['r41'] = self.relu10(out)

        return output

class decoder4(nn.Module):
    def __init__(self):
        super(decoder4,self).__init__()
        # decoder
        self.reflecPad11 = nn.ReflectionPad2d((1,1,1,1))
        self.conv11 = nn.Conv2d(512,256,3,1,0)
        self.relu11 = nn.ReLU(inplace=True)
        # 28 x 28

        self.unpool = nn.UpsamplingNearest2d(scale_factor=2)
        # 56 x 56

        self.reflecPad12 = nn.ReflectionPad2d((1,1,1,1))
        self.conv12 = nn.Conv2d(256,256,3,1,0)
        self.relu12 = nn.ReLU(inplace=True)
        # 56 x 56

        self.reflecPad13 = nn.ReflectionPad2d((1,1,1,1))
        self.conv13 = nn.Conv2d(256,256,3,1,0)
        self.relu13 = nn.ReLU(inplace=True)
        # 56 x 56

        self.reflecPad14 = nn.ReflectionPad2d((1,1,1,1))
        self.conv14 = nn.Conv2d(256,256,3,1,0)
        self.relu14 = nn.ReLU(inplace=True)
        # 56 x 56

        self.reflecPad15 = nn.ReflectionPad2d((1,1,1,1))
        self.conv15 = nn.Conv2d(256,128,3,1,0)
        self.relu15 = nn.ReLU(inplace=True)
        # 56 x 56

        self.unpool2 = nn.UpsamplingNearest2d(scale_factor=2)
        # 112 x 112

        self.reflecPad16 = nn.ReflectionPad2d((1,1,1,1))
        self.conv16 = nn.Conv2d(128,128,3,1,0)
        self.relu16 = nn.ReLU(inplace=True)
        # 112 x 112

        self.reflecPad17 = nn.ReflectionPad2d((1,1,1,1))
        self.conv17 = nn.Conv2d(128,64,3,1,0)
        self.relu17 = nn.ReLU(inplace=True)
        # 112 x 112

        self.unpool3 = nn.UpsamplingNearest2d(scale_factor=2)
        # 224 x 224

        self.reflecPad18 = nn.ReflectionPad2d((1,1,1,1))
        self.conv18 = nn.Conv2d(64,64,3,1,0)
        self.relu18 = nn.ReLU(inplace=True)
        # 224 x 224

        self.reflecPad19 = nn.ReflectionPad2d((1,1,1,1))
        self.conv19 = nn.Conv2d(64,3,3,1,0)

    def forward(self,x):
        # decoder
        out = self.reflecPad11(x)
        out = self.conv11(out)
        out = self.relu11(out)
        out = self.unpool(out)
        out = self.reflecPad12(out)
        out = self.conv12(out)

        out = self.relu12(out)
        out = self.reflecPad13(out)
        out = self.conv13(out)
        out = self.relu13(out)
        out = self.reflecPad14(out)
        out = self.conv14(out)
        out = self.relu14(out)
        out = self.reflecPad15(out)
        out = self.conv15(out)
        out = self.relu15(out)
        out = self.unpool2(out)
        out = self.reflecPad16(out)
        out = self.conv16(out)
        out = self.relu16(out)
        out = self.reflecPad17(out)
        out = self.conv17(out)
        out = self.relu17(out)
        out = self.unpool3(out)
        out = self.reflecPad18(out)
        out = self.conv18(out)
        out = self.relu18(out)
        out = self.reflecPad19(out)
        out = self.conv19(out)
        return out

class encoder5(nn.Module):
    def __init__(self):
        super(encoder5,self).__init__()
        # vgg
        # 224 x 224
        self.conv1 = nn.Conv2d(3,3,1,1,0)
        self.reflecPad1 = nn.ReflectionPad2d((1,1,1,1))
        # 226 x 226

        self.conv2 = nn.Conv2d(3,64,3,1,0)
        self.relu2 = nn.ReLU(inplace=True)
        # 224 x 224

        self.reflecPad3 = nn.ReflectionPad2d((1,1,1,1))
        self.conv3 = nn.Conv2d(64,64,3,1,0)
        self.relu3 = nn.ReLU(inplace=True)
        # 224 x 224

        self.maxPool = nn.MaxPool2d(kernel_size=2,stride=2)
        # 112 x 112

        self.reflecPad4 = nn.ReflectionPad2d((1,1,1,1))
        self.conv4 = nn.Conv2d(64,128,3,1,0)
        self.relu4 = nn.ReLU(inplace=True)
        # 112 x 112

        self.reflecPad5 = nn.ReflectionPad2d((1,1,1,1))
        self.conv5 = nn.Conv2d(128,128,3,1,0)
        self.relu5 = nn.ReLU(inplace=True)
        # 112 x 112

        self.maxPool2 = nn.MaxPool2d(kernel_size=2,stride=2)
        # 56 x 56

        self.reflecPad6 = nn.ReflectionPad2d((1,1,1,1))
        self.conv6 = nn.Conv2d(128,256,3,1,0)
        self.relu6 = nn.ReLU(inplace=True)
        # 56 x 56

        self.reflecPad7 = nn.ReflectionPad2d((1,1,1,1))
        self.conv7 = nn.Conv2d(256,256,3,1,0)
        self.relu7 = nn.ReLU(inplace=True)
        # 56 x 56

        self.reflecPad8 = nn.ReflectionPad2d((1,1,1,1))
        self.conv8 = nn.Conv2d(256,256,3,1,0)
        self.relu8 = nn.ReLU(inplace=True)
        # 56 x 56

        self.reflecPad9 = nn.ReflectionPad2d((1,1,1,1))
        self.conv9 = nn.Conv2d(256,256,3,1,0)
        self.relu9 = nn.ReLU(inplace=True)
        # 56 x 56

        self.maxPool3 = nn.MaxPool2d(kernel_size=2,stride=2)
        # 28 x 28

        self.reflecPad10 = nn.ReflectionPad2d((1,1,1,1))
        self.conv10 = nn.Conv2d(256,512,3,1,0)
        self.relu10 = nn.ReLU(inplace=True)

        self.reflecPad11 = nn.ReflectionPad2d((1,1,1,1))
        self.conv11 = nn.Conv2d(512,512,3,1,0)
        self.relu11 = nn.ReLU(inplace=True)

        self.reflecPad12 = nn.ReflectionPad2d((1,1,1,1))
        self.conv12 = nn.Conv2d(512,512,3,1,0)
        self.relu12 = nn.ReLU(inplace=True)

        self.reflecPad13 = nn.ReflectionPad2d((1,1,1,1))
        self.conv13 = nn.Conv2d(512,512,3,1,0)
        self.relu13 = nn.ReLU(inplace=True)

        self.maxPool4 = nn.MaxPool2d(kernel_size=2,stride=2)
        self.reflecPad14 = nn.ReflectionPad2d((1,1,1,1))
        self.conv14 = nn.Conv2d(512,512,3,1,0)
        self.relu14 = nn.ReLU(inplace=True)

    def forward(self,x,sF=None,contentV256=None,styleV256=None,matrix11=None,matrix21=None,matrix31=None):
        output = {}
        out = self.conv1(x)
        out = self.reflecPad1(out)
        out = self.conv2(out)
        output['r11'] = self.relu2(out)
        out = self.reflecPad7(output['r11'])

        #out = self.reflecPad3(output['r11'])
        out = self.conv3(out)
        output['r12'] = self.relu3(out)

        output['p1'] = self.maxPool(output['r12'])
        out = self.reflecPad4(output['p1'])
        out = self.conv4(out)
        output['r21'] = self.relu4(out)
        out = self.reflecPad7(output['r21'])

        #out = self.reflecPad5(output['r21'])
        out = self.conv5(out)
        output['r22'] = self.relu5(out)

        output['p2'] = self.maxPool2(output['r22'])
        out = self.reflecPad6(output['p2'])
        out = self.conv6(out)
        output['r31'] = self.relu6(out)
        if(styleV256 is not None):
            feature = matrix31(output['r31'],sF['r31'],contentV256,styleV256)
            out = self.reflecPad7(feature)
        else:
            out = self.reflecPad7(output['r31'])
        out = self.conv7(out)
        output['r32'] = self.relu7(out)

        out = self.reflecPad8(output['r32'])
        out = self.conv8(out)
        output['r33'] = self.relu8(out)

        out = self.reflecPad9(output['r33'])
        out = self.conv9(out)
        output['r34'] = self.relu9(out)

        output['p3'] = self.maxPool3(output['r34'])
        out = self.reflecPad10(output['p3'])
        out = self.conv10(out)
        output['r41'] = self.relu10(out)

        out = self.reflecPad11(out)
        out = self.conv11(out)
        out = self.relu11(out)
        out = self.reflecPad12(out)
        out = self.conv12(out)
        out = self.relu12(out)
        out = self.reflecPad13(out)
        out = self.conv13(out)
        out = self.relu13(out)
        out = self.maxPool4(out)
        out = self.reflecPad14(out)
        out = self.conv14(out)
        out = self.relu14(out)
        output['r51'] = out
        return output

class styleLoss(nn.Module):
    def forward(self, input, target):
        ib,ic,ih,iw = input.size()
        iF = input.view(ib,ic,-1)
        iMean = torch.mean(iF,dim=2)
        iCov = GramMatrix()(input)

        tb,tc,th,tw = target.size()
        tF = target.view(tb,tc,-1)
        tMean = torch.mean(tF,dim=2)
        tCov = GramMatrix()(target)

        loss = nn.MSELoss(size_average=False)(iMean,tMean) + nn.MSELoss(size_average=False)(iCov,tCov)
        return loss/tb

class GramMatrix(nn.Module):
    def forward(self, input):
        b, c, h, w = input.size()
        f = input.view(b,c,h*w) # bxcx(hxw)
        # torch.bmm(batch1, batch2, out=None) #
        # batch1: bxmxp, batch2: bxpxn -> bxmxn #
        G = torch.bmm(f,f.transpose(1,2)) # f: bxcx(hxw), f.transpose: bx(hxw)xc -> bxcxc
        return G.div_(c*h*w)

class LossCriterion(nn.Module):
    def __init__(self, style_layers, content_layers, style_weight, content_weight,
                 model_path = '/home/xtli/Documents/GITHUB/LinearStyleTransfer/models/'):
        super(LossCriterion,self).__init__()

        self.style_layers = style_layers
        self.content_layers = content_layers
        self.style_weight = style_weight
        self.content_weight = content_weight

        self.styleLosses = [styleLoss()] * len(style_layers)
        self.contentLosses = [nn.MSELoss()] * len(content_layers)

        self.vgg5 = encoder5()
        self.vgg5.load_state_dict(torch.load(os.path.join(model_path, 'vgg_r51.pth')))

        for param in self.vgg5.parameters():
            param.requires_grad = True

    def forward(self, transfer, image, content=True, style=True):
        cF = self.vgg5(image)
        sF = self.vgg5(image)
        tF = self.vgg5(transfer)

        losses = {}

        # content loss
        if content:
            totalContentLoss = 0
            for i,layer in enumerate(self.content_layers):
                cf_i = cF[layer]
                cf_i = cf_i.detach()
                tf_i = tF[layer]
                loss_i = self.contentLosses[i]
                totalContentLoss += loss_i(tf_i,cf_i)
            totalContentLoss = totalContentLoss * self.content_weight
            losses['content'] = totalContentLoss

        # style loss
        if style:
            totalStyleLoss = 0
            for i,layer in enumerate(self.style_layers):
                sf_i = sF[layer]
                sf_i = sf_i.detach()
                tf_i = tF[layer]
                loss_i = self.styleLosses[i]
                totalStyleLoss += loss_i(tf_i,sf_i)
            totalStyleLoss = totalStyleLoss * self.style_weight
            losses['style'] = totalStyleLoss

        return losses

class styleLossMask(nn.Module):
    def forward(self, input, target, mask):
        ib,ic,ih,iw = input.size()
        iF = input.view(ib,ic,-1)
        tb,tc,th,tw = target.size()
        tF = target.view(tb,tc,-1)

        loss = 0
        mb, mc, mh, mw = mask.shape
        for i in range(mb):
            # resize mask to have the same size of the feature
            maski = F.interpolate(mask[i:i+1], size = (ih, iw), mode = 'nearest')
            mask_flat = maski.view(mc, -1)
            for j in range(mc):
                # get features for each part
                idx = torch.nonzero(mask_flat[j]).squeeze()
                if len(idx.shape) == 0 or idx.shape[0] == 0:
                    continue
                ipart = torch.index_select(iF, 2, idx)
                tpart = torch.index_select(tF, 2, idx)

                iMean = torch.mean(ipart,dim=2)
                iGram = torch.bmm(ipart, ipart.transpose(1,2)).div_(ic*ih*iw) # f: bxcx(hxw), f.transpose: bx(hxw)xc -> bxcxc

                tMean = torch.mean(tpart,dim=2)
                tGram = torch.bmm(tpart, tpart.transpose(1,2)).div_(tc*th*tw) # f: bxcx(hxw), f.transpose: bx(hxw)xc -> bxcxc

                loss += nn.MSELoss()(iMean,tMean) + nn.MSELoss()(iGram,tGram)
        return loss/tb

class LossCriterionMask(nn.Module):
    def __init__(self, style_layers, content_layers, style_weight, content_weight,
                 model_path = '/home/xtli/Documents/GITHUB/LinearStyleTransfer/models/'):
        super(LossCriterionMask,self).__init__()

        self.style_layers = style_layers
        self.content_layers = content_layers
        self.style_weight = style_weight
        self.content_weight = content_weight

        self.styleLosses = [styleLossMask()] * len(style_layers)
        self.contentLosses = [nn.MSELoss()] * len(content_layers)

        self.vgg5 = encoder5()
        self.vgg5.load_state_dict(torch.load(os.path.join(model_path, 'vgg_r51.pth')))

        for param in self.vgg5.parameters():
            param.requires_grad = True

    def forward(self, transfer, image, mask, content=True, style=True):
        # mask: B, N, H, W
        cF = self.vgg5(image)
        sF = self.vgg5(image)
        tF = self.vgg5(transfer)

        losses = {}

        # content loss
        if content:
            totalContentLoss = 0
            for i,layer in enumerate(self.content_layers):
                cf_i = cF[layer]
                cf_i = cf_i.detach()
                tf_i = tF[layer]
                loss_i = self.contentLosses[i]
                totalContentLoss += loss_i(tf_i,cf_i)
            totalContentLoss = totalContentLoss * self.content_weight
            losses['content'] = totalContentLoss

        # style loss
        if style:
            totalStyleLoss = 0
            for i,layer in enumerate(self.style_layers):
                sf_i = sF[layer]
                sf_i = sf_i.detach()
                tf_i = tF[layer]
                loss_i = self.styleLosses[i]
                totalStyleLoss += loss_i(tf_i,sf_i, mask)
            totalStyleLoss = totalStyleLoss * self.style_weight
            losses['style'] = totalStyleLoss

        return losses

class VQEmbedding(nn.Module):
    def __init__(self, K, D):
        super().__init__()
        self.embedding = nn.Embedding(K, D)
        self.embedding.weight.data.uniform_(-1./K, 1./K)

    def forward(self, z_e_x):
        z_e_x_ = z_e_x.permute(0, 2, 3, 1).contiguous()
        latents = vq(z_e_x_, self.embedding.weight)
        return latents

    def straight_through(self, z_e_x, return_index=False):
        z_e_x_ = z_e_x.permute(0, 2, 3, 1).contiguous()
        z_q_x_, indices = vq_st(z_e_x_, self.embedding.weight.detach())
        z_q_x = z_q_x_.permute(0, 3, 1, 2).contiguous()

        z_q_x_bar_flatten = torch.index_select(self.embedding.weight,
                                               dim=0, index=indices)
        z_q_x_bar_ = z_q_x_bar_flatten.view_as(z_e_x_)
        z_q_x_bar = z_q_x_bar_.permute(0, 3, 1, 2).contiguous()

        if return_index:
            return z_q_x, z_q_x_bar, indices
        else:
            return z_q_x, z_q_x_bar
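A standalone sketch (not from this commit) of what GramMatrix and styleLoss above compute: the Gram matrix is the channel-by-channel inner product of the flattened feature map, normalized by c*h*w, and styleLoss sums an MSE term on channel means with an MSE term on Gram matrices. The tensors and shapes below are illustrative assumptions.

import torch

def gram(feat):                          # feat: (b, c, h, w)
    b, c, h, w = feat.size()
    f = feat.view(b, c, h * w)
    return torch.bmm(f, f.transpose(1, 2)) / (c * h * w)    # (b, c, c)

x = torch.rand(2, 256, 56, 56)           # stand-in for relu3_1 features of the transfer result
y = torch.rand(2, 256, 56, 56)           # stand-in for features of the style target
gram_term = torch.nn.functional.mse_loss(gram(x), gram(y), reduction='sum')
mean_term = torch.nn.functional.mse_loss(x.flatten(2).mean(-1), y.flatten(2).mean(-1), reduction='sum')
loss = (gram_term + mean_term) / x.size(0)                   # mirrors styleLoss.forward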
libs/custom_transform.py ADDED
@@ -0,0 +1,249 @@
import torch
import torchvision
from torchvision import transforms
import torch.nn.functional as F
import torchvision.transforms.functional as TF
import numpy as np
from PIL import Image, ImageFilter
import random

class BaseTransform(object):
    """
    Resize and center crop.
    """
    def __init__(self, res):
        self.res = res

    def __call__(self, index, image):
        image = TF.resize(image, self.res, Image.BILINEAR)
        w, h = image.size
        left = int(round((w - self.res) / 2.))
        top = int(round((h - self.res) / 2.))

        return TF.crop(image, top, left, self.res, self.res)


class ComposeTransform(object):
    def __init__(self, tlist):
        self.tlist = tlist

    def __call__(self, index, image):
        for trans in self.tlist:
            image = trans(index, image)

        return image

class RandomResize(object):
    def __init__(self, rmin, rmax, N):
        self.reslist = [random.randint(rmin, rmax) for _ in range(N)]

    def __call__(self, index, image):
        return TF.resize(image, self.reslist[index], Image.BILINEAR)

class RandomCrop(object):
    def __init__(self, res, N):
        self.res = res
        self.cons = [(np.random.uniform(0, 1), np.random.uniform(0, 1)) for _ in range(N)]

    def __call__(self, index, image):
        ws, hs = self.cons[index]
        w, h = image.size
        left = int(round((w-self.res)*ws))
        top = int(round((h-self.res)*hs))

        return TF.crop(image, top, left, self.res, self.res)

class RandomHorizontalFlip(object):
    def __init__(self, N, p=0.5):
        self.p_ref = p
        self.plist = np.random.random_sample(N)

    def __call__(self, index, image):
        if self.plist[index.cpu()] < self.p_ref:
            return TF.hflip(image)
        else:
            return image


class TensorTransform(object):
    def __init__(self):
        self.to_tensor = transforms.ToTensor()
        #self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    def __call__(self, image):
        image = self.to_tensor(image)
        #image = self.normalize(image)

        return image


class RandomGaussianBlur(object):
    def __init__(self, sigma, p, N):
        self.min_x = sigma[0]
        self.max_x = sigma[1]
        self.del_p = 1 - p
        self.p_ref = p
        self.plist = np.random.random_sample(N)

    def __call__(self, index, image):
        if self.plist[index] < self.p_ref:
            x = self.plist[index] - self.p_ref
            m = (self.max_x - self.min_x) / self.del_p
            b = self.min_x
            s = m * x + b

            return image.filter(ImageFilter.GaussianBlur(radius=s))
        else:
            return image


class RandomGrayScale(object):
    def __init__(self, p, N):
        self.grayscale = transforms.RandomGrayscale(p=1.) # Deterministic (We still want flexible out_dim).
        self.p_ref = p
        self.plist = np.random.random_sample(N)

    def __call__(self, index, image):
        if self.plist[index] < self.p_ref:
            return self.grayscale(image)
        else:
            return image


class RandomColorBrightness(object):
    def __init__(self, x, p, N):
        self.min_x = max(0, 1 - x)
        self.max_x = 1 + x
        self.p_ref = p
        self.plist = np.random.random_sample(N)
        self.rlist = [random.uniform(self.min_x, self.max_x) for _ in range(N)]

    def __call__(self, index, image):
        if self.plist[index] < self.p_ref:
            return TF.adjust_brightness(image, self.rlist[index])
        else:
            return image


class RandomColorContrast(object):
    def __init__(self, x, p, N):
        self.min_x = max(0, 1 - x)
        self.max_x = 1 + x
        self.p_ref = p
        self.plist = np.random.random_sample(N)
        self.rlist = [random.uniform(self.min_x, self.max_x) for _ in range(N)]

    def __call__(self, index, image):
        if self.plist[index] < self.p_ref:
            return TF.adjust_contrast(image, self.rlist[index])
        else:
            return image


class RandomColorSaturation(object):
    def __init__(self, x, p, N):
        self.min_x = max(0, 1 - x)
        self.max_x = 1 + x
        self.p_ref = p
        self.plist = np.random.random_sample(N)
        self.rlist = [random.uniform(self.min_x, self.max_x) for _ in range(N)]

    def __call__(self, index, image):
        if self.plist[index] < self.p_ref:
            return TF.adjust_saturation(image, self.rlist[index])
        else:
            return image


class RandomColorHue(object):
    def __init__(self, x, p, N):
        self.min_x = -x
        self.max_x = x
        self.p_ref = p
        self.plist = np.random.random_sample(N)
        self.rlist = [random.uniform(self.min_x, self.max_x) for _ in range(N)]

    def __call__(self, index, image):
        if self.plist[index] < self.p_ref:
            return TF.adjust_hue(image, self.rlist[index])
        else:
            return image


class RandomVerticalFlip(object):
    def __init__(self, N, p=0.5):
        self.p_ref = p
        self.plist = np.random.random_sample(N)

    def __call__(self, indice, image):
        I = np.nonzero(self.plist[indice] < self.p_ref)[0]

        if len(image.size()) == 3:
            image_t = image[I].flip([1])
        else:
            image_t = image[I].flip([2])

        return torch.stack([image_t[np.where(I==i)[0][0]] if i in I else image[i] for i in range(image.size(0))])


class RandomHorizontalTensorFlip(object):
    def __init__(self, N, p=0.5):
        self.p_ref = p
        self.plist = np.random.random_sample(N)

    def __call__(self, indice, image, is_label=False):
        I = np.nonzero(self.plist[indice] < self.p_ref)[0]

        if len(image.size()) == 3:
            image_t = image[I].flip([2])
        else:
            image_t = image[I].flip([3])

        return torch.stack([image_t[np.where(I==i)[0][0]] if i in I else image[i] for i in range(image.size(0))])


class RandomResizedCrop(object):
    def __init__(self, N, res, scale=(0.5, 1.0)):
        self.res = res
        self.scale = scale
        self.rscale = [np.random.uniform(*scale) for _ in range(N)]
        self.rcrop = [(np.random.uniform(0, 1), np.random.uniform(0, 1)) for _ in range(N)]

    def random_crop(self, idx, img):
        ws, hs = self.rcrop[idx]
        res1 = int(img.size(-1))
        res2 = int(self.rscale[idx]*res1)
        i1 = int(round((res1-res2)*ws))
        j1 = int(round((res1-res2)*hs))

        return img[:, :, i1:i1+res2, j1:j1+res2]

    def __call__(self, indice, image):
        new_image = []
        res_tar = self.res // 4 if image.size(1) > 5 else self.res # View 1 or View 2?

        for i, idx in enumerate(indice):
            img = image[[i]]
            img = self.random_crop(idx, img)
            img = F.interpolate(img, res_tar, mode='bilinear', align_corners=False)

            new_image.append(img)

        new_image = torch.cat(new_image)

        return new_image
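A hedged usage sketch (not from this commit) of the transforms above. The key design point is that every class pre-samples its random parameters per dataset index, so calling the pipeline with the same index reproduces the same augmentation; the dataset size, crop size, and image path below are illustrative assumptions.

from PIL import Image
from libs.custom_transform import (ComposeTransform, RandomCrop,
                                   RandomColorBrightness, TensorTransform)

N = 1000                                   # assumed dataset size
pipeline = ComposeTransform([
    RandomCrop(res=128, N=N),              # crop offset is fixed per index
    RandomColorBrightness(x=0.3, p=0.8, N=N),
])
to_tensor = TensorTransform()

img = Image.open('data/images/108073.jpg').convert('RGB')
view_a = to_tensor(pipeline(7, img))       # index 7 selects the pre-sampled parameters
view_b = to_tensor(pipeline(7, img))       # same index -> identical augmentation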
libs/data_coco_stuff.py ADDED
@@ -0,0 +1,166 @@
import cv2
import torch
from PIL import Image
import os.path as osp
import numpy as np
from torch.utils import data
import torch.nn.functional as F  # needed for F.interpolate in RandomResizedCrop below
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
import random

class RandomResizedCrop(object):
    def __init__(self, N, res, scale=(0.5, 1.0)):
        self.res = res
        self.scale = scale
        self.rscale = [np.random.uniform(*scale) for _ in range(N)]
        self.rcrop = [(np.random.uniform(0, 1), np.random.uniform(0, 1)) for _ in range(N)]

    def random_crop(self, idx, img):
        ws, hs = self.rcrop[idx]
        res1 = int(img.size(-1))
        res2 = int(self.rscale[idx]*res1)
        i1 = int(round((res1-res2)*ws))
        j1 = int(round((res1-res2)*hs))

        return img[:, :, i1:i1+res2, j1:j1+res2]

    def __call__(self, indice, image):
        new_image = []
        res_tar = self.res // 4 if image.size(1) > 5 else self.res # View 1 or View 2?

        for i, idx in enumerate(indice):
            img = image[[i]]
            img = self.random_crop(idx, img)
            img = F.interpolate(img, res_tar, mode='bilinear', align_corners=False)

            new_image.append(img)

        new_image = torch.cat(new_image)

        return new_image

class RandomVerticalFlip(object):
    def __init__(self, N, p=0.5):
        self.p_ref = p
        self.plist = np.random.random_sample(N)

    def __call__(self, indice, image):
        I = np.nonzero(self.plist[indice] < self.p_ref)[0]

        if len(image.size()) == 3:
            image_t = image[I].flip([1])
        else:
            image_t = image[I].flip([2])

        return torch.stack([image_t[np.where(I==i)[0][0]] if i in I else image[i] for i in range(image.size(0))])

class RandomHorizontalTensorFlip(object):
    def __init__(self, N, p=0.5):
        self.p_ref = p
        self.plist = np.random.random_sample(N)

    def __call__(self, indice, image, is_label=False):
        I = np.nonzero(self.plist[indice] < self.p_ref)[0]

        if len(image.size()) == 3:
            image_t = image[I].flip([2])
        else:
            image_t = image[I].flip([3])

        return torch.stack([image_t[np.where(I==i)[0][0]] if i in I else image[i] for i in range(image.size(0))])

class _Coco164kCuratedFew(data.Dataset):
    """Base class
    This contains fields and methods common to all COCO 164k curated few datasets:

    (curated) Coco164kFew_Stuff
    (curated) Coco164kFew_Stuff_People
    (curated) Coco164kFew_Stuff_Animals
    (curated) Coco164kFew_Stuff_People_Animals

    """
    def __init__(self, root, img_size, crop_size, split = "train2017"):
        super(_Coco164kCuratedFew, self).__init__()

        # work out name
        self.split = split
        self.root = root
        self.include_things_labels = False # people
        self.incl_animal_things = False # animals

        version = 6

        name = "Coco164kFew_Stuff"
        if self.include_things_labels and self.incl_animal_things:
            name += "_People_Animals"
        elif self.include_things_labels:
            name += "_People"
        elif self.incl_animal_things:
            name += "_Animals"

        self.name = (name + "_%d" % version)

        print("Specific type of _Coco164kCuratedFew dataset: %s" % self.name)

        self._set_files()

        self.transform = transforms.Compose([
            transforms.RandomChoice([
                transforms.ColorJitter(brightness=0.05),
                transforms.ColorJitter(contrast=0.05),
                transforms.ColorJitter(saturation=0.01),
                transforms.ColorJitter(hue=0.01)]),
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            transforms.Resize(int(img_size)),
            transforms.RandomCrop(crop_size)])

        N = len(self.files)
        self.random_horizontal_flip = RandomHorizontalTensorFlip(N=N)
        self.random_vertical_flip = RandomVerticalFlip(N=N)
        self.random_resized_crop = RandomResizedCrop(N=N, res=self.res1, scale=self.scale)

    def _set_files(self):
        # Create data list by parsing the "images" folder
        if self.split in ["train2017", "val2017"]:
            file_list = osp.join(self.root, "curated", self.split, self.name + ".txt")
            file_list = tuple(open(file_list, "r"))
            file_list = [id_.rstrip() for id_ in file_list]

            self.files = file_list
            print("In total {} images.".format(len(self.files)))
        else:
            raise ValueError("Invalid split name: {}".format(self.split))

    def __getitem__(self, index):
        # same as _Coco164k
        # Set paths
        image_id = self.files[index]
        image_path = osp.join(self.root, "images", self.split, image_id + ".jpg")
        label_path = osp.join(self.root, "annotations", self.split,
                              image_id + ".png")
        # Load an image
        #image = cv2.imread(image_path, cv2.IMREAD_COLOR).astype(np.uint8)
        ori_img = Image.open(image_path)
        ori_img = self.transform(ori_img)
        ori_img = np.array(ori_img)
        if ori_img.ndim < 3:
            ori_img = np.expand_dims(ori_img, axis=2).repeat(3, axis = 2)
        ori_img = ori_img[:, :, :3]
        ori_img = torch.from_numpy(ori_img).float().permute(2, 0, 1)
        ori_img = ori_img / 255.0

        #label = cv2.imread(label_path, cv2.IMREAD_GRAYSCALE).astype(np.int32)

        #label[label == 255] = -1 # to be consistent with 10k

        rets = []
        rets.append(ori_img)
        #rets.append(label)
        return rets

    def __len__(self):
        return len(self.files)
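A self-contained sketch (not from this commit) of the batched, index-conditioned flip used above: each dataset index owns a fixed coin flip, so the same indices always receive the same geometric augmentation. The batch contents and index values below are illustrative assumptions.

import numpy as np
import torch
from libs.data_coco_stuff import RandomHorizontalTensorFlip

flip = RandomHorizontalTensorFlip(N=100, p=0.5)   # pre-sample one probability per dataset index
batch = torch.rand(4, 3, 64, 64)                  # a batch of 4 images
indices = np.array([3, 17, 17, 42])               # dataset indices backing this batch
flipped = flip(indices, batch)                    # images whose plist[index] < 0.5 are mirrored
print(flipped.shape)                              # torch.Size([4, 3, 64, 64])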
libs/data_coco_stuff_geo_pho.py ADDED
@@ -0,0 +1,145 @@
import cv2
import torch
from PIL import Image
import os.path as osp
import numpy as np
from torch.utils import data
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
from .custom_transform import *

class _Coco164kCuratedFew(data.Dataset):
    """Base class
    This contains fields and methods common to all COCO 164k curated few datasets:

    (curated) Coco164kFew_Stuff
    (curated) Coco164kFew_Stuff_People
    (curated) Coco164kFew_Stuff_Animals
    (curated) Coco164kFew_Stuff_People_Animals

    """
    def __init__(self, root, img_size, crop_size, split = "train2017"):
        super(_Coco164kCuratedFew, self).__init__()

        # work out name
        self.split = split
        self.root = root
        self.include_things_labels = False # people
        self.incl_animal_things = False # animals

        version = 6

        name = "Coco164kFew_Stuff"
        if self.include_things_labels and self.incl_animal_things:
            name += "_People_Animals"
        elif self.include_things_labels:
            name += "_People"
        elif self.incl_animal_things:
            name += "_Animals"

        self.name = (name + "_%d" % version)

        print("Specific type of _Coco164kCuratedFew dataset: %s" % self.name)

        self._set_files()

        self.transform = transforms.Compose([
            transforms.Resize(int(img_size)),
            transforms.RandomCrop(crop_size)])

        N = len(self.files)
        # eqv transform
        self.random_horizontal_flip = RandomHorizontalTensorFlip(N=N)
        self.random_vertical_flip = RandomVerticalFlip(N=N)
        self.random_resized_crop = RandomResizedCrop(N=N, res=288)

        # photometric transform
        self.random_color_brightness = [RandomColorBrightness(x=0.3, p=0.8, N=N) for _ in range(2)] # Control this later (NOTE)
        self.random_color_contrast = [RandomColorContrast(x=0.3, p=0.8, N=N) for _ in range(2)] # Control this later (NOTE)
        self.random_color_saturation = [RandomColorSaturation(x=0.3, p=0.8, N=N) for _ in range(2)] # Control this later (NOTE)
        self.random_color_hue = [RandomColorHue(x=0.1, p=0.8, N=N) for _ in range(2)] # Control this later (NOTE)
        self.random_gray_scale = [RandomGrayScale(p=0.2, N=N) for _ in range(2)]
        self.random_gaussian_blur = [RandomGaussianBlur(sigma=[.1, 2.], p=0.5, N=N) for _ in range(2)]

        self.eqv_list = ['random_crop', 'h_flip']
        self.inv_list = ['brightness', 'contrast', 'saturation', 'hue', 'gray', 'blur']

        self.transform_tensor = TensorTransform()

    def _set_files(self):
        # Create data list by parsing the "images" folder
        if self.split in ["train2017", "val2017"]:
            file_list = osp.join(self.root, "curated", self.split, self.name + ".txt")
            file_list = tuple(open(file_list, "r"))
            file_list = [id_.rstrip() for id_ in file_list]

            self.files = file_list
            print("In total {} images.".format(len(self.files)))
        else:
            raise ValueError("Invalid split name: {}".format(self.split))

    def transform_eqv(self, indice, image):
        if 'random_crop' in self.eqv_list:
            image = self.random_resized_crop(indice, image)
        if 'h_flip' in self.eqv_list:
            image = self.random_horizontal_flip(indice, image)
        if 'v_flip' in self.eqv_list:
            image = self.random_vertical_flip(indice, image)

        return image

    def transform_inv(self, index, image, ver):
        """
        Hyperparameters same as MoCo v2.
        (https://github.com/facebookresearch/moco/blob/master/main_moco.py)
        """
        if 'brightness' in self.inv_list:
            image = self.random_color_brightness[ver](index, image)
        if 'contrast' in self.inv_list:
            image = self.random_color_contrast[ver](index, image)
        if 'saturation' in self.inv_list:
            image = self.random_color_saturation[ver](index, image)
        if 'hue' in self.inv_list:
            image = self.random_color_hue[ver](index, image)
        if 'gray' in self.inv_list:
            image = self.random_gray_scale[ver](index, image)
        if 'blur' in self.inv_list:
            image = self.random_gaussian_blur[ver](index, image)

        return image

    def transform_image(self, index, image):
        image1 = self.transform_inv(index, image, 0)
        image1 = self.transform_tensor(image1)  # tensorize the first photometric view
        image2 = self.transform_inv(index, image, 1)
        #image2 = TF.resize(image2, self.crop_size, Image.BILINEAR)
        image2 = self.transform_tensor(image2)
        return image1, image2

    def __getitem__(self, index):
        # same as _Coco164k
        # Set paths
        image_id = self.files[index]
        image_path = osp.join(self.root, "images", self.split, image_id + ".jpg")
        # Load an image
        ori_img = Image.open(image_path)
        ori_img = self.transform(ori_img)

        image1, image2 = self.transform_image(index, ori_img)
        if image1.shape[0] < 3:
            image1 = image1.repeat(3, 1, 1)
        if image2.shape[0] < 3:
            image2 = image2.repeat(3, 1, 1)

        rets = []
        rets.append(image1)
        rets.append(image2)
        rets.append(index)

        return rets

    def __len__(self):
        return len(self.files)
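A self-contained sketch (not from this commit) of the "equivariant replay" idea behind transform_eqv above: the same pre-sampled crop and flip, keyed by dataset index, can be applied both to an image batch and later to the matching feature maps so the two stay spatially aligned. The sizes and indices below are illustrative assumptions.

import numpy as np
import torch
from libs.custom_transform import RandomResizedCrop, RandomHorizontalTensorFlip

N = 100                                     # assumed dataset size
crop = RandomResizedCrop(N=N, res=288)
flip = RandomHorizontalTensorFlip(N=N)

indices = np.array([5, 9])                  # dataset indices for this mini-batch
images = torch.rand(2, 3, 288, 288)         # 3 channels <= 5, so crop resizes back to 288
feats = torch.rand(2, 64, 288, 288)         # >5 channels, so crop resizes to 288 // 4 = 72

img_aug = flip(indices, crop(indices, images))   # augmented images
feat_aug = flip(indices, crop(indices, feats))   # same crop/flip replayed on feature maps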
libs/data_geo.py ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """SLIC dataset
2
+ - Returns an image together with its SLIC segmentation map.
3
+ """
4
+ import torch
5
+ import torch.utils.data as data
6
+ import torchvision.transforms as transforms
7
+
8
+ import numpy as np
9
+ from glob import glob
10
+ from PIL import Image
11
+ from skimage.segmentation import slic
12
+ from skimage.color import rgb2lab
13
+ import torch.nn.functional as F
14
+
15
+ from .utils import label2one_hot_torch
16
+
17
+ class RandomResizedCrop(object):
18
+ def __init__(self, N, res, scale=(0.5, 1.0)):
19
+ self.res = res
20
+ self.scale = scale
21
+ self.rscale = [np.random.uniform(*scale) for _ in range(N)]
22
+ self.rcrop = [(np.random.uniform(0, 1), np.random.uniform(0, 1)) for _ in range(N)]
23
+
24
+ def random_crop(self, idx, img):
25
+ ws, hs = self.rcrop[idx]
26
+ res1 = int(img.size(-1))
27
+ res2 = int(self.rscale[idx]*res1)
28
+ i1 = int(round((res1-res2)*ws))
29
+ j1 = int(round((res1-res2)*hs))
30
+
31
+ return img[:, :, i1:i1+res2, j1:j1+res2]
32
+
33
+
34
+ def __call__(self, indice, image):
35
+ new_image = []
36
+ res_tar = self.res // 8 if image.size(1) > 5 else self.res # View 1 or View 2?
37
+
38
+ for i, idx in enumerate(indice):
39
+ img = image[[i]]
40
+ img = self.random_crop(idx, img)
41
+ img = F.interpolate(img, res_tar, mode='bilinear', align_corners=False)
42
+
43
+ new_image.append(img)
44
+
45
+ new_image = torch.cat(new_image)
46
+
47
+ return new_image
48
+
49
+ class RandomVerticalFlip(object):
50
+ def __init__(self, N, p=0.5):
51
+ self.p_ref = p
52
+ self.plist = np.random.random_sample(N)
53
+
54
+ def __call__(self, indice, image):
55
+ I = np.nonzero(self.plist[indice] < self.p_ref)[0]
56
+
57
+ if len(image.size()) == 3:
58
+ image_t = image[I].flip([1])
59
+ else:
60
+ image_t = image[I].flip([2])
61
+
62
+ return torch.stack([image_t[np.where(I==i)[0][0]] if i in I else image[i] for i in range(image.size(0))])
63
+
64
+ class RandomHorizontalTensorFlip(object):
65
+ def __init__(self, N, p=0.5):
66
+ self.p_ref = p
67
+ self.plist = np.random.random_sample(N)
68
+
69
+ def __call__(self, indice, image, is_label=False):
70
+ I = np.nonzero(self.plist[indice.cpu()] < self.p_ref)[0]
71
+
72
+ if len(image.size()) == 3:
73
+ image_t = image[I].flip([2])
74
+ else:
75
+ image_t = image[I].flip([3])
76
+
77
+ return torch.stack([image_t[np.where(I==i)[0][0]] if i in I else image[i] for i in range(image.size(0))])
78
+
79
+ class Dataset(data.Dataset):
80
+ def __init__(self, data_dir, img_size=256, crop_size=128, test=False,
81
+ sp_num=256, slic = True, lab = False):
82
+ super(Dataset, self).__init__()
83
+ #self.data_list = glob(os.path.join(data_dir, "*.jpg"))
84
+ ext = ["*.jpg"]
85
+ dl = []
86
+ [dl.extend(glob(data_dir + '/**/' + e, recursive=True)) for e in ext]
87
+ self.data_list = dl
88
+ self.sp_num = sp_num
89
+ self.slic = slic
90
+ self.lab = lab
91
+ if test:
92
+ self.transform = transforms.Compose([
93
+ transforms.Resize(img_size),
94
+ transforms.CenterCrop(crop_size)])
95
+ else:
96
+ self.transform = transforms.Compose([
97
+ transforms.RandomChoice([
98
+ transforms.ColorJitter(brightness=0.05),
99
+ transforms.ColorJitter(contrast=0.05),
100
+ transforms.ColorJitter(saturation=0.01),
101
+ transforms.ColorJitter(hue=0.01)]),
102
+ transforms.RandomHorizontalFlip(),
103
+ transforms.RandomVerticalFlip(),
104
+ transforms.Resize(int(img_size)),
105
+ transforms.RandomCrop(crop_size)])
106
+
107
+ N = len(self.data_list)
108
+ self.random_horizontal_flip = RandomHorizontalTensorFlip(N=N)
109
+ self.random_vertical_flip = RandomVerticalFlip(N=N)
110
+ self.random_resized_crop = RandomResizedCrop(N=N, res=224)
111
+ self.eqv_list = ['random_crop', 'h_flip']
112
+
113
+ def transform_eqv(self, indice, image):
114
+ if 'random_crop' in self.eqv_list:
115
+ image = self.random_resized_crop(indice, image)
116
+ if 'h_flip' in self.eqv_list:
117
+ image = self.random_horizontal_flip(indice, image)
118
+ if 'v_flip' in self.eqv_list:
119
+ image = self.random_vertical_flip(indice, image)
120
+
121
+ return image
122
+
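+ # Note (sketch, not part of the original pipeline): the crop/flip parameters are drawn once
+ # per dataset index in __init__, so calling transform_eqv with the same indices reproduces the
+ # same geometric transform; inputs with more than 5 channels (dense feature maps) are resized
+ # to res//8 instead of res. A typical use is therefore:
+ # img_t = dataset.transform_eqv(indices, img) # images
+ # feat_t = dataset.transform_eqv(indices, feat) # features receive the identical crops/flips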
123
+ def __getitem__(self, index):
124
+ data_path = self.data_list[index]
125
+ ori_img = Image.open(data_path)
126
+ ori_img = self.transform(ori_img)
127
+ ori_img = np.array(ori_img)
128
+
129
+ # compute slic
130
+ if self.slic:
131
+ slic_i = slic(ori_img, n_segments=self.sp_num, compactness=10, start_label=0, min_size_factor=0.3)
132
+ slic_i = torch.from_numpy(slic_i)
133
+ slic_i[slic_i >= self.sp_num] = self.sp_num - 1
134
+ oh = label2one_hot_torch(slic_i.unsqueeze(0).unsqueeze(0), C = self.sp_num).squeeze()
135
+
136
+ if ori_img.ndim < 3:
137
+ ori_img = np.expand_dims(ori_img, axis=2).repeat(3, axis = 2)
138
+ ori_img = ori_img[:, :, :3]
139
+
140
+ rets = []
141
+ if self.lab:
142
+ lab_img = rgb2lab(ori_img)
143
+ rets.append(torch.from_numpy(lab_img).float().permute(2, 0, 1))
144
+
145
+ ori_img = torch.from_numpy(ori_img).float().permute(2, 0, 1)
146
+ rets.append(ori_img/255.0)
147
+
148
+ if self.slic:
149
+ rets.append(oh)
150
+
151
+ rets.append(index)
152
+
153
+ return rets
154
+
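+ # Returned list with the default flags (slic=True, lab=False):
+ # rets[0]: RGB image, float tensor (3, crop_size, crop_size), scaled to [0, 1]
+ # rets[1]: SLIC one-hot map, (sp_num, crop_size, crop_size)
+ # rets[2]: sample index (int)
+ # With lab=True a (3, crop_size, crop_size) CIELAB tensor is prepended.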
155
+ def __len__(self):
156
+ return len(self.data_list)
157
+
158
+ if __name__ == '__main__':
159
+ import torchvision.utils as vutils
160
+ dataset = Dataset('/home/xtli/DATA/texture_data/', sp_num=256)
+ loader_ = torch.utils.data.DataLoader(dataset = dataset,
+ batch_size = 1,
+ shuffle = True,
+ num_workers = 1,
+ drop_last = True)
+ # With the default flags (slic=True, lab=False) each batch is [image, slic_onehot, index].
+ img, oh, index = next(iter(loader_))
+ vutils.save_image(img, 'img.png')
+ # save the superpixel labels as a normalised grayscale map
+ vutils.save_image(oh.argmax(1, keepdim=True).float() / oh.size(1), 'slic.png')
libs/data_geo_pho.py ADDED
@@ -0,0 +1,130 @@
1
+ """SLIC dataset
2
+ - Returns an image together with its SLIC segmentation map.
3
+ """
4
+ import torch
5
+ import torch.utils.data as data
6
+ import torchvision.transforms as transforms
7
+
8
+ import numpy as np
9
+ from glob import glob
10
+ from PIL import Image
11
+ import torch.nn.functional as F
12
+ import torchvision.transforms.functional as TF
13
+
14
+ from .custom_transform import *
15
+
16
+ class Dataset(data.Dataset):
17
+ def __init__(self, data_dir, img_size=256, crop_size=128, test=False,
18
+ sp_num=256, slic = True, lab = False):
19
+ super(Dataset, self).__init__()
20
+ #self.data_list = glob(os.path.join(data_dir, "*.jpg"))
21
+ ext = ["*.jpg"]
22
+ dl = []
23
+ [dl.extend(glob(data_dir + '/**/' + e, recursive=True)) for e in ext]
24
+ self.data_list = dl
25
+ self.sp_num = sp_num
26
+ self.slic = slic
27
+ self.lab = lab
28
+ if test:
29
+ self.transform = transforms.Compose([
30
+ transforms.Resize(img_size),
31
+ transforms.CenterCrop(crop_size)])
32
+ else:
33
+ self.transform = transforms.Compose([
34
+ transforms.Resize(int(img_size)),
35
+ transforms.RandomCrop(crop_size)])
36
+
37
+ N = len(self.data_list)
38
+ # eqv transform
39
+ self.random_horizontal_flip = RandomHorizontalTensorFlip(N=N)
40
+ self.random_vertical_flip = RandomVerticalFlip(N=N)
41
+ self.random_resized_crop = RandomResizedCrop(N=N, res=256)
42
+
43
+ # photometric transform
44
+ self.random_color_brightness = [RandomColorBrightness(x=0.3, p=0.8, N=N) for _ in range(2)] # Control this later (NOTE)]
45
+ self.random_color_contrast = [RandomColorContrast(x=0.3, p=0.8, N=N) for _ in range(2)] # Control this later (NOTE)
46
+ self.random_color_saturation = [RandomColorSaturation(x=0.3, p=0.8, N=N) for _ in range(2)] # Control this later (NOTE)
47
+ self.random_color_hue = [RandomColorHue(x=0.1, p=0.8, N=N) for _ in range(2)] # Control this later (NOTE)
48
+ self.random_gray_scale = [RandomGrayScale(p=0.2, N=N) for _ in range(2)]
49
+ self.random_gaussian_blur = [RandomGaussianBlur(sigma=[.1, 2.], p=0.5, N=N) for _ in range(2)]
50
+
51
+ self.eqv_list = ['random_crop', 'h_flip']
52
+ self.inv_list = ['brightness', 'contrast', 'saturation', 'hue', 'gray', 'blur']
53
+
54
+ self.transform_tensor = TensorTransform()
55
+
56
+ def transform_eqv(self, indice, image):
57
+ if 'random_crop' in self.eqv_list:
58
+ image = self.random_resized_crop(indice, image)
59
+ if 'h_flip' in self.eqv_list:
60
+ image = self.random_horizontal_flip(indice, image)
61
+ if 'v_flip' in self.eqv_list:
62
+ image = self.random_vertical_flip(indice, image)
63
+
64
+ return image
65
+
66
+ def transform_inv(self, index, image, ver):
67
+ """
68
+ Hyperparameters same as MoCo v2.
69
+ (https://github.com/facebookresearch/moco/blob/master/main_moco.py)
70
+ """
71
+ if 'brightness' in self.inv_list:
72
+ image = self.random_color_brightness[ver](index, image)
73
+ if 'contrast' in self.inv_list:
74
+ image = self.random_color_contrast[ver](index, image)
75
+ if 'saturation' in self.inv_list:
76
+ image = self.random_color_saturation[ver](index, image)
77
+ if 'hue' in self.inv_list:
78
+ image = self.random_color_hue[ver](index, image)
79
+ if 'gray' in self.inv_list:
80
+ image = self.random_gray_scale[ver](index, image)
81
+ if 'blur' in self.inv_list:
82
+ image = self.random_gaussian_blur[ver](index, image)
83
+
84
+ return image
85
+
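+ # Each photometric op keeps a per-index random state, and `ver` selects one of the two
+ # independently sampled parameter sets, so view 0 and view 1 of the same image receive
+ # different colour jitter / grayscale / blur while staying reproducible for a given index.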
86
+ def transform_image(self, index, image):
87
+ image1 = self.transform_inv(index, image, 0)
88
+ image1 = self.transform_tensor(image1)
89
+
90
+ image2 = self.transform_inv(index, image, 1)
91
+ #image2 = TF.resize(image2, self.crop_size, Image.BILINEAR)
92
+ image2 = self.transform_tensor(image2)
93
+ return image1, image2
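+ # The two views share the same geometric crop (applied in __getitem__ below via self.transform)
+ # and differ only photometrically, e.g. as an invariance pair for a contrastive objective.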
94
+
95
+ def __getitem__(self, index):
96
+ data_path = self.data_list[index]
97
+ ori_img = Image.open(data_path)
98
+ ori_img = self.transform(ori_img)
99
+
100
+ image1, image2 = self.transform_image(index, ori_img)
101
+
102
+ rets = []
103
+ rets.append(image1)
104
+ rets.append(image2)
105
+ rets.append(index)
106
+
107
+ return rets
108
+
109
+ def __len__(self):
110
+ return len(self.data_list)
111
+
112
+ if __name__ == '__main__':
113
+ import torchvision.utils as vutils
114
+ dataset = Dataset('/home/xtli/DATA/texture_data/')
+ loader_ = torch.utils.data.DataLoader(dataset = dataset,
+ batch_size = 1,
+ shuffle = True,
+ num_workers = 1,
+ drop_last = True)
+ # Each batch is [image1, image2, index]: two photometric views of the same crop.
+ image1, image2, index = next(iter(loader_))
+ vutils.save_image(image1, 'img1.png')
+ vutils.save_image(image2, 'img2.png')
libs/data_slic.py ADDED
@@ -0,0 +1,175 @@
1
+ """SLIC dataset
2
+ - Returns an image together with its SLIC segmentation map.
3
+ """
4
+ import torch
5
+ import torch.utils.data as data
6
+ import torchvision.transforms as transforms
7
+
8
+ import numpy as np
9
+ from glob import glob
10
+ from PIL import Image
11
+ from skimage.segmentation import slic
12
+ from skimage.color import rgb2lab
+ import torch.nn.functional as F  # needed for F.interpolate in RandomResizedCrop
13
+
14
+ from .utils import label2one_hot_torch
15
+
16
+ class RandomResizedCrop(object):
17
+ def __init__(self, N, res, scale=(0.5, 1.0)):
18
+ self.res = res
19
+ self.scale = scale
20
+ self.rscale = [np.random.uniform(*scale) for _ in range(N)]
21
+ self.rcrop = [(np.random.uniform(0, 1), np.random.uniform(0, 1)) for _ in range(N)]
22
+
23
+ def random_crop(self, idx, img):
24
+ ws, hs = self.rcrop[idx]
25
+ res1 = int(img.size(-1))
26
+ res2 = int(self.rscale[idx]*res1)
27
+ i1 = int(round((res1-res2)*ws))
28
+ j1 = int(round((res1-res2)*hs))
29
+
30
+ return img[:, :, i1:i1+res2, j1:j1+res2]
31
+
32
+
33
+ def __call__(self, indice, image):
34
+ new_image = []
35
+ res_tar = self.res // 4 if image.size(1) > 5 else self.res # View 1 or View 2?
36
+
37
+ for i, idx in enumerate(indice):
38
+ img = image[[i]]
39
+ img = self.random_crop(idx, img)
40
+ img = F.interpolate(img, res_tar, mode='bilinear', align_corners=False)
41
+
42
+ new_image.append(img)
43
+
44
+ new_image = torch.cat(new_image)
45
+
46
+ return new_image
47
+
48
+ class RandomVerticalFlip(object):
49
+ def __init__(self, N, p=0.5):
50
+ self.p_ref = p
51
+ self.plist = np.random.random_sample(N)
52
+
53
+ def __call__(self, indice, image):
54
+ I = np.nonzero(self.plist[indice] < self.p_ref)[0]
55
+
56
+ if len(image.size()) == 3:
57
+ image_t = image[I].flip([1])
58
+ else:
59
+ image_t = image[I].flip([2])
60
+
61
+ return torch.stack([image_t[np.where(I==i)[0][0]] if i in I else image[i] for i in range(image.size(0))])
62
+
63
+ class RandomHorizontalTensorFlip(object):
64
+ def __init__(self, N, p=0.5):
65
+ self.p_ref = p
66
+ self.plist = np.random.random_sample(N)
67
+
68
+ def __call__(self, indice, image, is_label=False):
69
+ I = np.nonzero(self.plist[indice] < self.p_ref)[0]
70
+
71
+ if len(image.size()) == 3:
72
+ image_t = image[I].flip([2])
73
+ else:
74
+ image_t = image[I].flip([3])
75
+
76
+ return torch.stack([image_t[np.where(I==i)[0][0]] if i in I else image[i] for i in range(image.size(0))])
77
+
78
+ class Dataset(data.Dataset):
79
+ def __init__(self, data_dir, img_size=256, crop_size=128, test=False,
80
+ sp_num=256, slic = True, lab = False):
81
+ super(Dataset, self).__init__()
82
+ #self.data_list = glob(os.path.join(data_dir, "*.jpg"))
83
+ ext = ["*.jpg"]
84
+ dl = []
85
+ [dl.extend(glob(data_dir + '/**/' + e, recursive=True)) for e in ext]
86
+ self.data_list = dl
87
+ self.sp_num = sp_num
88
+ self.slic = slic
89
+ self.lab = lab
90
+ if test:
91
+ self.transform = transforms.Compose([
92
+ transforms.Resize(img_size),
93
+ transforms.CenterCrop(crop_size)])
94
+ else:
95
+ self.transform = transforms.Compose([
96
+ transforms.RandomChoice([
97
+ transforms.ColorJitter(brightness=0.05),
98
+ transforms.ColorJitter(contrast=0.05),
99
+ transforms.ColorJitter(saturation=0.01),
100
+ transforms.ColorJitter(hue=0.01)]),
101
+ transforms.RandomHorizontalFlip(),
102
+ transforms.RandomVerticalFlip(),
103
+ transforms.Resize(int(img_size)),
104
+ transforms.RandomCrop(crop_size)])
105
+
106
+ N = len(self.data_list)
107
+ self.random_horizontal_flip = RandomHorizontalTensorFlip(N=N)
108
+ self.random_vertical_flip = RandomVerticalFlip(N=N)
109
+ self.random_resized_crop = RandomResizedCrop(N=N, res=img_size)
110
+ self.eqv_list = ['random_crop', 'h_flip']
111
+
112
+ def transform_eqv(self, indice, image):
113
+ if 'random_crop' in self.eqv_list:
114
+ image = self.random_resized_crop(indice, image)
115
+ if 'h_flip' in self.eqv_list:
116
+ image = self.random_horizontal_flip(indice, image)
117
+ if 'v_flip' in self.eqv_list:
118
+ image = self.random_vertical_flip(indice, image)
119
+
120
+ return image
121
+
122
+ def __getitem__(self, index):
123
+ data_path = self.data_list[index]
124
+ ori_img = Image.open(data_path)
125
+ ori_img = self.transform(ori_img)
126
+ ori_img = np.array(ori_img)
127
+
128
+ # compute slic
129
+ if self.slic:
130
+ slic_i = slic(ori_img, n_segments=self.sp_num, compactness=10, start_label=0, min_size_factor=0.3)
131
+ slic_i = torch.from_numpy(slic_i)
132
+ slic_i[slic_i >= self.sp_num] = self.sp_num - 1
133
+ oh = label2one_hot_torch(slic_i.unsqueeze(0).unsqueeze(0), C = self.sp_num).squeeze()
134
+
135
+ if ori_img.ndim < 3:
136
+ ori_img = np.expand_dims(ori_img, axis=2).repeat(3, axis = 2)
137
+ ori_img = ori_img[:, :, :3]
138
+
139
+ rets = []
140
+ if self.lab:
141
+ lab_img = rgb2lab(ori_img)
142
+ rets.append(torch.from_numpy(lab_img).float().permute(2, 0, 1))
143
+
144
+ ori_img = torch.from_numpy(ori_img).float().permute(2, 0, 1)
145
+ rets.append(ori_img/255.0)
146
+
147
+ if self.slic:
148
+ rets.append(oh)
149
+
150
+ rets.append(index)
151
+
152
+ return rets
153
+
154
+ def __len__(self):
155
+ return len(self.data_list)
156
+
157
+ if __name__ == '__main__':
158
+ import torchvision.utils as vutils
159
+ dataset = Dataset('/home/xtli/DATA/texture_data/', sp_num=256)
+ loader_ = torch.utils.data.DataLoader(dataset = dataset,
+ batch_size = 1,
+ shuffle = True,
+ num_workers = 1,
+ drop_last = True)
+ # With the default flags (slic=True, lab=False) each batch is [image, slic_onehot, index].
+ img, oh, index = next(iter(loader_))
+ vutils.save_image(img, 'img.png')
+ # save the superpixel labels as a normalised grayscale map
+ vutils.save_image(oh.argmax(1, keepdim=True).float() / oh.size(1), 'slic.png')
libs/discriminator.py ADDED
@@ -0,0 +1,60 @@
1
+ import functools
2
+ import torch.nn as nn
3
+
4
+ def weights_init(m):
5
+ classname = m.__class__.__name__
6
+ if classname.find('Conv') != -1:
7
+ nn.init.normal_(m.weight.data, 0.0, 0.02)
8
+ elif classname.find('BatchNorm') != -1:
9
+ nn.init.normal_(m.weight.data, 1.0, 0.02)
10
+ nn.init.constant_(m.bias.data, 0)
11
+
12
+
13
+ class NLayerDiscriminator(nn.Module):
14
+ """Defines a PatchGAN discriminator as in Pix2Pix
15
+ --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
16
+ """
17
+ def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
18
+ """Construct a PatchGAN discriminator
19
+ Parameters:
20
+ input_nc (int) -- the number of channels in input images
21
+ ndf (int) -- the number of filters in the last conv layer
22
+ n_layers (int) -- the number of conv layers in the discriminator
23
+ norm_layer -- normalization layer
24
+ """
25
+ super(NLayerDiscriminator, self).__init__()
26
+ norm_layer = nn.BatchNorm2d
27
+ if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
28
+ use_bias = norm_layer.func != nn.BatchNorm2d
29
+ else:
30
+ use_bias = norm_layer != nn.BatchNorm2d
31
+
32
+ kw = 4
33
+ padw = 1
34
+ sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
35
+ nf_mult = 1
36
+ nf_mult_prev = 1
37
+ for n in range(1, n_layers): # gradually increase the number of filters
38
+ nf_mult_prev = nf_mult
39
+ nf_mult = min(2 ** n, 8)
40
+ sequence += [
41
+ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
42
+ norm_layer(ndf * nf_mult),
43
+ nn.LeakyReLU(0.2, True)
44
+ ]
45
+
46
+ nf_mult_prev = nf_mult
47
+ nf_mult = min(2 ** n_layers, 8)
48
+ sequence += [
49
+ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
50
+ norm_layer(ndf * nf_mult),
51
+ nn.LeakyReLU(0.2, True)
52
+ ]
53
+
54
+ sequence += [
55
+ nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map
56
+ self.main = nn.Sequential(*sequence)
57
+
58
+ def forward(self, input):
59
+ """Standard forward."""
60
+ return self.main(input)
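+ # Rough shape sketch (assuming torch is imported and the defaults input_nc=3, ndf=64, n_layers=3):
+ # logits = NLayerDiscriminator()(torch.randn(1, 3, 128, 128)) # -> (1, 1, 14, 14)
+ # i.e. one real/fake logit per overlapping image patch rather than a single scalar per image.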
libs/flow_transforms.py ADDED
@@ -0,0 +1,393 @@
1
+ from __future__ import division
2
+ import torch
3
+ import random
4
+ import numpy as np
5
+ import numbers
6
+ import types
7
+ import scipy.ndimage as ndimage
8
+ import cv2
9
+ import matplotlib.pyplot as plt
10
+ from PIL import Image
11
+ # import torchvision.transforms.functional as FF
12
+
13
+ '''
14
+ Data augmentation file,
15
+ modified from
16
+ https://github.com/ClementPinard/FlowNetPytorch
17
+
18
+
19
+ '''
20
+
21
+
22
+
23
+ '''Set of random transform routines that take both input and target as arguments,
24
+ in order to have random but coherent transformations.
25
+ inputs are PIL Image pairs and targets are ndarrays'''
26
+
27
+ _pil_interpolation_to_str = {
28
+ Image.NEAREST: 'PIL.Image.NEAREST',
29
+ Image.BILINEAR: 'PIL.Image.BILINEAR',
30
+ Image.BICUBIC: 'PIL.Image.BICUBIC',
31
+ Image.LANCZOS: 'PIL.Image.LANCZOS',
32
+ Image.HAMMING: 'PIL.Image.HAMMING',
33
+ Image.BOX: 'PIL.Image.BOX',
34
+ }
35
+
36
+ class Compose(object):
37
+ """ Composes several co_transforms together.
38
+ For example:
39
+ >>> co_transforms.Compose([
40
+ >>> co_transforms.CenterCrop(10),
41
+ >>> co_transforms.ToTensor(),
42
+ >>> ])
43
+ """
44
+
45
+ def __init__(self, co_transforms):
46
+ self.co_transforms = co_transforms
47
+
48
+ def __call__(self, input, target):
49
+ for t in self.co_transforms:
50
+ input,target = t(input,target)
51
+ return input,target
52
+
53
+
54
+ class ArrayToTensor(object):
55
+ """Converts a numpy.ndarray (H x W x C) to a torch.FloatTensor of shape (C x H x W)."""
56
+
57
+ def __call__(self, array):
58
+ assert(isinstance(array, np.ndarray))
59
+
60
+ array = np.transpose(array, (2, 0, 1))
61
+ # handle numpy array
62
+ tensor = torch.from_numpy(array)
63
+ # put it from HWC to CHW format
64
+
65
+ return tensor.float()
66
+
67
+
68
+ class ArrayToPILImage(object):
69
+ """Converts a numpy.ndarray (H x W x C) to a torch.FloatTensor of shape (C x H x W)."""
70
+
71
+ def __call__(self, array):
72
+ assert(isinstance(array, np.ndarray))
73
+
74
+ img = Image.fromarray(array.astype(np.uint8))
75
+
76
+ return img
77
+
78
+ class PILImageToTensor(object):
79
+ """Converts a numpy.ndarray (H x W x C) to a torch.FloatTensor of shape (C x H x W)."""
80
+
81
+ def __call__(self, img):
82
+ assert(isinstance(img, Image.Image))
83
+
84
+ array = np.asarray(img)
85
+ array = np.transpose(array, (2, 0, 1))
86
+ tensor = torch.from_numpy(array)
87
+
88
+ return tensor.float()
89
+
90
+
91
+ class Lambda(object):
92
+ """Applies a lambda as a transform"""
93
+
94
+ def __init__(self, lambd):
95
+ assert isinstance(lambd, types.LambdaType)
96
+ self.lambd = lambd
97
+
98
+ def __call__(self, input,target):
99
+ return self.lambd(input,target)
100
+
101
+
102
+ class CenterCrop(object):
103
+ """Crops the given inputs and target arrays at the center to have a region of
104
+ the given size. size can be a tuple (target_height, target_width)
105
+ or an integer, in which case the target will be of a square shape (size, size)
106
+ Careful, img1 and img2 may not be the same size
107
+ """
108
+
109
+ def __init__(self, size):
110
+ if isinstance(size, numbers.Number):
111
+ self.size = (int(size), int(size))
112
+ else:
113
+ self.size = size
114
+
115
+ def __call__(self, inputs, target):
116
+ h1, w1, _ = inputs[0].shape
117
+ # h2, w2, _ = inputs[1].shape
118
+ th, tw = self.size
119
+ x1 = int(round((w1 - tw) / 2.))
120
+ y1 = int(round((h1 - th) / 2.))
121
+ # x2 = int(round((w2 - tw) / 2.))
122
+ # y2 = int(round((h2 - th) / 2.))
123
+ for i in range(len(inputs)):
124
+ inputs[i] = inputs[i][y1: y1 + th, x1: x1 + tw]
125
+ # inputs[0] = inputs[0][y1: y1 + th, x1: x1 + tw]
126
+ # inputs[1] = inputs[1][y2: y2 + th, x2: x2 + tw]
127
+ target = target[y1: y1 + th, x1: x1 + tw]
128
+ return inputs,target
129
+
130
+ class myRandomResized(object):
131
+ """
132
+ based on RandomResizedCrop in
133
+ https://pytorch.org/docs/stable/_modules/torchvision/transforms/transforms.html#RandomResizedCrop
134
+ """
135
+
136
+ def __init__(self, expect_min_size, scale=(0.8, 1.5), interpolation=cv2.INTER_NEAREST):
137
+ # assert (min(input_size) * min(scale) > max(expect_size))
138
+ # only one decimal place of the scale factor is considered
139
+ assert (isinstance(scale,tuple) and len(scale)==2)
140
+ self.interpolation = interpolation
141
+ self.scale = [x*0.1 for x in range(int(scale[0]*10), int(scale[1]*10))]
142
+ self.min_size = expect_min_size
143
+
144
+ @staticmethod
145
+ def get_params(img, scale, min_size):
146
+ """Get parameters for ``crop`` for a random sized crop.
147
+
148
+ Args:
149
+ img (PIL Image): Image to be cropped.
150
+ scale (tuple): range of size of the origin size cropped
151
+ ratio (tuple): range of aspect ratio of the origin aspect ratio cropped
152
+
153
+ Returns:
154
+ tuple: params (i, j, h, w) to be passed to ``crop`` for a random
155
+ sized crop.
156
+ """
157
+ # area = img.size[0] * img.size[1]
158
+ h, w, _ = img.shape
159
+ for attempt in range(10):
160
+ rand_scale_ = random.choice(scale)
161
+
162
+ if random.random() < 0.5:
163
+ rand_scale = rand_scale_
164
+ else:
165
+ rand_scale = -1.
166
+
167
+ if min_size[0] <= rand_scale * h and min_size[1] <= rand_scale * w\
168
+ and rand_scale * h % 16 == 0 and rand_scale * w %16 ==0 :
169
+ # the 16*n condition is for network architecture
170
+ return (int(rand_scale * h),int(rand_scale * w ))
171
+
172
+ # Fallback
173
+ return (h, w)
174
+
175
+ def __call__(self, inputs, tgt):
176
+ """
177
+ Args:
178
+ img (PIL Image): Image to be cropped and resized.
179
+
180
+ Returns:
181
+ PIL Image: Randomly cropped and resized image.
182
+ """
183
+ h,w = self.get_params(inputs[0], self.scale, self.min_size)
184
+ for i in range(len(inputs)):
185
+ inputs[i] = cv2.resize(inputs[i], (w,h), interpolation=self.interpolation)
186
+
187
+ tgt = cv2.resize(tgt, (w,h), interpolation=self.interpolation) #for input as h*w*1 the output is h*w
188
+ return inputs, np.expand_dims(tgt,-1)
189
+
190
+ def __repr__(self):
191
+ interpolate_str = _pil_interpolation_to_str[self.interpolation]
192
+ format_string = self.__class__.__name__ + '(min_size={0}'.format(self.min_size)
193
+ format_string += ', scale={0}'.format(tuple(round(s, 4) for s in self.scale))
194
+ format_string += ', interpolation={0})'.format(interpolate_str)
195
+ return format_string
196
+
197
+
198
+ class Scale(object):
199
+ """ Rescales the inputs and target arrays to the given 'size'.
200
+ 'size' will be the size of the smaller edge.
201
+ For example, if height > width, then image will be
202
+ rescaled to (size * height / width, size)
203
+ size: size of the smaller edge
204
+ interpolation order: Default: 2 (bilinear)
205
+ """
206
+
207
+ def __init__(self, size, order=2):
208
+ self.size = size
209
+ self.order = order
210
+
211
+ def __call__(self, inputs, target):
212
+ h, w, _ = inputs[0].shape
213
+ if (w <= h and w == self.size) or (h <= w and h == self.size):
214
+ return inputs,target
215
+ if w < h:
216
+ ratio = self.size/w
217
+ else:
218
+ ratio = self.size/h
219
+
220
+ for i in range(len(inputs)):
221
+ inputs[i] = ndimage.interpolation.zoom(inputs[i], ratio, order=self.order)[:, :, :3]
222
+
223
+ target = ndimage.interpolation.zoom(target, ratio, order=self.order)[:, :, :1]
224
+ #target *= ratio
225
+ return inputs, target
226
+
227
+
228
+ class RandomCrop(object):
229
+ """Crops the given PIL.Image at a random location to have a region of
230
+ the given size. size can be a tuple (target_height, target_width)
231
+ or an integer, in which case the target will be of a square shape (size, size)
232
+ """
233
+
234
+ def __init__(self, size):
235
+ if isinstance(size, numbers.Number):
236
+ self.size = (int(size), int(size))
237
+ else:
238
+ self.size = size
239
+
240
+ def __call__(self, inputs,target):
241
+ h, w, _ = inputs[0].shape
242
+ th, tw = self.size
243
+ if w == tw and h == th:
244
+ return inputs,target
245
+
246
+ x1 = random.randint(0, w - tw)
247
+ y1 = random.randint(0, h - th)
248
+ for i in range(len(inputs)):
249
+ inputs[i] = inputs[i][y1: y1 + th,x1: x1 + tw]
250
+ # inputs[1] = inputs[1][y1: y1 + th,x1: x1 + tw]
251
+ # inputs[2] = inputs[2][y1: y1 + th, x1: x1 + tw]
252
+
253
+ return inputs, target[y1: y1 + th,x1: x1 + tw]
254
+
255
+ class MyScale(object):
256
+ def __init__(self, size, order=2):
257
+ self.size = size
258
+ self.order = order
259
+
260
+ def __call__(self, inputs, target):
261
+ h, w, _ = inputs[0].shape
262
+ if (w <= h and w == self.size) or (h <= w and h == self.size):
263
+ return inputs,target
264
+ if w < h:
265
+ for i in range(len(inputs)):
266
+ inputs[i] = cv2.resize(inputs[i], (self.size, int(h * self.size / w)))
267
+ target = cv2.resize(target.squeeze(), (self.size, int(h * self.size / w)), interpolation=cv2.INTER_NEAREST)
268
+ else:
269
+ for i in range(len(inputs)):
270
+ inputs[i] = cv2.resize(inputs[i], (int(w * self.size / h), self.size))
271
+ target = cv2.resize(target.squeeze(), (int(w * self.size / h), self.size), interpolation=cv2.INTER_NEAREST)
272
+ target = np.expand_dims(target, axis=2)
273
+ return inputs, target
274
+
275
+ class RandomHorizontalFlip(object):
276
+ """Randomly horizontally flips the given PIL.Image with a probability of 0.5
277
+ """
278
+
279
+ def __call__(self, inputs, target):
280
+ if random.random() < 0.5:
281
+ for i in range(len(inputs)):
282
+ inputs[i] = np.copy(np.fliplr(inputs[i]))
283
+ # inputs[1] = np.copy(np.fliplr(inputs[1]))
284
+ # inputs[2] = np.copy(np.fliplr(inputs[2]))
285
+
286
+ target = np.copy(np.fliplr(target))
287
+ # target[:,:,0] *= -1
288
+ return inputs,target
289
+
290
+
291
+ class RandomVerticalFlip(object):
292
+ """Randomly horizontally flips the given PIL.Image with a probability of 0.5
293
+ """
294
+
295
+ def __call__(self, inputs, target):
296
+ if random.random() < 0.5:
297
+ for i in range(len(inputs)):
298
+ inputs[i] = np.copy(np.flipud(inputs[i]))
299
+ # inputs[1] = np.copy(np.flipud(inputs[1]))
300
+ # inputs[2] = np.copy(np.flipud(inputs[2]))
301
+
302
+ target = np.copy(np.flipud(target))
303
+ # target[:,:,1] *= -1 #for disp there is no y dim
304
+ return inputs,target
305
+
306
+
307
+ class RandomRotate(object):
308
+ """Random rotation of the image from -angle to angle (in degrees)
309
+ This is useful for dataAugmentation, especially for geometric problems such as FlowEstimation
310
+ angle: max angle of the rotation
311
+ interpolation order: Default: 2 (bilinear)
312
+ reshape: Default: false. If set to true, image size will be set to keep every pixel in the image.
313
+ diff_angle: Default: 0. Must stay less than 10 degrees, or linear approximation of flowmap will be off.
314
+ """
315
+
316
+ def __init__(self, angle, diff_angle=0, order=2, reshape=False):
317
+ self.angle = angle
318
+ self.reshape = reshape
319
+ self.order = order
320
+ self.diff_angle = diff_angle
321
+
322
+ def __call__(self, inputs,target):
323
+ applied_angle = random.uniform(-self.angle,self.angle)
324
+ diff = random.uniform(-self.diff_angle,self.diff_angle)
325
+ angle1 = applied_angle - diff/2
326
+ angle2 = applied_angle + diff/2
327
+ angle1_rad = angle1*np.pi/180
328
+
329
+ h, w, _ = target.shape
330
+
331
+ def rotate_flow(i,j,k):
332
+ return -k*(j-w/2)*(diff*np.pi/180) + (1-k)*(i-h/2)*(diff*np.pi/180)
333
+
334
+ rotate_flow_map = np.fromfunction(rotate_flow, target.shape)
335
+ target += rotate_flow_map
336
+
337
+ inputs[0] = ndimage.interpolation.rotate(inputs[0], angle1, reshape=self.reshape, order=self.order)
338
+ inputs[1] = ndimage.interpolation.rotate(inputs[1], angle2, reshape=self.reshape, order=self.order)
339
+ target = ndimage.interpolation.rotate(target, angle1, reshape=self.reshape, order=self.order)
340
+ # flow vectors must be rotated too! careful about Y flow which is upside down
341
+ target_ = np.copy(target)
342
+ target[:,:,0] = np.cos(angle1_rad)*target_[:,:,0] + np.sin(angle1_rad)*target_[:,:,1]
343
+ target[:,:,1] = -np.sin(angle1_rad)*target_[:,:,0] + np.cos(angle1_rad)*target_[:,:,1]
344
+ return inputs,target
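+ # The flow vectors themselves are rotated by R = [[cos a, sin a], [-sin a, cos a]]
+ # (a = angle1 in radians); rotate_flow above adds the small-angle correction field for the
+ # relative rotation `diff` between the two frames.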
345
+
346
+
347
+ class RandomTranslate(object):
348
+ def __init__(self, translation):
349
+ if isinstance(translation, numbers.Number):
350
+ self.translation = (int(translation), int(translation))
351
+ else:
352
+ self.translation = translation
353
+
354
+ def __call__(self, inputs,target):
355
+ h, w, _ = inputs[0].shape
356
+ th, tw = self.translation
357
+ tw = random.randint(-tw, tw)
358
+ th = random.randint(-th, th)
359
+ if tw == 0 and th == 0:
360
+ return inputs, target
361
+ # compute x1,x2,y1,y2 for img1 and target, and x3,x4,y3,y4 for img2
362
+ x1,x2,x3,x4 = max(0,tw), min(w+tw,w), max(0,-tw), min(w-tw,w)
363
+ y1,y2,y3,y4 = max(0,th), min(h+th,h), max(0,-th), min(h-th,h)
364
+
365
+ inputs[0] = inputs[0][y1:y2,x1:x2]
366
+ inputs[1] = inputs[1][y3:y4,x3:x4]
367
+ target = target[y1:y2,x1:x2]
368
+ target[:,:,0] += tw
369
+ target[:,:,1] += th
370
+
371
+ return inputs, target
372
+
373
+
374
+ class RandomColorWarp(object):
375
+ def __init__(self, mean_range=0, std_range=0):
376
+ self.mean_range = mean_range
377
+ self.std_range = std_range
378
+
379
+ def __call__(self, inputs, target):
380
+ random_std = np.random.uniform(-self.std_range, self.std_range, 3)
381
+ random_mean = np.random.uniform(-self.mean_range, self.mean_range, 3)
382
+ random_order = np.random.permutation(3)
383
+
384
+ inputs[0] *= (1 + random_std)
385
+ inputs[0] += random_mean
386
+
387
+ inputs[1] *= (1 + random_std)
388
+ inputs[1] += random_mean
389
+
390
+ inputs[0] = inputs[0][:,:,random_order]
391
+ inputs[1] = inputs[1][:,:,random_order]
392
+
393
+ return inputs, target
libs/losses.py ADDED
@@ -0,0 +1,416 @@
1
+ from libs.blocks import encoder5
2
+ import torch
3
+ import torchvision
4
+ import torch.nn as nn
5
+ from torch.nn import init
6
+ import torch.nn.functional as F
7
+ from .normalization import get_nonspade_norm_layer
8
+ from .blocks import encoder5
9
+
10
+ import os
11
+ import numpy as np
12
+
13
+ class BaseNetwork(nn.Module):
14
+ def __init__(self):
15
+ super(BaseNetwork, self).__init__()
16
+
17
+ def print_network(self):
18
+ if isinstance(self, list):
19
+ self = self[0]
20
+ num_params = 0
21
+ for param in self.parameters():
22
+ num_params += param.numel()
23
+ print('Network [%s] was created. Total number of parameters: %.1f million. '
24
+ 'To see the architecture, do print(network).'
25
+ % (type(self).__name__, num_params / 1000000))
26
+
27
+ def init_weights(self, init_type='normal', gain=0.02):
28
+ def init_func(m):
29
+ classname = m.__class__.__name__
30
+ if classname.find('BatchNorm2d') != -1:
31
+ if hasattr(m, 'weight') and m.weight is not None:
32
+ init.normal_(m.weight.data, 1.0, gain)
33
+ if hasattr(m, 'bias') and m.bias is not None:
34
+ init.constant_(m.bias.data, 0.0)
35
+ elif hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
36
+ if init_type == 'normal':
37
+ init.normal_(m.weight.data, 0.0, gain)
38
+ elif init_type == 'xavier':
39
+ init.xavier_normal_(m.weight.data, gain=gain)
40
+ elif init_type == 'xavier_uniform':
41
+ init.xavier_uniform_(m.weight.data, gain=1.0)
42
+ elif init_type == 'kaiming':
43
+ init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
44
+ elif init_type == 'orthogonal':
45
+ init.orthogonal_(m.weight.data, gain=gain)
46
+ elif init_type == 'none': # uses pytorch's default init method
47
+ m.reset_parameters()
48
+ else:
49
+ raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
50
+ if hasattr(m, 'bias') and m.bias is not None:
51
+ init.constant_(m.bias.data, 0.0)
52
+
53
+ self.apply(init_func)
54
+
55
+ # propagate to children
56
+ for m in self.children():
57
+ if hasattr(m, 'init_weights'):
58
+ m.init_weights(init_type, gain)
59
+
60
+ class NLayerDiscriminator(BaseNetwork):
61
+ def __init__(self):
62
+ super().__init__()
63
+
64
+ kw = 4
65
+ padw = int(np.ceil((kw - 1.0) / 2))
66
+ nf = 64
67
+ n_layers_D = 4
68
+ input_nc = 3
69
+
70
+ norm_layer = get_nonspade_norm_layer('spectralinstance')
71
+ sequence = [[nn.Conv2d(input_nc, nf, kernel_size=kw, stride=2, padding=padw),
72
+ nn.LeakyReLU(0.2, False)]]
73
+
74
+ for n in range(1, n_layers_D):
75
+ nf_prev = nf
76
+ nf = min(nf * 2, 512)
77
+ stride = 1 if n == n_layers_D - 1 else 2
78
+ sequence += [[norm_layer(nn.Conv2d(nf_prev, nf, kernel_size=kw,
79
+ stride=stride, padding=padw)),
80
+ nn.LeakyReLU(0.2, False)
81
+ ]]
82
+
83
+ sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]]
84
+
85
+ # We divide the layers into groups to extract intermediate layer outputs
86
+ for n in range(len(sequence)):
87
+ self.add_module('model' + str(n), nn.Sequential(*sequence[n]))
88
+
89
+ def forward(self, input, get_intermediate_features = True):
90
+ results = [input]
91
+ for submodel in self.children():
92
+ intermediate_output = submodel(results[-1])
93
+ results.append(intermediate_output)
94
+
95
+ if get_intermediate_features:
96
+ return results[1:]
97
+ else:
98
+ return results[-1]
99
+
100
+ class VGG19(torch.nn.Module):
101
+ def __init__(self, requires_grad=False):
102
+ super().__init__()
103
+ vgg_pretrained_features = torchvision.models.vgg19(pretrained=True).features
104
+ self.slice1 = torch.nn.Sequential()
105
+ self.slice2 = torch.nn.Sequential()
106
+ self.slice3 = torch.nn.Sequential()
107
+ self.slice4 = torch.nn.Sequential()
108
+ self.slice5 = torch.nn.Sequential()
109
+ for x in range(2):
110
+ self.slice1.add_module(str(x), vgg_pretrained_features[x])
111
+ for x in range(2, 7):
112
+ self.slice2.add_module(str(x), vgg_pretrained_features[x])
113
+ for x in range(7, 12):
114
+ self.slice3.add_module(str(x), vgg_pretrained_features[x])
115
+ for x in range(12, 21):
116
+ self.slice4.add_module(str(x), vgg_pretrained_features[x])
117
+ for x in range(21, 30):
118
+ self.slice5.add_module(str(x), vgg_pretrained_features[x])
120
+ if not requires_grad:
121
+ for param in self.parameters():
122
+ param.requires_grad = False
123
+
124
+ def forward(self, X):
125
+ h_relu1 = self.slice1(X)
126
+ h_relu2 = self.slice2(h_relu1)
127
+ h_relu3 = self.slice3(h_relu2)
128
+ h_relu4 = self.slice4(h_relu3)
129
+ h_relu5 = self.slice5(h_relu4)
130
+ out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
131
+ return out
132
+
133
+ class encoder5(nn.Module):
134
+ def __init__(self):
135
+ super(encoder5,self).__init__()
136
+ # vgg
137
+ # 224 x 224
138
+ self.conv1 = nn.Conv2d(3,3,1,1,0)
139
+ self.reflecPad1 = nn.ReflectionPad2d((1,1,1,1))
140
+ # 226 x 226
141
+
142
+ self.conv2 = nn.Conv2d(3,64,3,1,0)
143
+ self.relu2 = nn.ReLU(inplace=True)
144
+ # 224 x 224
145
+
146
+ self.reflecPad3 = nn.ReflectionPad2d((1,1,1,1))
147
+ self.conv3 = nn.Conv2d(64,64,3,1,0)
148
+ self.relu3 = nn.ReLU(inplace=True)
149
+ # 224 x 224
150
+
151
+ self.maxPool = nn.MaxPool2d(kernel_size=2,stride=2)
152
+ # 112 x 112
153
+
154
+ self.reflecPad4 = nn.ReflectionPad2d((1,1,1,1))
155
+ self.conv4 = nn.Conv2d(64,128,3,1,0)
156
+ self.relu4 = nn.ReLU(inplace=True)
157
+ # 112 x 112
158
+
159
+ self.reflecPad5 = nn.ReflectionPad2d((1,1,1,1))
160
+ self.conv5 = nn.Conv2d(128,128,3,1,0)
161
+ self.relu5 = nn.ReLU(inplace=True)
162
+ # 112 x 112
163
+
164
+ self.maxPool2 = nn.MaxPool2d(kernel_size=2,stride=2)
165
+ # 56 x 56
166
+
167
+ self.reflecPad6 = nn.ReflectionPad2d((1,1,1,1))
168
+ self.conv6 = nn.Conv2d(128,256,3,1,0)
169
+ self.relu6 = nn.ReLU(inplace=True)
170
+ # 56 x 56
171
+
172
+ self.reflecPad7 = nn.ReflectionPad2d((1,1,1,1))
173
+ self.conv7 = nn.Conv2d(256,256,3,1,0)
174
+ self.relu7 = nn.ReLU(inplace=True)
175
+ # 56 x 56
176
+
177
+ self.reflecPad8 = nn.ReflectionPad2d((1,1,1,1))
178
+ self.conv8 = nn.Conv2d(256,256,3,1,0)
179
+ self.relu8 = nn.ReLU(inplace=True)
180
+ # 56 x 56
181
+
182
+ self.reflecPad9 = nn.ReflectionPad2d((1,1,1,1))
183
+ self.conv9 = nn.Conv2d(256,256,3,1,0)
184
+ self.relu9 = nn.ReLU(inplace=True)
185
+ # 56 x 56
186
+
187
+ self.maxPool3 = nn.MaxPool2d(kernel_size=2,stride=2)
188
+ # 28 x 28
189
+
190
+ self.reflecPad10 = nn.ReflectionPad2d((1,1,1,1))
191
+ self.conv10 = nn.Conv2d(256,512,3,1,0)
192
+ self.relu10 = nn.ReLU(inplace=True)
193
+
194
+ self.reflecPad11 = nn.ReflectionPad2d((1,1,1,1))
195
+ self.conv11 = nn.Conv2d(512,512,3,1,0)
196
+ self.relu11 = nn.ReLU(inplace=True)
197
+
198
+ self.reflecPad12 = nn.ReflectionPad2d((1,1,1,1))
199
+ self.conv12 = nn.Conv2d(512,512,3,1,0)
200
+ self.relu12 = nn.ReLU(inplace=True)
201
+
202
+ self.reflecPad13 = nn.ReflectionPad2d((1,1,1,1))
203
+ self.conv13 = nn.Conv2d(512,512,3,1,0)
204
+ self.relu13 = nn.ReLU(inplace=True)
205
+
206
+ self.maxPool4 = nn.MaxPool2d(kernel_size=2,stride=2)
207
+ self.reflecPad14 = nn.ReflectionPad2d((1,1,1,1))
208
+ self.conv14 = nn.Conv2d(512,512,3,1,0)
209
+ self.relu14 = nn.ReLU(inplace=True)
210
+
211
+ def forward(self,x):
212
+ output = []
213
+ out = self.conv1(x)
214
+ out = self.reflecPad1(out)
215
+ out = self.conv2(out)
216
+ out = self.relu2(out)
217
+ output.append(out)
218
+
219
+ out = self.reflecPad3(out)
220
+ out = self.conv3(out)
221
+ out = self.relu3(out)
222
+ out = self.maxPool(out)
223
+ out = self.reflecPad4(out)
224
+ out = self.conv4(out)
225
+ out = self.relu4(out)
226
+ output.append(out)
227
+
228
+ out = self.reflecPad5(out)
229
+ out = self.conv5(out)
230
+ out = self.relu5(out)
231
+ out = self.maxPool2(out)
232
+ out = self.reflecPad6(out)
233
+ out = self.conv6(out)
234
+ out = self.relu6(out)
235
+ output.append(out)
236
+
237
+ out = self.reflecPad7(out)
238
+ out = self.conv7(out)
239
+ out = self.relu7(out)
240
+ out = self.reflecPad8(out)
241
+ out = self.conv8(out)
242
+ out = self.relu8(out)
243
+ out = self.reflecPad9(out)
244
+ out = self.conv9(out)
245
+ out = self.relu9(out)
246
+ out = self.maxPool3(out)
247
+ out = self.reflecPad10(out)
248
+ out = self.conv10(out)
249
+ out = self.relu10(out)
250
+ output.append(out)
251
+
252
+ out = self.reflecPad11(out)
253
+ out = self.conv11(out)
254
+ out = self.relu11(out)
255
+ out = self.reflecPad12(out)
256
+ out = self.conv12(out)
257
+ out = self.relu12(out)
258
+ out = self.reflecPad13(out)
259
+ out = self.conv13(out)
260
+ out = self.relu13(out)
261
+ out = self.maxPool4(out)
262
+ out = self.reflecPad14(out)
263
+ out = self.conv14(out)
264
+ out = self.relu14(out)
265
+
266
+ output.append(out)
267
+ return output
268
+
269
+ class VGGLoss(nn.Module):
270
+ def __init__(self, model_path):
271
+ super(VGGLoss, self).__init__()
272
+ self.vgg = encoder5().cuda()
273
+ self.vgg.load_state_dict(torch.load(os.path.join(model_path, 'vgg_r51.pth')))
274
+ self.criterion = nn.MSELoss()
275
+ self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]
276
+
277
+ def forward(self, x, y):
278
+ x_vgg, y_vgg = self.vgg(x), self.vgg(y)
279
+ loss = 0
280
+ for i in range(4):
281
+ loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach())
282
+ return loss
283
+
284
+ class GANLoss(nn.Module):
285
+ def __init__(self, gan_mode = 'hinge', target_real_label=1.0, target_fake_label=0.0,
286
+ tensor=torch.cuda.FloatTensor):
287
+ super(GANLoss, self).__init__()
288
+ self.real_label = target_real_label
289
+ self.fake_label = target_fake_label
290
+ self.real_label_tensor = None
291
+ self.fake_label_tensor = None
292
+ self.zero_tensor = None
293
+ self.Tensor = tensor
294
+ self.gan_mode = gan_mode
295
+ if gan_mode == 'ls':
296
+ pass
297
+ elif gan_mode == 'original':
298
+ pass
299
+ elif gan_mode == 'w':
300
+ pass
301
+ elif gan_mode == 'hinge':
302
+ pass
303
+ else:
304
+ raise ValueError('Unexpected gan_mode {}'.format(gan_mode))
305
+
306
+ def get_target_tensor(self, input, target_is_real):
307
+ if target_is_real:
308
+ if self.real_label_tensor is None:
309
+ self.real_label_tensor = self.Tensor(1).fill_(self.real_label)
310
+ self.real_label_tensor.requires_grad_(False)
311
+ return self.real_label_tensor.expand_as(input)
312
+ else:
313
+ if self.fake_label_tensor is None:
314
+ self.fake_label_tensor = self.Tensor(1).fill_(self.fake_label)
315
+ self.fake_label_tensor.requires_grad_(False)
316
+ return self.fake_label_tensor.expand_as(input)
317
+
318
+ def get_zero_tensor(self, input):
319
+ if self.zero_tensor is None:
320
+ self.zero_tensor = self.Tensor(1).fill_(0)
321
+ self.zero_tensor.requires_grad_(False)
322
+ return self.zero_tensor.expand_as(input)
323
+
324
+ def loss(self, input, target_is_real, for_discriminator=True):
325
+ if self.gan_mode == 'original': # cross entropy loss
326
+ target_tensor = self.get_target_tensor(input, target_is_real)
327
+ loss = F.binary_cross_entropy_with_logits(input, target_tensor)
328
+ return loss
329
+ elif self.gan_mode == 'ls':
330
+ target_tensor = self.get_target_tensor(input, target_is_real)
331
+ return F.mse_loss(input, target_tensor)
332
+ elif self.gan_mode == 'hinge':
333
+ if for_discriminator:
334
+ if target_is_real:
335
+ minval = torch.min(input - 1, self.get_zero_tensor(input))
336
+ loss = -torch.mean(minval)
337
+ else:
338
+ minval = torch.min(-input - 1, self.get_zero_tensor(input))
339
+ loss = -torch.mean(minval)
340
+ else:
341
+ assert target_is_real, "The generator's hinge loss must be aiming for real"
342
+ loss = -torch.mean(input)
343
+ return loss
344
+ else:
345
+ # wgan
346
+ if target_is_real:
347
+ return -input.mean()
348
+ else:
349
+ return input.mean()
350
+
351
+ def __call__(self, input, target_is_real, for_discriminator=True):
352
+ # computing loss is a bit complicated because |input| may not be
353
+ # a tensor, but list of tensors in case of multiscale discriminator
354
+ if isinstance(input, list):
355
+ loss = 0
356
+ for pred_i in input:
357
+ if isinstance(pred_i, list):
358
+ pred_i = pred_i[-1]
359
+ loss_tensor = self.loss(pred_i, target_is_real, for_discriminator)
360
+ bs = 1 if len(loss_tensor.size()) == 0 else loss_tensor.size(0)
361
+ new_loss = torch.mean(loss_tensor.view(bs, -1), dim=1)
362
+ loss += new_loss
363
+ return loss / len(input)
364
+ else:
365
+ return self.loss(input, target_is_real, for_discriminator)
366
+
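+ # Usage sketch (mirrors how SPADE_LOSS below drives it, with gan_mode='hinge'):
+ # d_loss = gan(D(fake.detach()), False, for_discriminator=True) + gan(D(real), True, for_discriminator=True)
+ # g_loss = gan(D(fake), True, for_discriminator=False)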
367
+ class SPADE_LOSS(nn.Module):
368
+ def __init__(self, model_path, lambda_feat = 1):
369
+ super(SPADE_LOSS, self).__init__()
370
+ self.criterionVGG = VGGLoss(model_path)
371
+ self.criterionGAN = GANLoss('hinge')
372
+ self.criterionL1 = nn.L1Loss()
373
+ self.discriminator = NLayerDiscriminator()
374
+ self.lambda_feat = lambda_feat
375
+
376
+ def forward(self, x, y, for_discriminator = False):
377
+ pred_real = self.discriminator(y)
378
+ if not for_discriminator:
379
+ pred_fake = self.discriminator(x)
380
+ VGGLoss = self.criterionVGG(x, y)
381
+ GANLoss = self.criterionGAN(pred_fake, True, for_discriminator = False)
382
+
383
+ # feature matching loss
384
+ # last output is the final prediction, so we exclude it
385
+ num_intermediate_outputs = len(pred_fake) - 1
386
+ GAN_Feat_loss = 0
387
+ for j in range(num_intermediate_outputs): # for each layer output
388
+ unweighted_loss = self.criterionL1(pred_fake[j], pred_real[j].detach())
389
+ GAN_Feat_loss += unweighted_loss * self.lambda_feat
390
+ L1Loss = self.criterionL1(x, y)
391
+ return VGGLoss, GANLoss, GAN_Feat_loss, L1Loss
392
+ else:
393
+ pred_fake = self.discriminator(x.detach())
394
+ GANLoss = self.criterionGAN(pred_fake, False, for_discriminator = True)
395
+ GANLoss += self.criterionGAN(pred_real, True, for_discriminator = True)
396
+ return GANLoss
397
+
398
+ class ContrastiveLoss(nn.Module):
399
+ """
400
+ Contrastive loss
401
+ Takes embeddings of two samples and a target label == 1 if samples are from the same class and label == 0 otherwise
402
+ """
403
+
404
+ def __init__(self, margin):
405
+ super(ContrastiveLoss, self).__init__()
406
+ self.margin = margin
407
+ self.eps = 1e-9
408
+
409
+ def forward(self, out1, out2, target, size_average=True, norm = True):
410
+ if norm:
+ output1 = out1 / out1.pow(2).sum(1, keepdim=True).sqrt()
+ output2 = out2 / out2.pow(2).sum(1, keepdim=True).sqrt()
+ else:
+ output1, output2 = out1, out2
413
+ distances = (output2 - output1).pow(2).sum(1) # squared distances
414
+ losses = 0.5 * (target.float() * distances +
415
+ (1 + -1 * target).float() * F.relu(self.margin - (distances + self.eps).sqrt()).pow(2))
416
+ return losses.mean() if size_average else losses.sum()
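+ # Closed form per pair (with embeddings z1, z2 after optional normalisation):
+ # target == 1 (same class): 0.5 * ||z1 - z2||^2
+ # target == 0 (different class): 0.5 * max(0, margin - ||z1 - z2||)^2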