diff --git a/data/___init__.py b/data/___init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/data/color150.mat b/data/color150.mat new file mode 100644 index 0000000000000000000000000000000000000000..c518b64fbbe899d4a8b2705f012eeba795339892 Binary files /dev/null and b/data/color150.mat differ diff --git a/data/images/108073.jpg b/data/images/108073.jpg new file mode 100755 index 0000000000000000000000000000000000000000..b4a23e5c18c7120d2cfc11996e2ce87fa6f04389 Binary files /dev/null and b/data/images/108073.jpg differ diff --git a/data/images/12003.jpg b/data/images/12003.jpg new file mode 100755 index 0000000000000000000000000000000000000000..2957db803259e68332440992b0c53ec5e9faac6a Binary files /dev/null and b/data/images/12003.jpg differ diff --git a/data/images/12074.jpg b/data/images/12074.jpg new file mode 100755 index 0000000000000000000000000000000000000000..4ee984b2ccc31fe890dff961e9e0248b3aca8676 Binary files /dev/null and b/data/images/12074.jpg differ diff --git a/data/images/134008.jpg b/data/images/134008.jpg new file mode 100755 index 0000000000000000000000000000000000000000..234fc1552fb1974c3989b4b20d91520d1196a099 Binary files /dev/null and b/data/images/134008.jpg differ diff --git a/data/images/134052.jpg b/data/images/134052.jpg new file mode 100755 index 0000000000000000000000000000000000000000..950cff04b426593827bac4582eb4a4533192c801 Binary files /dev/null and b/data/images/134052.jpg differ diff --git a/data/images/138032.jpg b/data/images/138032.jpg new file mode 100755 index 0000000000000000000000000000000000000000..cd876f53920057596fa73a4fdfb126b487dd4303 Binary files /dev/null and b/data/images/138032.jpg differ diff --git a/data/images/145053.jpg b/data/images/145053.jpg new file mode 100755 index 0000000000000000000000000000000000000000..408106f91e93c7050df832feefd35eb44d4e9520 Binary files /dev/null and b/data/images/145053.jpg differ diff --git a/data/images/164074.jpg b/data/images/164074.jpg new file mode 100755 index 0000000000000000000000000000000000000000..1160d1caadae0c9f6aebba68edcd8fcdf16fedd6 Binary files /dev/null and b/data/images/164074.jpg differ diff --git a/data/images/169012.jpg b/data/images/169012.jpg new file mode 100755 index 0000000000000000000000000000000000000000..d2bf33f4a2189ea61a5fc268d0e978777a8fab3b Binary files /dev/null and b/data/images/169012.jpg differ diff --git a/data/images/198023.jpg b/data/images/198023.jpg new file mode 100755 index 0000000000000000000000000000000000000000..fcdfa7ae05302d284beee7d53c70d8f63b4daa61 Binary files /dev/null and b/data/images/198023.jpg differ diff --git a/data/images/25098.jpg b/data/images/25098.jpg new file mode 100755 index 0000000000000000000000000000000000000000..db3519e0a6ce14ebcc2953b25d33242f8e272806 Binary files /dev/null and b/data/images/25098.jpg differ diff --git a/data/images/277095.jpg b/data/images/277095.jpg new file mode 100755 index 0000000000000000000000000000000000000000..879f5fb89c41a5225e062e13d64364feca73a289 Binary files /dev/null and b/data/images/277095.jpg differ diff --git a/data/images/45077.jpg b/data/images/45077.jpg new file mode 100755 index 0000000000000000000000000000000000000000..98ff1e6fc3180e44f6fd1d450bd912dd43eee8b8 Binary files /dev/null and b/data/images/45077.jpg differ diff --git a/data/palette.txt b/data/palette.txt new file mode 100644 index 0000000000000000000000000000000000000000..691ee4b07c28d4cdd4e0e4023bbef1b94177bc60 --- /dev/null +++ b/data/palette.txt @@ 
-0,0 +1,256 @@ +0 0 0 +128 0 0 +0 128 0 +128 128 0 +0 0 128 +128 0 128 +0 128 128 +128 128 128 +64 0 0 +191 0 0 +64 128 0 +191 128 0 +64 0 128 +191 0 128 +64 128 128 +191 128 128 +0 64 0 +128 64 0 +0 191 0 +128 191 0 +0 64 128 +128 64 128 +22 22 22 +23 23 23 +24 24 24 +25 25 25 +26 26 26 +27 27 27 +28 28 28 +29 29 29 +30 30 30 +31 31 31 +32 32 32 +33 33 33 +34 34 34 +35 35 35 +36 36 36 +37 37 37 +38 38 38 +39 39 39 +40 40 40 +41 41 41 +42 42 42 +43 43 43 +44 44 44 +45 45 45 +46 46 46 +47 47 47 +48 48 48 +49 49 49 +50 50 50 +51 51 51 +52 52 52 +53 53 53 +54 54 54 +55 55 55 +56 56 56 +57 57 57 +58 58 58 +59 59 59 +60 60 60 +61 61 61 +62 62 62 +63 63 63 +64 64 64 +65 65 65 +66 66 66 +67 67 67 +68 68 68 +69 69 69 +70 70 70 +71 71 71 +72 72 72 +73 73 73 +74 74 74 +75 75 75 +76 76 76 +77 77 77 +78 78 78 +79 79 79 +80 80 80 +81 81 81 +82 82 82 +83 83 83 +84 84 84 +85 85 85 +86 86 86 +87 87 87 +88 88 88 +89 89 89 +90 90 90 +91 91 91 +92 92 92 +93 93 93 +94 94 94 +95 95 95 +96 96 96 +97 97 97 +98 98 98 +99 99 99 +100 100 100 +101 101 101 +102 102 102 +103 103 103 +104 104 104 +105 105 105 +106 106 106 +107 107 107 +108 108 108 +109 109 109 +110 110 110 +111 111 111 +112 112 112 +113 113 113 +114 114 114 +115 115 115 +116 116 116 +117 117 117 +118 118 118 +119 119 119 +120 120 120 +121 121 121 +122 122 122 +123 123 123 +124 124 124 +125 125 125 +126 126 126 +127 127 127 +128 128 128 +129 129 129 +130 130 130 +131 131 131 +132 132 132 +133 133 133 +134 134 134 +135 135 135 +136 136 136 +137 137 137 +138 138 138 +139 139 139 +140 140 140 +141 141 141 +142 142 142 +143 143 143 +144 144 144 +145 145 145 +146 146 146 +147 147 147 +148 148 148 +149 149 149 +150 150 150 +151 151 151 +152 152 152 +153 153 153 +154 154 154 +155 155 155 +156 156 156 +157 157 157 +158 158 158 +159 159 159 +160 160 160 +161 161 161 +162 162 162 +163 163 163 +164 164 164 +165 165 165 +166 166 166 +167 167 167 +168 168 168 +169 169 169 +170 170 170 +171 171 171 +172 172 172 +173 173 173 +174 174 174 +175 175 175 +176 176 176 +177 177 177 +178 178 178 +179 179 179 +180 180 180 +181 181 181 +182 182 182 +183 183 183 +184 184 184 +185 185 185 +186 186 186 +187 187 187 +188 188 188 +189 189 189 +190 190 190 +191 191 191 +192 192 192 +193 193 193 +194 194 194 +195 195 195 +196 196 196 +197 197 197 +198 198 198 +199 199 199 +200 200 200 +201 201 201 +202 202 202 +203 203 203 +204 204 204 +205 205 205 +206 206 206 +207 207 207 +208 208 208 +209 209 209 +210 210 210 +211 211 211 +212 212 212 +213 213 213 +214 214 214 +215 215 215 +216 216 216 +217 217 217 +218 218 218 +219 219 219 +220 220 220 +221 221 221 +222 222 222 +223 223 223 +224 224 224 +225 225 225 +226 226 226 +227 227 227 +228 228 228 +229 229 229 +230 230 230 +231 231 231 +232 232 232 +233 233 233 +234 234 234 +235 235 235 +236 236 236 +237 237 237 +238 238 238 +239 239 239 +240 240 240 +241 241 241 +242 242 242 +243 243 243 +244 244 244 +245 245 245 +246 246 246 +247 247 247 +248 248 248 +249 249 249 +250 250 250 +251 251 251 +252 252 252 +253 253 253 +254 254 254 +255 255 255 \ No newline at end of file diff --git a/data/test_images/100039.jpg b/data/test_images/100039.jpg new file mode 100755 index 0000000000000000000000000000000000000000..96943e0a10bc3312ab54e50b167493fd61df47a3 Binary files /dev/null and b/data/test_images/100039.jpg differ diff --git a/data/test_images/108004.jpg b/data/test_images/108004.jpg new file mode 100755 index 0000000000000000000000000000000000000000..8bd18d4d03a21c49415e9e35fa53edb6fe2370f3 Binary files /dev/null and 
b/data/test_images/108004.jpg differ diff --git a/data/test_images/130014.jpg b/data/test_images/130014.jpg new file mode 100755 index 0000000000000000000000000000000000000000..91c64abb2ba5eccd4a182a1483918dc16c498939 Binary files /dev/null and b/data/test_images/130014.jpg differ diff --git a/data/test_images/130066.jpg b/data/test_images/130066.jpg new file mode 100755 index 0000000000000000000000000000000000000000..93bfa47bd0b6d169520b4c243dc61f2b02322856 Binary files /dev/null and b/data/test_images/130066.jpg differ diff --git a/data/test_images/16068.jpg b/data/test_images/16068.jpg new file mode 100755 index 0000000000000000000000000000000000000000..2017deaa09810b9a6da5c61fcddb5eb3c356334b Binary files /dev/null and b/data/test_images/16068.jpg differ diff --git a/data/test_images/2018.jpg b/data/test_images/2018.jpg new file mode 100755 index 0000000000000000000000000000000000000000..19031dbe81e165e0c71b7c2e892b42c06de3b025 Binary files /dev/null and b/data/test_images/2018.jpg differ diff --git a/data/test_images/208078.jpg b/data/test_images/208078.jpg new file mode 100755 index 0000000000000000000000000000000000000000..eed41933c8766c4c296dce5602ce123376b52f15 Binary files /dev/null and b/data/test_images/208078.jpg differ diff --git a/data/test_images/223060.jpg b/data/test_images/223060.jpg new file mode 100755 index 0000000000000000000000000000000000000000..690053020784601009e05b8844cdddf437fa1e7b Binary files /dev/null and b/data/test_images/223060.jpg differ diff --git a/data/test_images/226033.jpg b/data/test_images/226033.jpg new file mode 100755 index 0000000000000000000000000000000000000000..0080a2ccc21f22dd86ac55e69f3ea7e34cebd9ef Binary files /dev/null and b/data/test_images/226033.jpg differ diff --git a/data/test_images/388006.jpg b/data/test_images/388006.jpg new file mode 100755 index 0000000000000000000000000000000000000000..49d70f0f37c5f8b1435ac1f25aa6946dda2bef89 Binary files /dev/null and b/data/test_images/388006.jpg differ diff --git a/data/test_images/78098.jpg b/data/test_images/78098.jpg new file mode 100755 index 0000000000000000000000000000000000000000..8b3cb4d82fce65f184a3504175453d5a898c6da7 Binary files /dev/null and b/data/test_images/78098.jpg differ diff --git a/libs/__init__.py b/libs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/libs/__pycache__/__init__.cpython-37.pyc b/libs/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e1ba108cf6f079c1ad1ab58c610ca7d37030b9a7 Binary files /dev/null and b/libs/__pycache__/__init__.cpython-37.pyc differ diff --git a/libs/__pycache__/__init__.cpython-38.pyc b/libs/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5547802aa56b6ddc43ce38a17643670d07df2547 Binary files /dev/null and b/libs/__pycache__/__init__.cpython-38.pyc differ diff --git a/libs/__pycache__/flow_transforms.cpython-37.pyc b/libs/__pycache__/flow_transforms.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..71aba0ef81e84cea0c6cb7b3b08614f55dd79809 Binary files /dev/null and b/libs/__pycache__/flow_transforms.cpython-37.pyc differ diff --git a/libs/__pycache__/flow_transforms.cpython-38.pyc b/libs/__pycache__/flow_transforms.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..999e544c09df0cdf495bd6f383f6a34200b7ba57 Binary files /dev/null and 
b/libs/__pycache__/flow_transforms.cpython-38.pyc differ diff --git a/libs/__pycache__/nnutils.cpython-37.pyc b/libs/__pycache__/nnutils.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6a94f7612f9a3687307f1f0fb2c89304e71bf630 Binary files /dev/null and b/libs/__pycache__/nnutils.cpython-37.pyc differ diff --git a/libs/__pycache__/nnutils.cpython-38.pyc b/libs/__pycache__/nnutils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..50c80aa875b5fb0e02863ce0441f40f4f07483a5 Binary files /dev/null and b/libs/__pycache__/nnutils.cpython-38.pyc differ diff --git a/libs/__pycache__/options.cpython-37.pyc b/libs/__pycache__/options.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..59683b20a1ccad1cb2385de0e692daa047d04eb3 Binary files /dev/null and b/libs/__pycache__/options.cpython-37.pyc differ diff --git a/libs/__pycache__/options.cpython-38.pyc b/libs/__pycache__/options.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..037f3fa980801f07ea576d1c2c9015110529d80e Binary files /dev/null and b/libs/__pycache__/options.cpython-38.pyc differ diff --git a/libs/__pycache__/test_base.cpython-37.pyc b/libs/__pycache__/test_base.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9ea23641ce25be6dfa1988de03beb98b2d734fa8 Binary files /dev/null and b/libs/__pycache__/test_base.cpython-37.pyc differ diff --git a/libs/__pycache__/test_base.cpython-38.pyc b/libs/__pycache__/test_base.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f0b3b8f46dc32a88491ed9df8633a61beb9fe44d Binary files /dev/null and b/libs/__pycache__/test_base.cpython-38.pyc differ diff --git a/libs/__pycache__/utils.cpython-37.pyc b/libs/__pycache__/utils.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f6e5c63e53ff7120bd72c48be4c26308221f1f91 Binary files /dev/null and b/libs/__pycache__/utils.cpython-37.pyc differ diff --git a/libs/__pycache__/utils.cpython-38.pyc b/libs/__pycache__/utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8ba6c283f41bd9212dc03340662b7ecc091c2877 Binary files /dev/null and b/libs/__pycache__/utils.cpython-38.pyc differ diff --git a/libs/blocks.py b/libs/blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..1ca94fd47c45c7a81b0fabcc7c3bfa1e12b27d67 --- /dev/null +++ b/libs/blocks.py @@ -0,0 +1,739 @@ +"""Network Modules + - encoder3: vgg encoder up to relu31 + - decoder3: mirror decoder to encoder3 + - encoder4: vgg encoder up to relu41 + - decoder4: mirror decoder to encoder4 + - encoder5: vgg encoder up to relu51 + - styleLoss: gram matrix loss for all style layers + - styleLossMask: gram matrix loss for all style layers, compare between each part defined by a mask + - GramMatrix: compute gram matrix for one layer + - LossCriterion: style transfer loss that include both content & style losses + - LossCriterionMask: style transfer loss that include both content & style losses, use the styleLossMask + - VQEmbedding: codebook class for VQVAE +""" +import os +import torch +import torch.nn as nn +import torch.nn.functional as F +from .vq_functions import vq, vq_st +from collections import OrderedDict + +class MetaModule(nn.Module): + """ + Base class for PyTorch meta-learning modules. These modules accept an + additional argument `params` in their `forward` method. 
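+ For example, a call such as `layer(x, params=OrderedDict(layer.named_parameters()))` is equivalent to the plain `layer(x)`; passing a different dictionary of tensors evaluates the module with those weights instead (e.g. fast weights produced by a hypernetwork).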
+ + Notes + ----- + Objects inherited from `MetaModule` are fully compatible with PyTorch + modules from `torch.nn.Module`. The argument `params` is a dictionary of + tensors, with full support of the computation graph (for differentiation). + """ + def meta_named_parameters(self, prefix='', recurse=True): + gen = self._named_members( + lambda module: module._parameters.items() + if isinstance(module, MetaModule) else [], + prefix=prefix, recurse=recurse) + for elem in gen: + yield elem + + def meta_parameters(self, recurse=True): + for name, param in self.meta_named_parameters(recurse=recurse): + yield param + +class BatchLinear(nn.Linear, MetaModule): + '''A linear meta-layer that can deal with batched weight matrices and biases, as for instance output by a + hypernetwork.''' + __doc__ = nn.Linear.__doc__ + + def forward(self, input, params=None): + if params is None: + params = OrderedDict(self.named_parameters()) + + bias = params.get('bias', None) + weight = params['weight'] + + output = input.matmul(weight.permute(*[i for i in range(len(weight.shape) - 2)], -1, -2)) + output += bias.unsqueeze(-2) + return output + +class decoder1(nn.Module): + def __init__(self): + super(decoder1,self).__init__() + self.reflecPad2 = nn.ReflectionPad2d((1,1,1,1)) + # 226 x 226 + self.conv3 = nn.Conv2d(64,3,3,1,0) + # 224 x 224 + + def forward(self,x): + out = self.reflecPad2(x) + out = self.conv3(out) + return out + + +class decoder2(nn.Module): + def __init__(self): + super(decoder2,self).__init__() + # decoder + self.reflecPad5 = nn.ReflectionPad2d((1,1,1,1)) + self.conv5 = nn.Conv2d(128,64,3,1,0) + self.relu5 = nn.ReLU(inplace=True) + # 112 x 112 + + self.unpool = nn.UpsamplingNearest2d(scale_factor=2) + # 224 x 224 + + self.reflecPad6 = nn.ReflectionPad2d((1,1,1,1)) + self.conv6 = nn.Conv2d(64,64,3,1,0) + self.relu6 = nn.ReLU(inplace=True) + # 224 x 224 + + self.reflecPad7 = nn.ReflectionPad2d((1,1,1,1)) + self.conv7 = nn.Conv2d(64,3,3,1,0) + + def forward(self,x): + out = self.reflecPad5(x) + out = self.conv5(out) + out = self.relu5(out) + out = self.unpool(out) + out = self.reflecPad6(out) + out = self.conv6(out) + out = self.relu6(out) + out = self.reflecPad7(out) + out = self.conv7(out) + return out + +class encoder3(nn.Module): + def __init__(self): + super(encoder3,self).__init__() + # vgg + # 224 x 224 + self.conv1 = nn.Conv2d(3,3,1,1,0) + self.reflecPad1 = nn.ReflectionPad2d((1,1,1,1)) + # 226 x 226 + + self.conv2 = nn.Conv2d(3,64,3,1,0) + self.relu2 = nn.ReLU(inplace=True) + # 224 x 224 + + self.reflecPad3 = nn.ReflectionPad2d((1,1,1,1)) + self.conv3 = nn.Conv2d(64,64,3,1,0) + self.relu3 = nn.ReLU(inplace=True) + # 224 x 224 + + self.maxPool = nn.MaxPool2d(kernel_size=2,stride=2,return_indices = True) + # 112 x 112 + + self.reflecPad4 = nn.ReflectionPad2d((1,1,1,1)) + self.conv4 = nn.Conv2d(64,128,3,1,0) + self.relu4 = nn.ReLU(inplace=True) + # 112 x 112 + + self.reflecPad5 = nn.ReflectionPad2d((1,1,1,1)) + self.conv5 = nn.Conv2d(128,128,3,1,0) + self.relu5 = nn.ReLU(inplace=True) + # 112 x 112 + + self.maxPool2 = nn.MaxPool2d(kernel_size=2,stride=2,return_indices = True) + # 56 x 56 + + self.reflecPad6 = nn.ReflectionPad2d((1,1,1,1)) + self.conv6 = nn.Conv2d(128,256,3,1,0) + self.relu6 = nn.ReLU(inplace=True) + # 56 x 56 + def forward(self,x): + out = self.conv1(x) + out = self.reflecPad1(out) + out = self.conv2(out) + out = self.relu2(out) + out = self.reflecPad3(out) + out = self.conv3(out) + pool1 = self.relu3(out) + out,pool_idx = self.maxPool(pool1) + out = self.reflecPad4(out) + 
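+ # 112 x 112: conv4/relu4 lift the pooled features from 64 to 128 channels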
out = self.conv4(out) + out = self.relu4(out) + out = self.reflecPad5(out) + out = self.conv5(out) + pool2 = self.relu5(out) + out,pool_idx2 = self.maxPool2(pool2) + out = self.reflecPad6(out) + out = self.conv6(out) + out = self.relu6(out) + return out + +class decoder3(nn.Module): + def __init__(self): + super(decoder3,self).__init__() + # decoder + self.reflecPad7 = nn.ReflectionPad2d((1,1,1,1)) + self.conv7 = nn.Conv2d(256,128,3,1,0) + self.relu7 = nn.ReLU(inplace=True) + # 56 x 56 + + self.unpool = nn.UpsamplingNearest2d(scale_factor=2) + # 112 x 112 + + self.reflecPad8 = nn.ReflectionPad2d((1,1,1,1)) + self.conv8 = nn.Conv2d(128,128,3,1,0) + self.relu8 = nn.ReLU(inplace=True) + # 112 x 112 + + self.reflecPad9 = nn.ReflectionPad2d((1,1,1,1)) + self.conv9 = nn.Conv2d(128,64,3,1,0) + self.relu9 = nn.ReLU(inplace=True) + + self.unpool2 = nn.UpsamplingNearest2d(scale_factor=2) + # 224 x 224 + + self.reflecPad10 = nn.ReflectionPad2d((1,1,1,1)) + self.conv10 = nn.Conv2d(64,64,3,1,0) + self.relu10 = nn.ReLU(inplace=True) + + self.reflecPad11 = nn.ReflectionPad2d((1,1,1,1)) + self.conv11 = nn.Conv2d(64,3,3,1,0) + + def forward(self,x): + output = {} + out = self.reflecPad7(x) + out = self.conv7(out) + out = self.relu7(out) + out = self.unpool(out) + out = self.reflecPad8(out) + out = self.conv8(out) + out = self.relu8(out) + out = self.reflecPad9(out) + out = self.conv9(out) + out_relu9 = self.relu9(out) + out = self.unpool2(out_relu9) + out = self.reflecPad10(out) + out = self.conv10(out) + out = self.relu10(out) + out = self.reflecPad11(out) + out = self.conv11(out) + return out + +class encoder4(nn.Module): + def __init__(self): + super(encoder4,self).__init__() + # vgg + # 224 x 224 + self.conv1 = nn.Conv2d(3,3,1,1,0) + self.reflecPad1 = nn.ReflectionPad2d((1,1,1,1)) + # 226 x 226 + + self.conv2 = nn.Conv2d(3,64,3,1,0) + self.relu2 = nn.ReLU(inplace=True) + # 224 x 224 + + self.reflecPad3 = nn.ReflectionPad2d((1,1,1,1)) + self.conv3 = nn.Conv2d(64,64,3,1,0) + self.relu3 = nn.ReLU(inplace=True) + # 224 x 224 + + self.maxPool = nn.MaxPool2d(kernel_size=2,stride=2) + # 112 x 112 + + self.reflecPad4 = nn.ReflectionPad2d((1,1,1,1)) + self.conv4 = nn.Conv2d(64,128,3,1,0) + self.relu4 = nn.ReLU(inplace=True) + # 112 x 112 + + self.reflecPad5 = nn.ReflectionPad2d((1,1,1,1)) + self.conv5 = nn.Conv2d(128,128,3,1,0) + self.relu5 = nn.ReLU(inplace=True) + # 112 x 112 + + self.maxPool2 = nn.MaxPool2d(kernel_size=2,stride=2) + # 56 x 56 + + self.reflecPad6 = nn.ReflectionPad2d((1,1,1,1)) + self.conv6 = nn.Conv2d(128,256,3,1,0) + self.relu6 = nn.ReLU(inplace=True) + # 56 x 56 + + self.reflecPad7 = nn.ReflectionPad2d((1,1,1,1)) + self.conv7 = nn.Conv2d(256,256,3,1,0) + self.relu7 = nn.ReLU(inplace=True) + # 56 x 56 + + self.reflecPad8 = nn.ReflectionPad2d((1,1,1,1)) + self.conv8 = nn.Conv2d(256,256,3,1,0) + self.relu8 = nn.ReLU(inplace=True) + # 56 x 56 + + self.reflecPad9 = nn.ReflectionPad2d((1,1,1,1)) + self.conv9 = nn.Conv2d(256,256,3,1,0) + self.relu9 = nn.ReLU(inplace=True) + # 56 x 56 + + self.maxPool3 = nn.MaxPool2d(kernel_size=2,stride=2) + # 28 x 28 + + self.reflecPad10 = nn.ReflectionPad2d((1,1,1,1)) + self.conv10 = nn.Conv2d(256,512,3,1,0) + self.relu10 = nn.ReLU(inplace=True) + # 28 x 28 + + def forward(self,x,sF=None,matrix11=None,matrix21=None,matrix31=None): + output = {} + out = self.conv1(x) + out = self.reflecPad1(out) + out = self.conv2(out) + output['r11'] = self.relu2(out) + out = self.reflecPad7(output['r11']) + + out = self.conv3(out) + output['r12'] = self.relu3(out) + + 
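+ # p1: first max-pool, 224 -> 112, before the relu2_x block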
output['p1'] = self.maxPool(output['r12']) + out = self.reflecPad4(output['p1']) + out = self.conv4(out) + output['r21'] = self.relu4(out) + out = self.reflecPad7(output['r21']) + + out = self.conv5(out) + output['r22'] = self.relu5(out) + + output['p2'] = self.maxPool2(output['r22']) + out = self.reflecPad6(output['p2']) + out = self.conv6(out) + output['r31'] = self.relu6(out) + if(matrix31 is not None): + feature3,transmatrix3 = matrix31(output['r31'],sF['r31']) + out = self.reflecPad7(feature3) + else: + out = self.reflecPad7(output['r31']) + out = self.conv7(out) + output['r32'] = self.relu7(out) + + out = self.reflecPad8(output['r32']) + out = self.conv8(out) + output['r33'] = self.relu8(out) + + out = self.reflecPad9(output['r33']) + out = self.conv9(out) + output['r34'] = self.relu9(out) + + output['p3'] = self.maxPool3(output['r34']) + out = self.reflecPad10(output['p3']) + out = self.conv10(out) + output['r41'] = self.relu10(out) + + return output + +class decoder4(nn.Module): + def __init__(self): + super(decoder4,self).__init__() + # decoder + self.reflecPad11 = nn.ReflectionPad2d((1,1,1,1)) + self.conv11 = nn.Conv2d(512,256,3,1,0) + self.relu11 = nn.ReLU(inplace=True) + # 28 x 28 + + self.unpool = nn.UpsamplingNearest2d(scale_factor=2) + # 56 x 56 + + self.reflecPad12 = nn.ReflectionPad2d((1,1,1,1)) + self.conv12 = nn.Conv2d(256,256,3,1,0) + self.relu12 = nn.ReLU(inplace=True) + # 56 x 56 + + self.reflecPad13 = nn.ReflectionPad2d((1,1,1,1)) + self.conv13 = nn.Conv2d(256,256,3,1,0) + self.relu13 = nn.ReLU(inplace=True) + # 56 x 56 + + self.reflecPad14 = nn.ReflectionPad2d((1,1,1,1)) + self.conv14 = nn.Conv2d(256,256,3,1,0) + self.relu14 = nn.ReLU(inplace=True) + # 56 x 56 + + self.reflecPad15 = nn.ReflectionPad2d((1,1,1,1)) + self.conv15 = nn.Conv2d(256,128,3,1,0) + self.relu15 = nn.ReLU(inplace=True) + # 56 x 56 + + self.unpool2 = nn.UpsamplingNearest2d(scale_factor=2) + # 112 x 112 + + self.reflecPad16 = nn.ReflectionPad2d((1,1,1,1)) + self.conv16 = nn.Conv2d(128,128,3,1,0) + self.relu16 = nn.ReLU(inplace=True) + # 112 x 112 + + self.reflecPad17 = nn.ReflectionPad2d((1,1,1,1)) + self.conv17 = nn.Conv2d(128,64,3,1,0) + self.relu17 = nn.ReLU(inplace=True) + # 112 x 112 + + self.unpool3 = nn.UpsamplingNearest2d(scale_factor=2) + # 224 x 224 + + self.reflecPad18 = nn.ReflectionPad2d((1,1,1,1)) + self.conv18 = nn.Conv2d(64,64,3,1,0) + self.relu18 = nn.ReLU(inplace=True) + # 224 x 224 + + self.reflecPad19 = nn.ReflectionPad2d((1,1,1,1)) + self.conv19 = nn.Conv2d(64,3,3,1,0) + + def forward(self,x): + # decoder + out = self.reflecPad11(x) + out = self.conv11(out) + out = self.relu11(out) + out = self.unpool(out) + out = self.reflecPad12(out) + out = self.conv12(out) + + out = self.relu12(out) + out = self.reflecPad13(out) + out = self.conv13(out) + out = self.relu13(out) + out = self.reflecPad14(out) + out = self.conv14(out) + out = self.relu14(out) + out = self.reflecPad15(out) + out = self.conv15(out) + out = self.relu15(out) + out = self.unpool2(out) + out = self.reflecPad16(out) + out = self.conv16(out) + out = self.relu16(out) + out = self.reflecPad17(out) + out = self.conv17(out) + out = self.relu17(out) + out = self.unpool3(out) + out = self.reflecPad18(out) + out = self.conv18(out) + out = self.relu18(out) + out = self.reflecPad19(out) + out = self.conv19(out) + return out + +class encoder5(nn.Module): + def __init__(self): + super(encoder5,self).__init__() + # vgg + # 224 x 224 + self.conv1 = nn.Conv2d(3,3,1,1,0) + self.reflecPad1 = nn.ReflectionPad2d((1,1,1,1)) + # 226 x 
226 + + self.conv2 = nn.Conv2d(3,64,3,1,0) + self.relu2 = nn.ReLU(inplace=True) + # 224 x 224 + + self.reflecPad3 = nn.ReflectionPad2d((1,1,1,1)) + self.conv3 = nn.Conv2d(64,64,3,1,0) + self.relu3 = nn.ReLU(inplace=True) + # 224 x 224 + + self.maxPool = nn.MaxPool2d(kernel_size=2,stride=2) + # 112 x 112 + + self.reflecPad4 = nn.ReflectionPad2d((1,1,1,1)) + self.conv4 = nn.Conv2d(64,128,3,1,0) + self.relu4 = nn.ReLU(inplace=True) + # 112 x 112 + + self.reflecPad5 = nn.ReflectionPad2d((1,1,1,1)) + self.conv5 = nn.Conv2d(128,128,3,1,0) + self.relu5 = nn.ReLU(inplace=True) + # 112 x 112 + + self.maxPool2 = nn.MaxPool2d(kernel_size=2,stride=2) + # 56 x 56 + + self.reflecPad6 = nn.ReflectionPad2d((1,1,1,1)) + self.conv6 = nn.Conv2d(128,256,3,1,0) + self.relu6 = nn.ReLU(inplace=True) + # 56 x 56 + + self.reflecPad7 = nn.ReflectionPad2d((1,1,1,1)) + self.conv7 = nn.Conv2d(256,256,3,1,0) + self.relu7 = nn.ReLU(inplace=True) + # 56 x 56 + + self.reflecPad8 = nn.ReflectionPad2d((1,1,1,1)) + self.conv8 = nn.Conv2d(256,256,3,1,0) + self.relu8 = nn.ReLU(inplace=True) + # 56 x 56 + + self.reflecPad9 = nn.ReflectionPad2d((1,1,1,1)) + self.conv9 = nn.Conv2d(256,256,3,1,0) + self.relu9 = nn.ReLU(inplace=True) + # 56 x 56 + + self.maxPool3 = nn.MaxPool2d(kernel_size=2,stride=2) + # 28 x 28 + + self.reflecPad10 = nn.ReflectionPad2d((1,1,1,1)) + self.conv10 = nn.Conv2d(256,512,3,1,0) + self.relu10 = nn.ReLU(inplace=True) + + self.reflecPad11 = nn.ReflectionPad2d((1,1,1,1)) + self.conv11 = nn.Conv2d(512,512,3,1,0) + self.relu11 = nn.ReLU(inplace=True) + + self.reflecPad12 = nn.ReflectionPad2d((1,1,1,1)) + self.conv12 = nn.Conv2d(512,512,3,1,0) + self.relu12 = nn.ReLU(inplace=True) + + self.reflecPad13 = nn.ReflectionPad2d((1,1,1,1)) + self.conv13 = nn.Conv2d(512,512,3,1,0) + self.relu13 = nn.ReLU(inplace=True) + + self.maxPool4 = nn.MaxPool2d(kernel_size=2,stride=2) + self.reflecPad14 = nn.ReflectionPad2d((1,1,1,1)) + self.conv14 = nn.Conv2d(512,512,3,1,0) + self.relu14 = nn.ReLU(inplace=True) + + def forward(self,x,sF=None,contentV256=None,styleV256=None,matrix11=None,matrix21=None,matrix31=None): + output = {} + out = self.conv1(x) + out = self.reflecPad1(out) + out = self.conv2(out) + output['r11'] = self.relu2(out) + out = self.reflecPad7(output['r11']) + + #out = self.reflecPad3(output['r11']) + out = self.conv3(out) + output['r12'] = self.relu3(out) + + output['p1'] = self.maxPool(output['r12']) + out = self.reflecPad4(output['p1']) + out = self.conv4(out) + output['r21'] = self.relu4(out) + out = self.reflecPad7(output['r21']) + + #out = self.reflecPad5(output['r21']) + out = self.conv5(out) + output['r22'] = self.relu5(out) + + output['p2'] = self.maxPool2(output['r22']) + out = self.reflecPad6(output['p2']) + out = self.conv6(out) + output['r31'] = self.relu6(out) + if(styleV256 is not None): + feature = matrix31(output['r31'],sF['r31'],contentV256,styleV256) + out = self.reflecPad7(feature) + else: + out = self.reflecPad7(output['r31']) + out = self.conv7(out) + output['r32'] = self.relu7(out) + + out = self.reflecPad8(output['r32']) + out = self.conv8(out) + output['r33'] = self.relu8(out) + + out = self.reflecPad9(output['r33']) + out = self.conv9(out) + output['r34'] = self.relu9(out) + + output['p3'] = self.maxPool3(output['r34']) + out = self.reflecPad10(output['p3']) + out = self.conv10(out) + output['r41'] = self.relu10(out) + + out = self.reflecPad11(out) + out = self.conv11(out) + out = self.relu11(out) + out = self.reflecPad12(out) + out = self.conv12(out) + out = self.relu12(out) + out = 
self.reflecPad13(out) + out = self.conv13(out) + out = self.relu13(out) + out = self.maxPool4(out) + out = self.reflecPad14(out) + out = self.conv14(out) + out = self.relu14(out) + output['r51'] = out + return output + +class styleLoss(nn.Module): + def forward(self, input, target): + ib,ic,ih,iw = input.size() + iF = input.view(ib,ic,-1) + iMean = torch.mean(iF,dim=2) + iCov = GramMatrix()(input) + + tb,tc,th,tw = target.size() + tF = target.view(tb,tc,-1) + tMean = torch.mean(tF,dim=2) + tCov = GramMatrix()(target) + + loss = nn.MSELoss(size_average=False)(iMean,tMean) + nn.MSELoss(size_average=False)(iCov,tCov) + return loss/tb + +class GramMatrix(nn.Module): + def forward(self, input): + b, c, h, w = input.size() + f = input.view(b,c,h*w) # bxcx(hxw) + # torch.bmm(batch1, batch2, out=None) # + # batch1: bxmxp, batch2: bxpxn -> bxmxn # + G = torch.bmm(f,f.transpose(1,2)) # f: bxcx(hxw), f.transpose: bx(hxw)xc -> bxcxc + return G.div_(c*h*w) + +class LossCriterion(nn.Module): + def __init__(self, style_layers, content_layers, style_weight, content_weight, + model_path = '/home/xtli/Documents/GITHUB/LinearStyleTransfer/models/'): + super(LossCriterion,self).__init__() + + self.style_layers = style_layers + self.content_layers = content_layers + self.style_weight = style_weight + self.content_weight = content_weight + + self.styleLosses = [styleLoss()] * len(style_layers) + self.contentLosses = [nn.MSELoss()] * len(content_layers) + + self.vgg5 = encoder5() + self.vgg5.load_state_dict(torch.load(os.path.join(model_path, 'vgg_r51.pth'))) + + for param in self.vgg5.parameters(): + param.requires_grad = True + + def forward(self, transfer, image, content=True, style=True): + cF = self.vgg5(image) + sF = self.vgg5(image) + tF = self.vgg5(transfer) + + losses = {} + + # content loss + if content: + totalContentLoss = 0 + for i,layer in enumerate(self.content_layers): + cf_i = cF[layer] + cf_i = cf_i.detach() + tf_i = tF[layer] + loss_i = self.contentLosses[i] + totalContentLoss += loss_i(tf_i,cf_i) + totalContentLoss = totalContentLoss * self.content_weight + losses['content'] = totalContentLoss + + # style loss + if style: + totalStyleLoss = 0 + for i,layer in enumerate(self.style_layers): + sf_i = sF[layer] + sf_i = sf_i.detach() + tf_i = tF[layer] + loss_i = self.styleLosses[i] + totalStyleLoss += loss_i(tf_i,sf_i) + totalStyleLoss = totalStyleLoss * self.style_weight + losses['style'] = totalStyleLoss + + return losses + +class styleLossMask(nn.Module): + def forward(self, input, target, mask): + ib,ic,ih,iw = input.size() + iF = input.view(ib,ic,-1) + tb,tc,th,tw = target.size() + tF = target.view(tb,tc,-1) + + loss = 0 + mb, mc, mh, mw = mask.shape + for i in range(mb): + # resize mask to have the same size of the feature + maski = F.interpolate(mask[i:i+1], size = (ih, iw), mode = 'nearest') + mask_flat = maski.view(mc, -1) + for j in range(mc): + # get features for each part + idx = torch.nonzero(mask_flat[j]).squeeze() + if len(idx.shape) == 0 or idx.shape[0] == 0: + continue + ipart = torch.index_select(iF, 2, idx) + tpart = torch.index_select(tF, 2, idx) + + iMean = torch.mean(ipart,dim=2) + iGram = torch.bmm(ipart, ipart.transpose(1,2)).div_(ic*ih*iw) # f: bxcx(hxw), f.transpose: bx(hxw)xc -> bxcxc + + tMean = torch.mean(tpart,dim=2) + tGram = torch.bmm(tpart, tpart.transpose(1,2)).div_(tc*th*tw) # f: bxcx(hxw), f.transpose: bx(hxw)xc -> bxcxc + + loss += nn.MSELoss()(iMean,tMean) + nn.MSELoss()(iGram,tGram) + return loss/tb + +class LossCriterionMask(nn.Module): + def 
__init__(self, style_layers, content_layers, style_weight, content_weight, + model_path = '/home/xtli/Documents/GITHUB/LinearStyleTransfer/models/'): + super(LossCriterionMask,self).__init__() + + self.style_layers = style_layers + self.content_layers = content_layers + self.style_weight = style_weight + self.content_weight = content_weight + + self.styleLosses = [styleLossMask()] * len(style_layers) + self.contentLosses = [nn.MSELoss()] * len(content_layers) + + self.vgg5 = encoder5() + self.vgg5.load_state_dict(torch.load(os.path.join(model_path, 'vgg_r51.pth'))) + + for param in self.vgg5.parameters(): + param.requires_grad = True + + def forward(self, transfer, image, mask, content=True, style=True): + # mask: B, N, H, W + cF = self.vgg5(image) + sF = self.vgg5(image) + tF = self.vgg5(transfer) + + losses = {} + + # content loss + if content: + totalContentLoss = 0 + for i,layer in enumerate(self.content_layers): + cf_i = cF[layer] + cf_i = cf_i.detach() + tf_i = tF[layer] + loss_i = self.contentLosses[i] + totalContentLoss += loss_i(tf_i,cf_i) + totalContentLoss = totalContentLoss * self.content_weight + losses['content'] = totalContentLoss + + # style loss + if style: + totalStyleLoss = 0 + for i,layer in enumerate(self.style_layers): + sf_i = sF[layer] + sf_i = sf_i.detach() + tf_i = tF[layer] + loss_i = self.styleLosses[i] + totalStyleLoss += loss_i(tf_i,sf_i, mask) + totalStyleLoss = totalStyleLoss * self.style_weight + losses['style'] = totalStyleLoss + + return losses + +class VQEmbedding(nn.Module): + def __init__(self, K, D): + super().__init__() + self.embedding = nn.Embedding(K, D) + self.embedding.weight.data.uniform_(-1./K, 1./K) + + def forward(self, z_e_x): + z_e_x_ = z_e_x.permute(0, 2, 3, 1).contiguous() + latents = vq(z_e_x_, self.embedding.weight) + return latents + + def straight_through(self, z_e_x, return_index=False): + z_e_x_ = z_e_x.permute(0, 2, 3, 1).contiguous() + z_q_x_, indices = vq_st(z_e_x_, self.embedding.weight.detach()) + z_q_x = z_q_x_.permute(0, 3, 1, 2).contiguous() + + z_q_x_bar_flatten = torch.index_select(self.embedding.weight, + dim=0, index=indices) + z_q_x_bar_ = z_q_x_bar_flatten.view_as(z_e_x_) + z_q_x_bar = z_q_x_bar_.permute(0, 3, 1, 2).contiguous() + + if return_index: + return z_q_x, z_q_x_bar, indices + else: + return z_q_x, z_q_x_bar \ No newline at end of file diff --git a/libs/custom_transform.py b/libs/custom_transform.py new file mode 100644 index 0000000000000000000000000000000000000000..5dd980e7890ad3aca4359534d26fab6442173e46 --- /dev/null +++ b/libs/custom_transform.py @@ -0,0 +1,249 @@ +import torch +import torchvision +from torchvision import transforms +import torch.nn.functional as F +import torchvision.transforms.functional as TF +import numpy as np +from PIL import Image, ImageFilter +import random + +class BaseTransform(object): + """ + Resize and center crop. 
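+ The shorter image side is resized to `res`, then a `res` x `res` center crop is returned.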
+ """ + def __init__(self, res): + self.res = res + + def __call__(self, index, image): + image = TF.resize(image, self.res, Image.BILINEAR) + w, h = image.size + left = int(round((w - self.res) / 2.)) + top = int(round((h - self.res) / 2.)) + + return TF.crop(image, top, left, self.res, self.res) + + +class ComposeTransform(object): + def __init__(self, tlist): + self.tlist = tlist + + def __call__(self, index, image): + for trans in self.tlist: + image = trans(index, image) + + return image + +class RandomResize(object): + def __init__(self, rmin, rmax, N): + self.reslist = [random.randint(rmin, rmax) for _ in range(N)] + + def __call__(self, index, image): + return TF.resize(image, self.reslist[index], Image.BILINEAR) + +class RandomCrop(object): + def __init__(self, res, N): + self.res = res + self.cons = [(np.random.uniform(0, 1), np.random.uniform(0, 1)) for _ in range(N)] + + def __call__(self, index, image): + ws, hs = self.cons[index] + w, h = image.size + left = int(round((w-self.res)*ws)) + top = int(round((h-self.res)*hs)) + + return TF.crop(image, top, left, self.res, self.res) + +class RandomHorizontalFlip(object): + def __init__(self, N, p=0.5): + self.p_ref = p + self.plist = np.random.random_sample(N) + + def __call__(self, index, image): + if self.plist[index.cpu()] < self.p_ref: + return TF.hflip(image) + else: + return image + + +class TensorTransform(object): + def __init__(self): + self.to_tensor = transforms.ToTensor() + #self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + + def __call__(self, image): + image = self.to_tensor(image) + #image = self.normalize(image) + + return image + + +class RandomGaussianBlur(object): + def __init__(self, sigma, p, N): + self.min_x = sigma[0] + self.max_x = sigma[1] + self.del_p = 1 - p + self.p_ref = p + self.plist = np.random.random_sample(N) + + def __call__(self, index, image): + if self.plist[index] < self.p_ref: + x = self.plist[index] - self.p_ref + m = (self.max_x - self.min_x) / self.del_p + b = self.min_x + s = m * x + b + + return image.filter(ImageFilter.GaussianBlur(radius=s)) + else: + return image + + +class RandomGrayScale(object): + def __init__(self, p, N): + self.grayscale = transforms.RandomGrayscale(p=1.) # Deterministic (We still want flexible out_dim). 
+ self.p_ref = p + self.plist = np.random.random_sample(N) + + def __call__(self, index, image): + if self.plist[index] < self.p_ref: + return self.grayscale(image) + else: + return image + + +class RandomColorBrightness(object): + def __init__(self, x, p, N): + self.min_x = max(0, 1 - x) + self.max_x = 1 + x + self.p_ref = p + self.plist = np.random.random_sample(N) + self.rlist = [random.uniform(self.min_x, self.max_x) for _ in range(N)] + + def __call__(self, index, image): + if self.plist[index] < self.p_ref: + return TF.adjust_brightness(image, self.rlist[index]) + else: + return image + + +class RandomColorContrast(object): + def __init__(self, x, p, N): + self.min_x = max(0, 1 - x) + self.max_x = 1 + x + self.p_ref = p + self.plist = np.random.random_sample(N) + self.rlist = [random.uniform(self.min_x, self.max_x) for _ in range(N)] + + def __call__(self, index, image): + if self.plist[index] < self.p_ref: + return TF.adjust_contrast(image, self.rlist[index]) + else: + return image + + +class RandomColorSaturation(object): + def __init__(self, x, p, N): + self.min_x = max(0, 1 - x) + self.max_x = 1 + x + self.p_ref = p + self.plist = np.random.random_sample(N) + self.rlist = [random.uniform(self.min_x, self.max_x) for _ in range(N)] + + def __call__(self, index, image): + if self.plist[index] < self.p_ref: + return TF.adjust_saturation(image, self.rlist[index]) + else: + return image + + +class RandomColorHue(object): + def __init__(self, x, p, N): + self.min_x = -x + self.max_x = x + self.p_ref = p + self.plist = np.random.random_sample(N) + self.rlist = [random.uniform(self.min_x, self.max_x) for _ in range(N)] + + def __call__(self, index, image): + if self.plist[index] < self.p_ref: + return TF.adjust_hue(image, self.rlist[index]) + else: + return image + + +class RandomVerticalFlip(object): + def __init__(self, N, p=0.5): + self.p_ref = p + self.plist = np.random.random_sample(N) + + def __call__(self, indice, image): + I = np.nonzero(self.plist[indice] < self.p_ref)[0] + + if len(image.size()) == 3: + image_t = image[I].flip([1]) + else: + image_t = image[I].flip([2]) + + return torch.stack([image_t[np.where(I==i)[0][0]] if i in I else image[i] for i in range(image.size(0))]) + + + +class RandomHorizontalTensorFlip(object): + def __init__(self, N, p=0.5): + self.p_ref = p + self.plist = np.random.random_sample(N) + + def __call__(self, indice, image, is_label=False): + I = np.nonzero(self.plist[indice] < self.p_ref)[0] + + if len(image.size()) == 3: + image_t = image[I].flip([2]) + else: + image_t = image[I].flip([3]) + + return torch.stack([image_t[np.where(I==i)[0][0]] if i in I else image[i] for i in range(image.size(0))]) + + + +class RandomResizedCrop(object): + def __init__(self, N, res, scale=(0.5, 1.0)): + self.res = res + self.scale = scale + self.rscale = [np.random.uniform(*scale) for _ in range(N)] + self.rcrop = [(np.random.uniform(0, 1), np.random.uniform(0, 1)) for _ in range(N)] + + def random_crop(self, idx, img): + ws, hs = self.rcrop[idx] + res1 = int(img.size(-1)) + res2 = int(self.rscale[idx]*res1) + i1 = int(round((res1-res2)*ws)) + j1 = int(round((res1-res2)*hs)) + + return img[:, :, i1:i1+res2, j1:j1+res2] + + + def __call__(self, indice, image): + new_image = [] + res_tar = self.res // 4 if image.size(1) > 5 else self.res # View 1 or View 2? 
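+ # inputs with more than 5 channels (the "View 1 or View 2?" case above) are resampled to a quarter of self.res, otherwise to self.res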
+ + for i, idx in enumerate(indice): + img = image[[i]] + img = self.random_crop(idx, img) + img = F.interpolate(img, res_tar, mode='bilinear', align_corners=False) + + new_image.append(img) + + new_image = torch.cat(new_image) + + return new_image + + + + + + + + + + + + \ No newline at end of file diff --git a/libs/data_coco_stuff.py b/libs/data_coco_stuff.py new file mode 100644 index 0000000000000000000000000000000000000000..40146ac62219cbc2b4485a64e76a50618703d00d --- /dev/null +++ b/libs/data_coco_stuff.py @@ -0,0 +1,166 @@ +import cv2 +import torch +from PIL import Image +import os.path as osp +import numpy as np +from torch.utils import data +import torchvision.transforms as transforms +import torchvision.transforms.functional as TF +import random + +class RandomResizedCrop(object): + def __init__(self, N, res, scale=(0.5, 1.0)): + self.res = res + self.scale = scale + self.rscale = [np.random.uniform(*scale) for _ in range(N)] + self.rcrop = [(np.random.uniform(0, 1), np.random.uniform(0, 1)) for _ in range(N)] + + def random_crop(self, idx, img): + ws, hs = self.rcrop[idx] + res1 = int(img.size(-1)) + res2 = int(self.rscale[idx]*res1) + i1 = int(round((res1-res2)*ws)) + j1 = int(round((res1-res2)*hs)) + + return img[:, :, i1:i1+res2, j1:j1+res2] + + + def __call__(self, indice, image): + new_image = [] + res_tar = self.res // 4 if image.size(1) > 5 else self.res # View 1 or View 2? + + for i, idx in enumerate(indice): + img = image[[i]] + img = self.random_crop(idx, img) + img = F.interpolate(img, res_tar, mode='bilinear', align_corners=False) + + new_image.append(img) + + new_image = torch.cat(new_image) + + return new_image + +class RandomVerticalFlip(object): + def __init__(self, N, p=0.5): + self.p_ref = p + self.plist = np.random.random_sample(N) + + def __call__(self, indice, image): + I = np.nonzero(self.plist[indice] < self.p_ref)[0] + + if len(image.size()) == 3: + image_t = image[I].flip([1]) + else: + image_t = image[I].flip([2]) + + return torch.stack([image_t[np.where(I==i)[0][0]] if i in I else image[i] for i in range(image.size(0))]) + +class RandomHorizontalTensorFlip(object): + def __init__(self, N, p=0.5): + self.p_ref = p + self.plist = np.random.random_sample(N) + + def __call__(self, indice, image, is_label=False): + I = np.nonzero(self.plist[indice] < self.p_ref)[0] + + if len(image.size()) == 3: + image_t = image[I].flip([2]) + else: + image_t = image[I].flip([3]) + + return torch.stack([image_t[np.where(I==i)[0][0]] if i in I else image[i] for i in range(image.size(0))]) + +class _Coco164kCuratedFew(data.Dataset): + """Base class + This contains fields and methods common to all COCO 164k curated few datasets: + + (curated) Coco164kFew_Stuff + (curated) Coco164kFew_Stuff_People + (curated) Coco164kFew_Stuff_Animals + (curated) Coco164kFew_Stuff_People_Animals + + """ + def __init__(self, root, img_size, crop_size, split = "train2017"): + super(_Coco164kCuratedFew, self).__init__() + + # work out name + self.split = split + self.root = root + self.include_things_labels = False # people + self.incl_animal_things = False # animals + + version = 6 + + name = "Coco164kFew_Stuff" + if self.include_things_labels and self.incl_animal_things: + name += "_People_Animals" + elif self.include_things_labels: + name += "_People" + elif self.incl_animal_things: + name += "_Animals" + + self.name = (name + "_%d" % version) + + print("Specific type of _Coco164kCuratedFew dataset: %s" % self.name) + + self._set_files() + + + self.transform = transforms.Compose([ + 
transforms.RandomChoice([ + transforms.ColorJitter(brightness=0.05), + transforms.ColorJitter(contrast=0.05), + transforms.ColorJitter(saturation=0.01), + transforms.ColorJitter(hue=0.01)]), + transforms.RandomHorizontalFlip(), + transforms.RandomVerticalFlip(), + transforms.Resize(int(img_size)), + transforms.RandomCrop(crop_size)]) + + N = len(self.files) + self.random_horizontal_flip = RandomHorizontalTensorFlip(N=N) + self.random_vertical_flip = RandomVerticalFlip(N=N) + self.random_resized_crop = RandomResizedCrop(N=N, res=self.res1, scale=self.scale) + + + def _set_files(self): + # Create data list by parsing the "images" folder + if self.split in ["train2017", "val2017"]: + file_list = osp.join(self.root, "curated", self.split, self.name + ".txt") + file_list = tuple(open(file_list, "r")) + file_list = [id_.rstrip() for id_ in file_list] + + self.files = file_list + print("In total {} images.".format(len(self.files))) + else: + raise ValueError("Invalid split name: {}".format(self.split)) + + def __getitem__(self, index): + # same as _Coco164k + # Set paths + image_id = self.files[index] + image_path = osp.join(self.root, "images", self.split, image_id + ".jpg") + label_path = osp.join(self.root, "annotations", self.split, + image_id + ".png") + # Load an image + #image = cv2.imread(image_path, cv2.IMREAD_COLOR).astype(np.uint8) + ori_img = Image.open(image_path) + ori_img = self.transform(ori_img) + ori_img = np.array(ori_img) + if ori_img.ndim < 3: + ori_img = np.expand_dims(ori_img, axis=2).repeat(3, axis = 2) + ori_img = ori_img[:, :, :3] + ori_img = torch.from_numpy(ori_img).float().permute(2, 0, 1) + ori_img = ori_img / 255.0 + + #label = cv2.imread(label_path, cv2.IMREAD_GRAYSCALE).astype(np.int32) + + #label[label == 255] = -1 # to be consistent with 10k + + rets = [] + rets.append(ori_img) + #rets.append(label) + return rets + + def __len__(self): + return len(self.files) diff --git a/libs/data_coco_stuff_geo_pho.py b/libs/data_coco_stuff_geo_pho.py new file mode 100644 index 0000000000000000000000000000000000000000..53344c1fe090e79f0350649f30acbf5b9a6988b9 --- /dev/null +++ b/libs/data_coco_stuff_geo_pho.py @@ -0,0 +1,145 @@ +import cv2 +import torch +from PIL import Image +import os.path as osp +import numpy as np +from torch.utils import data +import torchvision.transforms as transforms +import torchvision.transforms.functional as TF +import torchvision.transforms.functional as TF +from .custom_transform import * + +class _Coco164kCuratedFew(data.Dataset): + """Base class + This contains fields and methods common to all COCO 164k curated few datasets: + + (curated) Coco164kFew_Stuff + (curated) Coco164kFew_Stuff_People + (curated) Coco164kFew_Stuff_Animals + (curated) Coco164kFew_Stuff_People_Animals + + """ + def __init__(self, root, img_size, crop_size, split = "train2017"): + super(_Coco164kCuratedFew, self).__init__() + + # work out name + self.split = split + self.root = root + self.include_things_labels = False # people + self.incl_animal_things = False # animals + + version = 6 + + name = "Coco164kFew_Stuff" + if self.include_things_labels and self.incl_animal_things: + name += "_People_Animals" + elif self.include_things_labels: + name += "_People" + elif self.incl_animal_things: + name += "_Animals" + + self.name = (name + "_%d" % version) + + print("Specific type of _Coco164kCuratedFew dataset: %s" % self.name) + + self._set_files() + + self.transform = transforms.Compose([ + transforms.Resize(int(img_size)), + transforms.RandomCrop(crop_size)]) + + N = 
len(self.files) + # eqv transform + self.random_horizontal_flip = RandomHorizontalTensorFlip(N=N) + self.random_vertical_flip = RandomVerticalFlip(N=N) + self.random_resized_crop = RandomResizedCrop(N=N, res=288) + + # photometric transform + self.random_color_brightness = [RandomColorBrightness(x=0.3, p=0.8, N=N) for _ in range(2)] # Control this later (NOTE)] + self.random_color_contrast = [RandomColorContrast(x=0.3, p=0.8, N=N) for _ in range(2)] # Control this later (NOTE) + self.random_color_saturation = [RandomColorSaturation(x=0.3, p=0.8, N=N) for _ in range(2)] # Control this later (NOTE) + self.random_color_hue = [RandomColorHue(x=0.1, p=0.8, N=N) for _ in range(2)] # Control this later (NOTE) + self.random_gray_scale = [RandomGrayScale(p=0.2, N=N) for _ in range(2)] + self.random_gaussian_blur = [RandomGaussianBlur(sigma=[.1, 2.], p=0.5, N=N) for _ in range(2)] + + self.eqv_list = ['random_crop', 'h_flip'] + self.inv_list = ['brightness', 'contrast', 'saturation', 'hue', 'gray', 'blur'] + + self.transform_tensor = TensorTransform() + + + def _set_files(self): + # Create data list by parsing the "images" folder + if self.split in ["train2017", "val2017"]: + file_list = osp.join(self.root, "curated", self.split, self.name + ".txt") + file_list = tuple(open(file_list, "r")) + file_list = [id_.rstrip() for id_ in file_list] + + self.files = file_list + print("In total {} images.".format(len(self.files))) + else: + raise ValueError("Invalid split name: {}".format(self.split)) + + def transform_eqv(self, indice, image): + if 'random_crop' in self.eqv_list: + image = self.random_resized_crop(indice, image) + if 'h_flip' in self.eqv_list: + image = self.random_horizontal_flip(indice, image) + if 'v_flip' in self.eqv_list: + image = self.random_vertical_flip(indice, image) + + return image + + def transform_inv(self, index, image, ver): + """ + Hyperparameters same as MoCo v2. 
+ (https://github.com/facebookresearch/moco/blob/master/main_moco.py) + """ + if 'brightness' in self.inv_list: + image = self.random_color_brightness[ver](index, image) + if 'contrast' in self.inv_list: + image = self.random_color_contrast[ver](index, image) + if 'saturation' in self.inv_list: + image = self.random_color_saturation[ver](index, image) + if 'hue' in self.inv_list: + image = self.random_color_hue[ver](index, image) + if 'gray' in self.inv_list: + image = self.random_gray_scale[ver](index, image) + if 'blur' in self.inv_list: + image = self.random_gaussian_blur[ver](index, image) + + return image + + def transform_image(self, index, image): + image1 = self.transform_inv(index, image, 0) + image1 = self.transform_tensor(image) + + image2 = self.transform_inv(index, image, 1) + #image2 = TF.resize(image2, self.crop_size, Image.BILINEAR) + image2 = self.transform_tensor(image2) + return image1, image2 + + def __getitem__(self, index): + # same as _Coco164k + # Set paths + image_id = self.files[index] + image_path = osp.join(self.root, "images", self.split, image_id + ".jpg") + # Load an image + ori_img = Image.open(image_path) + ori_img = self.transform(ori_img) + + image1, image2 = self.transform_image(index, ori_img) + if image1.shape[0] < 3: + image1 = image1.repeat(3, 1, 1) + if image2.shape[0] < 3: + image2 = image2.repeat(3, 1, 1) + + rets = [] + rets.append(image1) + rets.append(image2) + rets.append(index) + + return rets + + def __len__(self): + return len(self.files) diff --git a/libs/data_geo.py b/libs/data_geo.py new file mode 100644 index 0000000000000000000000000000000000000000..d97d85a4aeb747973c18e36cc310a517474ffaa7 --- /dev/null +++ b/libs/data_geo.py @@ -0,0 +1,176 @@ +"""SLIC dataset + - Returns an image together with its SLIC segmentation map. +""" +import torch +import torch.utils.data as data +import torchvision.transforms as transforms + +import numpy as np +from glob import glob +from PIL import Image +from skimage.segmentation import slic +from skimage.color import rgb2lab +import torch.nn.functional as F + +from .utils import label2one_hot_torch + +class RandomResizedCrop(object): + def __init__(self, N, res, scale=(0.5, 1.0)): + self.res = res + self.scale = scale + self.rscale = [np.random.uniform(*scale) for _ in range(N)] + self.rcrop = [(np.random.uniform(0, 1), np.random.uniform(0, 1)) for _ in range(N)] + + def random_crop(self, idx, img): + ws, hs = self.rcrop[idx] + res1 = int(img.size(-1)) + res2 = int(self.rscale[idx]*res1) + i1 = int(round((res1-res2)*ws)) + j1 = int(round((res1-res2)*hs)) + + return img[:, :, i1:i1+res2, j1:j1+res2] + + + def __call__(self, indice, image): + new_image = [] + res_tar = self.res // 8 if image.size(1) > 5 else self.res # View 1 or View 2? 
+ + for i, idx in enumerate(indice): + img = image[[i]] + img = self.random_crop(idx, img) + img = F.interpolate(img, res_tar, mode='bilinear', align_corners=False) + + new_image.append(img) + + new_image = torch.cat(new_image) + + return new_image + +class RandomVerticalFlip(object): + def __init__(self, N, p=0.5): + self.p_ref = p + self.plist = np.random.random_sample(N) + + def __call__(self, indice, image): + I = np.nonzero(self.plist[indice] < self.p_ref)[0] + + if len(image.size()) == 3: + image_t = image[I].flip([1]) + else: + image_t = image[I].flip([2]) + + return torch.stack([image_t[np.where(I==i)[0][0]] if i in I else image[i] for i in range(image.size(0))]) + +class RandomHorizontalTensorFlip(object): + def __init__(self, N, p=0.5): + self.p_ref = p + self.plist = np.random.random_sample(N) + + def __call__(self, indice, image, is_label=False): + I = np.nonzero(self.plist[indice.cpu()] < self.p_ref)[0] + + if len(image.size()) == 3: + image_t = image[I].flip([2]) + else: + image_t = image[I].flip([3]) + + return torch.stack([image_t[np.where(I==i)[0][0]] if i in I else image[i] for i in range(image.size(0))]) + +class Dataset(data.Dataset): + def __init__(self, data_dir, img_size=256, crop_size=128, test=False, + sp_num=256, slic = True, lab = False): + super(Dataset, self).__init__() + #self.data_list = glob(os.path.join(data_dir, "*.jpg")) + ext = ["*.jpg"] + dl = [] + [dl.extend(glob(data_dir + '/**/' + e, recursive=True)) for e in ext] + self.data_list = dl + self.sp_num = sp_num + self.slic = slic + self.lab = lab + if test: + self.transform = transforms.Compose([ + transforms.Resize(img_size), + transforms.CenterCrop(crop_size)]) + else: + self.transform = transforms.Compose([ + transforms.RandomChoice([ + transforms.ColorJitter(brightness=0.05), + transforms.ColorJitter(contrast=0.05), + transforms.ColorJitter(saturation=0.01), + transforms.ColorJitter(hue=0.01)]), + transforms.RandomHorizontalFlip(), + transforms.RandomVerticalFlip(), + transforms.Resize(int(img_size)), + transforms.RandomCrop(crop_size)]) + + N = len(self.data_list) + self.random_horizontal_flip = RandomHorizontalTensorFlip(N=N) + self.random_vertical_flip = RandomVerticalFlip(N=N) + self.random_resized_crop = RandomResizedCrop(N=N, res=224) + self.eqv_list = ['random_crop', 'h_flip'] + + def transform_eqv(self, indice, image): + if 'random_crop' in self.eqv_list: + image = self.random_resized_crop(indice, image) + if 'h_flip' in self.eqv_list: + image = self.random_horizontal_flip(indice, image) + if 'v_flip' in self.eqv_list: + image = self.random_vertical_flip(indice, image) + + return image + + def __getitem__(self, index): + data_path = self.data_list[index] + ori_img = Image.open(data_path) + ori_img = self.transform(ori_img) + ori_img = np.array(ori_img) + + # compute slic + if self.slic: + slic_i = slic(ori_img, n_segments=self.sp_num, compactness=10, start_label=0, min_size_factor=0.3) + slic_i = torch.from_numpy(slic_i) + slic_i[slic_i >= self.sp_num] = self.sp_num - 1 + oh = label2one_hot_torch(slic_i.unsqueeze(0).unsqueeze(0), C = self.sp_num).squeeze() + + if ori_img.ndim < 3: + ori_img = np.expand_dims(ori_img, axis=2).repeat(3, axis = 2) + ori_img = ori_img[:, :, :3] + + rets = [] + if self.lab: + lab_img = rgb2lab(ori_img) + rets.append(torch.from_numpy(lab_img).float().permute(2, 0, 1)) + + ori_img = torch.from_numpy(ori_img).float().permute(2, 0, 1) + rets.append(ori_img/255.0) + + if self.slic: + rets.append(oh) + + rets.append(index) + + return rets + + def __len__(self): + return 
len(self.data_list) + +if __name__ == '__main__': + import torchvision.utils as vutils + dataset = Dataset('/home/xtli/DATA/texture_data/', + sampled_num=3000) + loader_ = torch.utils.data.DataLoader(dataset = dataset, + batch_size = 1, + shuffle = True, + num_workers = 1, + drop_last = True) + loader = iter(loader_) + img, points, pixs = loader.next() + + crop_size = 128 + canvas = torch.zeros((1, 3, crop_size, crop_size)) + for i in range(points.shape[-2]): + p = (points[0, i] + 1) / 2.0 * (crop_size - 1) + canvas[0, :, int(p[0]), int(p[1])] = pixs[0, :, i] + vutils.save_image(canvas, 'canvas.png') + vutils.save_image(img, 'img.png') diff --git a/libs/data_geo_pho.py b/libs/data_geo_pho.py new file mode 100644 index 0000000000000000000000000000000000000000..61d1681e3b7520b8e2568b6a5d315be85ffc70cf --- /dev/null +++ b/libs/data_geo_pho.py @@ -0,0 +1,130 @@ +"""SLIC dataset + - Returns an image together with its SLIC segmentation map. +""" +import torch +import torch.utils.data as data +import torchvision.transforms as transforms + +import numpy as np +from glob import glob +from PIL import Image +import torch.nn.functional as F +import torchvision.transforms.functional as TF + +from .custom_transform import * + +class Dataset(data.Dataset): + def __init__(self, data_dir, img_size=256, crop_size=128, test=False, + sp_num=256, slic = True, lab = False): + super(Dataset, self).__init__() + #self.data_list = glob(os.path.join(data_dir, "*.jpg")) + ext = ["*.jpg"] + dl = [] + [dl.extend(glob(data_dir + '/**/' + e, recursive=True)) for e in ext] + self.data_list = dl + self.sp_num = sp_num + self.slic = slic + self.lab = lab + if test: + self.transform = transforms.Compose([ + transforms.Resize(img_size), + transforms.CenterCrop(crop_size)]) + else: + self.transform = transforms.Compose([ + transforms.Resize(int(img_size)), + transforms.RandomCrop(crop_size)]) + + N = len(self.data_list) + # eqv transform + self.random_horizontal_flip = RandomHorizontalTensorFlip(N=N) + self.random_vertical_flip = RandomVerticalFlip(N=N) + self.random_resized_crop = RandomResizedCrop(N=N, res=256) + + # photometric transform + self.random_color_brightness = [RandomColorBrightness(x=0.3, p=0.8, N=N) for _ in range(2)] # Control this later (NOTE)] + self.random_color_contrast = [RandomColorContrast(x=0.3, p=0.8, N=N) for _ in range(2)] # Control this later (NOTE) + self.random_color_saturation = [RandomColorSaturation(x=0.3, p=0.8, N=N) for _ in range(2)] # Control this later (NOTE) + self.random_color_hue = [RandomColorHue(x=0.1, p=0.8, N=N) for _ in range(2)] # Control this later (NOTE) + self.random_gray_scale = [RandomGrayScale(p=0.2, N=N) for _ in range(2)] + self.random_gaussian_blur = [RandomGaussianBlur(sigma=[.1, 2.], p=0.5, N=N) for _ in range(2)] + + self.eqv_list = ['random_crop', 'h_flip'] + self.inv_list = ['brightness', 'contrast', 'saturation', 'hue', 'gray', 'blur'] + + self.transform_tensor = TensorTransform() + + def transform_eqv(self, indice, image): + if 'random_crop' in self.eqv_list: + image = self.random_resized_crop(indice, image) + if 'h_flip' in self.eqv_list: + image = self.random_horizontal_flip(indice, image) + if 'v_flip' in self.eqv_list: + image = self.random_vertical_flip(indice, image) + + return image + + def transform_inv(self, index, image, ver): + """ + Hyperparameters same as MoCo v2. 
+ (https://github.com/facebookresearch/moco/blob/master/main_moco.py) + """ + if 'brightness' in self.inv_list: + image = self.random_color_brightness[ver](index, image) + if 'contrast' in self.inv_list: + image = self.random_color_contrast[ver](index, image) + if 'saturation' in self.inv_list: + image = self.random_color_saturation[ver](index, image) + if 'hue' in self.inv_list: + image = self.random_color_hue[ver](index, image) + if 'gray' in self.inv_list: + image = self.random_gray_scale[ver](index, image) + if 'blur' in self.inv_list: + image = self.random_gaussian_blur[ver](index, image) + + return image + + def transform_image(self, index, image): + image1 = self.transform_inv(index, image, 0) + image1 = self.transform_tensor(image) + + image2 = self.transform_inv(index, image, 1) + #image2 = TF.resize(image2, self.crop_size, Image.BILINEAR) + image2 = self.transform_tensor(image2) + return image1, image2 + + def __getitem__(self, index): + data_path = self.data_list[index] + ori_img = Image.open(data_path) + ori_img = self.transform(ori_img) + + image1, image2 = self.transform_image(index, ori_img) + + rets = [] + rets.append(image1) + rets.append(image2) + rets.append(index) + + return rets + + def __len__(self): + return len(self.data_list) + +if __name__ == '__main__': + import torchvision.utils as vutils + dataset = Dataset('/home/xtli/DATA/texture_data/', + sampled_num=3000) + loader_ = torch.utils.data.DataLoader(dataset = dataset, + batch_size = 1, + shuffle = True, + num_workers = 1, + drop_last = True) + loader = iter(loader_) + img, points, pixs = loader.next() + + crop_size = 128 + canvas = torch.zeros((1, 3, crop_size, crop_size)) + for i in range(points.shape[-2]): + p = (points[0, i] + 1) / 2.0 * (crop_size - 1) + canvas[0, :, int(p[0]), int(p[1])] = pixs[0, :, i] + vutils.save_image(canvas, 'canvas.png') + vutils.save_image(img, 'img.png') diff --git a/libs/data_slic.py b/libs/data_slic.py new file mode 100644 index 0000000000000000000000000000000000000000..52b407c25695aa9cf849e6e64282a6a1be84ccc6 --- /dev/null +++ b/libs/data_slic.py @@ -0,0 +1,175 @@ +"""SLIC dataset + - Returns an image together with its SLIC segmentation map. +""" +import torch +import torch.utils.data as data +import torchvision.transforms as transforms + +import numpy as np +from glob import glob +from PIL import Image +from skimage.segmentation import slic +from skimage.color import rgb2lab + +from .utils import label2one_hot_torch + +class RandomResizedCrop(object): + def __init__(self, N, res, scale=(0.5, 1.0)): + self.res = res + self.scale = scale + self.rscale = [np.random.uniform(*scale) for _ in range(N)] + self.rcrop = [(np.random.uniform(0, 1), np.random.uniform(0, 1)) for _ in range(N)] + + def random_crop(self, idx, img): + ws, hs = self.rcrop[idx] + res1 = int(img.size(-1)) + res2 = int(self.rscale[idx]*res1) + i1 = int(round((res1-res2)*ws)) + j1 = int(round((res1-res2)*hs)) + + return img[:, :, i1:i1+res2, j1:j1+res2] + + + def __call__(self, indice, image): + new_image = [] + res_tar = self.res // 4 if image.size(1) > 5 else self.res # View 1 or View 2? 
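        # The crop parameters are drawn once per dataset index (self.rscale / self.rcrop),
        # so calling this transform again with the same `indice` replays exactly the same
        # crop -- this is what makes it usable as an equivariant augmentation on both the
        # image and its per-pixel predictions. Inputs with more than 5 channels
        # (presumably the one-hot superpixel / cluster maps rather than RGB) are resized
        # to a quarter of the target resolution.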
+ + for i, idx in enumerate(indice): + img = image[[i]] + img = self.random_crop(idx, img) + img = F.interpolate(img, res_tar, mode='bilinear', align_corners=False) + + new_image.append(img) + + new_image = torch.cat(new_image) + + return new_image + +class RandomVerticalFlip(object): + def __init__(self, N, p=0.5): + self.p_ref = p + self.plist = np.random.random_sample(N) + + def __call__(self, indice, image): + I = np.nonzero(self.plist[indice] < self.p_ref)[0] + + if len(image.size()) == 3: + image_t = image[I].flip([1]) + else: + image_t = image[I].flip([2]) + + return torch.stack([image_t[np.where(I==i)[0][0]] if i in I else image[i] for i in range(image.size(0))]) + +class RandomHorizontalTensorFlip(object): + def __init__(self, N, p=0.5): + self.p_ref = p + self.plist = np.random.random_sample(N) + + def __call__(self, indice, image, is_label=False): + I = np.nonzero(self.plist[indice] < self.p_ref)[0] + + if len(image.size()) == 3: + image_t = image[I].flip([2]) + else: + image_t = image[I].flip([3]) + + return torch.stack([image_t[np.where(I==i)[0][0]] if i in I else image[i] for i in range(image.size(0))]) + +class Dataset(data.Dataset): + def __init__(self, data_dir, img_size=256, crop_size=128, test=False, + sp_num=256, slic = True, lab = False): + super(Dataset, self).__init__() + #self.data_list = glob(os.path.join(data_dir, "*.jpg")) + ext = ["*.jpg"] + dl = [] + [dl.extend(glob(data_dir + '/**/' + e, recursive=True)) for e in ext] + self.data_list = dl + self.sp_num = sp_num + self.slic = slic + self.lab = lab + if test: + self.transform = transforms.Compose([ + transforms.Resize(img_size), + transforms.CenterCrop(crop_size)]) + else: + self.transform = transforms.Compose([ + transforms.RandomChoice([ + transforms.ColorJitter(brightness=0.05), + transforms.ColorJitter(contrast=0.05), + transforms.ColorJitter(saturation=0.01), + transforms.ColorJitter(hue=0.01)]), + transforms.RandomHorizontalFlip(), + transforms.RandomVerticalFlip(), + transforms.Resize(int(img_size)), + transforms.RandomCrop(crop_size)]) + + N = len(self.data_list) + self.random_horizontal_flip = RandomHorizontalTensorFlip(N=N) + self.random_vertical_flip = RandomVerticalFlip(N=N) + self.random_resized_crop = RandomResizedCrop(N=N, res=img_size) + self.eqv_list = ['random_crop', 'h_flip'] + + def transform_eqv(self, indice, image): + if 'random_crop' in self.eqv_list: + image = self.random_resized_crop(indice, image) + if 'h_flip' in self.eqv_list: + image = self.random_horizontal_flip(indice, image) + if 'v_flip' in self.eqv_list: + image = self.random_vertical_flip(indice, image) + + return image + + def __getitem__(self, index): + data_path = self.data_list[index] + ori_img = Image.open(data_path) + ori_img = self.transform(ori_img) + ori_img = np.array(ori_img) + + # compute slic + if self.slic: + slic_i = slic(ori_img, n_segments=self.sp_num, compactness=10, start_label=0, min_size_factor=0.3) + slic_i = torch.from_numpy(slic_i) + slic_i[slic_i >= self.sp_num] = self.sp_num - 1 + oh = label2one_hot_torch(slic_i.unsqueeze(0).unsqueeze(0), C = self.sp_num).squeeze() + + if ori_img.ndim < 3: + ori_img = np.expand_dims(ori_img, axis=2).repeat(3, axis = 2) + ori_img = ori_img[:, :, :3] + + rets = [] + if self.lab: + lab_img = rgb2lab(ori_img) + rets.append(torch.from_numpy(lab_img).float().permute(2, 0, 1)) + + ori_img = torch.from_numpy(ori_img).float().permute(2, 0, 1) + rets.append(ori_img/255.0) + + if self.slic: + rets.append(oh) + + rets.append(index) + + return rets + + def __len__(self): + return 
len(self.data_list) + +if __name__ == '__main__': + import torchvision.utils as vutils + dataset = Dataset('/home/xtli/DATA/texture_data/', + sampled_num=3000) + loader_ = torch.utils.data.DataLoader(dataset = dataset, + batch_size = 1, + shuffle = True, + num_workers = 1, + drop_last = True) + loader = iter(loader_) + img, points, pixs = loader.next() + + crop_size = 128 + canvas = torch.zeros((1, 3, crop_size, crop_size)) + for i in range(points.shape[-2]): + p = (points[0, i] + 1) / 2.0 * (crop_size - 1) + canvas[0, :, int(p[0]), int(p[1])] = pixs[0, :, i] + vutils.save_image(canvas, 'canvas.png') + vutils.save_image(img, 'img.png') diff --git a/libs/discriminator.py b/libs/discriminator.py new file mode 100644 index 0000000000000000000000000000000000000000..29182012758597ef61b91ffa80a0dc347ff5cbdd --- /dev/null +++ b/libs/discriminator.py @@ -0,0 +1,60 @@ +import functools +import torch.nn as nn + +def weights_init(m): + classname = m.__class__.__name__ + if classname.find('Conv') != -1: + nn.init.normal_(m.weight.data, 0.0, 0.02) + elif classname.find('BatchNorm') != -1: + nn.init.normal_(m.weight.data, 1.0, 0.02) + nn.init.constant_(m.bias.data, 0) + + +class NLayerDiscriminator(nn.Module): + """Defines a PatchGAN discriminator as in Pix2Pix + --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py + """ + def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False): + """Construct a PatchGAN discriminator + Parameters: + input_nc (int) -- the number of channels in input images + ndf (int) -- the number of filters in the last conv layer + n_layers (int) -- the number of conv layers in the discriminator + norm_layer -- normalization layer + """ + super(NLayerDiscriminator, self).__init__() + norm_layer = nn.BatchNorm2d + if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters + use_bias = norm_layer.func != nn.BatchNorm2d + else: + use_bias = norm_layer != nn.BatchNorm2d + + kw = 4 + padw = 1 + sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)] + nf_mult = 1 + nf_mult_prev = 1 + for n in range(1, n_layers): # gradually increase the number of filters + nf_mult_prev = nf_mult + nf_mult = min(2 ** n, 8) + sequence += [ + nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias), + norm_layer(ndf * nf_mult), + nn.LeakyReLU(0.2, True) + ] + + nf_mult_prev = nf_mult + nf_mult = min(2 ** n_layers, 8) + sequence += [ + nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias), + norm_layer(ndf * nf_mult), + nn.LeakyReLU(0.2, True) + ] + + sequence += [ + nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map + self.main = nn.Sequential(*sequence) + + def forward(self, input): + """Standard forward.""" + return self.main(input) \ No newline at end of file diff --git a/libs/flow_transforms.py b/libs/flow_transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..1386fe01bf6073ce1de2a3b96a6cccf7f6a6880e --- /dev/null +++ b/libs/flow_transforms.py @@ -0,0 +1,393 @@ +from __future__ import division +import torch +import random +import numpy as np +import numbers +import types +import scipy.ndimage as ndimage +import cv2 +import matplotlib.pyplot as plt +from PIL import Image +# import torchvision.transforms.functional as FF + +''' +Data argumentation file +modifed from 
+https://github.com/ClementPinard/FlowNetPytorch + + +''' + + + +'''Set of tranform random routines that takes both input and target as arguments, +in order to have random but coherent transformations. +inputs are PIL Image pairs and targets are ndarrays''' + +_pil_interpolation_to_str = { + Image.NEAREST: 'PIL.Image.NEAREST', + Image.BILINEAR: 'PIL.Image.BILINEAR', + Image.BICUBIC: 'PIL.Image.BICUBIC', + Image.LANCZOS: 'PIL.Image.LANCZOS', + Image.HAMMING: 'PIL.Image.HAMMING', + Image.BOX: 'PIL.Image.BOX', +} + +class Compose(object): + """ Composes several co_transforms together. + For example: + >>> co_transforms.Compose([ + >>> co_transforms.CenterCrop(10), + >>> co_transforms.ToTensor(), + >>> ]) + """ + + def __init__(self, co_transforms): + self.co_transforms = co_transforms + + def __call__(self, input, target): + for t in self.co_transforms: + input,target = t(input,target) + return input,target + + +class ArrayToTensor(object): + """Converts a numpy.ndarray (H x W x C) to a torch.FloatTensor of shape (C x H x W).""" + + def __call__(self, array): + assert(isinstance(array, np.ndarray)) + + array = np.transpose(array, (2, 0, 1)) + # handle numpy array + tensor = torch.from_numpy(array) + # put it from HWC to CHW format + + return tensor.float() + + +class ArrayToPILImage(object): + """Converts a numpy.ndarray (H x W x C) to a torch.FloatTensor of shape (C x H x W).""" + + def __call__(self, array): + assert(isinstance(array, np.ndarray)) + + img = Image.fromarray(array.astype(np.uint8)) + + return img + +class PILImageToTensor(object): + """Converts a numpy.ndarray (H x W x C) to a torch.FloatTensor of shape (C x H x W).""" + + def __call__(self, img): + assert(isinstance(img, Image.Image)) + + array = np.asarray(img) + array = np.transpose(array, (2, 0, 1)) + tensor = torch.from_numpy(array) + + return tensor.float() + + +class Lambda(object): + """Applies a lambda as a transform""" + + def __init__(self, lambd): + assert isinstance(lambd, types.LambdaType) + self.lambd = lambd + + def __call__(self, input,target): + return self.lambd(input,target) + + +class CenterCrop(object): + """Crops the given inputs and target arrays at the center to have a region of + the given size. size can be a tuple (target_height, target_width) + or an integer, in which case the target will be of a square shape (size, size) + Careful, img1 and img2 may not be the same size + """ + + def __init__(self, size): + if isinstance(size, numbers.Number): + self.size = (int(size), int(size)) + else: + self.size = size + + def __call__(self, inputs, target): + h1, w1, _ = inputs[0].shape + # h2, w2, _ = inputs[1].shape + th, tw = self.size + x1 = int(round((w1 - tw) / 2.)) + y1 = int(round((h1 - th) / 2.)) + # x2 = int(round((w2 - tw) / 2.)) + # y2 = int(round((h2 - th) / 2.)) + for i in range(len(inputs)): + inputs[i] = inputs[i][y1: y1 + th, x1: x1 + tw] + # inputs[0] = inputs[0][y1: y1 + th, x1: x1 + tw] + # inputs[1] = inputs[1][y2: y2 + th, x2: x2 + tw] + target = target[y1: y1 + th, x1: x1 + tw] + return inputs,target + +class myRandomResized(object): + """ + based on RandomResizedCrop in + https://pytorch.org/docs/stable/_modules/torchvision/transforms/transforms.html#RandomResizedCrop + """ + + def __init__(self, expect_min_size, scale=(0.8, 1.5), interpolation=cv2.INTER_NEAREST): + # assert (min(input_size) * min(scale) > max(expect_size)) + # one consider one decimal !! 
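        # The continuous `scale` range is discretized below into 0.1 increments so that
        # get_params() can sample from a small, finite set of scales; the sampled
        # height/width must additionally be a multiple of 16 to fit the network
        # architecture. Note that `int(scale[1])*10` truncates a fractional upper bound
        # (e.g. 1.5 -> 1 -> 10), so the effective maximum scale can be smaller than
        # requested.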
+ assert (isinstance(scale,tuple) and len(scale)==2) + self.interpolation = interpolation + self.scale = [ x*0.1 for x in range(int(scale[0]*10),int(scale[1])*10 )] + self.min_size = expect_min_size + + @staticmethod + def get_params(img, scale, min_size): + """Get parameters for ``crop`` for a random sized crop. + + Args: + img (PIL Image): Image to be cropped. + scale (tuple): range of size of the origin size cropped + ratio (tuple): range of aspect ratio of the origin aspect ratio cropped + + Returns: + tuple: params (i, j, h, w) to be passed to ``crop`` for a random + sized crop. + """ + # area = img.size[0] * img.size[1] + h, w, _ = img.shape + for attempt in range(10): + rand_scale_ = random.choice(scale) + + if random.random() < 0.5: + rand_scale = rand_scale_ + else: + rand_scale = -1. + + if min_size[0] <= rand_scale * h and min_size[1] <= rand_scale * w\ + and rand_scale * h % 16 == 0 and rand_scale * w %16 ==0 : + # the 16*n condition is for network architecture + return (int(rand_scale * h),int(rand_scale * w )) + + # Fallback + return (h, w) + + def __call__(self, inputs, tgt): + """ + Args: + img (PIL Image): Image to be cropped and resized. + + Returns: + PIL Image: Randomly cropped and resized image. + """ + h,w = self.get_params(inputs[0], self.scale, self.min_size) + for i in range(len(inputs)): + inputs[i] = cv2.resize(inputs[i], (w,h), self.interpolation) + + tgt = cv2.resize(tgt, (w,h), self.interpolation) #for input as h*w*1 the output is h*w + return inputs, np.expand_dims(tgt,-1) + + def __repr__(self): + interpolate_str = _pil_interpolation_to_str[self.interpolation] + format_string = self.__class__.__name__ + '(min_size={0}'.format(self.min_size) + format_string += ', scale={0}'.format(tuple(round(s, 4) for s in self.scale)) + format_string += ', interpolation={0})'.format(interpolate_str) + return format_string + + +class Scale(object): + """ Rescales the inputs and target arrays to the given 'size'. + 'size' will be the size of the smaller edge. + For example, if height > width, then image will be + rescaled to (size * height / width, size) + size: size of the smaller edge + interpolation order: Default: 2 (bilinear) + """ + + def __init__(self, size, order=2): + self.size = size + self.order = order + + def __call__(self, inputs, target): + h, w, _ = inputs[0].shape + if (w <= h and w == self.size) or (h <= w and h == self.size): + return inputs,target + if w < h: + ratio = self.size/w + else: + ratio = self.size/h + + for i in range(len(inputs)): + inputs[i] = ndimage.interpolation.zoom(inputs[i], ratio, order=self.order)[:, :, :3] + + target = ndimage.interpolation.zoom(target, ratio, order=self.order)[:, :, :1] + #target *= ratio + return inputs, target + + +class RandomCrop(object): + """Crops the given PIL.Image at a random location to have a region of + the given size. 
size can be a tuple (target_height, target_width) + or an integer, in which case the target will be of a square shape (size, size) + """ + + def __init__(self, size): + if isinstance(size, numbers.Number): + self.size = (int(size), int(size)) + else: + self.size = size + + def __call__(self, inputs,target): + h, w, _ = inputs[0].shape + th, tw = self.size + if w == tw and h == th: + return inputs,target + + x1 = random.randint(0, w - tw) + y1 = random.randint(0, h - th) + for i in range(len(inputs)): + inputs[i] = inputs[i][y1: y1 + th,x1: x1 + tw] + # inputs[1] = inputs[1][y1: y1 + th,x1: x1 + tw] + # inputs[2] = inputs[2][y1: y1 + th, x1: x1 + tw] + + return inputs, target[y1: y1 + th,x1: x1 + tw] + +class MyScale(object): + def __init__(self, size, order=2): + self.size = size + self.order = order + + def __call__(self, inputs, target): + h, w, _ = inputs[0].shape + if (w <= h and w == self.size) or (h <= w and h == self.size): + return inputs,target + if w < h: + for i in range(len(inputs)): + inputs[i] = cv2.resize(inputs[i], (self.size, int(h * self.size / w))) + target = cv2.resize(target.squeeze(), (self.size, int(h * self.size / w)), cv2.INTER_NEAREST) + else: + for i in range(len(inputs)): + inputs[i] = cv2.resize(inputs[i], (int(w * self.size / h), self.size)) + target = cv2.resize(target.squeeze(), (int(w * self.size / h), self.size), cv2.INTER_NEAREST) + target = np.expand_dims(target, axis=2) + return inputs, target + +class RandomHorizontalFlip(object): + """Randomly horizontally flips the given PIL.Image with a probability of 0.5 + """ + + def __call__(self, inputs, target): + if random.random() < 0.5: + for i in range(len(inputs)): + inputs[i] = np.copy(np.fliplr(inputs[i])) + # inputs[1] = np.copy(np.fliplr(inputs[1])) + # inputs[2] = np.copy(np.fliplr(inputs[2])) + + target = np.copy(np.fliplr(target)) + # target[:,:,0] *= -1 + return inputs,target + + +class RandomVerticalFlip(object): + """Randomly horizontally flips the given PIL.Image with a probability of 0.5 + """ + + def __call__(self, inputs, target): + if random.random() < 0.5: + for i in range(len(inputs)): + inputs[i] = np.copy(np.flipud(inputs[i])) + # inputs[1] = np.copy(np.flipud(inputs[1])) + # inputs[2] = np.copy(np.flipud(inputs[2])) + + target = np.copy(np.flipud(target)) + # target[:,:,1] *= -1 #for disp there is no y dim + return inputs,target + + +class RandomRotate(object): + """Random rotation of the image from -angle to angle (in degrees) + This is useful for dataAugmentation, especially for geometric problems such as FlowEstimation + angle: max angle of the rotation + interpolation order: Default: 2 (bilinear) + reshape: Default: false. If set to true, image size will be set to keep every pixel in the image. + diff_angle: Default: 0. Must stay less than 10 degrees, or linear approximation of flowmap will be off. 
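    Note: the target (flow) map is rotated together with the inputs, and its vectors are
    additionally re-projected with the 2D rotation matrix for angle1 in __call__, since
    rotating the image grid alone would leave the flow directions unrotated.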
+ """ + + def __init__(self, angle, diff_angle=0, order=2, reshape=False): + self.angle = angle + self.reshape = reshape + self.order = order + self.diff_angle = diff_angle + + def __call__(self, inputs,target): + applied_angle = random.uniform(-self.angle,self.angle) + diff = random.uniform(-self.diff_angle,self.diff_angle) + angle1 = applied_angle - diff/2 + angle2 = applied_angle + diff/2 + angle1_rad = angle1*np.pi/180 + + h, w, _ = target.shape + + def rotate_flow(i,j,k): + return -k*(j-w/2)*(diff*np.pi/180) + (1-k)*(i-h/2)*(diff*np.pi/180) + + rotate_flow_map = np.fromfunction(rotate_flow, target.shape) + target += rotate_flow_map + + inputs[0] = ndimage.interpolation.rotate(inputs[0], angle1, reshape=self.reshape, order=self.order) + inputs[1] = ndimage.interpolation.rotate(inputs[1], angle2, reshape=self.reshape, order=self.order) + target = ndimage.interpolation.rotate(target, angle1, reshape=self.reshape, order=self.order) + # flow vectors must be rotated too! careful about Y flow which is upside down + target_ = np.copy(target) + target[:,:,0] = np.cos(angle1_rad)*target_[:,:,0] + np.sin(angle1_rad)*target_[:,:,1] + target[:,:,1] = -np.sin(angle1_rad)*target_[:,:,0] + np.cos(angle1_rad)*target_[:,:,1] + return inputs,target + + +class RandomTranslate(object): + def __init__(self, translation): + if isinstance(translation, numbers.Number): + self.translation = (int(translation), int(translation)) + else: + self.translation = translation + + def __call__(self, inputs,target): + h, w, _ = inputs[0].shape + th, tw = self.translation + tw = random.randint(-tw, tw) + th = random.randint(-th, th) + if tw == 0 and th == 0: + return inputs, target + # compute x1,x2,y1,y2 for img1 and target, and x3,x4,y3,y4 for img2 + x1,x2,x3,x4 = max(0,tw), min(w+tw,w), max(0,-tw), min(w-tw,w) + y1,y2,y3,y4 = max(0,th), min(h+th,h), max(0,-th), min(h-th,h) + + inputs[0] = inputs[0][y1:y2,x1:x2] + inputs[1] = inputs[1][y3:y4,x3:x4] + target = target[y1:y2,x1:x2] + target[:,:,0] += tw + target[:,:,1] += th + + return inputs, target + + +class RandomColorWarp(object): + def __init__(self, mean_range=0, std_range=0): + self.mean_range = mean_range + self.std_range = std_range + + def __call__(self, inputs, target): + random_std = np.random.uniform(-self.std_range, self.std_range, 3) + random_mean = np.random.uniform(-self.mean_range, self.mean_range, 3) + random_order = np.random.permutation(3) + + inputs[0] *= (1 + random_std) + inputs[0] += random_mean + + inputs[1] *= (1 + random_std) + inputs[1] += random_mean + + inputs[0] = inputs[0][:,:,random_order] + inputs[1] = inputs[1][:,:,random_order] + + return inputs, target diff --git a/libs/losses.py b/libs/losses.py new file mode 100644 index 0000000000000000000000000000000000000000..bbb9b67159407d49835029f95fd29ef185a50990 --- /dev/null +++ b/libs/losses.py @@ -0,0 +1,416 @@ +from libs.blocks import encoder5 +import torch +import torchvision +import torch.nn as nn +from torch.nn import init +import torch.nn.functional as F +from .normalization import get_nonspade_norm_layer +from .blocks import encoder5 + +import os +import numpy as np + +class BaseNetwork(nn.Module): + def __init__(self): + super(BaseNetwork, self).__init__() + + def print_network(self): + if isinstance(self, list): + self = self[0] + num_params = 0 + for param in self.parameters(): + num_params += param.numel() + print('Network [%s] was created. Total number of parameters: %.1f million. ' + 'To see the architecture, do print(network).' 
+ % (type(self).__name__, num_params / 1000000)) + + def init_weights(self, init_type='normal', gain=0.02): + def init_func(m): + classname = m.__class__.__name__ + if classname.find('BatchNorm2d') != -1: + if hasattr(m, 'weight') and m.weight is not None: + init.normal_(m.weight.data, 1.0, gain) + if hasattr(m, 'bias') and m.bias is not None: + init.constant_(m.bias.data, 0.0) + elif hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): + if init_type == 'normal': + init.normal_(m.weight.data, 0.0, gain) + elif init_type == 'xavier': + init.xavier_normal_(m.weight.data, gain=gain) + elif init_type == 'xavier_uniform': + init.xavier_uniform_(m.weight.data, gain=1.0) + elif init_type == 'kaiming': + init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') + elif init_type == 'orthogonal': + init.orthogonal_(m.weight.data, gain=gain) + elif init_type == 'none': # uses pytorch's default init method + m.reset_parameters() + else: + raise NotImplementedError('initialization method [%s] is not implemented' % init_type) + if hasattr(m, 'bias') and m.bias is not None: + init.constant_(m.bias.data, 0.0) + + self.apply(init_func) + + # propagate to children + for m in self.children(): + if hasattr(m, 'init_weights'): + m.init_weights(init_type, gain) + +class NLayerDiscriminator(BaseNetwork): + def __init__(self): + super().__init__() + + kw = 4 + padw = int(np.ceil((kw - 1.0) / 2)) + nf = 64 + n_layers_D = 4 + input_nc = 3 + + norm_layer = get_nonspade_norm_layer('spectralinstance') + sequence = [[nn.Conv2d(input_nc, nf, kernel_size=kw, stride=2, padding=padw), + nn.LeakyReLU(0.2, False)]] + + for n in range(1, n_layers_D): + nf_prev = nf + nf = min(nf * 2, 512) + stride = 1 if n == n_layers_D - 1 else 2 + sequence += [[norm_layer(nn.Conv2d(nf_prev, nf, kernel_size=kw, + stride=stride, padding=padw)), + nn.LeakyReLU(0.2, False) + ]] + + sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]] + + # We divide the layers into groups to extract intermediate layer outputs + for n in range(len(sequence)): + self.add_module('model' + str(n), nn.Sequential(*sequence[n])) + + def forward(self, input, get_intermediate_features = True): + results = [input] + for submodel in self.children(): + intermediate_output = submodel(results[-1]) + results.append(intermediate_output) + + if get_intermediate_features: + return results[1:] + else: + return results[-1] + +class VGG19(torch.nn.Module): + def __init__(self, requires_grad=False): + super().__init__() + vgg_pretrained_features = torchvision.models.vgg19(pretrained=True).features + self.slice1 = torch.nn.Sequential() + self.slice2 = torch.nn.Sequential() + self.slice3 = torch.nn.Sequential() + self.slice4 = torch.nn.Sequential() + self.slice5 = torch.nn.Sequential() + for x in range(2): + self.slice1.add_module(str(x), vgg_pretrained_features[x]) + for x in range(2, 7): + self.slice2.add_module(str(x), vgg_pretrained_features[x]) + for x in range(7, 12): + self.slice3.add_module(str(x), vgg_pretrained_features[x]) + for x in range(12, 21): + self.slice4.add_module(str(x), vgg_pretrained_features[x]) + for x in range(21, 30): + self.slice5.add_module(str(x), vgg_pretrained_features[x]) + import pdb; pdb.set_trace() + if not requires_grad: + for param in self.parameters(): + param.requires_grad = False + + def forward(self, X): + h_relu1 = self.slice1(X) + h_relu2 = self.slice2(h_relu1) + h_relu3 = self.slice3(h_relu2) + h_relu4 = self.slice4(h_relu3) + h_relu5 = self.slice5(h_relu4) + out = [h_relu1, h_relu2, 
h_relu3, h_relu4, h_relu5] + return out + +class encoder5(nn.Module): + def __init__(self): + super(encoder5,self).__init__() + # vgg + # 224 x 224 + self.conv1 = nn.Conv2d(3,3,1,1,0) + self.reflecPad1 = nn.ReflectionPad2d((1,1,1,1)) + # 226 x 226 + + self.conv2 = nn.Conv2d(3,64,3,1,0) + self.relu2 = nn.ReLU(inplace=True) + # 224 x 224 + + self.reflecPad3 = nn.ReflectionPad2d((1,1,1,1)) + self.conv3 = nn.Conv2d(64,64,3,1,0) + self.relu3 = nn.ReLU(inplace=True) + # 224 x 224 + + self.maxPool = nn.MaxPool2d(kernel_size=2,stride=2) + # 112 x 112 + + self.reflecPad4 = nn.ReflectionPad2d((1,1,1,1)) + self.conv4 = nn.Conv2d(64,128,3,1,0) + self.relu4 = nn.ReLU(inplace=True) + # 112 x 112 + + self.reflecPad5 = nn.ReflectionPad2d((1,1,1,1)) + self.conv5 = nn.Conv2d(128,128,3,1,0) + self.relu5 = nn.ReLU(inplace=True) + # 112 x 112 + + self.maxPool2 = nn.MaxPool2d(kernel_size=2,stride=2) + # 56 x 56 + + self.reflecPad6 = nn.ReflectionPad2d((1,1,1,1)) + self.conv6 = nn.Conv2d(128,256,3,1,0) + self.relu6 = nn.ReLU(inplace=True) + # 56 x 56 + + self.reflecPad7 = nn.ReflectionPad2d((1,1,1,1)) + self.conv7 = nn.Conv2d(256,256,3,1,0) + self.relu7 = nn.ReLU(inplace=True) + # 56 x 56 + + self.reflecPad8 = nn.ReflectionPad2d((1,1,1,1)) + self.conv8 = nn.Conv2d(256,256,3,1,0) + self.relu8 = nn.ReLU(inplace=True) + # 56 x 56 + + self.reflecPad9 = nn.ReflectionPad2d((1,1,1,1)) + self.conv9 = nn.Conv2d(256,256,3,1,0) + self.relu9 = nn.ReLU(inplace=True) + # 56 x 56 + + self.maxPool3 = nn.MaxPool2d(kernel_size=2,stride=2) + # 28 x 28 + + self.reflecPad10 = nn.ReflectionPad2d((1,1,1,1)) + self.conv10 = nn.Conv2d(256,512,3,1,0) + self.relu10 = nn.ReLU(inplace=True) + + self.reflecPad11 = nn.ReflectionPad2d((1,1,1,1)) + self.conv11 = nn.Conv2d(512,512,3,1,0) + self.relu11 = nn.ReLU(inplace=True) + + self.reflecPad12 = nn.ReflectionPad2d((1,1,1,1)) + self.conv12 = nn.Conv2d(512,512,3,1,0) + self.relu12 = nn.ReLU(inplace=True) + + self.reflecPad13 = nn.ReflectionPad2d((1,1,1,1)) + self.conv13 = nn.Conv2d(512,512,3,1,0) + self.relu13 = nn.ReLU(inplace=True) + + self.maxPool4 = nn.MaxPool2d(kernel_size=2,stride=2) + self.reflecPad14 = nn.ReflectionPad2d((1,1,1,1)) + self.conv14 = nn.Conv2d(512,512,3,1,0) + self.relu14 = nn.ReLU(inplace=True) + + def forward(self,x): + output = [] + out = self.conv1(x) + out = self.reflecPad1(out) + out = self.conv2(out) + out = self.relu2(out) + output.append(out) + + out = self.reflecPad3(out) + out = self.conv3(out) + out = self.relu3(out) + out = self.maxPool(out) + out = self.reflecPad4(out) + out = self.conv4(out) + out = self.relu4(out) + output.append(out) + + out = self.reflecPad5(out) + out = self.conv5(out) + out = self.relu5(out) + out = self.maxPool2(out) + out = self.reflecPad6(out) + out = self.conv6(out) + out = self.relu6(out) + output.append(out) + + out = self.reflecPad7(out) + out = self.conv7(out) + out = self.relu7(out) + out = self.reflecPad8(out) + out = self.conv8(out) + out = self.relu8(out) + out = self.reflecPad9(out) + out = self.conv9(out) + out = self.relu9(out) + out = self.maxPool3(out) + out = self.reflecPad10(out) + out = self.conv10(out) + out = self.relu10(out) + output.append(out) + + out = self.reflecPad11(out) + out = self.conv11(out) + out = self.relu11(out) + out = self.reflecPad12(out) + out = self.conv12(out) + out = self.relu12(out) + out = self.reflecPad13(out) + out = self.conv13(out) + out = self.relu13(out) + out = self.maxPool4(out) + out = self.reflecPad14(out) + out = self.conv14(out) + out = self.relu14(out) + + output.append(out) + 
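        # `output` now holds five feature maps of increasing depth and decreasing
        # resolution (roughly the relu1_1, relu2_1, relu3_1, relu4_1 and relu5_1 stages
        # of VGG-19); VGGLoss below compares a weighted subset of these features between
        # the prediction and the target.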
return output + +class VGGLoss(nn.Module): + def __init__(self, model_path): + super(VGGLoss, self).__init__() + self.vgg = encoder5().cuda() + self.vgg.load_state_dict(torch.load(os.path.join(model_path, 'vgg_r51.pth'))) + self.criterion = nn.MSELoss() + self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0] + + def forward(self, x, y): + x_vgg, y_vgg = self.vgg(x), self.vgg(y) + loss = 0 + for i in range(4): + loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach()) + return loss + +class GANLoss(nn.Module): + def __init__(self, gan_mode = 'hinge', target_real_label=1.0, target_fake_label=0.0, + tensor=torch.cuda.FloatTensor): + super(GANLoss, self).__init__() + self.real_label = target_real_label + self.fake_label = target_fake_label + self.real_label_tensor = None + self.fake_label_tensor = None + self.zero_tensor = None + self.Tensor = tensor + self.gan_mode = gan_mode + if gan_mode == 'ls': + pass + elif gan_mode == 'original': + pass + elif gan_mode == 'w': + pass + elif gan_mode == 'hinge': + pass + else: + raise ValueError('Unexpected gan_mode {}'.format(gan_mode)) + + def get_target_tensor(self, input, target_is_real): + if target_is_real: + if self.real_label_tensor is None: + self.real_label_tensor = self.Tensor(1).fill_(self.real_label) + self.real_label_tensor.requires_grad_(False) + return self.real_label_tensor.expand_as(input) + else: + if self.fake_label_tensor is None: + self.fake_label_tensor = self.Tensor(1).fill_(self.fake_label) + self.fake_label_tensor.requires_grad_(False) + return self.fake_label_tensor.expand_as(input) + + def get_zero_tensor(self, input): + if self.zero_tensor is None: + self.zero_tensor = self.Tensor(1).fill_(0) + self.zero_tensor.requires_grad_(False) + return self.zero_tensor.expand_as(input) + + def loss(self, input, target_is_real, for_discriminator=True): + if self.gan_mode == 'original': # cross entropy loss + target_tensor = self.get_target_tensor(input, target_is_real) + loss = F.binary_cross_entropy_with_logits(input, target_tensor) + return loss + elif self.gan_mode == 'ls': + target_tensor = self.get_target_tensor(input, target_is_real) + return F.mse_loss(input, target_tensor) + elif self.gan_mode == 'hinge': + if for_discriminator: + if target_is_real: + minval = torch.min(input - 1, self.get_zero_tensor(input)) + loss = -torch.mean(minval) + else: + minval = torch.min(-input - 1, self.get_zero_tensor(input)) + loss = -torch.mean(minval) + else: + assert target_is_real, "The generator's hinge loss must be aiming for real" + loss = -torch.mean(input) + return loss + else: + # wgan + if target_is_real: + return -input.mean() + else: + return input.mean() + + def __call__(self, input, target_is_real, for_discriminator=True): + # computing loss is a bit complicated because |input| may not be + # a tensor, but list of tensors in case of multiscale discriminator + if isinstance(input, list): + loss = 0 + for pred_i in input: + if isinstance(pred_i, list): + pred_i = pred_i[-1] + loss_tensor = self.loss(pred_i, target_is_real, for_discriminator) + bs = 1 if len(loss_tensor.size()) == 0 else loss_tensor.size(0) + new_loss = torch.mean(loss_tensor.view(bs, -1), dim=1) + loss += new_loss + return loss / len(input) + else: + return self.loss(input, target_is_real, for_discriminator) + +class SPADE_LOSS(nn.Module): + def __init__(self, model_path, lambda_feat = 1): + super(SPADE_LOSS, self).__init__() + self.criterionVGG = VGGLoss(model_path) + self.criterionGAN = GANLoss('hinge') + self.criterionL1 = nn.L1Loss() + 
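        # Composite SPADE-style objective: a hinge GAN loss
        #   D: E[relu(1 - D(real))] + E[relu(1 + D(fake))],   G: -E[D(fake)]
        # plus a VGG perceptual term and a discriminator feature-matching L1 term
        # (weighted by lambda_feat), in the spirit of pix2pixHD / SPADE.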
self.discriminator = NLayerDiscriminator() + self.lambda_feat = lambda_feat + + def forward(self, x, y, for_discriminator = False): + pred_real = self.discriminator(y) + if not for_discriminator: + pred_fake = self.discriminator(x) + VGGLoss = self.criterionVGG(x, y) + GANLoss = self.criterionGAN(pred_fake, True, for_discriminator = False) + + # feature matching loss + # last output is the final prediction, so we exclude it + num_intermediate_outputs = len(pred_fake) - 1 + GAN_Feat_loss = 0 + for j in range(num_intermediate_outputs): # for each layer output + unweighted_loss = self.criterionL1(pred_fake[j], pred_real[j].detach()) + GAN_Feat_loss += unweighted_loss * self.lambda_feat + L1Loss = self.criterionL1(x, y) + return VGGLoss, GANLoss, GAN_Feat_loss, L1Loss + else: + pred_fake = self.discriminator(x.detach()) + GANLoss = self.criterionGAN(pred_fake, False, for_discriminator = True) + GANLoss += self.criterionGAN(pred_real, True, for_discriminator = True) + return GANLoss + +class ContrastiveLoss(nn.Module): + """ + Contrastive loss + Takes embeddings of two samples and a target label == 1 if samples are from the same class and label == 0 otherwise + """ + + def __init__(self, margin): + super(ContrastiveLoss, self).__init__() + self.margin = margin + self.eps = 1e-9 + + def forward(self, out1, out2, target, size_average=True, norm = True): + if norm: + output1 = out1 / out1.pow(2).sum(1, keepdim=True).sqrt() + output2 = out1 / out2.pow(2).sum(1, keepdim=True).sqrt() + distances = (output2 - output1).pow(2).sum(1) # squared distances + losses = 0.5 * (target.float() * distances + + (1 + -1 * target).float() * F.relu(self.margin - (distances + self.eps).sqrt()).pow(2)) + return losses.mean() if size_average else losses.sum() diff --git a/libs/nnutils.py b/libs/nnutils.py new file mode 100644 index 0000000000000000000000000000000000000000..3622c5b19199b1c4b837c913e22748a96e3b08c6 --- /dev/null +++ b/libs/nnutils.py @@ -0,0 +1,126 @@ +""" Network utils + - poolfeat: aggregate superpixel features from pixel features + - upfeat: reconstruction pixel features from superpixel features + - quantize: quantization features given a codebook +""" +import torch + +def poolfeat(input, prob, avg = True): + """ A function to aggregate superpixel features from pixel features + + Args: + input (tensor): input feature tensor. + prob (tensor): one-hot superpixel segmentation. + avg (bool, optional): average or sum the pixel features to get superpixel features + + Returns: + cluster_feat (tensor): the superpixel features + + Shape: + input: (B, C, H, W) + prob: (B, N, H, W) + cluster_feat: (B, N, C) + """ + B, C, H, W = input.shape + B, N, H, W = prob.shape + prob_flat = prob.view(B, N, -1) + input_flat = input.view(B, C, -1) + cluster_feat = torch.matmul(prob_flat, input_flat.permute(0, 2, 1)) + if avg: + cluster_sum = torch.sum(prob_flat, dim = -1).view(B, N , 1) + cluster_feat = cluster_feat / (cluster_sum + 1e-8) + return cluster_feat + +def upfeat(input, prob): + """ A function to compute pixel features from superpixel features + + Args: + input (tensor): superpixel feature tensor. + prob (tensor): one-hot superpixel segmentation. + + Returns: + reconstr_feat (tensor): the pixel features. 
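        (upfeat is essentially the transpose of poolfeat: every pixel receives a
        weighted sum of the superpixel features, with weights given by its soft
        assignment in `prob`.)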
+ + Shape: + input: (B, N, C) + prob: (B, N, H, W) + reconstr_feat: (B, C, H, W) + """ + B, N, H, W = prob.shape + prob_flat = prob.view(B, N, -1) + reconstr_feat = torch.matmul(prob_flat.permute(0, 2, 1), input) + reconstr_feat = reconstr_feat.view(B, H, W, -1).permute(0, 3, 1, 2) + + return reconstr_feat + +def quantize(z, embedding, beta = 0.25): + """ + Inputs the output of the encoder network z and maps it to a discrete + one-hot vector that is the index of the closest embedding vector e_j + + Args: + z (tensor): features from the encoder network + embedding (tensor): codebook + beta (scalar, optional): commit loss weight + + Returns: + z_q: quantized features + loss: vq loss + commit loss * beta + min_encodings: quantization assignment one hot vector + min_encoding_indices: quantization assignment + + Shape: + z: B, N, C + embedding: B, K, C + z_q: B, N, C + min_encodings: B, N, K + min_encoding_indices: B, N, 1 + + Note: + Adapted from https://github.com/CompVis/taming-transformers/blob/master/taming/modules/vqvae/quantize.py + """ + + # B, 256, 32 + if embedding.shape[0] == 1: + d = torch.sum(z ** 2, dim=2, keepdim=True) + torch.sum(embedding**2, dim=2) - 2 * torch.matmul(z, embedding.transpose(1, 2)) + else: + ds = [] + for i in range(embedding.shape[0]): + z_i = z[i:i+1] + embedding_i = embedding[i:i+1] + ds.append(torch.sum(z_i ** 2, dim=2, keepdim=True) + torch.sum(embedding_i**2, dim=2) - 2 * torch.matmul(z_i, embedding_i.transpose(1, 2))) + d = torch.cat(ds) + + ## could possible replace this here + # #\start... + # find closest encodings + min_encoding_indices = torch.argmin(d, dim=2).unsqueeze(2) # B, 256, 1 + + #min_encodings = torch.zeros( + # min_encoding_indices.shape[0], self.n_e).to(z) + #min_encodings.scatter_(1, min_encoding_indices, 1) + n_e = embedding.shape[1] # 32 + min_encodings = torch.zeros(z.shape[0], z.shape[1], n_e).to(z) + min_encodings.scatter_(2, min_encoding_indices, 1) + + # dtype min encodings: torch.float32 + # min_encodings shape: torch.Size([2048, 512]) + # min_encoding_indices.shape: torch.Size([2048, 1]) + + # get quantized latent vectors + z_q = torch.matmul(min_encodings, embedding).view(z.shape) + #.........\end + + # with: + # .........\start + #min_encoding_indices = torch.argmin(d, dim=1) + #z_q = self.embedding(min_encoding_indices) + # ......\end......... 
(TODO) + + # compute loss for embedding + loss = torch.mean((z_q.detach()-z)**2) + beta * torch.mean((z_q - z.detach()) ** 2) + + # preserve gradients + z_q = z + (z_q - z).detach() + + return z_q, loss, (min_encodings, min_encoding_indices, d) \ No newline at end of file diff --git a/libs/normalization.py b/libs/normalization.py new file mode 100644 index 0000000000000000000000000000000000000000..9237546d040e2ae682c7b63d35c5f405111376ed --- /dev/null +++ b/libs/normalization.py @@ -0,0 +1,104 @@ +import re +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.nn.utils.spectral_norm as spectral_norm + + +# Returns a function that creates a normalization function +# that does not condition on semantic map +def get_nonspade_norm_layer(norm_type='instance'): + # helper function to get # output channels of the previous layer + def get_out_channel(layer): + if hasattr(layer, 'out_channels'): + return getattr(layer, 'out_channels') + return layer.weight.size(0) + + # this function will be returned + def add_norm_layer(layer): + nonlocal norm_type + if norm_type.startswith('spectral'): + layer = spectral_norm(layer) + subnorm_type = norm_type[len('spectral'):] + + if subnorm_type == 'none' or len(subnorm_type) == 0: + return layer + + # remove bias in the previous layer, which is meaningless + # since it has no effect after normalization + if getattr(layer, 'bias', None) is not None: + delattr(layer, 'bias') + layer.register_parameter('bias', None) + + if subnorm_type == 'batch': + norm_layer = nn.BatchNorm2d(get_out_channel(layer), affine=True) + elif subnorm_type == 'sync_batch': + norm_layer = SynchronizedBatchNorm2d(get_out_channel(layer), affine=True) + elif subnorm_type == 'instance': + norm_layer = nn.InstanceNorm2d(get_out_channel(layer), affine=False) + else: + raise ValueError('normalization layer %s is not recognized' % subnorm_type) + + return nn.Sequential(layer, norm_layer) + + return add_norm_layer + + +# Creates SPADE normalization layer based on the given configuration +# SPADE consists of two steps. First, it normalizes the activations using +# your favorite normalization method, such as Batch Norm or Instance Norm. +# Second, it applies scale and bias to the normalized output, conditioned on +# the segmentation map. +# The format of |config_text| is spade(norm)(ks), where +# (norm) specifies the type of parameter-free normalization. +# (e.g. syncbatch, batch, instance) +# (ks) specifies the size of kernel in the SPADE module (e.g. 3x3) +# Example |config_text| will be spadesyncbatch3x3, or spadeinstance5x5. 
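# (For instance, config_text='spadeinstance3x3' selects a parameter-free InstanceNorm2d
# and 3x3 convolutions for the shared MLP and the gamma/beta branches below.)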
+# Also, the other arguments are +# |norm_nc|: the #channels of the normalized activations, hence the output dim of SPADE +# |label_nc|: the #channels of the input semantic map, hence the input dim of SPADE +class SPADE(nn.Module): + def __init__(self, config_text, norm_nc, label_nc): + super().__init__() + + assert config_text.startswith('spade') + parsed = re.search('spade(\D+)(\d)x\d', config_text) + param_free_norm_type = str(parsed.group(1)) + ks = int(parsed.group(2)) + + if param_free_norm_type == 'instance': + self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False) + elif param_free_norm_type == 'syncbatch': + self.param_free_norm = SynchronizedBatchNorm2d(norm_nc, affine=False) + elif param_free_norm_type == 'batch': + self.param_free_norm = nn.BatchNorm2d(norm_nc, affine=False) + else: + raise ValueError('%s is not a recognized param-free norm type in SPADE' + % param_free_norm_type) + + # The dimension of the intermediate embedding space. Yes, hardcoded. + nhidden = 128 + + pw = ks // 2 + self.mlp_shared = nn.Sequential( + nn.Conv2d(label_nc, nhidden, kernel_size=ks, padding=pw), + nn.ReLU() + ) + self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw) + self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw) + + def forward(self, x, segmap): + + # Part 1. generate parameter-free normalized activations + normalized = self.param_free_norm(x) + + # Part 2. produce scaling and bias conditioned on semantic map + segmap = F.interpolate(segmap, size=x.size()[2:], mode='nearest') + actv = self.mlp_shared(segmap) + gamma = self.mlp_gamma(actv) + beta = self.mlp_beta(actv) + + # apply scale and bias + out = normalized * (1 + gamma) + beta + + return out \ No newline at end of file diff --git a/libs/options.py b/libs/options.py new file mode 100644 index 0000000000000000000000000000000000000000..97cd26da886dc646ffeef40f15049fcc21437f9a --- /dev/null +++ b/libs/options.py @@ -0,0 +1,144 @@ +import os +import json +import argparse + +class BaseOptions(): + def initialize(self, parser): + parser.add_argument('--nChannel', metavar='N', default=100, type=int, + help='number of channels') + parser.add_argument('--maxIter', metavar='T', default=1000, type=int, + help='number of maximum iterations') + parser.add_argument('--lr', metavar='LR', default=0.1, type=float, + help='learning rate') + parser.add_argument('--nConv', metavar='M', default=2, type=int, + help='number of convolutional layers') + parser.add_argument("--work_dir", type=str, default="/home/xli/WORKDIR", + help='project directory') + parser.add_argument("--out_dir", type=str, default=None, + help='logging output') + parser.add_argument("--use_wandb", type=int, default=0, + help='use wandb or not') + parser.add_argument("--data_path", type=str, default="/home/xli/DATA/BSR_processed/train_extend", + help="data path") + parser.add_argument("--img_path", type=str, default=None, + help="image path") + parser.add_argument('--crop_size', type=int, default= 224, + help='crop_size') + parser.add_argument("--batch_size", type=int, default=1, + help='batch size') + parser.add_argument('--workers', type=int, default=4, + help='number of data loading workers') + parser.add_argument("--use_slic", default = 1, type=int, + help="choose to use slic or gt label") + parser.add_argument("-f", "--config_file", type=str, default='models/week0417/json/single_scale_grouping_ft.json', + help='json files including all arguments') + parser.add_argument("--log_freq", type=int, default=10, + help='frequency to print 
log') + parser.add_argument("--display_freq", type=int, default=100, + help='frequency to save visualization') + parser.add_argument("--pretrained_ae", type=str, + default = "/home/xli/WORKDIR/07-16/transformer/cpk.pth") + parser.add_argument("--pretrained_path", type=str, default=None, + help='pretrained reconstruction model') + parser.add_argument('--momentum', type=float, default=0.5, + help='momentum for sgd, alpha parameter for adam') + parser.add_argument('--beta', type=float, default=0.999, + help='beta parameter for adam') + parser.add_argument("--l1_loss_wt", default=1.0, type=float) + parser.add_argument("--perceptual_loss_wt", default=1.0, type=float) + parser.add_argument('--project_name', type=str, default='test_time', + help='project name') + parser.add_argument("--save_freq", type=int, default=2000, + help='frequency to save model') + parser.add_argument("--local_rank", type=int) + parser.add_argument('--lr_decay_freq', type=int, default=3000, + help='frequency to decay learning rate') + parser.add_argument('--no_ganFeat_loss', action='store_true', + help='if specified, do *not* use discriminator feature matching loss') + parser.add_argument('--sp_num', type=int, default=None, + help='superpixel number') + parser.add_argument('--add_self_loops', type=int, default=1, + help='set to 1 to add self loops in GCNs') + parser.add_argument('--test_time', type=int, default=0, + help='set to 1 to add self loops in GCNs') + parser.add_argument('--add_texture_epoch', type=int, default=1000, + help='when to add texture synthesis') + parser.add_argument('--add_clustering_epoch', type=int, default=1000, + help='when to add grouping') + parser.add_argument('--temperature', type=int, default=1, + help='temperature in SoftMax') + parser.add_argument('--gumbel', type=int, default=0, + help='if use gumbel SoftMax') + parser.add_argument('--patch_size', type=int, default=40, + help='patch size in texture synthesis') + parser.add_argument("--netE_num_downsampling_sp", default=4, type=int) + parser.add_argument('--num_classes', type=int, default=0) + parser.add_argument( + "--netG_num_base_resnet_layers", + default=2, type=int, + help="The number of resnet layers before the upsampling layers." + ) + parser.add_argument("--netG_scale_capacity", default=1.0, type=float) + parser.add_argument("--netG_resnet_ch", type=int, default=256) + + parser.add_argument("--spatial_code_ch", default=8, type=int) + parser.add_argument("--texture_code_ch", default=256, type=int) + parser.add_argument("--netE_scale_capacity", default=1.0, type=float) + parser.add_argument("--netE_num_downsampling_gl", default=2, type=int) + parser.add_argument("--netE_nc_steepness", default=2.0, type=float) + parser.add_argument("--spatial_code_dim", type=int, default=256, help="codebook entry dimension") + return parser + + def print_options(self, opt): + """Print and save options + It will print both current options and default values(if different). 
+ It will save options into a text file / [checkpoints_dir] / opt.txt + """ + message = '' + message += '----------------- Options ---------------\n' + for k, v in sorted(vars(opt).items()): + comment = '' + default = self.parser.get_default(k) + if v != default: + comment = '\t[default: %s]' % str(default) + message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment) + message += '----------------- End -------------------' + print(message) + + def save_options(self, opt): + os.makedirs(opt.out_dir, exist_ok=True) + file_name = os.path.join(opt.out_dir, 'exp_args.txt') + with open(file_name, 'wt') as opt_file: + for k, v in sorted(vars(opt).items()): + comment = '' + default = self.parser.get_default(k) + if v != default: + comment = '\t[default: %s]' % str(default) + opt_file.write('{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)) + opt_file.close() + + def gather_options(self): + parser = argparse.ArgumentParser(description='PyTorch Unsupervised Segmentation') + self.parser = self.initialize(parser) + opt = self.parser.parse_args() + + opt = self.update_with_json(opt) + opt.out_dir = os.path.join(opt.work_dir, opt.exp_name) + opt.use_slic = (opt.use_slic == 1) + opt.use_wandb = (opt.use_wandb == 1) + + # logging + self.print_options(opt) + self.save_options(opt) + return opt + + def update_with_json(self, args): + arg_dict = vars(args) + + # arguments house keeping + with open(args.config_file, 'r') as f: + arg_str = f.read() + file_args = json.loads(arg_str) + arg_dict.update(file_args) + args = argparse.Namespace(**arg_dict) + return args diff --git a/libs/ssn/loss.py b/libs/ssn/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..3679f5f0ab57a9ff2f0cb921b3dfbd7a6055d018 --- /dev/null +++ b/libs/ssn/loss.py @@ -0,0 +1,61 @@ +import torch + +def reconstruction(assignment, labels, hard_assignment=None): + """ + reconstruction + + Args: + assignment: torch.Tensor + A Tensor of shape (B, n_spixels, n_pixels) + labels: torch.Tensor + A Tensor of shape (B, C, n_pixels) + hard_assignment: torch.Tensor + A Tensor of shape (B, n_pixels) + """ + labels = labels.permute(0, 2, 1).contiguous() + + # matrix product between (n_spixels, n_pixels) and (n_pixels, channels) + spixel_mean = torch.bmm(assignment, labels) / (assignment.sum(2, keepdim=True) + 1e-16) + if hard_assignment is None: + # (B, n_spixels, n_pixels) -> (B, n_pixels, n_spixels) + permuted_assignment = assignment.permute(0, 2, 1).contiguous() + # matrix product between (n_pixels, n_spixels) and (n_spixels, channels) + reconstructed_labels = torch.bmm(permuted_assignment, spixel_mean) + else: + # index sampling + reconstructed_labels = torch.stack([sm[ha, :] for sm, ha in zip(spixel_mean, hard_assignment)], 0) + return reconstructed_labels.permute(0, 2, 1).contiguous() + + +def reconstruct_loss_with_cross_etnropy(assignment, labels, hard_assignment=None): + """ + reconstruction loss with cross entropy + + Args: + assignment: torch.Tensor + A Tensor of shape (B, n_spixels, n_pixels) + labels: torch.Tensor + A Tensor of shape (B, C, n_pixels) + hard_assignment: torch.Tensor + A Tensor of shape (B, n_pixels) + """ + reconstracted_labels = reconstruction(assignment, labels, hard_assignment) + reconstracted_labels = reconstracted_labels / (1e-16 + reconstracted_labels.sum(1, keepdim=True)) + mask = labels > 0 + return -(reconstracted_labels[mask] + 1e-16).log().mean() + + +def reconstruct_loss_with_mse(assignment, labels, hard_assignment=None): + """ + reconstruction loss with mse + + Args: + assignment: 
torch.Tensor + A Tensor of shape (B, n_spixels, n_pixels) + labels: torch.Tensor + A Tensor of shape (B, C, n_pixels) + hard_assignment: torch.Tensor + A Tensor of shape (B, n_pixels) + """ + reconstracted_labels = reconstruction(assignment, labels, hard_assignment) + return torch.nn.functional.mse_loss(reconstracted_labels, labels) \ No newline at end of file diff --git a/libs/ssn/pair_wise_distance.py b/libs/ssn/pair_wise_distance.py new file mode 100644 index 0000000000000000000000000000000000000000..5d6755d468b02f344d21069ff73ee611c710f460 --- /dev/null +++ b/libs/ssn/pair_wise_distance.py @@ -0,0 +1,41 @@ +import torch +from torch.utils.cpp_extension import load_inline +from .pair_wise_distance_cuda_source import source + + +print("compile cuda source of 'pair_wise_distance' function...") +print("NOTE: if you avoid this process, you make .cu file and compile it following https://pytorch.org/tutorials/advanced/cpp_extension.html") +pair_wise_distance_cuda = load_inline( + "pair_wise_distance", cpp_sources="", cuda_sources=source +) +print("done") + + +class PairwiseDistFunction(torch.autograd.Function): + @staticmethod + def forward(self, pixel_features, spixel_features, init_spixel_indices, num_spixels_width, num_spixels_height): + self.num_spixels_width = num_spixels_width + self.num_spixels_height = num_spixels_height + output = pixel_features.new(pixel_features.shape[0], 9, pixel_features.shape[-1]).zero_() + self.save_for_backward(pixel_features, spixel_features, init_spixel_indices) + + return pair_wise_distance_cuda.forward( + pixel_features.contiguous(), spixel_features.contiguous(), + init_spixel_indices.contiguous(), output, + self.num_spixels_width, self.num_spixels_height) + + @staticmethod + def backward(self, dist_matrix_grad): + pixel_features, spixel_features, init_spixel_indices = self.saved_tensors + + pixel_features_grad = torch.zeros_like(pixel_features) + spixel_features_grad = torch.zeros_like(spixel_features) + + pixel_features_grad, spixel_features_grad = pair_wise_distance_cuda.backward( + dist_matrix_grad.contiguous(), pixel_features.contiguous(), + spixel_features.contiguous(), init_spixel_indices.contiguous(), + pixel_features_grad, spixel_features_grad, + self.num_spixels_width, self.num_spixels_height + ) + return pixel_features_grad, spixel_features_grad, None, None, None + diff --git a/libs/ssn/pair_wise_distance_cuda_source.py b/libs/ssn/pair_wise_distance_cuda_source.py new file mode 100644 index 0000000000000000000000000000000000000000..49aa12712660754df04e8b37a112d00797342351 --- /dev/null +++ b/libs/ssn/pair_wise_distance_cuda_source.py @@ -0,0 +1,177 @@ +source = ''' +#include +#include +#include +#include + +#define CUDA_NUM_THREADS 256 + +#include +#include +#include +#include + +#include +#include +#include + +template +__global__ void forward_kernel( + const scalar_t* __restrict__ pixel_features, + const scalar_t* __restrict__ spixel_features, + const scalar_t* __restrict__ spixel_indices, + scalar_t* __restrict__ dist_matrix, + int batchsize, int channels, int num_pixels, int num_spixels, + int num_spixels_w, int num_spixels_h +){ + int index = blockIdx.x * blockDim.x + threadIdx.x; + if (index >= batchsize * num_pixels * 9) return; + + int cp = channels * num_pixels; + int cs = channels * num_spixels; + + int b = index % batchsize; + int spixel_offset = (index / batchsize) % 9; + int p = (index / (batchsize * 9)) % num_pixels; + + int init_spix_index = spixel_indices[b * num_pixels + p]; + + int x_index = init_spix_index % num_spixels_w; + 
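    // spixel_offset in [0, 9) enumerates the 3x3 neighborhood of the initial superpixel:
    // (spixel_offset % 3 - 1, spixel_offset / 3 - 1) is the (dx, dy) offset in the
    // superpixel grid, so every pixel is compared against its 9 candidate superpixels.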
int spixel_offset_x = (spixel_offset % 3 - 1); + + int y_index = init_spix_index / num_spixels_w; + int spixel_offset_y = (spixel_offset / 3 - 1); + + if (x_index + spixel_offset_x < 0 || x_index + spixel_offset_x >= num_spixels_w) { + dist_matrix[b * (9 * num_pixels) + spixel_offset * num_pixels + p] = 1e16; + } + else if (y_index + spixel_offset_y < 0 || y_index + spixel_offset_y >= num_spixels_h) { + dist_matrix[b * (9 * num_pixels) + spixel_offset * num_pixels + p] = 1e16; + } + else { + int query_spixel_index = init_spix_index + spixel_offset_x + num_spixels_w * spixel_offset_y; + + scalar_t sum_squared_diff = 0; + for (int c=0; c<<< block, CUDA_NUM_THREADS >>>( + pixel_features.data(), + spixel_features.data(), + spixel_indices.data(), + dist_matrix.data(), + batchsize, channels, num_pixels, + num_spixels, num_spixels_w, num_spixels_h + ); + })); + + return dist_matrix; +} + +template +__global__ void backward_kernel( + const scalar_t* __restrict__ dist_matrix_grad, + const scalar_t* __restrict__ pixel_features, + const scalar_t* __restrict__ spixel_features, + const scalar_t* __restrict__ spixel_indices, + scalar_t* __restrict__ pixel_feature_grad, + scalar_t* __restrict__ spixel_feature_grad, + int batchsize, int channels, int num_pixels, int num_spixels, + int num_spixels_w, int num_spixels_h +){ + int index = blockIdx.x * blockDim.x + threadIdx.x; + if (index >= batchsize * num_pixels * 9) return; + + int cp = channels * num_pixels; + int cs = channels * num_spixels; + + int b = index % batchsize; + int spixel_offset = (index / batchsize) % 9; + int p = (index / (batchsize * 9)) % num_pixels; + + int init_spix_index = spixel_indices[b * num_pixels + p]; + + int x_index = init_spix_index % num_spixels_w; + int spixel_offset_x = (spixel_offset % 3 - 1); + + int y_index = init_spix_index / num_spixels_w; + int spixel_offset_y = (spixel_offset / 3 - 1); + + if (x_index + spixel_offset_x < 0 || x_index + spixel_offset_x >= num_spixels_w) return; + else if (y_index + spixel_offset_y < 0 || y_index + spixel_offset_y >= num_spixels_h) return; + else { + int query_spixel_index = init_spix_index + spixel_offset_x + num_spixels_w * spixel_offset_y; + + scalar_t dist_matrix_grad_val = dist_matrix_grad[b * (9 * num_pixels) + spixel_offset * num_pixels + p]; + + for (int c=0; c backward_cuda( + const torch::Tensor dist_matrix_grad, + const torch::Tensor pixel_features, + const torch::Tensor spixel_features, + const torch::Tensor spixel_indices, + torch::Tensor pixel_features_grad, + torch::Tensor spixel_features_grad, + int num_spixels_w, int num_spixels_h +){ + int batchsize = pixel_features.size(0); + int channels = pixel_features.size(1); + int num_pixels = pixel_features.size(2); + int num_spixels = spixel_features.size(2); + + + dim3 block((batchsize * 9 * num_pixels + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS); + + AT_DISPATCH_FLOATING_TYPES(pixel_features_grad.type(), "backward_kernel", ([&] { + backward_kernel<<< block, CUDA_NUM_THREADS >>>( + dist_matrix_grad.data(), + pixel_features.data(), + spixel_features.data(), + spixel_indices.data(), + pixel_features_grad.data(), + spixel_features_grad.data(), + batchsize, channels, num_pixels, + num_spixels, num_spixels_w, num_spixels_h + ); + })); + + return {pixel_features_grad, spixel_features_grad}; +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", &forward_cuda, "pair_wise_distance forward"); + m.def("backward", &backward_cuda, "pair_wise_distance backward"); +} +''' \ No newline at end of file diff --git 
a/libs/ssn/sparse_utils.py b/libs/ssn/sparse_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..37b72dfce7216bb3e9c9a2814e957490968f4ea0 --- /dev/null +++ b/libs/ssn/sparse_utils.py @@ -0,0 +1,15 @@ +import torch + + +def naive_sparse_bmm(sparse_mat, dense_mat, transpose=False): + if transpose: + return torch.stack([torch.sparse.mm(s_mat, d_mat.t()) for s_mat, d_mat in zip(sparse_mat, dense_mat)], 0) + else: + return torch.stack([torch.sparse.mm(s_mat, d_mat) for s_mat, d_mat in zip(sparse_mat, dense_mat)], 0) + +def sparse_permute(sparse_mat, order): + values = sparse_mat.coalesce().values() + indices = sparse_mat.coalesce().indices() + indices = torch.stack([indices[o] for o in order], 0).contiguous() + return torch.sparse_coo_tensor(indices, values) + diff --git a/libs/ssn/ssn.py b/libs/ssn/ssn.py new file mode 100644 index 0000000000000000000000000000000000000000..7c8d3c57920a453389ed3223c05046edcff8aaf4 --- /dev/null +++ b/libs/ssn/ssn.py @@ -0,0 +1,193 @@ +import math +import torch + +from .pair_wise_distance import PairwiseDistFunction +from .sparse_utils import naive_sparse_bmm + + +def calc_init_centroid(images, num_spixels_width, num_spixels_height): + """ + calculate initial superpixels + + Args: + images: torch.Tensor + A Tensor of shape (B, C, H, W) + spixels_width: int + initial superpixel width + spixels_height: int + initial superpixel height + + Return: + centroids: torch.Tensor + A Tensor of shape (B, C, H * W) + init_label_map: torch.Tensor + A Tensor of shape (B, H * W) + num_spixels_width: int + A number of superpixels in each column + num_spixels_height: int + A number of superpixels int each raw + """ + batchsize, channels, height, width = images.shape + device = images.device + + centroids = torch.nn.functional.adaptive_avg_pool2d(images, (num_spixels_height, num_spixels_width)) + + with torch.no_grad(): + num_spixels = num_spixels_width * num_spixels_height + labels = torch.arange(num_spixels, device=device).reshape(1, 1, *centroids.shape[-2:]).type_as(centroids) + init_label_map = torch.nn.functional.interpolate(labels, size=(height, width), mode="nearest") + init_label_map = init_label_map.repeat(batchsize, 1, 1, 1) + + init_label_map = init_label_map.reshape(batchsize, -1) + centroids = centroids.reshape(batchsize, channels, -1) + + return centroids, init_label_map + + +@torch.no_grad() +def get_abs_indices(init_label_map, num_spixels_width): + b, n_pixel = init_label_map.shape + device = init_label_map.device + r = torch.arange(-1, 2.0, device=device) + relative_spix_indices = torch.cat([r - num_spixels_width, r, r + num_spixels_width], 0) + + abs_pix_indices = torch.arange(n_pixel, device=device)[None, None].repeat(b, 9, 1).reshape(-1).long() + abs_spix_indices = (init_label_map[:, None] + relative_spix_indices[None, :, None]).reshape(-1).long() + abs_batch_indices = torch.arange(b, device=device)[:, None, None].repeat(1, 9, n_pixel).reshape(-1).long() + + return torch.stack([abs_batch_indices, abs_spix_indices, abs_pix_indices], 0) + + +@torch.no_grad() +def get_hard_abs_labels(affinity_matrix, init_label_map, num_spixels_width): + relative_label = affinity_matrix.max(1)[1] + r = torch.arange(-1, 2.0, device=affinity_matrix.device) + relative_spix_indices = torch.cat([r - num_spixels_width, r, r + num_spixels_width], 0) + label = init_label_map + relative_spix_indices[relative_label] + return label.long() + + +@torch.no_grad() +def sparse_ssn_iter(pixel_features, num_spixels, n_iter): + """ + computing assignment iterations with 
sparse matrix + detailed process is in Algorithm 1, line 2 - 6 + NOTE: this function does NOT guarantee the backward computation. + + Args: + pixel_features: torch.Tensor + A Tensor of shape (B, C, H, W) + num_spixels: int + A number of superpixels + n_iter: int + A number of iterations + return_hard_label: bool + return hard assignment or not + """ + height, width = pixel_features.shape[-2:] + num_spixels_width = int(math.sqrt(num_spixels * width / height)) + num_spixels_height = int(math.sqrt(num_spixels * height / width)) + + spixel_features, init_label_map = \ + calc_init_centroid(pixel_features, num_spixels_width, num_spixels_height) + abs_indices = get_abs_indices(init_label_map, num_spixels_width) + + pixel_features = pixel_features.reshape(*pixel_features.shape[:2], -1) + permuted_pixel_features = pixel_features.permute(0, 2, 1) + + for _ in range(n_iter): + dist_matrix = PairwiseDistFunction.apply( + pixel_features, spixel_features, init_label_map, num_spixels_width, num_spixels_height) + + affinity_matrix = (-dist_matrix).softmax(1) + reshaped_affinity_matrix = affinity_matrix.reshape(-1) + + mask = (abs_indices[1] >= 0) * (abs_indices[1] < num_spixels) + sparse_abs_affinity = torch.sparse_coo_tensor(abs_indices[:, mask], reshaped_affinity_matrix[mask]) + spixel_features = naive_sparse_bmm(sparse_abs_affinity, permuted_pixel_features) \ + / (torch.sparse.sum(sparse_abs_affinity, 2).to_dense()[..., None] + 1e-16) + + spixel_features = spixel_features.permute(0, 2, 1) + + hard_labels = get_hard_abs_labels(affinity_matrix, init_label_map, num_spixels_width) + + return sparse_abs_affinity, hard_labels, spixel_features + + +def ssn_iter(pixel_features, num_spixels, n_iter): + """ + computing assignment iterations + detailed process is in Algorithm 1, line 2 - 6 + + Args: + pixel_features: torch.Tensor + A Tensor of shape (B, C, H, W) + num_spixels: int + A number of superpixels + n_iter: int + A number of iterations + return_hard_label: bool + return hard assignment or not + """ + height, width = pixel_features.shape[-2:] + num_spixels_width = int(math.sqrt(num_spixels * width / height)) + num_spixels_height = int(math.sqrt(num_spixels * height / width)) + + # spixel_features: 10 * 202 * 64 + # init_label_map: 10 * 40000 + spixel_features, init_label_map = \ + calc_init_centroid(pixel_features, num_spixels_width, num_spixels_height) + # get indices of the 9 neighbors + abs_indices = get_abs_indices(init_label_map, num_spixels_width) + + # 10 * 202 * 40000 + pixel_features = pixel_features.reshape(*pixel_features.shape[:2], -1) + # 10 * 40000 * 202 + permuted_pixel_features = pixel_features.permute(0, 2, 1).contiguous() + + for _ in range(n_iter): + # 10 * 9 * 40000 + dist_matrix = PairwiseDistFunction.apply( + pixel_features, spixel_features, init_label_map, num_spixels_width, num_spixels_height) + + affinity_matrix = (-dist_matrix).softmax(1) + reshaped_affinity_matrix = affinity_matrix.reshape(-1) + + mask = (abs_indices[1] >= 0) * (abs_indices[1] < num_spixels) + # 10 * 64 * 40000 + sparse_abs_affinity = torch.sparse_coo_tensor(abs_indices[:, mask], reshaped_affinity_matrix[mask]) + + abs_affinity = sparse_abs_affinity.to_dense().contiguous() + spixel_features = torch.bmm(abs_affinity, permuted_pixel_features) \ + / (abs_affinity.sum(2, keepdim=True) + 1e-16) + + spixel_features = spixel_features.permute(0, 2, 1).contiguous() + + + hard_labels = get_hard_abs_labels(affinity_matrix, init_label_map, num_spixels_width) + + return abs_affinity, hard_labels, spixel_features + +def 
ssn_iter2(pixel_features, num_spixels, n_iter, init_spixel_features, temp = 1): + """ + computing assignment iterations for second layer + + Args: + pixel_features: torch.Tensor + A Tensor of shape (B, C, N) + num_spixels: int + A number of superpixels + init_spixel_features: + A Tensor of shape (B, C, num_spixels) + """ + spixel_features = init_spixel_features.permute(0, 2, 1) + pixel_features = pixel_features.permute(0, 2, 1) + for _ in range(n_iter): + # compute distance to all spixel_features + dist = torch.cdist(pixel_features, spixel_features) # B, N, num_spixels + aff = (-dist * temp).softmax(-1).permute(0, 2, 1) # B, num_spixels, N + # compute new superpixels centers + spixel_features = torch.bmm(aff, pixel_features) / (aff.sum(2, keepdim=True) + 1e-6) # B, num_spixels, C + hard_labels = torch.argmax(aff, dim = 1) + return aff, hard_labels, spixel_features + diff --git a/libs/test_base.py b/libs/test_base.py new file mode 100644 index 0000000000000000000000000000000000000000..54998fe2abd2f3487f9d8c28a7b757731c11f17e --- /dev/null +++ b/libs/test_base.py @@ -0,0 +1,135 @@ +""" +Testing base class. +""" +import torchvision.transforms as transforms +import torch.backends.cudnn as cudnn +import torch.nn.functional as F +import torch + +import numpy as np +import math + +from . import flow_transforms + +class TesterBase(): + def __init__(self, args): + cudnn.benchmark = True + + self.mean_values = torch.tensor([0.411, 0.432, 0.45]).view(1, 3, 1, 1).cuda() + + self.args = args + + def init_dataset(self): + if self.args.dataset == 'BSD500': + from ..data import BSD500 + # ========== Data loading code ============== + input_transform = transforms.Compose([ + flow_transforms.ArrayToTensor(), + transforms.Normalize(mean=[0,0,0], std=[255,255,255]), + transforms.Normalize(mean=[0.411,0.432,0.45], std=[1,1,1]) + ]) + + val_input_transform = transforms.Compose([ + flow_transforms.ArrayToTensor(), + transforms.Normalize(mean=[0, 0, 0], std=[255, 255, 255]), + transforms.Normalize(mean=[0.411, 0.432, 0.45], std=[1, 1, 1]) + ]) + + target_transform = transforms.Compose([ + flow_transforms.ArrayToTensor(), + ]) + + co_transform = flow_transforms.Compose([ + flow_transforms.CenterCrop((self.args.train_img_height , self.args.train_img_width)), + ]) + print("=> loading img pairs from '{}'".format(self.args.data)) + if self.args.crop_img == 0: + train_set, val_set = BSD500(self.args.data, + transform=input_transform, + val_transform = val_input_transform, + target_transform=target_transform) + else: + train_set, val_set = BSD500(self.args.data, + transform=input_transform, + val_transform = val_input_transform, + target_transform=target_transform, + co_transform=co_transform) + print('{} samples found, {} train samples and {} val samples '.format(len(val_set)+len(train_set), len(train_set), len(val_set))) + + + self.train_loader = torch.utils.data.DataLoader( + train_set, batch_size=self.args.batch_size, + num_workers=self.args.workers, pin_memory=True, shuffle=False, drop_last=True) + elif self.args.dataset == 'texture': + from ..data.texture_v3 import Dataset + dataset = Dataset(self.args.data_path, crop_size=self.args.train_img_height, test=True) + self.train_loader = torch.utils.data.DataLoader(dataset = dataset, + batch_size = self.args.batch_size, + num_workers = self.args.workers, + shuffle = False, + drop_last = True) + else: + from basicsr.data import create_dataloader, create_dataset + opt = {} + opt['dist'] = False + opt['phase'] = 'train' + + opt['name'] = 'DIV2K' + opt['type'] = 
'PairedImageDataset' + opt['dataroot_gt'] = self.args.HR_dir + opt['dataroot_lq'] = self.args.LR_dir + opt['filename_tmpl'] = '{}' + opt['io_backend'] = dict(type='disk') + + opt['gt_size'] = self.args.train_img_height + opt['use_flip'] = True + opt['use_rot'] = True + + opt['use_shuffle'] = True + opt['num_worker_per_gpu'] = self.args.workers + opt['batch_size_per_gpu'] = self.args.batch_size + opt['scale'] = int(self.args.ratio) + + opt['dataset_enlarge_ratio'] = 1 + dataset = create_dataset(opt) + self.train_loader = create_dataloader( + dataset, opt, num_gpu=1, dist=opt['dist'], sampler=None) + + def init_testing(self): + self.init_constant() + self.init_dataset() + self.define_model() + + def init_constant(self): + return + + def define_model(self): + raise NotImplementedError + + def display(self): + raise NotImplementedError + + def forward(self, iteration): + raise NotImplementedError + + def test(self): + args = self.args + for iteration, data in enumerate(self.train_loader): + print("Iteration: {}.".format(iteration)) + if args.dataset == 'BSD500': + image = data[0].cuda() + self.label = data[1].cuda() + self.gt = None + elif args.dataset == 'texture': + image = data[0].cuda() + self.image2 = data[1].cuda() + else: + image = data['lq'].cuda() + self.gt = data['gt'].cuda() + image = image.cuda() + self.image = image + self.forward() + self.display(iteration) + if iteration > args.niteration: + break + diff --git a/libs/train_base.py b/libs/train_base.py new file mode 100644 index 0000000000000000000000000000000000000000..567fbd8cf150155ba416d6ca53df5b849adfeca9 --- /dev/null +++ b/libs/train_base.py @@ -0,0 +1,230 @@ +"""Training base class +""" +import torchvision.transforms as transforms +import torchvision.utils as vutils +import torch.backends.cudnn as cudnn +import torch.nn.functional as F +import torch.fft +import torch + +import numpy as np +import argparse +import wandb +import math +import time +import os +from . import flow_transforms + +class TrainerBase(): + def __init__(self, args): + """ + Initialization function. 
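+ Sets up optional wandb logging, the per-channel dataset mean used to de-normalize images, and the color palette used to colorize segmentation maps.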
+ """ + cudnn.benchmark = True + + os.environ['WANDB_DIR'] = args.work_dir + args.use_wandb = (args.use_wandb == 1) + + if args.use_wandb: + wandb.login(key="d56eb81cd6396f0a181524ba214f488cf281e76b") + wandb.init(project=args.project_name, name=args.exp_name) + wandb.config.update(args) + + self.mean_values = torch.tensor([0.411, 0.432, 0.45]).view(1, 3, 1, 1).cuda() + self.color_palette = np.loadtxt('data/palette.txt',dtype=np.uint8).reshape(-1,3) + + self.args = args + + def init_dataset(self): + """ + Initialize dataset + """ + if self.args.dataset == 'BSD500': + from ..data import BSD500 + # ========== Data loading code ============== + input_transform = transforms.Compose([ + flow_transforms.ArrayToTensor(), + transforms.Normalize(mean=[0,0,0], std=[255,255,255]), + transforms.Normalize(mean=[0.411,0.432,0.45], std=[1,1,1]) + ]) + + val_input_transform = transforms.Compose([ + flow_transforms.ArrayToTensor(), + transforms.Normalize(mean=[0, 0, 0], std=[255, 255, 255]), + transforms.Normalize(mean=[0.411, 0.432, 0.45], std=[1, 1, 1]) + ]) + + target_transform = transforms.Compose([ + flow_transforms.ArrayToTensor(), + ]) + + co_transform = flow_transforms.Compose([ + flow_transforms.RandomCrop((self.args.train_img_height , self.args.train_img_width)), + flow_transforms.RandomVerticalFlip(), + flow_transforms.RandomHorizontalFlip() + ]) + print("=> loading img pairs from '{}'".format(self.args.data)) + train_set, val_set = BSD500(self.args.data, + transform=input_transform, + val_transform = val_input_transform, + target_transform=target_transform, + co_transform=co_transform, + bi_filter=True) + print('{} samples found, {} train samples and {} val samples '.format(len(val_set)+len(train_set), len(train_set), len(val_set))) + + self.train_loader = torch.utils.data.DataLoader( + train_set, batch_size=self.args.batch_size, + num_workers=self.args.workers, pin_memory=True, shuffle=True, drop_last=True) + elif self.args.dataset == 'texture': + from ..data.texture_v3 import Dataset + dataset = Dataset(self.args.data_path, crop_size=self.args.train_img_height) + self.train_loader = torch.utils.data.DataLoader(dataset = dataset, + batch_size = self.args.batch_size, + shuffle = True, + num_workers = self.args.workers, + drop_last = True) + elif self.args.dataset == 'DIV2K': + from basicsr.data import create_dataloader, create_dataset + opt = {} + opt['dist'] = False + opt['phase'] = 'train' + + opt['name'] = 'DIV2K' + opt['type'] = 'PairedImageDataset' + opt['dataroot_gt'] = self.args.HR_dir + opt['dataroot_lq'] = self.args.LR_dir + opt['filename_tmpl'] = '{}' + opt['io_backend'] = dict(type='disk') + + opt['gt_size'] = self.args.train_img_height + opt['use_flip'] = True + opt['use_rot'] = True + + opt['use_shuffle'] = True + opt['num_worker_per_gpu'] = self.args.workers + opt['batch_size_per_gpu'] = self.args.batch_size + opt['scale'] = int(self.args.ratio) + + opt['dataset_enlarge_ratio'] = 1 + dataset = create_dataset(opt) + self.train_loader = create_dataloader( + dataset, opt, num_gpu=1, dist=opt['dist'], sampler=None) + else: + raise ValueError("Unknown dataset: {}.".format(self.args.dataset)) + + def init_training(self): + self.init_constant() + self.init_dataset() + self.define_model() + self.define_criterion() + self.define_optimizer() + + def adjust_learning_rate(self, iteration): + """Sets the learning rate to the initial LR decayed by 10 every 30 epochs""" + lr = self.args.lr * (0.95 ** (iteration // self.args.lr_decay_freq)) + for param_group in self.optimizer.param_groups: + 
param_group['lr'] = lr + + def logging(self, iteration, epoch): + print_str = "[{}/{}][{}/{}], ".format(iteration, len(self.train_loader), epoch, self.args.nepochs) + for k,v in self.losses.items(): + print_str += "{}: {:4f} ".format(k, v) + print_str += "time: {:2f}.".format(self.iter_time) + print(print_str) + + def get_sp_grid(self, H, W, G, R = 1): + W = int(W // R) + H = int(H // R) + if G > min(H, W): + raise ValueError('Grid size must be smaller than image size!') + grid = torch.from_numpy(np.arange(G**2)).view(1, 1, G, G) + grid = torch.cat([grid]*(int(math.ceil(W/G))), dim = -1) + grid = torch.cat([grid]*(int(math.ceil(H/G))), dim = -2) + grid = grid[:, :, :H, :W] + return grid.float() + + def save_network(self, name = None): + cpk = {} + cpk['epoch'] = self.epoch + cpk['lr'] = self.optimizer.param_groups[0]['lr'] + if hasattr(self.model, 'module'): + cpk['model'] = self.model.module.cpu().state_dict() + else: + cpk['model'] = self.model.cpu().state_dict() + if name is None: + out_path = os.path.join(self.args.out_dir, "cpk.pth") + else: + out_path = os.path.join(self.args.out_dir, name + ".pth") + torch.save(cpk, out_path) + self.model.cuda() + return + + def init_constant(self): + return + + def define_model(self): + raise NotImplementedError + + def define_criterion(self): + raise NotImplementedError + + def define_optimizer(self): + raise NotImplementedError + + def display(self): + raise NotImplementedError + + def forward(self): + raise NotImplementedError + + def train(self): + args = self.args + total_iteration = 0 + for epoch in range(args.nepochs): + self.epoch = epoch + for iteration, data in enumerate(self.train_loader): + if args.dataset == 'BSD500': + image = data[0].cuda() + self.label = data[1].cuda() + elif args.dataset == 'texture': + image = data[0].cuda() + self.image2 = data[1].cuda() + else: + image = data['lq'].cuda() + self.gt = data['gt'].cuda() + start_time = time.time() + total_iteration += 1 + self.optimizer.zero_grad() + image = image.cuda() + if args.dataset == 'BSD500': + self.image = image + self.mean_values + self.gt = self.image + else: + self.image = image + self.forward() + total_loss = 0 + for k,v in self.losses.items(): + if hasattr(args, '{}_wt'.format(k)): + total_loss += v * getattr(args, '{}_wt'.format(k)) + else: + total_loss += v + total_loss.backward() + self.optimizer.step() + end_time = time.time() + self.iter_time = end_time - start_time + + self.adjust_learning_rate(total_iteration) + + if((iteration + 1) % args.log_freq == 0): + self.logging(iteration, epoch) + if args.use_wandb: + wandb.log(self.losses) + + if(iteration % args.display_freq == 0): + example_images = self.display() + if args.use_wandb: + wandb.log({'images': example_images}) + + if((epoch + 1) % args.save_freq == 0): + self.save_network() + diff --git a/libs/transformer.py b/libs/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..79ea7129fa64d1dc1b2486a2a8575f3550f8f824 --- /dev/null +++ b/libs/transformer.py @@ -0,0 +1,286 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class PositionEmbs(nn.Module): + def __init__(self, num_patches, emb_dim, dropout_rate=0.1): + super(PositionEmbs, self).__init__() + self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, emb_dim)) + if dropout_rate > 0: + self.dropout = nn.Dropout(dropout_rate) + else: + self.dropout = None + + def forward(self, x): + out = x + self.pos_embedding + + if self.dropout: + out = self.dropout(out) + + return out + + +class 
MlpBlock(nn.Module): + """ Transformer Feed-Forward Block """ + def __init__(self, in_dim, mlp_dim, out_dim, dropout_rate=0.1): + super(MlpBlock, self).__init__() + + # init layers + self.fc1 = nn.Linear(in_dim, mlp_dim) + self.fc2 = nn.Linear(mlp_dim, out_dim) + self.act = nn.GELU() + if dropout_rate > 0.0: + self.dropout1 = nn.Dropout(dropout_rate) + self.dropout2 = nn.Dropout(dropout_rate) + else: + self.dropout1 = None + self.dropout2 = None + + def forward(self, x): + + out = self.fc1(x) + out = self.act(out) + if self.dropout1: + out = self.dropout1(out) + + out = self.fc2(out) + if self.dropout2: + out = self.dropout2(out) + return out + + +class LinearGeneral(nn.Module): + def __init__(self, in_dim=(768,), feat_dim=(12, 64)): + super(LinearGeneral, self).__init__() + + #self.weight = nn.Parameter(torch.randn(*in_dim, *feat_dim)) + self.weight = torch.randn(*in_dim, *feat_dim) + self.weight.normal_(0, 0.02) + self.weight = nn.Parameter(self.weight) + self.bias = nn.Parameter(torch.zeros(*feat_dim)) + + def forward(self, x, dims): + a = torch.tensordot(x, self.weight, dims=dims) + self.bias + return a + + +class SelfAttention(nn.Module): + def __init__(self, in_dim, heads=8, dropout_rate=0.1): + super(SelfAttention, self).__init__() + self.heads = heads + self.head_dim = in_dim // heads + self.scale = self.head_dim ** 0.5 + + self.query = LinearGeneral((in_dim,), (self.heads, self.head_dim)) + self.key = LinearGeneral((in_dim,), (self.heads, self.head_dim)) + self.value = LinearGeneral((in_dim,), (self.heads, self.head_dim)) + self.out = LinearGeneral((self.heads, self.head_dim), (in_dim,)) + + if dropout_rate > 0: + self.dropout = nn.Dropout(dropout_rate) + else: + self.dropout = None + + def forward(self, x, vis_attn = False): + b, n, _ = x.shape + + q = self.query(x, dims=([2], [0])) + k = self.key(x, dims=([2], [0])) + v = self.value(x, dims=([2], [0])) + + q = q.permute(0, 2, 1, 3) + k = k.permute(0, 2, 1, 3) + v = v.permute(0, 2, 1, 3) + + attn_weights = torch.matmul(q, k.transpose(-2, -1)) / self.scale + attn_weights = F.softmax(attn_weights, dim=-1) + out = torch.matmul(attn_weights, v) + out = out.permute(0, 2, 1, 3) + + out = self.out(out, dims=([2, 3], [0, 1])) + + if not vis_attn: + return out + else: + return out, attn_weights + +class EncoderBlock(nn.Module): + def __init__(self, in_dim, mlp_dim, num_heads, dropout_rate=0.1, attn_dropout_rate=0.1, normalize = 'layer_norm'): + super(EncoderBlock, self).__init__() + + if normalize == 'layer_norm': + self.norm1 = nn.LayerNorm(in_dim) + self.norm2 = nn.LayerNorm(in_dim) + elif normalize == 'group_norm': + self.norm1 = Normalize(in_dim) + self.norm2 = Normalize(in_dim) + + self.attn = SelfAttention(in_dim, heads=num_heads, dropout_rate=attn_dropout_rate) + if dropout_rate > 0: + self.dropout = nn.Dropout(dropout_rate) + else: + self.dropout = None + self.mlp = MlpBlock(in_dim, mlp_dim, in_dim, dropout_rate) + + def forward(self, x, vis_attn = False): + residual = x + out = self.norm1(x) + if vis_attn: + out, attn_weights = self.attn(out, vis_attn) + else: + out = self.attn(out, vis_attn) + if self.dropout: + out = self.dropout(out) + out += residual + residual = out + + out = self.norm2(out) + out = self.mlp(out) + out += residual + if vis_attn: + return out, attn_weights + else: + return out + +class Encoder(nn.Module): + def __init__(self, num_patches, emb_dim, mlp_dim, num_layers=12, num_heads=12, dropout_rate=0.1, attn_dropout_rate=0.0): + super(Encoder, self).__init__() + + # positional embedding + self.pos_embedding = 
PositionEmbs(num_patches, emb_dim, dropout_rate) + + # encoder blocks + in_dim = emb_dim + self.encoder_layers = nn.ModuleList() + for i in range(num_layers): + layer = EncoderBlock(in_dim, mlp_dim, num_heads, dropout_rate, attn_dropout_rate) + self.encoder_layers.append(layer) + self.norm = nn.LayerNorm(in_dim) + + def forward(self, x): + + out = self.pos_embedding(x) + + for layer in self.encoder_layers: + out = layer(out) + + out = self.norm(out) + return out + +class VisionTransformer(nn.Module): + """ Vision Transformer """ + def __init__(self, + image_size=(256, 256), + patch_size=(16, 16), + emb_dim=768, + mlp_dim=3072, + num_heads=12, + num_layers=12, + num_classes=1000, + attn_dropout_rate=0.0, + dropout_rate=0.1, + feat_dim=None): + super(VisionTransformer, self).__init__() + h, w = image_size + + # embedding layer + fh, fw = patch_size + gh, gw = h // fh, w // fw + num_patches = gh * gw + self.embedding = nn.Conv2d(3, emb_dim, kernel_size=(fh, fw), stride=(fh, fw)) + # class token + self.cls_token = nn.Parameter(torch.zeros(1, 1, emb_dim)) + + # transformer + self.transformer = Encoder( + num_patches=num_patches, + emb_dim=emb_dim, + mlp_dim=mlp_dim, + num_layers=num_layers, + num_heads=num_heads, + dropout_rate=dropout_rate, + attn_dropout_rate=attn_dropout_rate) + + # classfier + self.classifier = nn.Linear(emb_dim, num_classes) + + def forward(self, x): + emb = self.embedding(x) # (n, c, gh, gw) + emb = emb.permute(0, 2, 3, 1) # (n, gh, hw, c) + b, h, w, c = emb.shape + emb = emb.reshape(b, h * w, c) + + # prepend class token + cls_token = self.cls_token.repeat(b, 1, 1) + emb = torch.cat([cls_token, emb], dim=1) + + # transformer + feat = self.transformer(emb) + + # classifier + logits = self.classifier(feat[:, 0]) + return logits + +def Normalize(in_channels): + return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + +class AttnBlock(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b,c,h,w = q.shape + q = q.reshape(b,c,h*w) + q = q.permute(0,2,1) # b,hw,c + k = k.reshape(b,c,h*w) # b,c,hw + w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c)**(-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b,c,h*w) + w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b,c,h,w) + + h_ = self.proj_out(h_) + + return x+h_ + +if __name__ == '__main__': + model = VisionTransformer(num_layers=2) + import pdb; pdb.set_trace() + x = torch.randn((2, 3, 256, 256)) + out = model(x) \ No newline at end of file diff --git a/libs/transformer_cluster.py b/libs/transformer_cluster.py new file mode 100644 index 0000000000000000000000000000000000000000..9789418c44c37d2f77bb0b6ff4fee20237a66ba0 --- /dev/null +++ b/libs/transformer_cluster.py @@ -0,0 +1,219 @@ +import torch +import 
torch.nn as nn +import torch.nn.functional as F + + +class PositionEmbs(nn.Module): + def __init__(self, num_patches, emb_dim, dropout_rate=0.1): + super(PositionEmbs, self).__init__() + self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, emb_dim)) + if dropout_rate > 0: + self.dropout = nn.Dropout(dropout_rate) + else: + self.dropout = None + + def forward(self, x): + out = x + self.pos_embedding + + if self.dropout: + out = self.dropout(out) + + return out + + +class MlpBlock(nn.Module): + """ Transformer Feed-Forward Block """ + def __init__(self, in_dim, mlp_dim, out_dim, dropout_rate=0.1): + super(MlpBlock, self).__init__() + + # init layers + self.fc1 = nn.Linear(in_dim, mlp_dim) + self.fc2 = nn.Linear(mlp_dim, out_dim) + self.act = nn.GELU() + if dropout_rate > 0.0: + self.dropout1 = nn.Dropout(dropout_rate) + self.dropout2 = nn.Dropout(dropout_rate) + else: + self.dropout1 = None + self.dropout2 = None + + def forward(self, x): + + out = self.fc1(x) + out = self.act(out) + if self.dropout1: + out = self.dropout1(out) + + out = self.fc2(out) + if self.dropout2: + out = self.dropout2(out) + return out + + +class LinearGeneral(nn.Module): + def __init__(self, in_dim=(768,), feat_dim=(12, 64)): + super(LinearGeneral, self).__init__() + + self.weight = nn.Parameter(torch.randn(*in_dim, *feat_dim)) + self.bias = nn.Parameter(torch.zeros(*feat_dim)) + + def forward(self, x, dims): + a = torch.tensordot(x, self.weight, dims=dims) + self.bias + return a + + +class SelfAttention(nn.Module): + def __init__(self, in_dim, heads=8, dropout_rate=0.1): + super(SelfAttention, self).__init__() + self.heads = heads + self.head_dim = in_dim // heads + self.scale = self.head_dim ** 0.5 + + self.query = LinearGeneral((in_dim,), (self.heads, self.head_dim)) + self.key = LinearGeneral((in_dim,), (self.heads, self.head_dim)) + self.value = LinearGeneral((in_dim,), (self.heads, self.head_dim)) + self.out = LinearGeneral((self.heads, self.head_dim), (in_dim,)) + + if dropout_rate > 0: + self.dropout = nn.Dropout(dropout_rate) + else: + self.dropout = None + + self.cluster_mlp = nn.Sequential(nn.Linear(256 * 100, 64 * 100), + nn.LeakyReLU(0.2), + nn.Linear(64 * 100, 8 * 100)) + + def forward(self, x): + b, n, _ = x.shape + + q = self.query(x, dims=([2], [0])) + q = self.cluster_mlp(q.view(b, -1)).view(b, 8, 1, 100) + k = self.key(x, dims=([2], [0])) + v = self.value(x, dims=([2], [0])) + + q = q.permute(0, 2, 1, 3) + k = k.permute(0, 2, 1, 3) + v = v.permute(0, 2, 1, 3) + + attn_weights = torch.matmul(q, k.transpose(-2, -1)) / self.scale + attn_weights = F.softmax(attn_weights, dim=-1) + out = torch.matmul(attn_weights, v) + out = out.permute(0, 2, 1, 3) + + out = self.out(out, dims=([2, 3], [0, 1])) + + return out + + +class EncoderBlock(nn.Module): + def __init__(self, in_dim, mlp_dim, num_heads, dropout_rate=0.1, attn_dropout_rate=0.1): + super(EncoderBlock, self).__init__() + + self.norm1 = nn.LayerNorm(in_dim) + self.attn = SelfAttention(in_dim, heads=num_heads, dropout_rate=attn_dropout_rate) + if dropout_rate > 0: + self.dropout = nn.Dropout(dropout_rate) + else: + self.dropout = None + self.norm2 = nn.LayerNorm(in_dim) + self.mlp = MlpBlock(in_dim, mlp_dim, in_dim, dropout_rate) + + def forward(self, x): + residual = x + out = self.norm1(x) + out = self.attn(out) + if self.dropout: + out = self.dropout(out) + #out += residual + residual = out + + out = self.norm2(out) + out = self.mlp(out) + out += residual + return out + + +class Encoder(nn.Module): + def __init__(self, num_patches, 
emb_dim, mlp_dim, num_layers=12, num_heads=12, dropout_rate=0.1, attn_dropout_rate=0.0): + super(Encoder, self).__init__() + + # positional embedding + self.pos_embedding = PositionEmbs(num_patches, emb_dim, dropout_rate) + + # encoder blocks + in_dim = emb_dim + self.encoder_layers = nn.ModuleList() + for i in range(num_layers): + layer = EncoderBlock(in_dim, mlp_dim, num_heads, dropout_rate, attn_dropout_rate) + self.encoder_layers.append(layer) + self.norm = nn.LayerNorm(in_dim) + + def forward(self, x): + + out = self.pos_embedding(x) + + for layer in self.encoder_layers: + out = layer(out) + + out = self.norm(out) + return out + +class VisionTransformer(nn.Module): + """ Vision Transformer """ + def __init__(self, + image_size=(256, 256), + patch_size=(16, 16), + emb_dim=768, + mlp_dim=3072, + num_heads=12, + num_layers=12, + num_classes=1000, + attn_dropout_rate=0.0, + dropout_rate=0.1, + feat_dim=None): + super(VisionTransformer, self).__init__() + h, w = image_size + + # embedding layer + fh, fw = patch_size + gh, gw = h // fh, w // fw + num_patches = gh * gw + self.embedding = nn.Conv2d(3, emb_dim, kernel_size=(fh, fw), stride=(fh, fw)) + # class token + self.cls_token = nn.Parameter(torch.zeros(1, 1, emb_dim)) + + # transformer + self.transformer = Encoder( + num_patches=num_patches, + emb_dim=emb_dim, + mlp_dim=mlp_dim, + num_layers=num_layers, + num_heads=num_heads, + dropout_rate=dropout_rate, + attn_dropout_rate=attn_dropout_rate) + + # classfier + self.classifier = nn.Linear(emb_dim, num_classes) + + def forward(self, x): + emb = self.embedding(x) # (n, c, gh, gw) + emb = emb.permute(0, 2, 3, 1) # (n, gh, hw, c) + b, h, w, c = emb.shape + emb = emb.reshape(b, h * w, c) + + # prepend class token + cls_token = self.cls_token.repeat(b, 1, 1) + emb = torch.cat([cls_token, emb], dim=1) + + # transformer + feat = self.transformer(emb) + + # classifier + logits = self.classifier(feat[:, 0]) + return logits + +if __name__ == '__main__': + model = VisionTransformer(num_layers=2) + import pdb; pdb.set_trace() + x = torch.randn((2, 3, 256, 256)) + out = model(x) \ No newline at end of file diff --git a/libs/utils.py b/libs/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9987acdae9ef4a5f088895b6d7e542904e93fd66 --- /dev/null +++ b/libs/utils.py @@ -0,0 +1,166 @@ +import torch +import torch.nn.functional as F + +import numpy as np +from scipy.io import loadmat + +def init_spixel_grid(args, b_train=True, ratio = 1, downsize = 16): + curr_img_height = args.crop_size + curr_img_width = args.crop_size + + # pixel coord + all_h_coords = np.arange(0, curr_img_height, 1) + all_w_coords = np.arange(0, curr_img_width, 1) + curr_pxl_coord = np.array(np.meshgrid(all_h_coords, all_w_coords, indexing='ij')) + + coord_tensor = np.concatenate([curr_pxl_coord[1:2, :, :], curr_pxl_coord[:1, :, :]]) + + all_XY_feat = (torch.from_numpy( + np.tile(coord_tensor, (1, 1, 1, 1)).astype(np.float32)).cuda()) + + return all_XY_feat + +def label2one_hot_torch(labels, C=14): + """ Converts an integer label torch.autograd.Variable to a one-hot Variable. 
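+ + Example (illustrative): + >>> labels = torch.tensor([[[[0, 1], [2, 0]]]]) # (1, 1, 2, 2) + >>> label2one_hot_torch(labels, C=3).shape + torch.Size([1, 3, 2, 2])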
+ + Args: + labels(tensor) : segmentation label + C (integer) : number of classes in labels + + Returns: + target (tensor) : one-hot vector of the input label + + Shape: + labels: (B, 1, H, W) + target: (B, N, H, W) + """ + b,_, h, w = labels.shape + one_hot = torch.zeros(b, C, h, w, dtype=torch.long).to(labels) + target = one_hot.scatter_(1, labels.type(torch.long).data, 1) #require long type + + return target.type(torch.float32) + +colors = loadmat('data/color150.mat')['colors'] +colors = np.concatenate((colors, colors, colors, colors)) + +def unique(ar, return_index=False, return_inverse=False, return_counts=False): + ar = np.asanyarray(ar).flatten() + + optional_indices = return_index or return_inverse + optional_returns = optional_indices or return_counts + + if ar.size == 0: + if not optional_returns: + ret = ar + else: + ret = (ar,) + if return_index: + ret += (np.empty(0, np.bool),) + if return_inverse: + ret += (np.empty(0, np.bool),) + if return_counts: + ret += (np.empty(0, np.intp),) + return ret + if optional_indices: + perm = ar.argsort(kind='mergesort' if return_index else 'quicksort') + aux = ar[perm] + else: + ar.sort() + aux = ar + flag = np.concatenate(([True], aux[1:] != aux[:-1])) + + if not optional_returns: + ret = aux[flag] + else: + ret = (aux[flag],) + if return_index: + ret += (perm[flag],) + if return_inverse: + iflag = np.cumsum(flag) - 1 + inv_idx = np.empty(ar.shape, dtype=np.intp) + inv_idx[perm] = iflag + ret += (inv_idx,) + if return_counts: + idx = np.concatenate(np.nonzero(flag) + ([ar.size],)) + ret += (np.diff(idx),) + return ret + +def colorEncode(labelmap, mode='RGB'): + labelmap = labelmap.astype('int') + labelmap_rgb = np.zeros((labelmap.shape[0], labelmap.shape[1], 3), + dtype=np.uint8) + for label in unique(labelmap): + if label < 0: + continue + labelmap_rgb += (labelmap == label)[:, :, np.newaxis] * \ + np.tile(colors[label], + (labelmap.shape[0], labelmap.shape[1], 1)) + + if mode == 'BGR': + return labelmap_rgb[:, :, ::-1] + else: + return labelmap_rgb + +def get_edges(sp_label, sp_num): + # This function returns a (hw) * (hw) matrix N. + # If Nij = 1, then superpixel i and j are neighbors + # Otherwise Nij = 0. 
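+ # Here N is sp_num x sp_num per image, and the function returns, for each image, the + # indices of its nonzero entries as an edge list. Adjacency is detected by differencing + # the label map against its one-pixel shifts (vertical, horizontal and both diagonals): + # a nonzero difference marks a boundary between two distinct superpixels.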
+ top = sp_label[:, :, :-1, :] - sp_label[:, :, 1:, :] + left = sp_label[:, :, :, :-1] - sp_label[:, :, :, 1:] + top_left = sp_label[:, :, :-1, :-1] - sp_label[:, :, 1:, 1:] + top_right = sp_label[:, :, :-1, 1:] - sp_label[:, :, 1:, :-1] + n_affs = [] + edge_indices = [] + for i in range(sp_label.shape[0]): + # change to torch.ones below to include self-loop in graph + n_aff = torch.zeros(sp_num, sp_num).unsqueeze(0).cuda() + # top/bottom + top_i = top[i].squeeze() + x, y = torch.nonzero(top_i, as_tuple = True) + sp1 = sp_label[i, :, x, y].squeeze().long() + sp2 = sp_label[i, :, x+1, y].squeeze().long() + n_aff[:, sp1, sp2] = 1 + n_aff[:, sp2, sp1] = 1 + + # left/right + left_i = left[i].squeeze() + try: + x, y = torch.nonzero(left_i, as_tuple = True) + except: + import pdb; pdb.set_trace() + sp1 = sp_label[i, :, x, y].squeeze().long() + sp2 = sp_label[i, :, x, y+1].squeeze().long() + n_aff[:, sp1, sp2] = 1 + n_aff[:, sp2, sp1] = 1 + + # top left + top_left_i = top_left[i].squeeze() + x, y = torch.nonzero(top_left_i, as_tuple = True) + sp1 = sp_label[i, :, x, y].squeeze().long() + sp2 = sp_label[i, :, x+1, y+1].squeeze().long() + n_aff[:, sp1, sp2] = 1 + n_aff[:, sp2, sp1] = 1 + + # top right + top_right_i = top_right[i].squeeze() + x, y = torch.nonzero(top_right_i, as_tuple = True) + sp1 = sp_label[i, :, x, y+1].squeeze().long() + sp2 = sp_label[i, :, x+1, y].squeeze().long() + n_aff[:, sp1, sp2] = 1 + n_aff[:, sp2, sp1] = 1 + + n_affs.append(n_aff) + edge_index = torch.stack(torch.nonzero(n_aff.squeeze(), as_tuple=True)) + edge_indices.append(edge_index.cuda()) + return edge_indices + + +def draw_color_seg(seg): + seg = seg.detach().cpu().numpy() + color_ = [] + for i in range(seg.shape[0]): + colori = colorEncode(seg[i].squeeze()) + colori = torch.from_numpy(colori / 255.0).float().permute(2, 0, 1) + color_.append(colori) + color_ = torch.stack(color_) + return color_ diff --git a/libs/vq_functions.py b/libs/vq_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..54d14cdecc404633a110a6305908e9a04b029099 --- /dev/null +++ b/libs/vq_functions.py @@ -0,0 +1,70 @@ +"""Module functions for the VQVAE + - Adopted from https://github.com/ritheshkumar95/pytorch-vqvae +""" +import torch +from torch.autograd import Function + +class VectorQuantization(Function): + @staticmethod + def forward(ctx, inputs, codebook): + with torch.no_grad(): + embedding_size = codebook.size(1) + inputs_size = inputs.size() + inputs_flatten = inputs.view(-1, embedding_size) + + codebook_sqr = torch.sum(codebook ** 2, dim=1) + inputs_sqr = torch.sum(inputs_flatten ** 2, dim=1, keepdim=True) + + # Compute the distances to the codebook + distances = torch.addmm(codebook_sqr + inputs_sqr, + inputs_flatten, codebook.t(), alpha=-2.0, beta=1.0) + + _, indices_flatten = torch.min(distances, dim=1) + indices = indices_flatten.view(*inputs_size[:-1]) + ctx.mark_non_differentiable(indices) + + return indices + + @staticmethod + def backward(ctx, grad_output): + raise RuntimeError('Trying to call `.grad()` on graph containing ' + '`VectorQuantization`. The function `VectorQuantization` ' + 'is not differentiable. 
Use `VectorQuantizationStraightThrough` ' + 'if you want a straight-through estimator of the gradient.') + +class VectorQuantizationStraightThrough(Function): + @staticmethod + def forward(ctx, inputs, codebook): + indices = vq(inputs, codebook) + indices_flatten = indices.view(-1) + ctx.save_for_backward(indices_flatten, codebook) + ctx.mark_non_differentiable(indices_flatten) + + codes_flatten = torch.index_select(codebook, dim=0, + index=indices_flatten) + codes = codes_flatten.view_as(inputs) + + return (codes, indices_flatten) + + @staticmethod + def backward(ctx, grad_output, grad_indices): + grad_inputs, grad_codebook = None, None + + if ctx.needs_input_grad[0]: + # Straight-through estimator + grad_inputs = grad_output.clone() + if ctx.needs_input_grad[1]: + # Gradient wrt. the codebook + indices, codebook = ctx.saved_tensors + embedding_size = codebook.size(1) + + grad_output_flatten = (grad_output.contiguous() + .view(-1, embedding_size)) + grad_codebook = torch.zeros_like(codebook) + grad_codebook.index_add_(0, indices, grad_output_flatten) + + return (grad_inputs, grad_codebook) + +vq = VectorQuantization.apply +vq_st = VectorQuantizationStraightThrough.apply +__all__ = [vq, vq_st] \ No newline at end of file diff --git a/libs/vqperceptual.py b/libs/vqperceptual.py new file mode 100644 index 0000000000000000000000000000000000000000..3b36c56ce9a36e8659d8ce3434e8b0e5e850f92f --- /dev/null +++ b/libs/vqperceptual.py @@ -0,0 +1,252 @@ +"""VQGAN Loss + - Adapted from https://github.com/CompVis/taming-transformers +""" +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .discriminator import NLayerDiscriminator, weights_init +from .blocks import LossCriterion, LossCriterionMask + +class DummyLoss(nn.Module): + def __init__(self): + super().__init__() + + +def adopt_weight(weight, global_step, threshold=0, value=0.): + if global_step < threshold: + weight = value + return weight + + +def hinge_d_loss(logits_real, logits_fake): + loss_real = torch.mean(F.relu(1. - logits_real)) + loss_fake = torch.mean(F.relu(1. 
+ logits_fake)) + d_loss = 0.5 * (loss_real + loss_fake) + return d_loss + + +def vanilla_d_loss(logits_real, logits_fake): + d_loss = 0.5 * ( + torch.mean(torch.nn.functional.softplus(-logits_real)) + + torch.mean(torch.nn.functional.softplus(logits_fake))) + return d_loss + +def fft_loss(pred, tgt): + return ((torch.fft.fftn(pred, dim=(-2,-1)) - torch.fft.fftn(tgt, dim=(-2,-1)))).abs().mean() + +class LPIPSWithDiscriminator(nn.Module): + def __init__(self, disc_start, model_path, pixelloss_weight=1.0, + disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=0.8, + perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, + disc_ndf=64, disc_loss="hinge", rec_loss="FFT", + style_layers = [], content_layers = ['r41']): + super().__init__() + assert disc_loss in ["hinge", "vanilla"] + self.pixel_weight = pixelloss_weight + self.perceptual_loss = LossCriterion(style_layers, content_layers, + 0, perceptual_weight, + model_path = model_path) + + self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, + n_layers=disc_num_layers, + use_actnorm=use_actnorm, + ndf=disc_ndf + ).apply(weights_init) + self.discriminator_iter_start = disc_start + if disc_loss == "hinge": + self.disc_loss = hinge_d_loss + elif disc_loss == "vanilla": + self.disc_loss = vanilla_d_loss + else: + raise ValueError(f"Unknown GAN loss '{disc_loss}'.") + print(f"VQLPIPSWithDiscriminator running with {disc_loss} and {rec_loss} loss.") + self.disc_factor = disc_factor + self.discriminator_weight = disc_weight + self.disc_conditional = disc_conditional + + self.rec_loss = rec_loss + self.perceptual_weight = perceptual_weight + + def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): + if last_layer is not None: + nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] + g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] + else: + nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] + g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] + + d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) + d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() + d_weight = d_weight * self.discriminator_weight + return d_weight + + def forward(self, inputs, reconstructions, optimizer_idx, + global_step, last_layer=None, cond=None, split="train"): + if self.rec_loss == "L1": + rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()).mean() + elif self.rec_loss == "MSE": + rec_loss = F.mse_loss(reconstructions, inputs) + elif self.rec_loss == "FFT": + rec_loss = fft_loss(inputs, reconstructions) + elif self.rec_loss is None: + rec_loss = 0 + else: + raise ValueError("Unkown reconstruction loss, choices are [FFT, L1]") + + if self.perceptual_weight > 0: + loss_dict = self.perceptual_loss(reconstructions, inputs, style = False) + p_loss = loss_dict['content'] + rec_loss = rec_loss + p_loss + else: + p_loss = torch.zeros(1).cuda() + nll_loss = rec_loss + + # adversarial loss for both branches + if optimizer_idx == 0: + log = {} + disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) + # generator update + if disc_factor > 0: + logits_fake = self.discriminator(reconstructions.contiguous()) + g_loss = -torch.mean(logits_fake) + + try: + d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) + except RuntimeError: + #assert not self.training + d_weight = torch.tensor(0.0) + loss = nll_loss + d_weight * 
disc_factor * g_loss + log["d_weight"] = d_weight.detach() + log["disc_factor"] = torch.tensor(disc_factor) + log["g_loss"] = g_loss.detach().mean() + else: + loss = nll_loss + + log["total_loss"] = loss.clone().detach().mean() + log["nll_loss"] = nll_loss.detach().mean() + log["rec_loss"] = rec_loss.detach().mean() + log["p_loss"] = p_loss.detach().mean() + return loss, log + + if optimizer_idx == 1: + # second pass for discriminator update + logits_real = self.discriminator(inputs.contiguous().detach()) + logits_fake = self.discriminator(reconstructions.contiguous().detach()) + + disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) + d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) + + log = {"disc_loss": d_loss.clone().detach().mean(), + "logits_real": logits_real.detach().mean(), + "logits_fake": logits_fake.detach().mean() + } + return d_loss, log + +class LPIPSWithDiscriminatorMask(nn.Module): + def __init__(self, disc_start, model_path, pixelloss_weight=1.0, + disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=0.8, + perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, + disc_ndf=64, disc_loss="hinge", rec_loss="FFT", + style_layers = [], content_layers = ['r41']): + super().__init__() + assert disc_loss in ["hinge", "vanilla"] + self.pixel_weight = pixelloss_weight + self.perceptual_loss = LossCriterionMask(style_layers, content_layers, + 0.2, perceptual_weight, + model_path = model_path) + + self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, + n_layers=disc_num_layers, + use_actnorm=use_actnorm, + ndf=disc_ndf + ).apply(weights_init) + self.discriminator_iter_start = disc_start + if disc_loss == "hinge": + self.disc_loss = hinge_d_loss + elif disc_loss == "vanilla": + self.disc_loss = vanilla_d_loss + else: + raise ValueError(f"Unknown GAN loss '{disc_loss}'.") + print(f"VQLPIPSWithDiscriminator running with {disc_loss} and {rec_loss} loss.") + self.disc_factor = disc_factor + self.discriminator_weight = disc_weight + self.disc_conditional = disc_conditional + + self.rec_loss = rec_loss + self.perceptual_weight = perceptual_weight + + def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): + if last_layer is not None: + nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] + g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] + else: + nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] + g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] + + d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) + d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() + d_weight = d_weight * self.discriminator_weight + return d_weight + + def forward(self, inputs, reconstructions, optimizer_idx, + global_step, mask, last_layer=None, cond=None, split="train"): + if self.rec_loss == "L1": + rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()).mean() + elif self.rec_loss == "MSE": + rec_loss = F.mse_loss(reconstructions, inputs) + elif self.rec_loss == "FFT": + rec_loss = fft_loss(inputs, reconstructions) + elif self.rec_loss is None: + rec_loss = 0 + else: + raise ValueError("Unkown reconstruction loss, choices are [FFT, L1]") + + if self.perceptual_weight > 0: + loss_dict = self.perceptual_loss(reconstructions, inputs, mask, style = True) + p_loss = loss_dict['content'] + s_loss = loss_dict['style'] + rec_loss = rec_loss + p_loss + s_loss + else: + 
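# perceptual weight is zero: skip the perceptual/style terms and keep a zero p_loss so the log below stays defined +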
p_loss = torch.zeros(1).cuda() + nll_loss = rec_loss + + # adversarial loss for both branches + if optimizer_idx == 0: + # generator update + logits_fake = self.discriminator(reconstructions.contiguous()) + g_loss = -torch.mean(logits_fake) + + try: + d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) + except RuntimeError: + #assert not self.training + d_weight = torch.tensor(0.0) + + disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) + loss = nll_loss + d_weight * disc_factor * g_loss + + log = {"total_loss": loss.clone().detach().mean(), + "nll_loss": nll_loss.detach().mean(), + "rec_loss": rec_loss.detach().mean(), + "p_loss": p_loss.detach().mean(), + "s_loss": s_loss, + "d_weight": d_weight.detach(), + "disc_factor": torch.tensor(disc_factor), + "g_loss": g_loss.detach().mean(), + } + return loss, log + + if optimizer_idx == 1: + # second pass for discriminator update + logits_real = self.discriminator(inputs.contiguous().detach()) + logits_fake = self.discriminator(reconstructions.contiguous().detach()) + + disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) + d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) + + log = {"disc_loss": d_loss.clone().detach().mean(), + "logits_real": logits_real.detach().mean(), + "logits_fake": logits_fake.detach().mean() + } + return d_loss, log \ No newline at end of file diff --git a/models/week0417/__pycache__/loss.cpython-37.pyc b/models/week0417/__pycache__/loss.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bbd0d288c382b3eb2e5dfc11e7fda3882a6755fb Binary files /dev/null and b/models/week0417/__pycache__/loss.cpython-37.pyc differ diff --git a/models/week0417/__pycache__/model.cpython-37.pyc b/models/week0417/__pycache__/model.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..58cc3b0210a50d5705c3ade66ee80d58218e0b02 Binary files /dev/null and b/models/week0417/__pycache__/model.cpython-37.pyc differ diff --git a/models/week0417/__pycache__/model.cpython-38.pyc b/models/week0417/__pycache__/model.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9bc322a5e947bf9b0b4813168d88e9aaad541b93 Binary files /dev/null and b/models/week0417/__pycache__/model.cpython-38.pyc differ diff --git a/models/week0417/__pycache__/nnutils.cpython-37.pyc b/models/week0417/__pycache__/nnutils.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f38697517f94ff4149aa43527f196baaecade692 Binary files /dev/null and b/models/week0417/__pycache__/nnutils.cpython-37.pyc differ diff --git a/models/week0417/__pycache__/nnutils.cpython-38.pyc b/models/week0417/__pycache__/nnutils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..789841ac94321ed5e9130c388cedc5f3b2f41c47 Binary files /dev/null and b/models/week0417/__pycache__/nnutils.cpython-38.pyc differ diff --git a/models/week0417/__pycache__/taming_blocks.cpython-37.pyc b/models/week0417/__pycache__/taming_blocks.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4d8fda340e84c6d37d2427d2b3ad8ace8abd29d3 Binary files /dev/null and b/models/week0417/__pycache__/taming_blocks.cpython-37.pyc differ diff --git a/models/week0417/__pycache__/taming_blocks.cpython-38.pyc b/models/week0417/__pycache__/taming_blocks.cpython-38.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..58764590e88f070da9f03c7565d9ed618ef9c398 Binary files /dev/null and b/models/week0417/__pycache__/taming_blocks.cpython-38.pyc differ diff --git a/models/week0417/dataset.py b/models/week0417/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..6b1f28fbf0b05a81e1674a50b7a5adcee017c5a3 --- /dev/null +++ b/models/week0417/dataset.py @@ -0,0 +1,150 @@ +import torch +import torch.utils.data as data +import torchvision.transforms as transforms + +import os +import numpy as np +from PIL import Image +from glob import glob +from skimage.segmentation import slic +import torchvision.transforms.functional as TF +from scipy.io import loadmat +import random +import cv2 + +import sys +sys.path.append("../..") + +def label2one_hot_torch(labels, C=14): + """ Converts an integer label torch.autograd.Variable to a one-hot Variable. + + Args: + labels(tensor) : segmentation label + C (integer) : number of classes in labels + + Returns: + target (tensor) : one-hot vector of the input label + + Shape: + labels: (B, 1, H, W) + target: (B, N, H, W) + """ + b,_, h, w = labels.shape + one_hot = torch.zeros(b, C, h, w, dtype=torch.long).to(labels) + target = one_hot.scatter_(1, labels.type(torch.long).data, 1) #require long type + + return target.type(torch.float32) + +class Dataset(data.Dataset): + def __init__(self, data_dir, crop_size = 128, test=False, + sp_num = 256, slic = True, preprocess_name = False, + gt_label = False, label_path = None, test_time = False, + img_path = None): + super(Dataset, self).__init__() + ext = ["*.jpg"] + dl = [] + self.test = test + self.test_time = test_time + + [dl.extend(glob(data_dir + '/**/' + e, recursive=True)) for e in ext] + + data_list = sorted(dl) + self.data_list = data_list + self.sp_num = sp_num + self.slic = slic + + self.crop = transforms.CenterCrop(size = (crop_size, crop_size)) + self.crop_size = crop_size + self.test = test + + self.gt_label = gt_label + if gt_label: + self.label_path = label_path + + self.img_path = img_path + + def preprocess_label(self, seg): + segs = label2one_hot_torch(seg.unsqueeze(0), C = seg.max() + 1) + new_seg = [] + for cnt in range(seg.max() + 1): + if segs[0, cnt].sum() > 0: + new_seg.append(segs[0, cnt]) + new_seg = torch.stack(new_seg) + return torch.argmax(new_seg, dim = 0) + + def __getitem__(self, index): + if self.img_path is None: + data_path = self.data_list[index] + else: + data_path = self.img_path + rgb_img = Image.open(data_path) + imgH, imgW = rgb_img.size + + if self.gt_label: + img_name = data_path.split("/")[-1].split("_")[0] + mat_path = os.path.join(self.label_path, data_path.split('/')[-2], img_name.replace('.jpg', '.mat')) + mat = loadmat(mat_path) + max_label_num = 0 + final_seg = None + for i in range(len(mat['groundTruth'][0])): + seg = mat['groundTruth'][0][i][0][0][0] + if len(np.unique(seg)) > max_label_num: + max_label_num = len(np.unique(seg)) + final_seg = seg + seg = torch.from_numpy(final_seg.astype(np.float32)) + segs = seg.long().unsqueeze(0) + + if self.img_path is None: + i, j, h, w = transforms.RandomCrop.get_params(rgb_img, output_size=(self.crop_size, self.crop_size)) + else: + i = 40; j = 40; h = self.crop_size; w = self.crop_size + rgb_img = TF.crop(rgb_img, i, j, h, w) + if self.gt_label: + segs = TF.crop(segs, i, j, h, w) + segs = self.preprocess_label(segs) + + if self.slic: + sp_num = self.sp_num + # compute superpixel + slic_i = slic(np.array(rgb_img), n_segments=sp_num, compactness=10, start_label=0, 
min_size_factor=0.3) + slic_i = torch.from_numpy(slic_i) + slic_i[slic_i >= sp_num] = sp_num - 1 + oh = label2one_hot_torch(slic_i.unsqueeze(0).unsqueeze(0), C = sp_num).squeeze() + + rgb_img = TF.to_tensor(rgb_img) + if rgb_img.shape[0] == 1: + rgb_img = rgb_img.repeat(3, 1, 1) + rgb_img = rgb_img[:3, :, :] + + rets = [] + rets.append(rgb_img) + if self.slic: + rets.append(oh) + rets.append(data_path.split("/")[-1]) + rets.append(index) + if self.gt_label: + rets.append(segs.view(1, segs.shape[-2], segs.shape[-1])) + return rets + + def __len__(self): + return len(self.data_list) + +if __name__ == '__main__': + import torchvision.utils as vutils + dataset = Dataset('/home/xtli/DATA/texture_data/', + sampled_num=3000) + loader_ = torch.utils.data.DataLoader(dataset = dataset, + batch_size = 1, + shuffle = True, + num_workers = 1, + drop_last = True) + loader = iter(loader_) + img, points, pixs = loader.next() + + crop_size = 128 + canvas = torch.zeros((1, 3, crop_size, crop_size)) + for i in range(points.shape[-2]): + p = (points[0, i] + 1) / 2.0 * (crop_size - 1) + canvas[0, :, int(p[0]), int(p[1])] = pixs[0, :, i] + vutils.save_image(canvas, 'canvas.png') + vutils.save_image(img, 'img.png') diff --git a/models/week0417/focal_frequency_loss.py b/models/week0417/focal_frequency_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..ed7bd5d0f466ad1841cf5f5ecb3b7e090f538ef5 --- /dev/null +++ b/models/week0417/focal_frequency_loss.py @@ -0,0 +1,104 @@ +import torch +import torch.nn as nn + + +class FocalFrequencyLoss(nn.Module): + """The torch.nn.Module class that implements focal frequency loss - a + frequency domain loss function for optimizing generative models. + + Ref: + Focal Frequency Loss for Image Reconstruction and Synthesis. In ICCV 2021. + + + Args: + loss_weight (float): weight for focal frequency loss. Default: 1.0 + alpha (float): the scaling factor alpha of the spectrum weight matrix for flexibility. Default: 1.0 + patch_factor (int): the factor to crop image patches for patch-based focal frequency loss. Default: 1 + ave_spectrum (bool): whether to use minibatch average spectrum. Default: False + log_matrix (bool): whether to adjust the spectrum weight matrix by logarithm. Default: False + batch_matrix (bool): whether to calculate the spectrum weight matrix using batch-based statistics. 
Default: False + """ + + def __init__(self, loss_weight=1.0, alpha=1.0, patch_factor=1, ave_spectrum=False, log_matrix=False, batch_matrix=False): + super(FocalFrequencyLoss, self).__init__() + self.loss_weight = loss_weight + self.alpha = alpha + self.patch_factor = patch_factor + self.ave_spectrum = ave_spectrum + self.log_matrix = log_matrix + self.batch_matrix = batch_matrix + + def tensor2freq(self, x): + # crop image patches + patch_factor = self.patch_factor + _, _, h, w = x.shape + assert h % patch_factor == 0 and w % patch_factor == 0, ( + 'Patch factor should be divisible by image height and width') + patch_list = [] + patch_h = h // patch_factor + patch_w = w // patch_factor + for i in range(patch_factor): + for j in range(patch_factor): + patch_list.append(x[:, :, i * patch_h:(i + 1) * patch_h, j * patch_w:(j + 1) * patch_w]) + + # stack to patch tensor + y = torch.stack(patch_list, 1) + + # perform 2D DFT (real-to-complex, orthonormalization) + return torch.rfft(y, 2, onesided=False, normalized=True) + + def loss_formulation(self, recon_freq, real_freq, matrix=None): + # spectrum weight matrix + if matrix is not None: + # if the matrix is predefined + weight_matrix = matrix.detach() + else: + # if the matrix is calculated online: continuous, dynamic, based on current Euclidean distance + matrix_tmp = (recon_freq - real_freq) ** 2 + matrix_tmp = torch.sqrt(matrix_tmp[..., 0] + matrix_tmp[..., 1]) ** self.alpha + + # whether to adjust the spectrum weight matrix by logarithm + if self.log_matrix: + matrix_tmp = torch.log(matrix_tmp + 1.0) + + # whether to calculate the spectrum weight matrix using batch-based statistics + if self.batch_matrix: + matrix_tmp = matrix_tmp / matrix_tmp.max() + else: + matrix_tmp = matrix_tmp / matrix_tmp.max(-1).values.max(-1).values[:, :, :, None, None] + + matrix_tmp[torch.isnan(matrix_tmp)] = 0.0 + matrix_tmp = torch.clamp(matrix_tmp, min=0.0, max=1.0) + weight_matrix = matrix_tmp.clone().detach() + + assert weight_matrix.min().item() >= 0 and weight_matrix.max().item() <= 1, ( + 'The values of spectrum weight matrix should be in the range [0, 1], ' + 'but got Min: %.10f Max: %.10f' % (weight_matrix.min().item(), weight_matrix.max().item())) + + # frequency distance using (squared) Euclidean distance + tmp = (recon_freq - real_freq) ** 2 + freq_distance = tmp[..., 0] + tmp[..., 1] + + # dynamic spectrum weighting (Hadamard product) + loss = weight_matrix * freq_distance + return torch.mean(loss) + + def forward(self, pred, target, matrix=None, **kwargs): + """Forward function to calculate focal frequency loss. + + Args: + pred (torch.Tensor): of shape (N, C, H, W). Predicted tensor. + target (torch.Tensor): of shape (N, C, H, W). Target tensor. + matrix (torch.Tensor, optional): Element-wise spectrum weight matrix. + Default: None (If set to None: calculated online, dynamic). 
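Note that tensor2freq above calls torch.rfft, which no longer exists in recent PyTorch (it was removed in the 1.8 torch.fft rework). A minimal sketch of an equivalent two-sided, orthonormalized spectrum on current PyTorch, keeping the trailing real/imaginary dimension that loss_formulation indexes as [..., 0] and [..., 1] (the helper name is illustrative, not part of this diff):

import torch

def tensor2freq_fft2(y):
    # y: (N, patches, C, H, W) stacked patches, as built in tensor2freq
    # equivalent of torch.rfft(y, 2, onesided=False, normalized=True)
    freq = torch.fft.fft2(y, dim=(-2, -1), norm="ortho")
    # trailing dimension of size 2 holds the real and imaginary parts
    return torch.stack([freq.real, freq.imag], dim=-1)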
+ """ + pred_freq = self.tensor2freq(pred) + target_freq = self.tensor2freq(target) + + # whether to use minibatch average spectrum + if self.ave_spectrum: + pred_freq = torch.mean(pred_freq, 0, keepdim=True) + target_freq = torch.mean(target_freq, 0, keepdim=True) + + # calculate focal frequency loss + return self.loss_formulation(pred_freq, target_freq, matrix) * self.loss_weight diff --git a/models/week0417/json/ab1_no_noise.json b/models/week0417/json/ab1_no_noise.json new file mode 100644 index 0000000000000000000000000000000000000000..65e19641ac6b0f7afab3c0ad117a3c0a2d9e9eeb --- /dev/null +++ b/models/week0417/json/ab1_no_noise.json @@ -0,0 +1,26 @@ +{ + "exp_name": "04-17/ab1_no_noise", + "batch_size": 2, + "lr": 5e-5, + "hidden_dim": 256, + "crop_size": 224, + "model_name": "model", + "lambda_L1": 1, + "lambda_GAN": 1, + "num_D": 2, + "n_layers_D": 3, + "dataset": "dataset", + "n_cluster": 10, + "temperature": 23, + "add_gcn_epoch": 100, + "add_clustering_epoch": 200, + "add_texture_epoch": 300, + "lambda_style_loss": 1.0, + "dec_input_mode": "sine_wave", + "sine_weight": 1, + "spatial_code_dim": 32, + "project_name": "March", + "use_wandb": 1, + "data_path": "/workspace/BSR_processed/train", + "model_path": "/data/weights/models" +} diff --git a/models/week0417/json/ab2_noise_injection.json b/models/week0417/json/ab2_noise_injection.json new file mode 100644 index 0000000000000000000000000000000000000000..fe3fafa6d8bcc273f0e7462300611679480eb5a5 --- /dev/null +++ b/models/week0417/json/ab2_noise_injection.json @@ -0,0 +1,26 @@ +{ + "exp_name": "04-17/ab2_noise_injection", + "batch_size": 2, + "lr": 5e-5, + "hidden_dim": 256, + "crop_size": 224, + "model_name": "model_noise_injection", + "lambda_L1": 1, + "lambda_GAN": 1, + "num_D": 2, + "n_layers_D": 3, + "dataset": "dataset", + "n_cluster": 10, + "temperature": 23, + "add_gcn_epoch": 100, + "add_clustering_epoch": 200, + "add_texture_epoch": 300, + "lambda_style_loss": 1.0, + "dec_input_mode": "sine_wave", + "sine_weight": 1, + "spatial_code_dim": 32, + "project_name": "March", + "use_wandb": 1, + "data_path": "/workspace/BSR_processed/train", + "model_path": "/data/weights/models" +} diff --git a/models/week0417/json/ab3_multi_scale.json b/models/week0417/json/ab3_multi_scale.json new file mode 100644 index 0000000000000000000000000000000000000000..3ca20e8dc1b97dd7572a328b8170df0f8f6aca33 --- /dev/null +++ b/models/week0417/json/ab3_multi_scale.json @@ -0,0 +1,26 @@ +{ + "exp_name": "04-17/ab3_multi_scale", + "batch_size": 2, + "lr": 5e-5, + "hidden_dim": 256, + "crop_size": 192, + "model_name": "model_multi_scale", + "lambda_L1": 1, + "lambda_GAN": 1, + "num_D": 2, + "n_layers_D": 3, + "dataset": "dataset", + "n_cluster": 10, + "temperature": 23, + "add_gcn_epoch": 100, + "add_clustering_epoch": 200, + "add_texture_epoch": 300, + "lambda_style_loss": 1.0, + "dec_input_mode": "sine_wave_noise", + "sine_weight": 1, + "spatial_code_dim": 32, + "project_name": "March", + "use_wandb": 1, + "data_path": "/workspace/BSR_processed/train", + "model_path": "/data/weights/models" +} diff --git a/models/week0417/json/ab4_multi_scale_noise_injection.json b/models/week0417/json/ab4_multi_scale_noise_injection.json new file mode 100644 index 0000000000000000000000000000000000000000..77524ad672d8e34137cf2473204c50ad6515e4db --- /dev/null +++ b/models/week0417/json/ab4_multi_scale_noise_injection.json @@ -0,0 +1,26 @@ +{ + "exp_name": "04-17/ab4_multi_scale_noise_injection", + "batch_size": 2, + "lr": 5e-5, + "hidden_dim": 256, + "crop_size": 
192, + "model_name": "model_multi_scale_noise_injection", + "lambda_L1": 1, + "lambda_GAN": 1, + "num_D": 2, + "n_layers_D": 3, + "dataset": "dataset", + "n_cluster": 10, + "temperature": 23, + "add_gcn_epoch": 100, + "add_clustering_epoch": 200, + "add_texture_epoch": 300, + "lambda_style_loss": 1.0, + "dec_input_mode": "sine_wave", + "sine_weight": 1, + "spatial_code_dim": 32, + "project_name": "March", + "use_wandb": 1, + "data_path": "/workspace/BSR_processed/train", + "model_path": "/data/weights/models" +} diff --git a/models/week0417/json/fcn.json b/models/week0417/json/fcn.json new file mode 100644 index 0000000000000000000000000000000000000000..47f44079b490da6476d787e94d2a64c26edbc331 --- /dev/null +++ b/models/week0417/json/fcn.json @@ -0,0 +1,28 @@ +{ + "exp_name": "04-21/gcn", + "batch_size": 2, + "lr": 5e-5, + "hidden_dim": 256, + "crop_size": 224, + "model_name": "model_fcn", + "save_freq": 1000, + "lambda_L1": 1, + "lambda_GAN": 1, + "num_D": 2, + "n_layers_D": 3, + "dataset": "dataset", + "n_cluster": 10, + "temperature": 23, + "add_gcn_epoch": 100, + "add_clustering_epoch": 200, + "add_texture_epoch": 300, + "lambda_style_loss": 1.0, + "dec_input_mode": "sine_wave_noise", + "sine_weight": 1, + "spatial_code_dim": 32, + "project_name": "March", + "use_wandb": 1, + "data_path": "/workspace/BSR_processed/train", + "model_path": "/data/weights/models", + "pretrained_fcn": "/data/superpixel_fcn/pretrain_ckpt/SpixelNet_bsd_ckpt.tar" +} diff --git a/models/week0417/json/fcn_ft.json b/models/week0417/json/fcn_ft.json new file mode 100644 index 0000000000000000000000000000000000000000..4f344dce13b8dfc5d0923dbdf634e0a460aaaeaf --- /dev/null +++ b/models/week0417/json/fcn_ft.json @@ -0,0 +1,28 @@ +{ + "exp_name": "04-21/gcn_ft", + "batch_size": 1, + "lr": 5e-5, + "hidden_dim": 256, + "crop_size": 224, + "model_name": "model_fcn_no_rec", + "save_freq": 1000, + "lambda_L1": 1, + "lambda_GAN": 1, + "num_D": 2, + "n_layers_D": 3, + "dataset": "dataset", + "n_cluster": 10, + "temperature": 23, + "add_gcn_epoch": 0, + "add_clustering_epoch": 0, + "add_texture_epoch": 0, + "lambda_style_loss": 1.0, + "dec_input_mode": "sine_wave_noise", + "sine_weight": 1, + "spatial_code_dim": 32, + "pretrained_fcn": "/home/xli/Documents/github/superpixel_fcn/pretrain_ckpt/SpixelNet_bsd_ckpt.tar", + "work_dir": "/home/xli/WORKDIR/", + "data_path": "/home/xli/Data/BSR_processed/train", + "pretrained_path": "/home/xli/WORKDIR/04-21/gcn/cpk.pth", + "nepochs": 20 +} diff --git a/models/week0417/json/single_scale_grouping.json b/models/week0417/json/single_scale_grouping.json new file mode 100644 index 0000000000000000000000000000000000000000..64192a4a8f962e461a0c76693d793a5b3facc7e7 --- /dev/null +++ b/models/week0417/json/single_scale_grouping.json @@ -0,0 +1,28 @@ +{ + "exp_name": "04-15/single_scale_grouping_resume", + "batch_size": 2, + "lr": 5e-5, + "hidden_dim": 256, + "crop_size": 224, + "model_name": "model", + "save_freq": 1000, + "lambda_L1": 1, + "lambda_GAN": 1, + "num_D": 2, + "n_layers_D": 3, + "dataset": "dataset", + "n_cluster": 10, + "temperature": 23, + "add_gcn_epoch": 0, + "add_clustering_epoch": 0, + "add_texture_epoch": 0, + "lambda_style_loss": 1.0, + "dec_input_mode": "sine_wave_noise", + "sine_weight": 1, + "spatial_code_dim": 32, + "project_name": "March", + "use_wandb": 1, + "data_path": "/workspace/BSR_processed/train", + "model_path": "/data/weights/models", + "pretrained_path": "/data/workdir/04-15/single_scale_grouping/cpk.pth" +} diff --git 
a/models/week0417/json/single_scale_grouping_ft.json b/models/week0417/json/single_scale_grouping_ft.json new file mode 100644 index 0000000000000000000000000000000000000000..eca841e7500c7551563ff777164dd88f0140cb2f --- /dev/null +++ b/models/week0417/json/single_scale_grouping_ft.json @@ -0,0 +1,25 @@ +{ + "exp_name": "04-18/", + "batch_size": 1, + "lr": 5e-5, + "hidden_dim": 256, + "crop_size": 224, + "model_name": "model", + "save_freq": 1000, + "lambda_L1": 1, + "lambda_GAN": 1, + "num_D": 2, + "n_layers_D": 3, + "dataset": "dataset", + "n_cluster": 10, + "temperature": 23, + "add_gcn_epoch": 0, + "add_clustering_epoch": 0, + "add_texture_epoch": 0, + "lambda_style_loss": 1.0, + "dec_input_mode": "sine_wave_noise", + "sine_weight": 1, + "spatial_code_dim": 32, + "nepochs": 20, + "pretrained_path": "/home/xtli/WORKDIR/04-15/single_scale_grouping_resume/cpk.pth" +} diff --git a/models/week0417/json/single_scale_grouping_ft_learnable_grouping.json b/models/week0417/json/single_scale_grouping_ft_learnable_grouping.json new file mode 100644 index 0000000000000000000000000000000000000000..1a8bf9629c56e3f5355f815af19dd9eb3294182a --- /dev/null +++ b/models/week0417/json/single_scale_grouping_ft_learnable_grouping.json @@ -0,0 +1,24 @@ +{ + "exp_name": "04-17/single_scale_grouping_ft_learnable_grouping", + "batch_size": 1, + "lr": 5e-5, + "hidden_dim": 256, + "crop_size": 224, + "model_name": "model_no_rec_no_connectivity_learnable_grouping", + "save_freq": 1000, + "lambda_L1": 1, + "lambda_GAN": 1, + "num_D": 2, + "n_layers_D": 3, + "dataset": "dataset", + "n_cluster": 10, + "temperature": 23, + "add_gcn_epoch": 0, + "add_clustering_epoch": 0, + "add_texture_epoch": 0, + "lambda_style_loss": 1.0, + "dec_input_mode": "sine_wave_noise", + "sine_weight": 1, + "spatial_code_dim": 32, + "pretrained_path": "/home/xtli/WORKDIR/04-15/single_scale_grouping_resume/cpk.pth" +} diff --git a/models/week0417/json/single_scale_grouping_ft_no_connectivity.json b/models/week0417/json/single_scale_grouping_ft_no_connectivity.json new file mode 100644 index 0000000000000000000000000000000000000000..fa9edba6f84c7deb8eddbc5cbf04d48d679bdd43 --- /dev/null +++ b/models/week0417/json/single_scale_grouping_ft_no_connectivity.json @@ -0,0 +1,24 @@ +{ + "exp_name": "04-17/single_scale_grouping_ft_no_connectivity", + "batch_size": 1, + "lr": 5e-5, + "hidden_dim": 256, + "crop_size": 224, + "model_name": "model_no_rec_no_connectivity", + "save_freq": 1000, + "lambda_L1": 1, + "lambda_GAN": 1, + "num_D": 2, + "n_layers_D": 3, + "dataset": "dataset", + "n_cluster": 10, + "temperature": 23, + "add_gcn_epoch": 0, + "add_clustering_epoch": 0, + "add_texture_epoch": 0, + "lambda_style_loss": 1.0, + "dec_input_mode": "sine_wave_noise", + "sine_weight": 1, + "spatial_code_dim": 32, + "pretrained_path": "/home/xtli/WORKDIR/04-15/single_scale_grouping_resume/cpk.pth" +} diff --git a/models/week0417/loss.py b/models/week0417/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..1c449e0dc4e3d658f22ed3327601910eb57aaaae --- /dev/null +++ b/models/week0417/loss.py @@ -0,0 +1,481 @@ +""" +07-21 +StyleLoss to encourage style statistics to be consistent within each cluster. 
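The losses in this file largely reduce to the same operation: gather the VGG feature columns covered by a region mask and compare Gram statistics. A self-contained sketch of that core step, using the same c*h*w normalization as styleLossMask below (the helper name is illustrative, not part of this diff):

import torch

def masked_gram(feat, mask):
    # feat: (1, C, H, W) feature map from one VGG layer
    # mask: (1, 1, H, W) binary mask for one cluster / region
    b, c, h, w = feat.shape
    cols = feat.view(b, c, -1)                      # (1, C, H*W)
    idx = torch.nonzero(mask.view(-1)).squeeze(1)   # indices of masked pixels
    part = torch.index_select(cols, 2, idx)         # (1, C, |region|)
    # C x C Gram matrix of the masked feature columns
    return torch.bmm(part, part.transpose(1, 2)) / (c * h * w)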
+""" +import torch +import torchvision +import torch.nn as nn +import torch.nn.functional as F + +# VGG architecter, used for the perceptual loss using a pretrained VGG network +class VGG19(torch.nn.Module): + def __init__(self, requires_grad=False, device = torch.device(f'cuda:0')): + super().__init__() + vgg_pretrained_features = torchvision.models.vgg19(pretrained=True).features + self.slice1 = torch.nn.Sequential() + self.slice2 = torch.nn.Sequential() + self.slice3 = torch.nn.Sequential() + self.slice4 = torch.nn.Sequential() + self.slice5 = torch.nn.Sequential() + for x in range(2): + self.slice1.add_module(str(x), vgg_pretrained_features[x]) + for x in range(2, 7): + self.slice2.add_module(str(x), vgg_pretrained_features[x]) + for x in range(7, 12): + self.slice3.add_module(str(x), vgg_pretrained_features[x]) + for x in range(12, 21): + self.slice4.add_module(str(x), vgg_pretrained_features[x]) + for x in range(21, 30): + self.slice5.add_module(str(x), vgg_pretrained_features[x]) + if not requires_grad: + for param in self.parameters(): + param.requires_grad = False + + cnn_normalization_mean = torch.tensor([0.485, 0.456, 0.406]).to(device) + cnn_normalization_std = torch.tensor([0.229, 0.224, 0.225]).to(device) + self.normalization = Normalization(cnn_normalization_mean, cnn_normalization_std) + + def forward(self, X): + #X = self.normalization(X) + h_relu1 = self.slice1(X) + h_relu2 = self.slice2(h_relu1) + h_relu3 = self.slice3(h_relu2) + h_relu4 = self.slice4(h_relu3) + h_relu5 = self.slice5(h_relu4) + out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5] + return out + +# create a module to normalize input image so we can easily put it in a +# nn.Sequential +class Normalization(nn.Module): + def __init__(self, mean, std): + super(Normalization, self).__init__() + # .view the mean and std to make them [C x 1 x 1] so that they can + # directly work with image Tensor of shape [B x C x H x W]. + # B is batch size. C is number of channels. H is height and W is width. + self.mean = torch.tensor(mean).view(-1, 1, 1) + self.std = torch.tensor(std).view(-1, 1, 1) + + def forward(self, img): + # normalize img + return (img - self.mean) / self.std + +class GramMatrix(nn.Module): + def forward(self,input): + b, c, h, w = input.size() + f = input.view(b,c,h*w) # bxcx(hxw) + # torch.bmm(batch1, batch2, out=None) # + # batch1: bxmxp, batch2: bxpxn -> bxmxn # + G = torch.bmm(f,f.transpose(1,2)) # f: bxcx(hxw), f.transpose: bx(hxw)xc -> bxcxc + return G.div_(c*h*w) + +class StyleLoss(nn.Module): + """ + Version 1. Compare mean and variance cluster-wise. 
+ """ + def __init__(self, style_layers = 'relu3, relu4, relu5', device = torch.device(f'cuda:0'), style_mode = 'gram'): + super().__init__() + self.vgg = VGG19() + self.style_layers = [] + for style_layer in style_layers.split(','): + self.style_layers.append(int(style_layer[-1]) - 1) + self.style_mode = style_mode + + cnn_normalization_mean = torch.tensor([0.485, 0.456, 0.406]).to(device) + cnn_normalization_std = torch.tensor([0.229, 0.224, 0.225]).to(device) + self.normalization = Normalization(cnn_normalization_mean, cnn_normalization_std) + + def forward(self, pred, gt): + """ + INPUTS: + - pred: (B, 3, H, W) + - gt: (B, 3, H, W) + - seg: (B, H, W) + """ + # extract features for images + B, _, H, W = pred.shape + pred = self.normalization(pred) + gt = self.normalization(gt) + pred_feats = self.vgg(pred) + gt_feats = self.vgg(gt) + loss = 0 + for style_layer in self.style_layers: + pred_feat = pred_feats[style_layer] + gt_feat = gt_feats[style_layer] + pred_gram = GramMatrix()(pred_feat) + gt_gram = GramMatrix()(gt_feat) + loss += torch.sum((pred_gram - gt_gram) ** 2) / B + return loss + +class styleLossMask(nn.Module): + """ + Version 1. Compare mean and variance cluster-wise. + """ + def __init__(self, style_layers = 'relu3, relu4, relu5', device = torch.device(f'cuda:0'), style_mode = 'gram'): + super().__init__() + self.vgg = VGG19() + self.style_layers = [] + for style_layer in style_layers.split(','): + self.style_layers.append(int(style_layer[-1]) - 1) + self.style_mode = style_mode + + #cnn_normalization_mean = torch.tensor([0.485, 0.456, 0.406]).to(device) + #cnn_normalization_std = torch.tensor([0.229, 0.224, 0.225]).to(device) + #self.normalization = Normalization(cnn_normalization_mean, cnn_normalization_std) + + def forward(self, input, target, mask): + B, _, H, W = input.shape + #pred = self.normalization(input) + #target = self.normalization(target) + pred_feats = self.vgg(input) + gt_feats = self.vgg(target) + + + loss = 0 + mb, mc, mh, mw = mask.shape + for style_layer in self.style_layers: + pred_feat = pred_feats[style_layer] + gt_feat = gt_feats[style_layer] + + ib,ic,ih,iw = pred_feat.size() + iF = pred_feat.view(ib,ic,-1) + tb,tc,th,tw = gt_feat.size() + tF = gt_feat.view(tb,tc,-1) + + for i in range(mb): + # resize mask to have the same size of the feature + maski = F.interpolate(mask[i:i+1], size = (ih, iw), mode = 'nearest') + mask_flat_i = maski.view(mc, -1) + + maskt = F.interpolate(mask[i:i+1], size = (th, tw), mode = 'nearest') + mask_flat_t = maskt.view(mc, -1) + for j in range(mc): + # get features for each part + idx = torch.nonzero(mask_flat_i[j]).squeeze() + if len(idx.shape) == 0 or idx.shape[0] == 0: + continue + ipart = torch.index_select(iF, 2, idx) + + idx = torch.nonzero(mask_flat_t[j]).squeeze() + if len(idx.shape) == 0 or idx.shape[0] == 0: + continue + tpart = torch.index_select(tF, 2, idx) + + iMean = torch.mean(ipart,dim=2) + iGram = torch.bmm(ipart, ipart.transpose(1,2)).div_(ic*ih*iw) # f: bxcx(hxw), f.transpose: bx(hxw)xc -> bxcxc + + tMean = torch.mean(tpart,dim=2) + tGram = torch.bmm(tpart, tpart.transpose(1,2)).div_(tc*th*tw) # f: bxcx(hxw), f.transpose: bx(hxw)xc -> bxcxc + + loss_j = nn.MSELoss()(iMean,tMean) + nn.MSELoss()(iGram,tGram) + loss += loss_j + return loss/tb + +# Perceptual loss that uses a pretrained VGG network +class VGGLoss(nn.Module): + def __init__(self, weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0], device = torch.device(f'cuda:0')): + super(VGGLoss, self).__init__() + self.vgg = VGG19(device = device) + 
self.criterion = nn.L1Loss() + self.weights = weights + + def forward(self, x, y): + x_vgg, y_vgg = self.vgg(x), self.vgg(y) + loss = 0 + for i in range(len(x_vgg)): + loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach()) + return loss + +class styleLossMaskv2(nn.Module): + def __init__(self, style_layers = 'relu3, relu4, relu5', device = torch.device(f'cuda:0')): + super().__init__() + self.vgg = VGG19(device = device) + self.style_layers = [] + for style_layer in style_layers.split(','): + self.style_layers.append(int(style_layer[-1]) - 1) + cnn_normalization_mean = torch.tensor([0.485, 0.456, 0.406]).to(device) + cnn_normalization_std = torch.tensor([0.229, 0.224, 0.225]).to(device) + self.normalization = Normalization(cnn_normalization_mean, cnn_normalization_std) + + def forward(self, input, target, mask_input, mask_target): + B, _, H, W = input.shape + input = self.normalization(input) + target = self.normalization(target) + pred_feats = self.vgg(input) + gt_feats = self.vgg(target) + + + loss = 0 + mb, mc, mh, mw = mask_input.shape + for style_layer in self.style_layers: + pred_feat = pred_feats[style_layer] + gt_feat = gt_feats[style_layer] + + ib,ic,ih,iw = pred_feat.size() + iF = pred_feat.view(ib,ic,-1) + tb,tc,th,tw = gt_feat.size() + tF = gt_feat.view(tb,tc,-1) + + for i in range(mb): + # resize mask to have the same size of the feature + maski = F.interpolate(mask_input[i:i+1], size = (ih, iw), mode = 'nearest') + mask_flat_i = maski.view(mc, -1) + + maskt = F.interpolate(mask_target[i:i+1], size = (th, tw), mode = 'nearest') + mask_flat_t = maskt.view(mc, -1) + for j in range(mc): + # get features for each part + idx = torch.nonzero(mask_flat_i[j]).squeeze() + if len(idx.shape) == 0 or idx.shape[0] == 0: + continue + ipart = torch.index_select(iF[i:i+1], 2, idx) + + idx = torch.nonzero(mask_flat_t[j]).squeeze() + if len(idx.shape) == 0 or idx.shape[0] == 0: + continue + tpart = torch.index_select(tF[i:i+1], 2, idx) + + iGram = torch.bmm(ipart, ipart.transpose(1,2)).div_(ipart.shape[1] * ipart.shape[2]) # f: bxcx(hxw), f.transpose: bx(hxw)xc -> bxcxc + tGram = torch.bmm(tpart, tpart.transpose(1,2)).div_(tpart.shape[1] * tpart.shape[2]) # f: bxcx(hxw), f.transpose: bx(hxw)xc -> bxcxc + + loss += torch.sum((iGram - tGram) ** 2) + return loss/tb + +class styleLossMaskv3(nn.Module): + def __init__(self, style_layers = 'relu1, relu2, relu3, relu4, relu5', device = torch.device(f'cuda:0')): + super().__init__() + self.vgg = VGG19(device = device) + self.style_layers = [] + for style_layer in style_layers.split(','): + self.style_layers.append(int(style_layer[-1]) - 1) + cnn_normalization_mean = torch.tensor([0.485, 0.456, 0.406]).to(device) + cnn_normalization_std = torch.tensor([0.229, 0.224, 0.225]).to(device) + self.normalization = Normalization(cnn_normalization_mean, cnn_normalization_std) + + def forward_img_img(self, input, target, mask_input, mask_target): + B, _, H, W = input.shape + input = self.normalization(input) + target = self.normalization(target) + pred_feats = self.vgg(input) + gt_feats = self.vgg(target) + + loss = 0 + mb, mc, mh, mw = mask_input.shape + for style_layer in self.style_layers: + pred_feat = pred_feats[style_layer] + gt_feat = gt_feats[style_layer] + + ib,ic,ih,iw = pred_feat.size() + iF = pred_feat.view(ib,ic,-1) + tb,tc,th,tw = gt_feat.size() + tF = gt_feat.view(tb,tc,-1) + + for i in range(mb): + # resize mask to have the same size of the feature + maski = F.interpolate(mask_input[i:i+1], size = (ih, iw), mode = 'nearest') + 
mask_flat_i = maski.view(mc, -1) + + maskt = F.interpolate(mask_target[i:i+1], size = (th, tw), mode = 'nearest') + mask_flat_t = maskt.view(mc, -1) + for j in range(mc): + # get features for each part + idx = torch.nonzero(mask_flat_i[j]).squeeze() + if len(idx.shape) == 0 or idx.shape[0] == 0: + continue + ipart = torch.index_select(iF[i:i+1], 2, idx) + + idx = torch.nonzero(mask_flat_t[j]).squeeze() + if len(idx.shape) == 0 or idx.shape[0] == 0: + continue + tpart = torch.index_select(tF[i:i+1], 2, idx) + + iGram = torch.bmm(ipart, ipart.transpose(1,2)).div_(ipart.shape[1] * ipart.shape[2]) # f: bxcx(hxw), f.transpose: bx(hxw)xc -> bxcxc + tGram = torch.bmm(tpart, tpart.transpose(1,2)).div_(tpart.shape[1] * tpart.shape[2]) # f: bxcx(hxw), f.transpose: bx(hxw)xc -> bxcxc + + #loss += torch.sum((iGram - tGram) ** 2) + loss += F.mse_loss(iGram, tGram) + #return loss/tb + return loss * 100000 / tb + + def forward_patch_img(self, input, target, mask_target): + input = self.normalization(input) + target = self.normalization(target) + pred_feats = self.vgg(input) + gt_feats = self.vgg(target) + patch_num = input.shape[0] // target.shape[0] + + loss = 0 + for style_layer in self.style_layers: + pred_feat = pred_feats[style_layer] + gt_feat = gt_feats[style_layer] + + ib,ic,ih,iw = pred_feat.size() + iF = pred_feat.view(ib,ic,-1) + tb,tc,th,tw = gt_feat.size() + tF = gt_feat.view(tb,tc,-1) + + for i in range(tb): + # resize mask to have the same size of the feature + maskt = F.interpolate(mask_target[i:i+1], size = (th, tw), mode = 'nearest') + mask_flat_t = maskt.view(-1) + + idx = torch.nonzero(mask_flat_t).squeeze() + if len(idx.shape) == 0 or idx.shape[0] == 0: + continue + tpart = torch.index_select(tF[i:i+1], 2, idx) + tGram = torch.bmm(tpart, tpart.transpose(1,2)).div_(tpart.shape[1] * tpart.shape[2]) # f: bxcx(hxw), f.transpose: bx(hxw)xc -> bxcxc + + ipart = iF[i * patch_num: (i + 1) * patch_num] + iGram = torch.bmm(ipart, ipart.transpose(1,2)).div_(ipart.shape[1] * ipart.shape[2]) # f: bxcx(hxw), f.transpose: bx(hxw)xc -> bxcxc + + #loss += torch.sum((iGram - tGram.repeat(patch_num, 1, 1)) ** 2) + loss += F.mse_loss(iGram, tGram.repeat(patch_num, 1, 1)) + return loss/ib * 100000 + +class KLDLoss(nn.Module): + def forward(self, mu, logvar): + return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) + +class LPIPSorGramMatch(nn.Module): + """ + Version 1. Compare mean and variance cluster-wise. 
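A usage sketch for the mask-aware loss defined above; styleLossMaskv3.forward_img_img expects one-hot cluster masks with one channel per cluster. The batch size, crop size, cluster count, and CPU device below are illustrative assumptions:

import torch

crit = styleLossMaskv3(device=torch.device('cpu'))   # builds a pretrained VGG19
pred = torch.rand(1, 3, 224, 224)
target = torch.rand(1, 3, 224, 224)
# one-hot cluster masks: one channel per cluster (10 clusters here)
labels = torch.randint(0, 10, (1, 1, 224, 224))
masks = torch.zeros(1, 10, 224, 224).scatter_(1, labels, 1.0)
loss = crit.forward_img_img(pred, target, masks, masks)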
+ """ + def __init__(self, style_layers = 'relu3, relu4, relu5', device = torch.device(f'cuda:0'), + weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0],): + super().__init__() + self.vgg = VGG19(device = device) + self.style_layers = [] + for style_layer in style_layers.split(','): + self.style_layers.append(int(style_layer[-1]) - 1) + + cnn_normalization_mean = torch.tensor([0.485, 0.456, 0.406]).to(device) + cnn_normalization_std = torch.tensor([0.229, 0.224, 0.225]).to(device) + self.normalization = Normalization(cnn_normalization_mean, cnn_normalization_std) + + self.criterion = nn.L1Loss() + self.weights = weights + + def forward(self, pred, gt, mode = 'lpips'): + """ + INPUTS: + - pred: (B, 3, H, W) + - gt: (B, 3, H, W) + - seg: (B, H, W) + """ + # extract features for images + B, _, H, W = pred.shape + pred = self.normalization(pred) + gt = self.normalization(gt) + pred_feats = self.vgg(pred) + gt_feats = self.vgg(gt) + + if mode == 'lpips': + lpips_loss = 0 + for i in range(len(pred_feats)): + lpips_loss += self.weights[i] * self.criterion(pred_feats[i], gt_feats[i].detach()) + return lpips_loss + elif mode == 'gram_match': + gram_match_loss = 0 + for style_layer in self.style_layers: + pred_feat = pred_feats[style_layer] + gt_feat = gt_feats[style_layer] + pred_gram = GramMatrix()(pred_feat) + gt_gram = GramMatrix()(gt_feat) + gram_match_loss += torch.sum((pred_gram - gt_gram) ** 2) / B + return gram_match_loss + else: + raise ValueError("Only computes lpips or gram match loss.") + +class styleLossMaskv4(nn.Module): + def __init__(self, style_layers = 'relu3, relu4, relu5', device = torch.device(f'cuda:0')): + super().__init__() + self.vgg = VGG19(device = device) + self.style_layers = [] + for style_layer in style_layers.split(','): + self.style_layers.append(int(style_layer[-1]) - 1) + cnn_normalization_mean = torch.tensor([0.485, 0.456, 0.406]).to(device) + cnn_normalization_std = torch.tensor([0.229, 0.224, 0.225]).to(device) + self.normalization = Normalization(cnn_normalization_mean, cnn_normalization_std) + + def forward_img_img(self, input, target, mask_input, mask_target): + B, _, H, W = input.shape + input = self.normalization(input) + target = self.normalization(target) + pred_feats = self.vgg(input) + gt_feats = self.vgg(target) + + + loss = 0 + mb, mc, mh, mw = mask_input.shape + for style_layer in self.style_layers: + pred_feat = pred_feats[style_layer] + gt_feat = gt_feats[style_layer] + + ib,ic,ih,iw = pred_feat.size() + iF = pred_feat.view(ib,ic,-1) + tb,tc,th,tw = gt_feat.size() + tF = gt_feat.view(tb,tc,-1) + + for i in range(mb): + # resize mask to have the same size of the feature + maski = F.interpolate(mask_input[i:i+1], size = (ih, iw), mode = 'nearest') + mask_flat_i = maski.view(mc, -1) + + maskt = F.interpolate(mask_target[i:i+1], size = (th, tw), mode = 'nearest') + mask_flat_t = maskt.view(mc, -1) + for j in range(mc): + # get features for each part + idx = torch.nonzero(mask_flat_i[j]).squeeze() + if len(idx.shape) == 0 or idx.shape[0] == 0: + continue + ipart = torch.index_select(iF[i:i+1], 2, idx) + + idx = torch.nonzero(mask_flat_t[j]).squeeze() + if len(idx.shape) == 0 or idx.shape[0] == 0: + continue + tpart = torch.index_select(tF[i:i+1], 2, idx) + + iGram = torch.bmm(ipart, ipart.transpose(1,2)).div_(ipart.shape[1] * ipart.shape[2]) # f: bxcx(hxw), f.transpose: bx(hxw)xc -> bxcxc + iMean = torch.mean(ipart, dim=2) + tGram = torch.bmm(tpart, tpart.transpose(1,2)).div_(tpart.shape[1] * tpart.shape[2]) # f: bxcx(hxw), f.transpose: bx(hxw)xc -> 
bxcxc + tMean = torch.mean(tpart, dim=2) + + loss += torch.sum((iGram - tGram) ** 2) + torch.sum((iMean - tMean) ** 2) * 0.01 + return loss/tb + + def forward_patch_img(self, input, target, mask_target): + input = self.normalization(input) + target = self.normalization(target) + pred_feats = self.vgg(input) + gt_feats = self.vgg(target) + patch_num = input.shape[0] // target.shape[0] + + loss = 0 + for style_layer in self.style_layers: + pred_feat = pred_feats[style_layer] + gt_feat = gt_feats[style_layer] + + ib,ic,ih,iw = pred_feat.size() + iF = pred_feat.view(ib,ic,-1) + tb,tc,th,tw = gt_feat.size() + tF = gt_feat.view(tb,tc,-1) + + for i in range(tb): + # resize mask to have the same size of the feature + maskt = F.interpolate(mask_target[i:i+1], size = (th, tw), mode = 'nearest') + mask_flat_t = maskt.view(-1) + + idx = torch.nonzero(mask_flat_t).squeeze() + if len(idx.shape) == 0 or idx.shape[0] == 0: + continue + tpart = torch.index_select(tF[i:i+1], 2, idx) + tGram = torch.bmm(tpart, tpart.transpose(1,2)).div_(tpart.shape[1] * tpart.shape[2]) # f: bxcx(hxw), f.transpose: bx(hxw)xc -> bxcxc + tMean = torch.mean(tpart, dim=2) + + ipart = iF[i * patch_num: (i + 1) * patch_num] + iMean = torch.mean(ipart, dim=2) + iGram = torch.bmm(ipart, ipart.transpose(1,2)).div_(ipart.shape[1] * ipart.shape[2]) # f: bxcx(hxw), f.transpose: bx(hxw)xc -> bxcxc + + loss += torch.sum((iGram - tGram.repeat(patch_num, 1, 1)) ** 2) + loss += torch.sum((iMean - tMean.repeat(patch_num, 1)) ** 2) * 0.01 + return loss/ib diff --git a/models/week0417/model.py b/models/week0417/model.py new file mode 100644 index 0000000000000000000000000000000000000000..23c2c8045979894495158fb5406cbfd240094c79 --- /dev/null +++ b/models/week0417/model.py @@ -0,0 +1,204 @@ +import numpy as np + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision.transforms as transforms +import torchvision.transforms.functional as TF + +from .taming_blocks import Encoder +from .nnutils import SPADEResnetBlock, get_edges, initWave + +from libs.nnutils import poolfeat, upfeat +from libs.utils import label2one_hot_torch + +from swapae.models.networks.stylegan2_layers import ConvLayer +from torch_geometric.nn import GCNConv +from torch_geometric.utils import softmax +from .loss import styleLossMaskv3 + +class GCN(nn.Module): + def __init__(self, n_cluster, temperature = 1, add_self_loops = True, hidden_dim = 256): + super().__init__() + self.gcnconv1 = GCNConv(hidden_dim, hidden_dim, add_self_loops = add_self_loops) + self.gcnconv2 = GCNConv(hidden_dim, hidden_dim, add_self_loops = add_self_loops) + self.pool1 = nn.Sequential(nn.Conv2d(hidden_dim, n_cluster, 3, 1, 1)) + self.temperature = temperature + + def compute_edge_score_softmax(self, raw_edge_score, edge_index, num_nodes): + return softmax(raw_edge_score, edge_index[1], num_nodes=num_nodes) + + def compute_edge_weight(self, node_feature, edge_index): + src_feat = torch.gather(node_feature, 0, edge_index[0].unsqueeze(1).repeat(1, node_feature.shape[1])) + tgt_feat = torch.gather(node_feature, 0, edge_index[1].unsqueeze(1).repeat(1, node_feature.shape[1])) + raw_edge_weight = nn.CosineSimilarity(dim=1, eps=1e-6)(src_feat, tgt_feat) + edge_weight = self.compute_edge_score_softmax(raw_edge_weight, edge_index, node_feature.shape[0]) + return raw_edge_weight.squeeze(), edge_weight.squeeze() + + def forward(self, sp_code, slic, clustering = False): + edges, aff = get_edges(torch.argmax(slic, dim = 1).unsqueeze(1), sp_code.shape[1]) + prop_code = [] + sp_assign = 
[] + edge_weights = [] + conv_feats = [] + for i in range(sp_code.shape[0]): + # compute edge weight + edge_index = edges[i] + raw_edge_weight, edge_weight = self.compute_edge_weight(sp_code[i], edge_index) + feat = self.gcnconv1(sp_code[i], edge_index, edge_weight = edge_weight) + raw_edge_weight, edge_weight = self.compute_edge_weight(feat, edge_index) + edge_weights.append(raw_edge_weight) + feat = F.leaky_relu(feat, 0.2) + feat = self.gcnconv2(feat, edge_index, edge_weight = edge_weight) + + # maybe clustering + conv_feat = upfeat(feat, slic[i:i+1]) + conv_feats.append(conv_feat) + if not clustering: + feat = conv_feat + pred_mask = slic[i:i+1] + else: + pred_mask = self.pool1(conv_feat) + # enforce pixels belong to the same superpixel to have same grouping label + pred_mask = upfeat(poolfeat(pred_mask, slic[i:i+1]), slic[i:i+1]) + s_ = F.softmax(pred_mask * self.temperature, dim = 1) + + # compute texture code w.r.t grouping + pool_feat = poolfeat(conv_feat, s_, avg = True) + # hard upsampling + #hard_s_ = label2one_hot_torch(torch.argmax(s_, dim = 1).unsqueeze(1), C = s_.shape[1]) + feat = upfeat(pool_feat, s_) + #feat = upfeat(pool_feat, hard_s_) + + prop_code.append(feat) + sp_assign.append(pred_mask) + prop_code = torch.cat(prop_code) + conv_feats = torch.cat(conv_feats) + return prop_code, torch.cat(sp_assign), conv_feats + +class SPADEGenerator(nn.Module): + def __init__(self, in_dim, hidden_dim): + super().__init__() + nf = hidden_dim // 16 + + self.head_0 = SPADEResnetBlock(in_dim, 16 * nf) + + self.G_middle_0 = SPADEResnetBlock(16 * nf, 16 * nf) + self.G_middle_1 = SPADEResnetBlock(16 * nf, 16 * nf) + + self.up_0 = SPADEResnetBlock(16 * nf, 8 * nf) + self.up_1 = SPADEResnetBlock(8 * nf, 4 * nf) + self.up_2 = SPADEResnetBlock(4 * nf, 2 * nf) + self.up_3 = SPADEResnetBlock(2 * nf, 1 * nf) + + final_nc = nf + + self.conv_img = nn.Conv2d(final_nc, 3, 3, padding=1) + + self.up = nn.Upsample(scale_factor=2) + + def forward(self, sine_wave, texon): + + x = self.head_0(sine_wave, texon) + + x = self.up(x) + x = self.G_middle_0(x, texon) + x = self.G_middle_1(x, texon) + + x = self.up(x) + x = self.up_0(x, texon) + x = self.up(x) + x = self.up_1(x, texon) + #x = self.up(x) + x = self.up_2(x, texon) + #x = self.up(x) + x = self.up_3(x, texon) + + x = self.conv_img(F.leaky_relu(x, 2e-1)) + return x + +class Waver(nn.Module): + def __init__(self, tex_code_dim, zPeriodic): + super(Waver, self).__init__() + K = tex_code_dim + layers = [nn.Conv2d(tex_code_dim, K, 1)] + layers += [nn.ReLU(True)] + layers += [nn.Conv2d(K, 2 * zPeriodic, 1)] + self.learnedWN = nn.Sequential(*layers) + self.waveNumbers = initWave(zPeriodic) + + def forward(self, GLZ=None): + return (self.waveNumbers.to(GLZ.device) + self.learnedWN(GLZ)) + +class AE(nn.Module): + def __init__(self, args, **ignore_kwargs): + super(AE, self).__init__() + + # encoder & decoder + self.enc = Encoder(ch=64, out_ch=3, ch_mult=[1,2,4,8], num_res_blocks=1, attn_resolutions=[], + in_channels=3, resolution=args.crop_size, z_channels=args.hidden_dim, double_z=False) + if args.dec_input_mode == 'sine_wave_noise': + self.G = SPADEGenerator(args.spatial_code_dim * 2, args.hidden_dim) + else: + self.G = SPADEGenerator(args.spatial_code_dim, args.hidden_dim) + + self.add_module( + "ToTexCode", + nn.Sequential( + ConvLayer(args.hidden_dim, args.hidden_dim, kernel_size=3, activate=True, bias=True), + ConvLayer(args.hidden_dim, args.tex_code_dim, kernel_size=3, activate=True, bias=True), + ConvLayer(args.tex_code_dim, args.hidden_dim, 
kernel_size=1, activate=False, bias=False) + ) + ) + self.gcn = GCN(n_cluster = args.n_cluster, temperature = args.temperature, add_self_loops = (args.add_self_loops == 1), hidden_dim = args.hidden_dim) + + self.add_gcn_epoch = args.add_gcn_epoch + self.add_clustering_epoch = args.add_clustering_epoch + self.add_texture_epoch = args.add_texture_epoch + + self.patch_size = args.patch_size + self.sine_wave_dim = args.spatial_code_dim + + # inpainting network + self.learnedWN = Waver(args.hidden_dim, zPeriodic = args.spatial_code_dim) + self.dec_input_mode = args.dec_input_mode + self.style_loss = styleLossMaskv3(device = args.device) + + if args.sine_weight: + if args.dec_input_mode == 'sine_wave_noise': + self.add_module( + "ChannelWeight", + nn.Sequential( + ConvLayer(args.hidden_dim, args.hidden_dim//2, kernel_size=3, activate=True, bias=True, downsample=True), + ConvLayer(args.hidden_dim//2, args.hidden_dim//4, kernel_size=3, activate=True, bias=True, downsample=True), + ConvLayer(args.hidden_dim//4, args.spatial_code_dim*2, kernel_size=1, activate=False, bias=False, downsample=True))) + else: + self.add_module( + "ChannelWeight", + nn.Sequential( + ConvLayer(args.hidden_dim, args.hidden_dim//2, kernel_size=3, activate=True, bias=True, downsample=True), + ConvLayer(args.hidden_dim//2, args.hidden_dim//4, kernel_size=3, activate=True, bias=True, downsample=True), + ConvLayer(args.hidden_dim//4, args.spatial_code_dim, kernel_size=1, activate=False, bias=False, downsample=True))) + + def get_sine_wave(self, GL, offset_mode = 'random'): + img_size = GL.shape[-1] // 8 + GL = F.interpolate(GL, size = (img_size, img_size), mode = 'nearest') + xv, yv = np.meshgrid(np.arange(img_size), np.arange(img_size),indexing='ij') + c = torch.FloatTensor(np.concatenate([xv[np.newaxis], yv[np.newaxis]], 0)[np.newaxis]) + c = c.to(GL.device) + # c: 1, 2, 28, 28 + c = c.repeat(GL.shape[0], self.sine_wave_dim, 1, 1) + # c: 1, 64, 28, 28 + period = self.learnedWN(GL) + # period: 1, 64, 28, 28 + raw = period * c + if offset_mode == 'random': + offset = torch.zeros((GL.shape[0], self.sine_wave_dim, 1, 1)).to(GL.device).uniform_(-1, 1) * 6.28 + offset = offset.repeat(1, 1, img_size, img_size) + wave = torch.sin(raw[:, ::2] + raw[:, 1::2] + offset) + elif offset_mode == 'rec': + wave = torch.sin(raw[:, ::2] + raw[:, 1::2]) + return wave + + def forward(self, rgb_img, slic, epoch = 0, test_time = False, test = False, tex_idx = None): + return diff --git a/models/week0417/model_utils.py b/models/week0417/model_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c761cd7c08b4e1682e0fa1301a07e509213d6973 --- /dev/null +++ b/models/week0417/model_utils.py @@ -0,0 +1,40 @@ +import torch +import torch.nn as nn + +## *************************** my functions **************************** + +def predict_param(in_planes, channel=3): + return nn.Conv2d(in_planes, channel, kernel_size=3, stride=1, padding=1, bias=True) + +def predict_mask(in_planes, channel=9): + return nn.Conv2d(in_planes, channel, kernel_size=3, stride=1, padding=1, bias=True) + +def predict_feat(in_planes, channel=20, stride=1): + return nn.Conv2d(in_planes, channel, kernel_size=3, stride=stride, padding=1, bias=True) + +def predict_prob(in_planes, channel=9): + return nn.Sequential( + nn.Conv2d(in_planes, channel, kernel_size=3, stride=1, padding=1, bias=True), + nn.Softmax(1) + ) +#*********************************************************************** + +def conv(batchNorm, in_planes, out_planes, kernel_size=3, stride=1): + if 
batchNorm: + return nn.Sequential( + nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=(kernel_size-1)//2, bias=False), + nn.BatchNorm2d(out_planes), + nn.LeakyReLU(0.1) + ) + else: + return nn.Sequential( + nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=(kernel_size-1)//2, bias=True), + nn.LeakyReLU(0.1) + ) + + +def deconv(in_planes, out_planes): + return nn.Sequential( + nn.ConvTranspose2d(in_planes, out_planes, kernel_size=4, stride=2, padding=1, bias=True), + nn.LeakyReLU(0.1) + ) diff --git a/models/week0417/nnutils.py b/models/week0417/nnutils.py new file mode 100644 index 0000000000000000000000000000000000000000..a4cb2479397321e21039dd4bea0906a2deb23c08 --- /dev/null +++ b/models/week0417/nnutils.py @@ -0,0 +1,205 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from skimage.segmentation._slic import _enforce_label_connectivity_cython + +def initWave(nPeriodic): + buf = [] + for i in range(nPeriodic // 4+1): + v = 0.5 + i / float(nPeriodic//4+1e-10) + buf += [0, v, v, 0] + buf += [0, -v, v, 0] #so from other quadrants as well.. + buf = buf[:2*nPeriodic] + awave = np.array(buf, dtype=np.float32) * np.pi + awave = torch.FloatTensor(awave).unsqueeze(-1).unsqueeze(-1).unsqueeze(0) + return awave + +class SPADEGenerator(nn.Module): + def __init__(self, hidden_dim): + super().__init__() + nf = hidden_dim // 16 + + self.head_0 = SPADEResnetBlock(16 * nf, 16 * nf) + + self.G_middle_0 = SPADEResnetBlock(16 * nf, 16 * nf) + self.G_middle_1 = SPADEResnetBlock(16 * nf, 16 * nf) + + self.up_0 = SPADEResnetBlock(16 * nf, 8 * nf) + self.up_1 = SPADEResnetBlock(8 * nf, 4 * nf) + self.up_2 = SPADEResnetBlock(4 * nf, nf) + #self.up_3 = SPADEResnetBlock(2 * nf, 1 * nf) + + final_nc = nf + + self.conv_img = nn.Conv2d(final_nc, 3, 3, padding=1) + + self.up = nn.Upsample(scale_factor=2) + + + def forward(self, x, input): + seg = input + + x = self.head_0(x, seg) + + x = self.up(x) + x = self.G_middle_0(x, seg) + x = self.G_middle_1(x, seg) + + x = self.up(x) + x = self.up_0(x, seg) + x = self.up(x) + x = self.up_1(x, seg) + x = self.up(x) + x = self.up_2(x, seg) + #x = self.up(x) + #x = self.up_3(x, seg) + + x = self.conv_img(F.leaky_relu(x, 2e-1)) + return x + +class SPADE(nn.Module): + def __init__(self, norm_nc, label_nc): + super().__init__() + + ks = 3 + + self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False) + + # The dimension of the intermediate embedding space. Yes, hardcoded. + nhidden = 128 + + pw = ks // 2 + self.mlp_shared = nn.Sequential( + nn.Conv2d(label_nc, nhidden, kernel_size=ks, padding=pw), + nn.ReLU() + ) + self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw) + self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw) + + def forward(self, x, segmap): + + # Part 1. generate parameter-free normalized activations + normalized = self.param_free_norm(x) + + # Part 2. 
produce scaling and bias conditioned on semantic map + #segmap = F.interpolate(segmap, size=x.size()[2:], mode='nearest') + segmap = F.interpolate(segmap, size=x.size()[2:], mode='bilinear', align_corners = False) + actv = self.mlp_shared(segmap) + gamma = self.mlp_gamma(actv) + beta = self.mlp_beta(actv) + + # apply scale and bias + out = normalized * (1 + gamma) + beta + + return out + +class SPADEResnetBlock(nn.Module): + def __init__(self, fin, fout): + super().__init__() + # Attributes + self.learned_shortcut = (fin != fout) + fmiddle = min(fin, fout) + + # create conv layers + self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=1) + self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=1) + if self.learned_shortcut: + self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False) + + # define normalization layers + self.norm_0 = SPADE(fin, 256) + self.norm_1 = SPADE(fmiddle, 256) + if self.learned_shortcut: + self.norm_s = SPADE(fin, 256) + + # note the resnet block with SPADE also takes in |seg|, + # the semantic segmentation map as input + def forward(self, x, seg): + x_s = self.shortcut(x, seg) + + dx = self.conv_0(self.actvn(self.norm_0(x, seg))) + dx = self.conv_1(self.actvn(self.norm_1(dx, seg))) + + out = x_s + dx + + return out + + def shortcut(self, x, seg): + if self.learned_shortcut: + x_s = self.conv_s(self.norm_s(x, seg)) + else: + x_s = x + return x_s + + def actvn(self, x): + return F.leaky_relu(x, 2e-1) + +def get_edges(sp_label, sp_num): + # This function returns a (hw) * (hw) matrix N. + # If Nij = 1, then superpixel i and j are neighbors + # Otherwise Nij = 0. + top = sp_label[:, :, :-1, :] - sp_label[:, :, 1:, :] + left = sp_label[:, :, :, :-1] - sp_label[:, :, :, 1:] + top_left = sp_label[:, :, :-1, :-1] - sp_label[:, :, 1:, 1:] + top_right = sp_label[:, :, :-1, 1:] - sp_label[:, :, 1:, :-1] + n_affs = [] + edge_indices = [] + for i in range(sp_label.shape[0]): + # change to torch.ones below to include self-loop in graph + n_aff = torch.zeros(sp_num, sp_num).unsqueeze(0).to(sp_label.device) + # top/bottom + top_i = top[i].squeeze() + x, y = torch.nonzero(top_i, as_tuple = True) + sp1 = sp_label[i, :, x, y].squeeze().long() + sp2 = sp_label[i, :, x+1, y].squeeze().long() + n_aff[:, sp1, sp2] = 1 + n_aff[:, sp2, sp1] = 1 + + # left/right + left_i = left[i].squeeze() + try: + x, y = torch.nonzero(left_i, as_tuple = True) + except: + import pdb; pdb.set_trace() + sp1 = sp_label[i, :, x, y].squeeze().long() + sp2 = sp_label[i, :, x, y+1].squeeze().long() + n_aff[:, sp1, sp2] = 1 + n_aff[:, sp2, sp1] = 1 + + # top left + top_left_i = top_left[i].squeeze() + x, y = torch.nonzero(top_left_i, as_tuple = True) + sp1 = sp_label[i, :, x, y].squeeze().long() + sp2 = sp_label[i, :, x+1, y+1].squeeze().long() + n_aff[:, sp1, sp2] = 1 + n_aff[:, sp2, sp1] = 1 + + # top right + top_right_i = top_right[i].squeeze() + x, y = torch.nonzero(top_right_i, as_tuple = True) + sp1 = sp_label[i, :, x, y+1].squeeze().long() + sp2 = sp_label[i, :, x+1, y].squeeze().long() + n_aff[:, sp1, sp2] = 1 + n_aff[:, sp2, sp1] = 1 + + n_affs.append(n_aff) + edge_index = torch.stack(torch.nonzero(n_aff.squeeze(), as_tuple=True)) + edge_indices.append(edge_index.to(sp_label.device)) + return edge_indices, torch.cat(n_affs) + +def enforce_connectivity(segs, H, W, sp_num = 196, min_size = None, max_size = None): + rets = [] + for i in range(segs.shape[0]): + seg = segs[i] + seg = seg.squeeze().cpu().numpy() + + segment_size = H * W / sp_num + if min_size is None: + min_size = int(0.1 * 
segment_size) + if max_size is None: + max_size = int(1000.0 * segment_size) + seg = _enforce_label_connectivity_cython(seg[None], min_size, max_size)[0] + seg = torch.from_numpy(seg).unsqueeze(0).unsqueeze(0) + rets.append(seg) + return torch.cat(rets) diff --git a/models/week0417/taming_blocks.py b/models/week0417/taming_blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..2e66fbad331894b925020f6b157defc64d644b82 --- /dev/null +++ b/models/week0417/taming_blocks.py @@ -0,0 +1,778 @@ +# pytorch_diffusion + derived encoder decoder +import math +import torch +import torch.nn as nn +import numpy as np + + +def get_timestep_embedding(timesteps, embedding_dim): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: + From Fairseq. + Build sinusoidal embeddings. + This matches the implementation in tensor2tensor, but differs slightly + from the description in Section 3.5 of "Attention Is All You Need". + """ + assert len(timesteps.shape) == 1 + + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) + emb = emb.to(device=timesteps.device) + emb = timesteps.float()[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0,1,0,0)) + return emb + + +def nonlinearity(x): + # swish + return x*torch.sigmoid(x) + + +def Normalize(in_channels): + return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + + +class Upsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=2, + padding=0) + + def forward(self, x): + if self.with_conv: + pad = (0,1,0,1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + return x + + +class ResnetBlock(nn.Module): + def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, + dropout, temb_channels=512): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels) + self.conv1 = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + if temb_channels > 0: + self.temb_proj = torch.nn.Linear(temb_channels, + out_channels) + self.norm2 = Normalize(out_channels) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv2d(out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + else: + self.nin_shortcut = torch.nn.Conv2d(in_channels, + 
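# 1x1 "network-in-network" shortcut: when in_channels != out_channels and
# use_conv_shortcut is False, the residual branch is matched in channel count
# with a pointwise convolution instead of the 3x3 conv_shortcut above.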
out_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x, temb): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None] + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x+h + + +class AttnBlock(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b,c,h,w = q.shape + q = q.reshape(b,c,h*w) + q = q.permute(0,2,1) # b,hw,c + k = k.reshape(b,c,h*w) # b,c,hw + w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c)**(-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b,c,h*w) + w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b,c,h,w) + + h_ = self.proj_out(h_) + + return x+h_ + + +class Model(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, use_timestep=True): + super().__init__() + self.ch = ch + self.temb_ch = self.ch*4 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + self.use_timestep = use_timestep + if self.use_timestep: + # timestep embedding + self.temb = nn.Module() + self.temb.dense = nn.ModuleList([ + torch.nn.Linear(self.ch, + self.temb_ch), + torch.nn.Linear(self.temb_ch, + self.temb_ch), + ]) + + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, + self.ch, + kernel_size=3, + stride=1, + padding=1) + + curr_res = resolution + in_ch_mult = (1,)+tuple(ch_mult) + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch*in_ch_mult[i_level] + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttnBlock(block_in)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions-1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # upsampling + self.up = 
nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch*ch_mult[i_level] + skip_in = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks+1): + if i_block == self.num_res_blocks: + skip_in = ch*in_ch_mult[i_level] + block.append(ResnetBlock(in_channels=block_in+skip_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttnBlock(block_in)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_ch, + kernel_size=3, + stride=1, + padding=1) + + + def forward(self, x, t=None): + #assert x.shape[2] == x.shape[3] == self.resolution + + if self.use_timestep: + # timestep embedding + assert t is not None + temb = get_timestep_embedding(t, self.ch) + temb = self.temb.dense[0](temb) + temb = nonlinearity(temb) + temb = self.temb.dense[1](temb) + else: + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions-1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks+1): + h = self.up[i_level].block[i_block]( + torch.cat([h, hs.pop()], dim=1), temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class Encoder(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, z_channels, double_z=True, **ignore_kwargs): + super().__init__() + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, + self.ch, + kernel_size=3, + stride=1, + padding=1) + + curr_res = resolution + in_ch_mult = (1,)+tuple(ch_mult) + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch*in_ch_mult[i_level] + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttnBlock(block_in)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions-1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + 
dropout=dropout) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + 2*z_channels if double_z else z_channels, + kernel_size=3, + stride=1, + padding=1) + + + def forward(self, x): + #assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution) + + # timestep embedding + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions-1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h, hs + + +class Decoder(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, z_channels, give_pre_end=False, with_mid=True, **ignorekwargs): + super().__init__() + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.give_pre_end = give_pre_end + self.with_mid = with_mid + + # compute in_ch_mult, block_in and curr_res at lowest res + in_ch_mult = (1,)+tuple(ch_mult) + block_in = ch*ch_mult[self.num_resolutions-1] + curr_res = resolution // 2**(self.num_resolutions-1) + self.z_shape = (1,z_channels,curr_res,curr_res) + print("Working with z of shape {} = {} dimensions.".format( + self.z_shape, np.prod(self.z_shape))) + + # z to block_in + self.conv_in = torch.nn.Conv2d(z_channels, + block_in, + kernel_size=3, + stride=1, + padding=1) + + # middle + if self.with_mid: + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks+1): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttnBlock(block_in)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_ch, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, z): + #assert z.shape[1:] == self.z_shape[1:] + self.last_z_shape = z.shape + + # timestep embedding + temb = None + + # z to block_in + h = self.conv_in(z) + + # middle + if self.with_mid: + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = 
self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks+1): + h = self.up[i_level].block[i_block](h, temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + if self.give_pre_end: + return h + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class VUNet(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, + in_channels, c_channels, + resolution, z_channels, use_timestep=False, **ignore_kwargs): + super().__init__() + self.ch = ch + self.temb_ch = self.ch*4 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + + self.use_timestep = use_timestep + if self.use_timestep: + # timestep embedding + self.temb = nn.Module() + self.temb.dense = nn.ModuleList([ + torch.nn.Linear(self.ch, + self.temb_ch), + torch.nn.Linear(self.temb_ch, + self.temb_ch), + ]) + + # downsampling + self.conv_in = torch.nn.Conv2d(c_channels, + self.ch, + kernel_size=3, + stride=1, + padding=1) + + curr_res = resolution + in_ch_mult = (1,)+tuple(ch_mult) + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch*in_ch_mult[i_level] + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttnBlock(block_in)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions-1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + self.z_in = torch.nn.Conv2d(z_channels, + block_in, + kernel_size=1, + stride=1, + padding=0) + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=2*block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch*ch_mult[i_level] + skip_in = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks+1): + if i_block == self.num_res_blocks: + skip_in = ch*in_ch_mult[i_level] + block.append(ResnetBlock(in_channels=block_in+skip_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttnBlock(block_in)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_ch, + kernel_size=3, + stride=1, + padding=1) + + + def forward(self, x, z): + #assert x.shape[2] == x.shape[3] == self.resolution + + if self.use_timestep: + # timestep embedding + assert t is not None + temb = get_timestep_embedding(t, self.ch) + temb = self.temb.dense[0](temb) + 
temb = nonlinearity(temb) + temb = self.temb.dense[1](temb) + else: + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions-1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + z = self.z_in(z) + h = torch.cat((h,z),dim=1) + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks+1): + h = self.up[i_level].block[i_block]( + torch.cat([h, hs.pop()], dim=1), temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class SimpleDecoder(nn.Module): + def __init__(self, in_channels, out_channels, *args, **kwargs): + super().__init__() + self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1), + ResnetBlock(in_channels=in_channels, + out_channels=2 * in_channels, + temb_channels=0, dropout=0.0), + ResnetBlock(in_channels=2 * in_channels, + out_channels=4 * in_channels, + temb_channels=0, dropout=0.0), + ResnetBlock(in_channels=4 * in_channels, + out_channels=2 * in_channels, + temb_channels=0, dropout=0.0), + nn.Conv2d(2*in_channels, in_channels, 1), + Upsample(in_channels, with_conv=True)]) + # end + self.norm_out = Normalize(in_channels) + self.conv_out = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + for i, layer in enumerate(self.model): + if i in [1,2,3]: + x = layer(x, None) + else: + x = layer(x) + + h = self.norm_out(x) + h = nonlinearity(h) + x = self.conv_out(h) + return x + + +class UpsampleDecoder(nn.Module): + def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution, + ch_mult=(2,2), dropout=0.0): + super().__init__() + # upsampling + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + block_in = in_channels + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.res_blocks = nn.ModuleList() + self.upsample_blocks = nn.ModuleList() + for i_level in range(self.num_resolutions): + res_block = [] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + res_block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + self.res_blocks.append(nn.ModuleList(res_block)) + if i_level != self.num_resolutions - 1: + self.upsample_blocks.append(Upsample(block_in, True)) + curr_res = curr_res * 2 + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + # upsampling + h = x + for k, i_level in enumerate(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.res_blocks[i_level][i_block](h, None) + if i_level != self.num_resolutions - 1: + h = self.upsample_blocks[k](h) + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 
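The autoencoder blocks above (Encoder, Decoder, plus the skip-connected Model and VUNet variants) all follow the same pattern: each of the len(ch_mult) levels halves the spatial resolution, so the latent lives at resolution // 2**(len(ch_mult)-1). Note that VUNet.forward references t inside its use_timestep branch without accepting it as an argument, so that path only works once t is threaded through the signature. A minimal usage sketch, assuming the classes above are in scope and using made-up hyperparameters:

import torch

# Hypothetical settings: 64x64 RGB input, three resolution levels -> 16x16 latent.
enc = Encoder(ch=64, out_ch=3, ch_mult=(1, 2, 4), num_res_blocks=2,
              attn_resolutions=[16], in_channels=3, resolution=64,
              z_channels=8, double_z=False)
dec = Decoder(ch=64, out_ch=3, ch_mult=(1, 2, 4), num_res_blocks=2,
              attn_resolutions=[16], in_channels=3, resolution=64,
              z_channels=8)

x = torch.randn(1, 3, 64, 64)
z, skips = enc(x)     # this Encoder variant returns the latent and the per-level features
print(z.shape)        # torch.Size([1, 8, 16, 16])
x_rec = dec(z)        # Decoder ignores the skip features; Model/VUNet keep their own U-Net skips
print(x_rec.shape)    # torch.Size([1, 3, 64, 64])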
0000000000000000000000000000000000000000..541dfb345e7d76cac449a50d12e86e5135809f94 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,10 @@ +torch==1.7.1 +torch-geometric==1.7.2 +torch-scatter==2.0.8 +torch-sparse==0.6.11 +torchvision==0.8.2 +streamlit==1.8.1 +st-clickable-images==0.0.3 +func-timeout==4.3.5 +scikit-learn==1.0.2 +imageio==2.17.0 diff --git a/swapae/data/__init__.py b/swapae/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..db43dab4e9925a9e273a495eabf8521a5a48c47b --- /dev/null +++ b/swapae/data/__init__.py @@ -0,0 +1,129 @@ +"""This package includes all the modules related to data loading and preprocessing + + To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset. + You need to implement four functions: + -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt). + -- <__len__>: return the size of dataset. + -- <__getitem__>: get a data point from data loader. + -- : (optionally) add dataset-specific options and set default options. + +Now you can use the dataset class by specifying flag '--dataset_mode dummy'. +See our template dataset class 'template_dataset.py' for more details. +""" +import importlib +import torch.utils.data +from swapae.data.base_dataset import BaseDataset +import swapae.util as util + + +def find_dataset_using_name(dataset_name): + """Import the module "data/[dataset_name]_dataset.py". + + In the file, the class called DatasetNameDataset() will + be instantiated. It has to be a subclass of BaseDataset, + and it is case-insensitive. + """ + dataset_filename = "swapae.data." + dataset_name + "_dataset" + datasetlib = importlib.import_module(dataset_filename) + + dataset = None + target_dataset_name = dataset_name.replace('_', '') + 'dataset' + for name, cls in datasetlib.__dict__.items(): + if name.lower() == target_dataset_name.lower() \ + and issubclass(cls, BaseDataset): + dataset = cls + + if dataset is None: + raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." 
% (dataset_filename, target_dataset_name)) + + return dataset + + +def get_option_setter(dataset_name): + """Return the static method of the dataset class.""" + dataset_class = find_dataset_using_name(dataset_name) + return dataset_class.modify_commandline_options + + +def create_dataset(opt): + return ConfigurableDataLoader(opt) + + +class DataPrefetcher(): + def __init__(self, dataset): + self.dataset = dataset + self.stream = torch.cuda.Stream() + self.preload() + + def preload(self): + try: + self.next_input = next(self.dataset) + except StopIteration: + self.next_input = None + return + + with torch.cuda.stream(self.stream): + self.next_input = self.next_input.cuda(non_blocking=True) + + def __next__(self): + torch.cuda.current_stream().wait_stream(self.stream) + input = self.next_input + self.preload() + return input + + def __iter__(self): + return self + + def __len__(self): + return len(self.dataset) + + +class ConfigurableDataLoader(): + def __init__(self, opt): + self.opt = opt + self.initialize(opt.phase) + + def initialize(self, phase): + opt = self.opt + self.phase = phase + if hasattr(self, "dataloader"): + del self.dataloader + dataset_class = find_dataset_using_name(opt.dataset_mode) + dataset = dataset_class(util.copyconf(opt, phase=phase, isTrain=phase == "train")) + shuffle = phase == "train" if opt.shuffle_dataset is None else opt.shuffle_dataset == "true" + print("dataset [%s] of size %d was created. shuffled=%s" % (type(dataset).__name__, len(dataset), shuffle)) + #dataset = DataPrefetcher(dataset) + self.opt = opt + self.dataloader = torch.utils.data.DataLoader( + dataset, + batch_size=opt.batch_size, + shuffle=shuffle, + num_workers=int(opt.num_gpus), + drop_last=phase == "train", + ) + #self.dataloader = dataset + self.dataloader_iterator = iter(self.dataloader) + self.repeat = phase == "train" + self.length = len(dataset) + self.underlying_dataset = dataset + + def set_phase(self, target_phase): + if self.phase != target_phase: + self.initialize(target_phase) + + def __iter__(self): + self.dataloader_iterator = iter(self.dataloader) + return self + + def __len__(self): + return self.length + + def __next__(self): + try: + return next(self.dataloader_iterator) + except StopIteration: + if self.repeat: + self.dataloader_iterator = iter(self.dataloader) + return next(self.dataloader_iterator) + else: + raise StopIteration diff --git a/swapae/data/base_dataset.py b/swapae/data/base_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..93441f5c2cd9dbc6fb2a2074d0ef895ef70c0d06 --- /dev/null +++ b/swapae/data/base_dataset.py @@ -0,0 +1,250 @@ +"""This module implements an abstract base class (ABC) 'BaseDataset' for datasets. + +It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses. +""" +import random +import numpy as np +import torch.utils.data as data +from PIL import Image +import torchvision.transforms as transforms +from abc import ABC, abstractmethod + + +class BaseDataset(data.Dataset, ABC): + """This class is an abstract base class (ABC) for datasets. + + To create a subclass, you need to implement the following four functions: + -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt). + -- <__len__>: return the size of dataset. + -- <__getitem__>: get a data point. + -- : (optionally) add dataset-specific options and set default options. 
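As a concrete illustration of the contract spelled out above, a hypothetical swapae/data/dummy_dataset.py (not part of this diff), selected with --dataset_mode dummy and resolved by find_dataset_using_name via its DummyDataset class name, could look like this sketch:

import torch
from swapae.data.base_dataset import BaseDataset


class DummyDataset(BaseDataset):
    @staticmethod
    def modify_commandline_options(parser, is_train):
        parser.set_defaults(load_size=256, crop_size=256)  # illustrative defaults only
        return parser

    def __init__(self, opt):
        BaseDataset.__init__(self, opt)
        self.length = 8  # tiny synthetic dataset

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        # return the "real_A"/"path_A" keys the rest of this codebase expects
        image = torch.zeros(3, self.opt.crop_size, self.opt.crop_size)
        return {"real_A": image, "path_A": "dummy_%06d.png" % index}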
+ """ + + def __init__(self, opt): + """Initialize the class; save the options in the class + + Parameters: + opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions + """ + self.opt = opt + self.root = opt.dataroot + + @staticmethod + def modify_commandline_options(parser, is_train): + """Add new dataset-specific options, and rewrite default values for existing options. + + Parameters: + parser -- original option parser + is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. + + Returns: + the modified parser. + """ + + return parser + + @abstractmethod + def __len__(self): + """Return the total number of images in the dataset.""" + return 0 + + @abstractmethod + def __getitem__(self, index): + """Return a data point and its metadata information. + + Parameters: + index - - a random integer for data indexing + + Returns: + a dictionary of data with their names. It ususally contains the data itself and its metadata information. + """ + pass + + def set_phase(self, phase): + assert phase in ["train", "test", "val"] + self.current_phase = phase + pass + + +def get_params(opt, size): + w, h = size + new_h = h + new_w = w + if opt.preprocess == 'resize_and_crop': + new_h = new_w = opt.load_size + elif opt.preprocess == 'scale_width_and_crop': + new_w = opt.load_size + new_h = opt.load_size * h // w + + x = random.randint(0, np.maximum(0, new_w - opt.crop_size)) + y = random.randint(0, np.maximum(0, new_h - opt.crop_size)) + + return {'crop_pos': (x, y)} + + +def get_transform(opt, params=None, grayscale=False, method=Image.BICUBIC, convert=True): + transform_list = [] + if grayscale: + transform_list.append(transforms.Grayscale(1)) + if 'fixsize' in opt.preprocess: + transform_list.append(transforms.Resize((opt.crop_size, opt.load_size), method)) + if 'resize' in opt.preprocess: + osize = [opt.load_size, opt.load_size] + if "gta2cityscapes" in opt.dataroot: + osize[0] = opt.load_size // 2 + transform_list.append(transforms.Resize(osize, method)) + elif 'scale_width' in opt.preprocess: + transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, opt.crop_size, method))) + elif 'scale_shortside' in opt.preprocess: + transform_list.append(transforms.Lambda(lambda img: __scale_shortside(img, opt.load_size, opt.crop_size, method))) + elif 'scale_longside' in opt.preprocess: + transform_list.append(transforms.Lambda(lambda img: __scale_longside(img, opt.load_size, opt.crop_size, method))) + + + #if 'rotate' in opt.preprocess: + # transform_list.append(transforms.RandomRotation(180, resample=Image.BILINEAR)) + + if 'zoom' in opt.preprocess: + if params is None: + transform_list.append(transforms.Lambda(lambda img: __random_zoom(img, opt.load_size, opt.crop_size, method))) + else: + transform_list.append(transforms.Lambda(lambda img: __random_zoom(img, opt.load_size, opt.crop_size, method, factor=params["scale_factor"]))) + + if 'centercrop'in opt.preprocess: + transform_list.append(transforms.Lambda(lambda img: __centercrop(img))) + elif 'crop' in opt.preprocess: + if params is None or 'crop_pos' not in params: + transform_list.append(transforms.RandomCrop(opt.crop_size, padding=opt.preprocess_crop_padding)) + else: + transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.crop_size))) + + if 'patch' in opt.preprocess: + transform_list.append(transforms.Lambda(lambda img: __patch(img, params['patch_index'], opt.crop_size))) + + if 
'trim' in opt.preprocess: + transform_list.append(transforms.Lambda(lambda img: __trim(img, opt.crop_size))) + + #if opt.preprocess == 'none': + transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base=16, method=method))) + + random_flip = opt.isTrain and (not opt.no_flip) + if random_flip: + transform_list.append(transforms.RandomHorizontalFlip()) + #elif 'flip' in params: + # transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip']))) + + if convert: + transform_list += [transforms.ToTensor()] + if grayscale: + transform_list += [transforms.Normalize((0.5,), (0.5,))] + else: + transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] + return transforms.Compose(transform_list) + + +def __make_power_2(img, base, method=Image.BICUBIC): + ow, oh = img.size + h = int(round(oh / base) * base) + w = int(round(ow / base) * base) + if h == oh and w == ow: + return img + + #__print_size_warning(ow, oh, w, h) + return img.resize((w, h), method) + + +def __random_zoom(img, target_width, crop_width, method=Image.BICUBIC, factor=None): + iw, ih = img.size + if factor is None: + zoom_level = np.random.uniform(crop_width / iw, 1.0, size=[2]) + else: + zoom_level = (factor[0], factor[1]) + zoomw = max(crop_width, iw * zoom_level[0]) + zoomh = max(crop_width, ih * zoom_level[1]) + img = img.resize((int(round(zoomw)), int(round(zoomh))), method) + return img + + +def __scale_shortside(img, target_width, crop_width, method=Image.BICUBIC): + ow, oh = img.size + shortside = min(ow, oh) + scale = target_width / shortside + return img.resize((round(ow * scale), round(oh * scale)), method) + +def __centercrop(img): + ow, oh = img.size + s = min(ow, oh) + return img.crop(((ow - s) // 2, (oh - s) // 2, (ow + s) // 2, (oh + s) // 2)) + +def __scale_longside(img, target_width, crop_width, method=Image.BICUBIC): + ow, oh = img.size + longside = max(ow, oh) + scale = target_width / longside + return img.resize((round(ow * scale), round(oh * scale)), method) + +def __trim(img, trim_width): + ow, oh = img.size + if ow > trim_width: + xstart = np.random.randint(ow - trim_width) + xend = xstart + trim_width + else: + xstart = 0 + xend = ow + if oh > trim_width: + ystart = np.random.randint(oh - trim_width) + yend = ystart + trim_width + else: + ystart = 0 + yend = oh + return img.crop((xstart, ystart, xend, yend)) + + +def __scale_width(img, target_width, crop_width, method=Image.BICUBIC): + ow, oh = img.size + if ow == target_width and oh >= crop_width: + return img + w = target_width + #h = int(max(target_width * oh / ow, crop_width)) + h = int(target_width * oh / ow) + return img.resize((w, h), method) + + +def __crop(img, pos, size): + ow, oh = img.size + x1, y1 = pos + tw = th = size + if (ow > tw or oh > th): + return img.crop((x1, y1, x1 + tw, y1 + th)) + return img + + +def __patch(img, index, size): + ow, oh = img.size + nw, nh = ow // size, oh // size + roomx = ow - nw * size + roomy = oh - nh * size + startx = np.random.randint(int(roomx) + 1) + starty = np.random.randint(int(roomy) + 1) + + index = index % (nw * nh) + ix = index // nh + iy = index % nh + gridx = startx + ix * size + gridy = starty + iy * size + return img.crop((gridx, gridy, gridx + size, gridy + size)) + + +def __flip(img, flip): + if flip: + return img.transpose(Image.FLIP_LEFT_RIGHT) + return img + + +def __print_size_warning(ow, oh, w, h): + """Print warning information about image size(only print once)""" + if not hasattr(__print_size_warning, 'has_printed'): + print("The image 
size needs to be a multiple of 4. " + "The loaded image size was (%d, %d), so it was adjusted to " + "(%d, %d). This adjustment will be done to all images " + "whose sizes are not multiples of 4" % (ow, oh, w, h)) + __print_size_warning.has_printed = True diff --git a/swapae/data/cifar100_dataset.py b/swapae/data/cifar100_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..b7914df7135a82a8e4e3b9c6e713a7780637bdd8 --- /dev/null +++ b/swapae/data/cifar100_dataset.py @@ -0,0 +1,58 @@ +import random +import numpy as np +import os.path +from swapae.data.base_dataset import BaseDataset, get_transform +import torchvision + + +class CIFAR100Dataset(BaseDataset): + @staticmethod + def modify_commandline_options(parser, is_train): + parser.set_defaults(load_size=32, crop_size=32, preprocess_crop_padding=0, + preprocess='crop', num_classes=100, use_class_labels=True) + opt, _ = parser.parse_known_args() + assert opt.preprocess == 'crop' and opt.load_size == 32 and opt.crop_size == 32 + return parser + + def __init__(self, opt): + self.opt = opt + self.torch_dataset = torchvision.datasets.CIFAR100( + opt.dataroot, train=opt.isTrain, download=True + ) + self.transform = get_transform(self.opt, grayscale=False) + self.class_list = self.create_class_list() + + def create_class_list(self): + cache_path = os.path.join(self.opt.dataroot, "%s_classlist.npy" % self.opt.phase) + if os.path.exists(cache_path): + cache = np.load(cache_path) + classlist = {i: [] for i in range(100)} + for i, c in enumerate(cache): + classlist[c].append(i) + return classlist + + print("creating cache list of classes...") + classes = np.zeros((len(self.torch_dataset)), dtype=int) + for i in range(len(self.torch_dataset)): + _, class_id = self.torch_dataset[i] + classes[i] = class_id + if i % 100 == 0: + print("%d/%d\r" % (i, len(self.torch_dataset)), end="", flush=True) + np.save(cache_path, classes) + print("cache saved at %s" % cache_path) + return self.create_class_list() + + def __getitem__(self, index): + index = index % len(self.torch_dataset) + image, class_id = self.torch_dataset[index] + + another_image_index = random.choice(self.class_list[class_id]) + another_image, another_class_id = self.torch_dataset[another_image_index] + assert class_id == another_class_id + return {"real_A": self.transform(image), + "real_B": self.transform(another_image), + "class_A": class_id, "class_B": class_id} + + def __len__(self): + return len(self.torch_dataset) + diff --git a/swapae/data/dataset_tools.py b/swapae/data/dataset_tools.py new file mode 100644 index 0000000000000000000000000000000000000000..a34c2d9f4b0f7d4735e114ce0e887667e672e2d5 --- /dev/null +++ b/swapae/data/dataset_tools.py @@ -0,0 +1,98 @@ +import PIL +import os +import numpy as np +import argparse +import lmdb +import sys +import cv2 +sys.path.append(".") +from swapae.data.image_folder import make_dataset + + +def create_lmdb_from_images(opt): + paths = sorted(make_dataset(opt.input)) + + output_dir = opt.output + print('Extracting images to "%s"' % output_dir) + if not os.path.isdir(output_dir): + os.makedirs(output_dir) + + # initialize lmdb + output_dir = opt.output + lmdb_env = lmdb.open(output_dir, map_size=1099511627776) + + with lmdb_env.begin(write=True) as txn: + for idx, image_path in enumerate(paths): + if idx % 10 == 0: + print('%d\r' % idx, end='', flush=True) + img = PIL.Image.open(image_path).convert('RGB') + img = np.asarray(img) + img = cv2.imencode('.png', img)[1].tostring() + txn.put(image_path.encode('ascii'), img) + + 
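A usage sketch, not part of the diff (paths are placeholders): the functions in this file are driven by the argparse entry point at the bottom, and create_lmdb_from_images above stores PNG-encoded bytes keyed by the source image path (the tfrecords variant below uses a running index), which is the layout that LMDBDataset, added later in this diff, reads back.

# build an LMDB from a directory of images
#   python swapae/data/dataset_tools.py create_lmdb_from_images \
#       --input /path/to/images --output /path/to/images_lmdb

import lmdb
import cv2
import numpy as np

env = lmdb.open("/path/to/images_lmdb", readonly=True, lock=False)
with env.begin(write=False) as txn:
    key, buf = next(iter(txn.cursor()))                        # key is the original image path
    img = cv2.imdecode(np.frombuffer(buf, dtype=np.uint8), 1)  # round-trips the encoded array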
+def create_lmdb_from_tfrecords(opt): + # initialize tensorflow + assert opt.stylegan_codebase_path is not None + sys.path.append(opt.stylegan_codebase_path) + import dnnlib.tflib as tflib + from training import dataset + import tensorflow as tf + tfrecord_dir = opt.input + print('Loading dataset "%s"' % tfrecord_dir) + tflib.init_tf({'gpu_options.allow_growth': True}) + dset = dataset.TFRecordDataset(tfrecord_dir, max_label_size=0, repeat=False, shuffle_mb=0) + tflib.init_uninitialized_vars() + + # initialize lmdb + output_dir = opt.output + lmdb_env = lmdb.open(output_dir, map_size=1099511627776) + + print('Extracting images to "%s"' % output_dir) + if not os.path.isdir(output_dir): + os.makedirs(output_dir) + idx = 0 + with lmdb_env.begin(write=True) as txn: + while True: + idx += 1 + if idx % 10 == 0: + print('%d\r' % idx, end='', flush=True) + try: + images, _labels = dset.get_minibatch_np(1) + except tf.errors.OutOfRangeError: + break + if images.shape[1] == 1: + img = PIL.Image.fromarray(images[0][0], 'L').convert('RGB') + else: + img = PIL.Image.fromarray(images[0].transpose(1, 2, 0), 'RGB') + img = np.asarray(img) + img = cv2.imencode('.png', img)[1].tostring() + imagekey = "%08d" % idx + txn.put(imagekey.encode('ascii'), img) + + + + + +if __name__ == "__main__": + os.environ["CUDA_VISIBLE_DEVICES"] = "-1" + parser = argparse.ArgumentParser() + + parser.add_argument("mode", + choices=("create_lmdb_from_images", + "create_lmdb_from_tfrecords", + )) + parser.add_argument("--input", help="input path") + parser.add_argument("--output", help="input path") + parser.add_argument("--stylegan_codebase_path", + help="path to stylegan codebase. Path to git clone https://github.com/NVlabs/stylegan.git") + + opt = parser.parse_args() + + if opt.mode == "create_lmdb_from_images": + create_lmdb_from_images(opt) + elif opt.mode == "create_lmdb_from_tfrecords": + create_lmdb_from_tfrecords(opt) + + print("Finished") + diff --git a/swapae/data/image_folder.py b/swapae/data/image_folder.py new file mode 100644 index 0000000000000000000000000000000000000000..c76a08f871fefef6452fcd1ee72e507cf1f0990c --- /dev/null +++ b/swapae/data/image_folder.py @@ -0,0 +1,67 @@ +"""A modified image folder class + +We modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py) +so that this class can load images from both current directory and its subdirectories. 
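A brief, hypothetical usage sketch of the two helpers defined just below (the directory path is a placeholder):

from swapae.data.image_folder import make_dataset, ImageFolder

paths = make_dataset("/path/to/images")            # recursive walk, filtered by IMG_EXTENSIONS
folder = ImageFolder("/path/to/images", return_paths=True)
img, path = folder[0]                              # PIL RGB image plus its path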
+""" + +import torch.utils.data as data + +from PIL import Image +import os +import os.path + +IMG_EXTENSIONS = [ + '.jpg', '.JPG', '.jpeg', '.JPEG', + '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', + '.tif', '.TIF', '.tiff', '.TIFF', '.webp', +] + + +def is_image_file(filename): + return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) + + +def make_dataset(dir, max_dataset_size=float("inf")): + images = [] + assert os.path.isdir(dir), '%s is not a valid directory' % dir + + for root, _, fnames in sorted(os.walk(dir, followlinks=True)): + for fname in fnames: + if is_image_file(fname): + path = os.path.join(root, fname) + images.append(path) + return images[:min(max_dataset_size, len(images))] + + +def default_loader(path): + return Image.open(path).convert('RGB') + + +class ImageFolder(data.Dataset): + + def __init__(self, root, transform=None, return_paths=False, + loader=default_loader): + imgs = make_dataset(root) + if len(imgs) == 0: + raise(RuntimeError("Found 0 images in: " + root + "\n" + "Supported image extensions are: " + + ",".join(IMG_EXTENSIONS))) + + self.root = root + self.imgs = imgs + self.transform = transform + self.return_paths = return_paths + self.loader = loader + + def __getitem__(self, index): + path = self.imgs[index] + img = self.loader(path) + if self.transform is not None: + img = self.transform(img) + if self.return_paths: + return img, path + else: + return img + + def __len__(self): + return len(self.imgs) diff --git a/swapae/data/imagefolder_dataset.py b/swapae/data/imagefolder_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..a5319bd3d4d58e6cc308e1c966dab6b7080c81e1 --- /dev/null +++ b/swapae/data/imagefolder_dataset.py @@ -0,0 +1,33 @@ +import random +from swapae.data.base_dataset import BaseDataset, get_transform +from swapae.data.image_folder import make_dataset +from PIL import Image + + +class ImageFolderDataset(BaseDataset): + def __init__(self, opt): + BaseDataset.__init__(self, opt) + self.dir_A = opt.dataroot + + self.A_paths = sorted(make_dataset(self.dir_A)) + self.A_size = len(self.A_paths) + self.transform_A = get_transform(self.opt, grayscale=False) + + def __getitem__(self, index): + A_path = self.A_paths[index % self.A_size] + return self.getitem_by_path(A_path) + + def getitem_by_path(self, A_path): + try: + A_img = Image.open(A_path).convert('RGB') + except OSError as err: + print(err) + return self.__getitem__(random.randint(0, len(self) - 1)) + + # apply image transformation + A = self.transform_A(A_img) + + return {'real_A': A, 'path_A': A_path} + + def __len__(self): + return self.A_size diff --git a/swapae/data/lmdb_dataset.py b/swapae/data/lmdb_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..82f4cc9b0b8d533b3ea92bdb020d5c67ae0d2028 --- /dev/null +++ b/swapae/data/lmdb_dataset.py @@ -0,0 +1,73 @@ +import random +import sys +import os.path +from PIL import Image +from swapae.data.base_dataset import BaseDataset, get_transform +import cv2 +import numpy as np +if sys.version_info[0] == 2: + import cPickle as pickle +else: + import pickle +import torchvision.transforms as transforms + + +class LMDBDataset(BaseDataset): + def __init__(self, opt): + import lmdb + self.opt = opt + write_cache = True + root = opt.dataroot + self.root = os.path.expanduser(root) + self.env = lmdb.open(root, readonly=True, lock=False) + with self.env.begin(write=False) as txn: + self.length = txn.stat()['entries'] + print('lmdb file at %s opened.' 
% root) + cache_file = os.path.join(root, '_cache_') + if os.path.isfile(cache_file): + self.keys = pickle.load(open(cache_file, "rb")) + elif write_cache: + print('generating keys') + with self.env.begin(write=False) as txn: + self.keys = [key for key, _ in txn.cursor()] + pickle.dump(self.keys, open(cache_file, "wb")) + print('cache file generated at %s' % cache_file) + else: + self.keys = [] + + random.Random(0).shuffle(self.keys) + + self.transform = get_transform(self.opt, grayscale=False) + if "lsun" in self.opt.dataroot.lower(): + print("Seems like a LSUN dataset, so we will apply BGR->RGB conversion") + + + def __getitem__(self, index): + path = self.keys[index] + return self.getitem_by_path(path) + + def getitem_by_path(self, path): + env = self.env + with env.begin(write=False) as txn: + imgbuf = txn.get(path) + try: + img = cv2.imdecode( + np.fromstring(imgbuf, dtype=np.uint8), 1) + except cv2.error as e: + print(path, e) + return self.__getitem__(random.randint(0, self.length - 1)) + if "lsun" in self.opt.dataroot.lower(): + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img = Image.fromarray(img) + + return {"real_A": self.transform(img), "path_A": path.decode("utf-8")} + + def set_phase(self, phase): + super().set_phase(phase) + pass + + def __len__(self): + return self.length + + def __repr__(self): + return self.__class__.__name__ + ' (' + self.root + ')' diff --git a/swapae/data/unaligned_dataset.py b/swapae/data/unaligned_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..dcde330becebad4a3246c2a081283fd454fdca7b --- /dev/null +++ b/swapae/data/unaligned_dataset.py @@ -0,0 +1,76 @@ +import os.path +from swapae.data.base_dataset import BaseDataset, get_transform +from swapae.data.image_folder import make_dataset +from PIL import Image +import random + + +class UnalignedDataset(BaseDataset): + """ + This dataset class can load unaligned/unpaired datasets. + + It requires two directories to host training images from domain A '/path/to/data/trainA' + and from domain B '/path/to/data/trainB' respectively. + You can train the model with the dataset flag '--dataroot /path/to/data'. + Similarly, you need to prepare two directories: + '/path/to/data/testA' and '/path/to/data/testB' during test time. + """ + + def __init__(self, opt): + """Initialize this dataset class. 
+ + Parameters: + opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions + """ + BaseDataset.__init__(self, opt) + self.dir_A = os.path.join(opt.dataroot, opt.phase + 'A') # create a path '/path/to/data/trainA' + self.dir_B = os.path.join(opt.dataroot, opt.phase + 'B') # create a path '/path/to/data/trainB' + + if opt.phase == "test" and not os.path.exists(self.dir_A) \ + and os.path.exists(os.path.join(opt.dataroot, "valA")): + self.dir_A = os.path.join(opt.dataroot, "testA") + self.dir_B = os.path.join(opt.dataroot, "testB") + + self.A_paths = sorted(make_dataset(self.dir_A)) # load images from '/path/to/data/trainA' + random.Random(0).shuffle(self.A_paths) + self.B_paths = sorted(make_dataset(self.dir_B)) # load images from '/path/to/data/trainB' + random.Random(0).shuffle(self.B_paths) + self.A_size = len(self.A_paths) # get the size of dataset A + self.B_size = len(self.B_paths) # get the size of dataset B + self.transform_A = get_transform(self.opt, grayscale=False) + self.transform_B = get_transform(self.opt, grayscale=False) + self.B_indices = list(range(self.B_size)) + + def __getitem__(self, index): + """Return a data point and its metadata information. + + Parameters: + index (int) -- a random integer for data indexing + + Returns a dictionary that contains A, B, A_paths and B_paths + A (tensor) -- an image in the input domain + B (tensor) -- its corresponding image in the target domain + A_paths (str) -- image paths + B_paths (str) -- image paths + """ + A_path = self.A_paths[index % self.A_size] # make sure index is within then range + if index == 0 and self.opt.isTrain: + random.shuffle(self.B_indices) + index_B = self.B_indices[index % self.B_size] + B_path = self.B_paths[index_B] + A_img = Image.open(A_path).convert('RGB') + B_img = Image.open(B_path).convert('RGB') + + # apply image transformation + A = self.transform_A(A_img) + B = self.transform_B(B_img) + + return {'real_A': A, 'real_B': B, 'path_A': A_path, 'path_B': B_path} + + def __len__(self): + """Return the total number of images in the dataset. 
+ + As we have two datasets with potentially different number of images, + we take a maximum of + """ + return max(self.A_size, self.B_size) diff --git a/swapae/data/unaligned_lmdb_dataset.py b/swapae/data/unaligned_lmdb_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..0fd688bbfcbbc304d8a741cbc1afbaac08465f60 --- /dev/null +++ b/swapae/data/unaligned_lmdb_dataset.py @@ -0,0 +1,31 @@ +import random +import os.path +from swapae.data.base_dataset import BaseDataset +from swapae.data.lmdb_dataset import LMDBDataset +import swapae.util + + +class UnalignedLMDBDataset(BaseDataset): + def __init__(self, opt): + super().__init__(opt) + self.dir_A = os.path.join(opt.dataroot, opt.phase + 'A') # create a path '/path/to/data/trainA' + self.dir_B = os.path.join(opt.dataroot, opt.phase + 'B') # create a path '/path/to/data/trainB' + + self.dataset_A = LMDBDataset(util.copyconf(opt, dataroot=self.dir_A)) + self.dataset_B = LMDBDataset(util.copyconf(opt, dataroot=self.dir_B)) + self.B_indices = list(range(len(self.dataset_B))) + + + def __len__(self): + return max(len(self.dataset_A), len(self.dataset_B)) + + def __getitem__(self, index): + if index == 0 and self.opt.isTrain: + random.shuffle(self.B_indices) + + result = self.dataset_A.__getitem__(index % len(self.dataset_A)) + B_index = self.B_indices[index % len(self.dataset_B)] + B_result = self.dataset_B.__getitem__(B_index) + result["real_B"] = B_result["real_A"] + result["path_B"] = B_result["path_A"] + return result diff --git a/swapae/evaluation/__init__.py b/swapae/evaluation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1a2440773c2c6edb6c89e4d1027a1ca83445183f --- /dev/null +++ b/swapae/evaluation/__init__.py @@ -0,0 +1,7 @@ +from swapae.evaluation.base_evaluator import BaseEvaluator +from swapae.evaluation.group_evaluator import GroupEvaluator + + +def get_option_setter(): + return GroupEvaluator.modify_commandline_options + diff --git a/swapae/evaluation/base_evaluator.py b/swapae/evaluation/base_evaluator.py new file mode 100644 index 0000000000000000000000000000000000000000..dd95ebfc0a37bee8158f91e0f273713b523130fa --- /dev/null +++ b/swapae/evaluation/base_evaluator.py @@ -0,0 +1,24 @@ +import os + + +class BaseEvaluator(): + @staticmethod + def modify_commandline_options(parser, is_train): + return parser + + def __init__(self, opt, target_phase): + super().__init__() + self.opt = opt + self.target_phase = target_phase + + def output_dir(self): + evaluator_name = str(type(self).__name__).lower().replace('evaluator', '') + expr_name = self.opt.name + if self.opt.isTrain: + result_dir = os.path.join(self.opt.checkpoints_dir, expr_name, "snapshots") + else: + result_dir = os.path.join(self.opt.result_dir, expr_name, evaluator_name) + return result_dir + + def evaluate(self, model, dataset, nsteps=None): + pass diff --git a/swapae/evaluation/group_evaluator.py b/swapae/evaluation/group_evaluator.py new file mode 100644 index 0000000000000000000000000000000000000000..b9022485331a761eecf8c82ce1bee00ff178444c --- /dev/null +++ b/swapae/evaluation/group_evaluator.py @@ -0,0 +1,70 @@ +import torch +from swapae.evaluation.base_evaluator import BaseEvaluator +import swapae.util as util + +def find_evaluator_using_name(filename): + target_class_name = filename + module_name = 'swapae.evaluation.' 
+ filename + eval_class = util.find_class_in_module(target_class_name, module_name) + + assert issubclass(eval_class, BaseEvaluator), \ + "Class %s should be a subclass of BaseEvaluator" % eval_class + + return eval_class + + +def find_evaluator_classes(opt): + if len(opt.evaluation_metrics) == 0: + return [] + + eval_metrics = opt.evaluation_metrics.split(",") + + all_classes = [] + target_phases = [] + for metric in eval_metrics: + if metric.startswith("train"): + target_phases.append("train") + metric = metric[len("train"):] + elif metric.startswith("test"): + target_phases.append("test") + metric = metric[len("test"):] + else: + target_phases.append("test") + + metric_class = find_evaluator_using_name("%s_evaluator" % metric) + all_classes.append(metric_class) + + return all_classes, target_phases + + +class GroupEvaluator(BaseEvaluator): + @staticmethod + def modify_commandline_options(parser, is_train): + parser.add_argument("--evaluation_metrics", default="structure_style_grid_generation") + + opt, _ = parser.parse_known_args() + evaluator_classes, _ = find_evaluator_classes(opt) + + for eval_class in evaluator_classes: + parser = eval_class.modify_commandline_options(parser, is_train) + + return parser + + def __init__(self, opt, target_phase=None): + super().__init__(opt, target_phase=None) + self.opt = opt + evaluator_classes, target_phases = find_evaluator_classes(opt) + self.evaluators = [cls(opt, target_phase=phs) for cls, phs in zip(evaluator_classes, target_phases)] + + def evaluate(self, model, dataset, nsteps=None): + original_phase = dataset.phase + metrics = {} + for i, evaluator in enumerate(self.evaluators): + print("Entering evaluation using %s on %s images" % (type(evaluator).__name__, evaluator.target_phase)) + dataset.set_phase(evaluator.target_phase) + with torch.no_grad(): + new_metrics = evaluator.evaluate(model, dataset, nsteps) + metrics.update(new_metrics) + print("Finished evaluation of %s" % type(evaluator).__name__) + dataset.set_phase(original_phase) + return metrics diff --git a/swapae/evaluation/none_evaluator.py b/swapae/evaluation/none_evaluator.py new file mode 100644 index 0000000000000000000000000000000000000000..6c0adff233e7f2c5e67c16f0d779f3888a511078 --- /dev/null +++ b/swapae/evaluation/none_evaluator.py @@ -0,0 +1,14 @@ +from swapae.evaluation.base_evaluator import BaseEvaluator + + +class NoneEvaluator(BaseEvaluator): + @staticmethod + def modify_commandline_options(parser, is_train): + return parser + + def __init__(self, opt, target_phase): + super().__init__(opt, target_phase) + self.opt = opt + + def evaluate(self, model, dataset, nsteps): + return {} diff --git a/swapae/evaluation/simple_swapping_evaluator.py b/swapae/evaluation/simple_swapping_evaluator.py new file mode 100644 index 0000000000000000000000000000000000000000..7a990d7740ec37c3fdebf55bcb6a3b5edb4fdcf5 --- /dev/null +++ b/swapae/evaluation/simple_swapping_evaluator.py @@ -0,0 +1,67 @@ +import os +import torchvision.transforms as transforms +from PIL import Image +from swapae.evaluation import BaseEvaluator +from swapae.data.base_dataset import get_transform +import swapae.util as util + + +class SimpleSwappingEvaluator(BaseEvaluator): + @staticmethod + def modify_commandline_options(parser, is_train): + parser.add_argument("--input_structure_image", required=True, type=str) + parser.add_argument("--input_texture_image", required=True, type=str) + parser.add_argument("--texture_mix_alphas", type=float, nargs='+', + default=[1.0], + help="Performs interpolation of the texture 
image." + "If set to 1.0, it performs full swapping." + "If set to 0.0, it performs direct reconstruction" + ) + + opt, _ = parser.parse_known_args() + dataroot = os.path.dirname(opt.input_structure_image) + + # dataroot and dataset_mode are ignored in SimpleSwapplingEvaluator. + # Just set it to the directory that contains the input structure image. + parser.set_defaults(dataroot=dataroot, dataset_mode="imagefolder") + + return parser + + def load_image(self, path): + path = os.path.expanduser(path) + img = Image.open(path).convert('RGB') + transform = get_transform(self.opt) + tensor = transform(img).unsqueeze(0) + return tensor + + def evaluate(self, model, dataset, nsteps=None): + structure_image = self.load_image(self.opt.input_structure_image) + texture_image = self.load_image(self.opt.input_texture_image) + os.makedirs(self.output_dir(), exist_ok=True) + + model(sample_image=structure_image, command="fix_noise") + structure_code, source_texture_code = model( + structure_image, command="encode") + _, target_texture_code = model(texture_image, command="encode") + + alphas = self.opt.texture_mix_alphas + for alpha in alphas: + texture_code = util.lerp( + source_texture_code, target_texture_code, alpha) + + output_image = model(structure_code, texture_code, command="decode") + output_image = transforms.ToPILImage()( + (output_image[0].clamp(-1.0, 1.0) + 1.0) * 0.5) + + output_name = "%s_%s_%.2f.png" % ( + os.path.splitext(os.path.basename(self.opt.input_structure_image))[0], + os.path.splitext(os.path.basename(self.opt.input_texture_image))[0], + alpha + ) + + output_path = os.path.join(self.output_dir(), output_name) + + output_image.save(output_path) + print("Saved at " + output_path) + + return {} diff --git a/swapae/evaluation/structure_style_grid_generation_evaluator.py b/swapae/evaluation/structure_style_grid_generation_evaluator.py new file mode 100644 index 0000000000000000000000000000000000000000..e3bc7ec6b3299925998673f034499d840a06b3d1 --- /dev/null +++ b/swapae/evaluation/structure_style_grid_generation_evaluator.py @@ -0,0 +1,87 @@ +import os +import torch +from swapae.evaluation import BaseEvaluator +import swapae.util as util +import numpy as np +from PIL import Image + + +class StructureStyleGridGenerationEvaluator(BaseEvaluator): + """ generate swapping images and save to disk """ + @staticmethod + def modify_commandline_options(parser, is_train): + return parser + + def create_webpage(self, nsteps): + nsteps = self.opt.resume_iter if nsteps is None else nsteps + savedir = os.path.join(self.output_dir(), "%s_%s" % (self.target_phase, nsteps)) + os.makedirs(savedir, exist_ok=True) + webpage_title = "%s. iter=%s. 
phase=%s" % \ + (self.opt.name, str(nsteps), self.target_phase) + self.webpage = util.HTML(savedir, webpage_title) + + def add_to_webpage(self, images, filenames, tile=1): + converted_images = [] + for image in images: + if isinstance(image, list): + image = torch.stack(image, dim=0).flatten(0, 1) + image = Image.fromarray(util.tensor2im(image, tile=min(image.size(0), tile))) + converted_images.append(image) + + self.webpage.add_images(converted_images, + filenames) + print("saved %s" % str(filenames)) + #self.webpage.save() + + def evaluate(self, model, dataset, nsteps=None): + self.create_webpage(nsteps) + + structure_images, style_images = {}, {} + for i, data_i in enumerate(dataset): + bs = data_i["real_A"].size(0) + #sp, gl = model(data_i["real_A"].cuda(), command="encode") + + for j in range(bs): + image = data_i["real_A"][j:j+1] + path = data_i["path_A"][j] + imagename = os.path.splitext(os.path.basename(path))[0] + if "/structure/" in path: + structure_images[imagename] = image + else: + style_images[imagename] = image + + gls = [] + style_paths = list(style_images.keys()) + for style_path in style_paths: + style_image = style_images[style_path].cuda() + gls.append(model(style_image, command="encode")[1]) + + sps = [] + structure_paths = list(structure_images.keys()) + for structure_path in structure_paths: + structure_image = structure_images[structure_path].cuda() + sps.append(model(structure_image, command="encode")[0]) + + # top row to show the input images + blank_image = style_images[style_paths[0]] * 0.0 + 1.0 + self.add_to_webpage([blank_image] + [style_images[style_path] for style_path in style_paths], + ["blank.png"] + [style_path + ".png" for style_path in style_paths], + tile=1) + + # swapping + for i, structure_path in enumerate(structure_paths): + structure_image = structure_images[structure_path] + swaps = [] + filenames = [] + for j, style_path in enumerate(style_paths): + swaps.append(model(sps[i], gls[j], command="decode")) + filenames.append(structure_path + "_" + style_path + ".png") + self.add_to_webpage([structure_image] + swaps, + [structure_path + ".png"] + filenames, + tile=1) + + self.webpage.save() + return {} + + + diff --git a/swapae/evaluation/swap_generation_from_arranged_result_evaluator.py b/swapae/evaluation/swap_generation_from_arranged_result_evaluator.py new file mode 100644 index 0000000000000000000000000000000000000000..d08a668f5225406fa94675fecd734506409bff36 --- /dev/null +++ b/swapae/evaluation/swap_generation_from_arranged_result_evaluator.py @@ -0,0 +1,123 @@ +import glob +import torchvision.transforms as transforms +import os +import torch +from swapae.evaluation import BaseEvaluator +import swapae.util as util +from PIL import Image + + +class InputDataset(torch.utils.data.Dataset): + def __init__(self, dataroot): + structure_images = sorted(glob.glob(os.path.join(dataroot, "input_structure", "*.png"))) + style_images = sorted(glob.glob(os.path.join(dataroot, "input_style", "*.png"))) + + for structure_path, style_path in zip(structure_images, style_images): + assert structure_path.replace("structure", "style") == style_path, \ + "%s and %s do not match" % (structure_path, style_path) + + assert len(structure_images) == len(style_images) + print("found %d images at %s" % (len(structure_images), dataroot)) + + self.structure_images = structure_images + self.style_images = style_images + self.transform = transforms.Compose( + [ + transforms.ToTensor(), + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) + ] + ) + + def __len__(self): 
+ return len(self.structure_images) + + def __getitem__(self, idx): + structure_image = self.transform(Image.open(self.structure_images[idx]).convert('RGB')) + style_image = self.transform(Image.open(self.style_images[idx]).convert('RGB')) + return {'structure': structure_image, + 'style': style_image, + 'path': self.structure_images[idx]} + + +class SwapGenerationFromArrangedResultEvaluator(BaseEvaluator): + """ Given two directories containing input structure and style (texture) + images, respectively, generate reconstructed and swapped images. + The input directories should contain the same set of image filenames. + It differs from StructureStyleGridGenerationEvaluator, which creates + N^2 outputs (i.e. swapping of all possible pairs between the structure and + style images). + """ + @staticmethod + def modify_commandline_options(parser, is_train): + return parser + + def image_save_dir(self, nsteps): + return os.path.join(self.output_dir(), "%s_%s" % (self.target_phase, nsteps), "images") + + def create_webpage(self, nsteps): + if nsteps is None: + nsteps = self.opt.resume_iter + elif isinstance(nsteps, int): + nsteps = str(round(nsteps / 1000)) + "k" + savedir = os.path.join(self.output_dir(), "%s_%s" % (self.target_phase, nsteps)) + os.makedirs(savedir, exist_ok=True) + webpage_title = "%s. iter=%s. phase=%s" % \ + (self.opt.name, str(nsteps), self.target_phase) + self.webpage = util.HTML(savedir, webpage_title) + + def add_to_webpage(self, images, filenames, tile=1): + converted_images = [] + for image in images: + if isinstance(image, list): + image = torch.stack(image, dim=0).flatten(0, 1) + image = Image.fromarray(util.tensor2im(image, tile=min(image.size(0), tile))) + converted_images.append(image) + + self.webpage.add_images(converted_images, + filenames) + print("saved %s" % str(filenames)) + #self.webpage.save() + + def set_num_test_images(self, num_images): + self.num_test_images = num_images + + def evaluate(self, model, dataset, nsteps=None): + input_dataset = torch.utils.data.DataLoader( + InputDataset(self.opt.dataroot), + batch_size=1, + shuffle=False, drop_last=False, num_workers=0 + ) + + self.num_test_images = None + self.create_webpage(nsteps) + image_num = 0 + for i, data_i in enumerate(input_dataset): + structure = data_i["structure"].cuda() + style = data_i["style"].cuda() + path = data_i["path"][0] + path = os.path.basename(path) + #if "real_B" in data_i: + # image = torch.cat([image, data_i["real_B"].cuda()], dim=0) + # paths = paths + data_i["path_B"] + sp, gl = model(structure, command="encode") + rec = model(sp, gl, command="decode") + + _, gl = model(style, command="encode") + swapped = model(sp, gl, command="decode") + + self.add_to_webpage([structure, style, rec, swapped], + ["%s_structure.png" % (path), + "%s_style.png" % (path), + "%s_rec.png" % (path), + "%s_swap.png" % (path)], + tile=1) + image_num += 1 + if self.num_test_images is not None and self.num_test_images <= image_num: + self.webpage.save() + return {} + + self.webpage.save() + return {} + + + diff --git a/swapae/evaluation/swap_visualization_evaluator.py b/swapae/evaluation/swap_visualization_evaluator.py new file mode 100644 index 0000000000000000000000000000000000000000..73989dea8025f950d12dd0e66cafebce884eb488 --- /dev/null +++ b/swapae/evaluation/swap_visualization_evaluator.py @@ -0,0 +1,91 @@ +import os +from PIL import Image +import numpy as np +import torch +from swapae.evaluation import BaseEvaluator +import swapae.util as util + + +class SwapVisualizationEvaluator(BaseEvaluator): 
+ @staticmethod + def modify_commandline_options(parser, is_train): + parser.add_argument("--swap_num_columns", type=int, default=4, + help="number of images to be shown in the swap visualization grid. Setting this value will result in 4x4 swapping grid, with additional row and col for showing original images.") + parser.add_argument("--swap_num_images", type=int, default=16, + help="total number of images to perform swapping. In the end, (swap_num_images / swap_num_columns) grid will be saved to disk") + return parser + + def gather_images(self, dataset): + all_images = [] + num_images_to_gather = max(self.opt.swap_num_columns, self.opt.num_gpus) + exhausted = False + while len(all_images) < num_images_to_gather: + try: + data = next(dataset) + except StopIteration: + print("Exhausted the dataset at %s" % (self.opt.dataroot)) + exhausted = True + break + for i in range(data["real_A"].size(0)): + all_images.append(data["real_A"][i:i+1]) + if "real_B" in data: + all_images.append(data["real_B"][i:i+1]) + if len(all_images) >= num_images_to_gather: + break + if len(all_images) == 0: + return None, None, True + return all_images, exhausted + + def generate_mix_grid(self, model, images): + sps, gls = [], [] + for image in images: + assert image.size(0) == 1 + sp, gl = model(image.expand(self.opt.num_gpus, -1, -1, -1), command="encode") + sp = sp[:1] + gl = gl[:1] + sps.append(sp) + gls.append(gl) + gl = torch.cat(gls, dim=0) + + def put_img(img, canvas, row, col): + h, w = img.shape[0], img.shape[1] + start_x = int(self.opt.load_size * col + (self.opt.load_size - w) * 0.5) + start_y = int(self.opt.load_size * row + (self.opt.load_size - h) * 0.5) + canvas[start_y:start_y + h, start_x: start_x + w] = img + grid_w = self.opt.load_size * (gl.size(0) + 1) + grid_h = self.opt.load_size * (gl.size(0) + 1) + grid_img = np.ones((grid_h, grid_w, 3), dtype=np.uint8) + #images_np = util.tensor2im(images, tile=False) + for i, image in enumerate(images): + image_np = util.tensor2im(image, tile=False)[0] + put_img(image_np, grid_img, 0, i + 1) + put_img(image_np, grid_img, i + 1, 0) + + for i, sp in enumerate(sps): + sp_for_current_row = sp.repeat(gl.size(0), 1, 1, 1) + mix_row = model(sp_for_current_row, gl, command="decode") + mix_row = util.tensor2im(mix_row, tile=False) + for j, mix in enumerate(mix_row): + put_img(mix, grid_img, i + 1, j + 1) + + final_grid = Image.fromarray(grid_img) + return final_grid + + def evaluate(self, model, dataset, nsteps): + nsteps = self.opt.resume_iter if nsteps is None else str(round(nsteps / 1000)) + "k" + savedir = os.path.join(self.output_dir(), "%s_%s" % (self.target_phase, nsteps)) + os.makedirs(savedir, exist_ok=True) + webpage_title = "Swap Visualization of %s. iter=%s. 
phase=%s" % \ + (self.opt.name, str(nsteps), self.target_phase) + webpage = util.HTML(savedir, webpage_title) + num_repeats = int(np.ceil(self.opt.swap_num_images / max(self.opt.swap_num_columns, self.opt.num_gpus))) + for i in range(num_repeats): + images, should_break = self.gather_images(dataset) + if images is None: + break + mix_grid = self.generate_mix_grid(model, images) + webpage.add_images([mix_grid], ["%04d.png" % i]) + if should_break: + break + webpage.save() + return {} diff --git a/swapae/models/__init__.py b/swapae/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..dffdbcc468ea04f6bc3ef55c4b578465b2a6bc85 --- /dev/null +++ b/swapae/models/__init__.py @@ -0,0 +1,109 @@ +"""This package contains modules related to objective functions, optimizations, and network architectures. + +To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel. +You need to implement the following five functions: + -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt). + -- : unpack data from dataset and apply preprocessing. + -- : produce intermediate results. + -- : calculate loss, gradients, and update network weights. + -- : (optionally) add model-specific options and set default options. + +In the function <__init__>, you need to define four lists: + -- self.loss_names (str list): specify the training losses that you want to plot and save. + -- self.model_names (str list): define networks used in our training. + -- self.visual_names (str list): specify the images that you want to display and save. + -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an usage. + +Now you can use the model class by specifying flag '--model dummy'. +See our template model class 'template_model.py' for more details. +""" + +import os +import importlib +from swapae.models.base_model import BaseModel +import torch +from torch.nn.parallel import DataParallel + + +def find_model_using_name(model_name): + """Import the module "models/[model_name]_model.py". + + In the file, the class called DatasetNameModel() will + be instantiated. It has to be a subclass of BaseModel, + and it is case-insensitive. + """ + model_filename = "swapae.models." + model_name + "_model" + modellib = importlib.import_module(model_filename) + model = None + target_model_name = model_name.replace('_', '') + 'model' + for name, cls in modellib.__dict__.items(): + if name.lower() == target_model_name.lower() \ + and issubclass(cls, BaseModel): + model = cls + + if model is None: + print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name)) + exit(0) + + return model + + +def get_option_setter(model_name): + """Return the static method of the model class.""" + model_class = find_model_using_name(model_name) + return model_class.modify_commandline_options + + +def create_model(opt): + """Create a model given the option. + + This function warps the class CustomDatasetDataLoader. 
+ This is the main interface between this package and 'train.py'/'test.py' + + Example: + >>> from models import create_model + >>> model = create_model(opt) + """ + model = find_model_using_name(opt.model) + instance = model(opt) + instance.initialize() + multigpu_instance = MultiGPUModelWrapper(opt, instance) + print("model [%s] was created" % type(instance).__name__) + return multigpu_instance + + +class MultiGPUModelWrapper(): + def __init__(self, opt, model: BaseModel): + self.opt = opt + if opt.num_gpus > 0: + model = model.to('cuda:0') + self.parallelized_model = torch.nn.parallel.DataParallel(model) + self.parallelized_model(command="per_gpu_initialize") + self.singlegpu_model = self.parallelized_model.module + self.singlegpu_model(command="per_gpu_initialize") + + def get_parameters_for_mode(self, mode): + return self.singlegpu_model.get_parameters_for_mode(mode) + + def save(self, total_steps_so_far): + self.singlegpu_model.save(total_steps_so_far) + + def __call__(self, *args, **kwargs): + """ Calls are forwarded to __call__ of BaseModel through DataParallel, and corresponding methods specified by |command| will be called. Please see BaseModel.forward() to see how it is done. """ + return self.parallelized_model(*args, **kwargs) + + +class StateVariableStorage(): + pass + + +_state_variables = StateVariableStorage() +_state_variables.fix_noise = False + + +def fixed_noise(): + return _state_variables.fix_noise + + +def fix_noise(set=True): + _state_variables.fix_noise = set diff --git a/swapae/models/__pycache__/__init__.cpython-37.pyc b/swapae/models/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aa11a3e46c6929f5b74b4c78a5d255281fd41900 Binary files /dev/null and b/swapae/models/__pycache__/__init__.cpython-37.pyc differ diff --git a/swapae/models/__pycache__/__init__.cpython-38.pyc b/swapae/models/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d42006bbc27b9832db384c377317d69d81d1c8c1 Binary files /dev/null and b/swapae/models/__pycache__/__init__.cpython-38.pyc differ diff --git a/swapae/models/__pycache__/base_model.cpython-37.pyc b/swapae/models/__pycache__/base_model.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2e51b6ebdf2e958d55e328e07d066ad74c67b53d Binary files /dev/null and b/swapae/models/__pycache__/base_model.cpython-37.pyc differ diff --git a/swapae/models/__pycache__/base_model.cpython-38.pyc b/swapae/models/__pycache__/base_model.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a79152ebf5baa8401a0f76e509d13362564d185d Binary files /dev/null and b/swapae/models/__pycache__/base_model.cpython-38.pyc differ diff --git a/swapae/models/base_model.py b/swapae/models/base_model.py new file mode 100644 index 0000000000000000000000000000000000000000..7b2a8d14caeb6ce5403c46bb4cf79126a353dc81 --- /dev/null +++ b/swapae/models/base_model.py @@ -0,0 +1,123 @@ +import os +import torch + + +class BaseModel(torch.nn.Module): + @staticmethod + def modify_commandline_options(parser, is_train): + return parser + + def __init__(self, opt): + super().__init__() + self.opt = opt + self.device = torch.device('cuda:0') if opt.num_gpus > 0 else torch.device('cpu') + + def initialize(self): + pass + + def per_gpu_initialize(self): + pass + + def compute_generator_losses(self, data_i): + return {} + + def compute_discriminator_losses(self, data_i): + return {} + + def get_visuals_for_snapshot(self, 
data_i): + return {} + + def get_parameters_for_mode(self, mode): + return {} + + def save(self, total_steps_so_far): + savedir = os.path.join(self.opt.checkpoints_dir, self.opt.name) + checkpoint_name = "%dk_checkpoint.pth" % (total_steps_so_far // 1000) + savepath = os.path.join(savedir, checkpoint_name) + torch.save(self.state_dict(), savepath) + sympath = os.path.join(savedir, "latest_checkpoint.pth") + if os.path.exists(sympath): + os.remove(sympath) + os.symlink(checkpoint_name, sympath) + + def load(self): + if self.opt.isTrain and self.opt.pretrained_name is not None: + loaddir = os.path.join(self.opt.checkpoints_dir, self.opt.pretrained_name) + else: + loaddir = os.path.join(self.opt.checkpoints_dir, self.opt.name) + checkpoint_name = "%s_checkpoint.pth" % self.opt.resume_iter + checkpoint_path = os.path.join(loaddir, checkpoint_name) + if not os.path.exists(checkpoint_path): + print("\n\ncheckpoint %s does not exist!" % checkpoint_path) + assert self.opt.isTrain, "In test mode, the checkpoint file must exist" + print("Training will start from scratch") + return + state_dict = torch.load(checkpoint_path, + map_location=str(self.device)) + # self.load_state_dict(state_dict) + own_state = self.state_dict() + skip_all = False + for name, own_param in own_state.items(): + if not self.opt.isTrain and (name.startswith("D.") or name.startswith("Dpatch.")): + continue + if name not in state_dict: + print("Key %s does not exist in checkpoint. Skipping..." % name) + continue + # if name.startswith("C.net"): + # continue + param = state_dict[name] + if own_param.shape != param.shape: + message = "Key [%s]: Shape does not match the created model (%s) and loaded checkpoint (%s)" % (name, str(own_param.shape), str(param.shape)) + if skip_all: + print(message) + min_shape = [min(s1, s2) for s1, s2 in zip(own_param.shape, param.shape)] + ms = min_shape + if len(min_shape) == 1: + own_param[:ms[0]].copy_(param[:ms[0]]) + own_param[ms[0]:].copy_(own_param[ms[0]:] * 0) + elif len(min_shape) == 2: + own_param[:ms[0], :ms[1]].copy_(param[:ms[0], :ms[1]]) + own_param[ms[0]:, ms[1]:].copy_(own_param[ms[0]:, ms[1]:] * 0) + elif len(ms) == 4: + own_param[:ms[0], :ms[1], :ms[2], :ms[3]].copy_(param[:ms[0], :ms[1], :ms[2], :ms[3]]) + own_param[ms[0]:, ms[1]:, ms[2]:, ms[3]:].copy_(own_param[ms[0]:, ms[1]:, ms[2]:, ms[3]:] * 0) + else: + print("Skipping min_shape of %s" % str(ms)) + continue + userinput = input("%s. Force loading? (yes, no, all) " % (message)) + if userinput.lower() == "yes": + pass + elif userinput.lower() == "no": + #assert own_param.shape == param.shape + continue + elif userinput.lower() == "all": + skip_all = True + else: + raise ValueError(userinput) + min_shape = [min(s1, s2) for s1, s2 in zip(own_param.shape, param.shape)] + ms = min_shape + if len(min_shape) == 1: + own_param[:ms[0]].copy_(param[:ms[0]]) + own_param[ms[0]:].copy_(own_param[ms[0]:] * 0) + elif len(min_shape) == 2: + own_param[:ms[0], :ms[1]].copy_(param[:ms[0], :ms[1]]) + own_param[ms[0]:, ms[1]:].copy_(own_param[ms[0]:, ms[1]:] * 0) + elif len(ms) == 4: + own_param[:ms[0], :ms[1], :ms[2], :ms[3]].copy_(param[:ms[0], :ms[1], :ms[2], :ms[3]]) + own_param[ms[0]:, ms[1]:, ms[2]:, ms[3]:].copy_(own_param[ms[0]:, ms[1]:, ms[2]:, ms[3]:] * 0) + else: + print("Skipping min_shape of %s" % str(ms)) + continue + own_param.copy_(param) + print("checkpoint loaded from %s" % os.path.join(loaddir, checkpoint_name)) + + def forward(self, *args, command=None, **kwargs): + """ wrapper for multigpu training. 
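As a usage sketch of this command-based dispatch (the option object opt and the batch dict data_i are assumed to come from the surrounding training loop; only create_model and MultiGPUModelWrapper from above are exercised):

# hypothetical driver snippet; opt and data_i are assumed to already exist
from swapae.models import create_model

model = create_model(opt)  # a MultiGPUModelWrapper around the chosen BaseModel subclass
d_losses = model(data_i, command="compute_discriminator_losses")  # forwarded through DataParallel
g_losses = model(data_i, command="compute_generator_losses")
model.save(total_steps_so_far=50000)  # delegated to the underlying single-GPU module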
BaseModel is expected to be + wrapped in nn.parallel.DataParallel, which distributes its call to + the BaseModel instance on each GPU """ + if command is not None: + method = getattr(self, command) + assert callable(method), "[%s] is not a method of %s" % (command, type(self).__name__) + return method(*args, **kwargs) + else: + raise ValueError(command) diff --git a/swapae/models/networks/__init__.py b/swapae/models/networks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6a0bf12b652f51078228ce9b899bbf8a18386c51 --- /dev/null +++ b/swapae/models/networks/__init__.py @@ -0,0 +1,46 @@ +import torch +import swapae.util as util +from swapae.models.networks.base_network import BaseNetwork + + +def find_network_using_name(target_network_name, filename): + target_class_name = target_network_name + filename + module_name = 'swapae.models.networks.' + filename + network = util.find_class_in_module(target_class_name, module_name) + + assert issubclass(network, BaseNetwork), \ + "Class %s should be a subclass of BaseNetwork" % network + + return network + + +def modify_commandline_options(parser, is_train): + opt, _ = parser.parse_known_args() + + netE_cls = find_network_using_name(opt.netE, 'encoder') + assert netE_cls is not None + parser = netE_cls.modify_commandline_options(parser, is_train) + + netG_cls = find_network_using_name(opt.netG, 'generator') + assert netG_cls is not None + parser = netG_cls.modify_commandline_options(parser, is_train) + + netD_cls = find_network_using_name(opt.netD, 'discriminator') + parser = netD_cls.modify_commandline_options(parser, is_train) + + if opt.netPatchD is not None: + netD_cls = find_network_using_name(opt.netPatchD, 'patch_discriminator') + assert netD_cls is not None + parser = netD_cls.modify_commandline_options(parser, is_train) + + return parser + + +def create_network(opt, network_name, mode, verbose=True): + if network_name is None: + return None + net_cls = find_network_using_name(network_name, mode) + net = net_cls(opt) + if verbose: + net.print_architecture(verbose=True) + return net diff --git a/swapae/models/networks/__pycache__/__init__.cpython-37.pyc b/swapae/models/networks/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6766939101527b26d84e85eaf452e07c950a8166 Binary files /dev/null and b/swapae/models/networks/__pycache__/__init__.cpython-37.pyc differ diff --git a/swapae/models/networks/__pycache__/__init__.cpython-38.pyc b/swapae/models/networks/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..80656d17832658ceba1c299dbb218f31ba0735ba Binary files /dev/null and b/swapae/models/networks/__pycache__/__init__.cpython-38.pyc differ diff --git a/swapae/models/networks/__pycache__/base_network.cpython-37.pyc b/swapae/models/networks/__pycache__/base_network.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fd4437dc6b2f548dd8035f47fdcfa73c0f964049 Binary files /dev/null and b/swapae/models/networks/__pycache__/base_network.cpython-37.pyc differ diff --git a/swapae/models/networks/__pycache__/base_network.cpython-38.pyc b/swapae/models/networks/__pycache__/base_network.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..86233c12e140e3461fc0a1a990ad68e3f2a55d8d Binary files /dev/null and b/swapae/models/networks/__pycache__/base_network.cpython-38.pyc differ diff --git a/swapae/models/networks/__pycache__/loss.cpython-37.pyc 
b/swapae/models/networks/__pycache__/loss.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e37b4b497cbe258ac0a99f0606b4d768e63ff618 Binary files /dev/null and b/swapae/models/networks/__pycache__/loss.cpython-37.pyc differ diff --git a/swapae/models/networks/__pycache__/patch_discriminator.cpython-37.pyc b/swapae/models/networks/__pycache__/patch_discriminator.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ba5783d2f36710038914cd717ff4a72a241b61fe Binary files /dev/null and b/swapae/models/networks/__pycache__/patch_discriminator.cpython-37.pyc differ diff --git a/swapae/models/networks/__pycache__/stylegan2_layers.cpython-37.pyc b/swapae/models/networks/__pycache__/stylegan2_layers.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0d368a5289864628a05b575aecdf9b39bccecd76 Binary files /dev/null and b/swapae/models/networks/__pycache__/stylegan2_layers.cpython-37.pyc differ diff --git a/swapae/models/networks/__pycache__/stylegan2_layers.cpython-38.pyc b/swapae/models/networks/__pycache__/stylegan2_layers.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..14cbc5d19d7f5e1a9df9488dc31a946251d4e202 Binary files /dev/null and b/swapae/models/networks/__pycache__/stylegan2_layers.cpython-38.pyc differ diff --git a/swapae/models/networks/base_network.py b/swapae/models/networks/base_network.py new file mode 100644 index 0000000000000000000000000000000000000000..02d33f713880be1c28b543865334de5ef9ee8586 --- /dev/null +++ b/swapae/models/networks/base_network.py @@ -0,0 +1,57 @@ +import torch + + +class BaseNetwork(torch.nn.Module): + @staticmethod + def modify_commandline_options(parser, is_train): + return parser + + def __init__(self, opt): + super().__init__() + self.opt = opt + + def print_architecture(self, verbose=False): + name = type(self).__name__ + result = '-------------------%s---------------------\n' % name + total_num_params = 0 + for i, (name, child) in enumerate(self.named_children()): + num_params = sum([p.numel() for p in child.parameters()]) + total_num_params += num_params + if verbose: + result += "%s: %3.3fM\n" % (name, (num_params / 1e6)) + for i, (name, grandchild) in enumerate(child.named_children()): + num_params = sum([p.numel() for p in grandchild.parameters()]) + if verbose: + result += "\t%s: %3.3fM\n" % (name, (num_params / 1e6)) + result += '[Network %s] Total number of parameters : %.3f M\n' % (name, total_num_params / 1e6) + result += '-----------------------------------------------\n' + print(result) + + def set_requires_grad(self, requires_grad): + for param in self.parameters(): + param.requires_grad = requires_grad + + def collect_parameters(self, name): + params = [] + for m in self.modules(): + if type(m).__name__ == name: + params += list(m.parameters()) + return params + + def fix_and_gather_noise_parameters(self): + params = [] + device = next(self.parameters()).device + for m in self.modules(): + if type(m).__name__ == "NoiseInjection": + assert m.image_size is not None, "One forward call should be made to determine size of noise parameters" + m.fixed_noise = torch.nn.Parameter(torch.randn(m.image_size[0], 1, m.image_size[2], m.image_size[3], device=device)) + params.append(m.fixed_noise) + return params + + def remove_noise_parameters(self, name): + for m in self.modules(): + if type(m).__name__ == "NoiseInjection": + m.fixed_noise = None + + def forward(self, x): + return x diff --git 
a/swapae/models/networks/classifier.py b/swapae/models/networks/classifier.py new file mode 100644 index 0000000000000000000000000000000000000000..86b2eab0c31ef8999536232eaf405229dd3b0953 --- /dev/null +++ b/swapae/models/networks/classifier.py @@ -0,0 +1,28 @@ +import torch +from swapae.models.networks import BaseNetwork +from swapae.models.networks.pyramidnet import PyramidNet + + +class PyramidNetClassifier(BaseNetwork): + @staticmethod + def modify_commandline_options(parser, is_train): + parser.add_argument("--pyramid_alpha", type=int, default=240) + parser.add_argument("--pyramid_depth", type=int, default=200) + return parser + + def __init__(self, opt): + super().__init__(opt) + assert "cifar" in opt.dataset_mode + self.net = PyramidNet( + opt.dataset_mode, depth=opt.pyramid_depth, alpha=opt.pyramid_alpha, + num_classes=opt.num_classes, bottleneck=True) + + mean = torch.tensor([x / 127.5 - 1.0 for x in [125.3, 123.0, 113.9]], dtype=torch.float) + std = torch.tensor([x / 127.5 for x in [63.0, 62.1, 66.7]], dtype=torch.float) + self.register_buffer("mean", mean[None, :, None, None]) + self.register_buffer("std", std[None, :, None, None]) + + def forward(self, x): + x = (x - self.mean) / self.std + return self.net(x) + diff --git a/swapae/models/networks/discriminator.py b/swapae/models/networks/discriminator.py new file mode 100644 index 0000000000000000000000000000000000000000..08a2191fdbef001eef075826c58ad6fdb81dc588 --- /dev/null +++ b/swapae/models/networks/discriminator.py @@ -0,0 +1,32 @@ +from swapae.models.networks import BaseNetwork +from swapae.models.networks.stylegan2_layers import Discriminator as OriginalStyleGAN2Discriminator + + +class StyleGAN2Discriminator(BaseNetwork): + @staticmethod + def modify_commandline_options(parser, is_train): + parser.add_argument("--netD_scale_capacity", default=1.0, type=float) + return parser + + def __init__(self, opt): + super().__init__(opt) + self.stylegan2_D = OriginalStyleGAN2Discriminator( + opt.crop_size, + 2.0 * opt.netD_scale_capacity, + blur_kernel=[1, 3, 3, 1] if self.opt.use_antialias else [1] + ) + + def forward(self, x): + pred = self.stylegan2_D(x) + return pred + + def get_features(self, x): + return self.stylegan2_D.get_features(x) + + def get_pred_from_features(self, feat, label): + assert label is None + feat = feat.flatten(1) + out = self.stylegan2_D.final_linear(feat) + return out + + diff --git a/swapae/models/networks/encoder.py b/swapae/models/networks/encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..28d6da385d52ee0468389e2a93bc8255baecc21f --- /dev/null +++ b/swapae/models/networks/encoder.py @@ -0,0 +1,113 @@ +import numpy as np +import torch +import torch.nn.functional as F +import torch.nn as nn +import swapae.util as util +from swapae.models.networks import BaseNetwork +from swapae.models.networks.stylegan2_layers import ResBlock, ConvLayer, ToRGB, EqualLinear, Blur, Upsample, make_kernel +from swapae.models.networks.stylegan2_op import upfirdn2d + + +class ToSpatialCode(torch.nn.Module): + def __init__(self, inch, outch, scale): + super().__init__() + hiddench = inch // 2 + self.conv1 = ConvLayer(inch, hiddench, 1, activate=True, bias=True) + self.conv2 = ConvLayer(hiddench, outch, 1, activate=False, bias=True) + self.scale = scale + self.upsample = Upsample([1, 3, 3, 1], 2) + self.blur = Blur([1, 3, 3, 1], pad=(2, 1)) + self.register_buffer('kernel', make_kernel([1, 3, 3, 1])) + + def forward(self, x): + x = self.conv1(x) + x = self.conv2(x) + for i in 
range(int(np.log2(self.scale))): + x = self.upsample(x) + return x + + +class StyleGAN2ResnetEncoder(BaseNetwork): + @staticmethod + def modify_commandline_options(parser, is_train): + parser.add_argument("--netE_scale_capacity", default=1.0, type=float) + parser.add_argument("--netE_num_downsampling_sp", default=4, type=int) + parser.add_argument("--netE_num_downsampling_gl", default=2, type=int) + parser.add_argument("--netE_nc_steepness", default=2.0, type=float) + return parser + + def __init__(self, opt): + super().__init__(opt) + + # If antialiasing is used, create a very lightweight Gaussian kernel. + blur_kernel = [1, 2, 1] if self.opt.use_antialias else [1] + + self.add_module("FromRGB", ConvLayer(3, self.nc(0), 1)) + + self.DownToSpatialCode = nn.Sequential() + for i in range(self.opt.netE_num_downsampling_sp): + self.DownToSpatialCode.add_module( + "ResBlockDownBy%d" % (2 ** i), + ResBlock(self.nc(i), self.nc(i + 1), blur_kernel, + reflection_pad=True) + ) + + # Spatial Code refers to the Structure Code, and + # Global Code refers to the Texture Code of the paper. + nchannels = self.nc(self.opt.netE_num_downsampling_sp) + self.add_module( + "ToSpatialCode", + nn.Sequential( + ConvLayer(nchannels, nchannels, 1, activate=True, bias=True), + ConvLayer(nchannels, self.opt.spatial_code_ch, kernel_size=1, + activate=False, bias=True) + ) + ) + + self.DownToGlobalCode = nn.Sequential() + for i in range(self.opt.netE_num_downsampling_gl): + idx_from_beginning = self.opt.netE_num_downsampling_sp + i + self.DownToGlobalCode.add_module( + "ConvLayerDownBy%d" % (2 ** idx_from_beginning), + ConvLayer(self.nc(idx_from_beginning), + self.nc(idx_from_beginning + 1), kernel_size=3, + blur_kernel=[1], downsample=True, pad=0) + ) + + nchannels = self.nc(self.opt.netE_num_downsampling_sp + + self.opt.netE_num_downsampling_gl) + self.add_module( + "ToGlobalCode", + nn.Sequential( + EqualLinear(nchannels, self.opt.global_code_ch) + ) + ) + + def nc(self, idx): + nc = self.opt.netE_nc_steepness ** (5 + idx) + nc = nc * self.opt.netE_scale_capacity + # nc = min(self.opt.global_code_ch, int(round(nc))) + return round(nc) + + def forward(self, x, extract_features=False): + x = self.FromRGB(x) + midpoint = self.DownToSpatialCode(x) + sp = self.ToSpatialCode(midpoint) + + if extract_features: + padded_midpoint = F.pad(midpoint, (1, 0, 1, 0), mode='reflect') + feature = self.DownToGlobalCode[0](padded_midpoint) + assert feature.size(2) == sp.size(2) // 2 and \ + feature.size(3) == sp.size(3) // 2 + feature = F.interpolate( + feature, size=(7, 7), mode='bilinear', align_corners=False) + + x = self.DownToGlobalCode(midpoint) + x = x.mean(dim=(2, 3)) + gl = self.ToGlobalCode(x) + sp = util.normalize(sp) + gl = util.normalize(gl) + if extract_features: + return sp, gl, feature + else: + return sp, gl diff --git a/swapae/models/networks/generator.py b/swapae/models/networks/generator.py new file mode 100644 index 0000000000000000000000000000000000000000..9dbe8fe1e36420a873178c03de237d6f0b824b72 --- /dev/null +++ b/swapae/models/networks/generator.py @@ -0,0 +1,161 @@ +import math +import torch +import swapae.util as util +import torch.nn.functional as F +from swapae.models.networks import BaseNetwork +from swapae.models.networks.stylegan2_layers import ConvLayer, ToRGB, EqualLinear, StyledConv + + +class UpsamplingBlock(torch.nn.Module): + def __init__(self, inch, outch, styledim, + blur_kernel=[1, 3, 3, 1], use_noise=False): + super().__init__() + self.inch, self.outch, self.styledim = inch, outch, styledim + 
self.conv1 = StyledConv(inch, outch, 3, styledim, upsample=True, + blur_kernel=blur_kernel, use_noise=use_noise) + self.conv2 = StyledConv(outch, outch, 3, styledim, upsample=False, + use_noise=use_noise) + + def forward(self, x, style): + return self.conv2(self.conv1(x, style), style) + + +class ResolutionPreservingResnetBlock(torch.nn.Module): + def __init__(self, opt, inch, outch, styledim): + super().__init__() + self.conv1 = StyledConv(inch, outch, 3, styledim, upsample=False) + self.conv2 = StyledConv(outch, outch, 3, styledim, upsample=False) + if inch != outch: + self.skip = ConvLayer(inch, outch, 1, activate=False, bias=False) + else: + self.skip = torch.nn.Identity() + + def forward(self, x, style): + skip = self.skip(x) + res = self.conv2(self.conv1(x, style), style) + return (skip + res) / math.sqrt(2) + + +class UpsamplingResnetBlock(torch.nn.Module): + def __init__(self, inch, outch, styledim, blur_kernel=[1, 3, 3, 1], use_noise=False): + super().__init__() + self.inch, self.outch, self.styledim = inch, outch, styledim + self.conv1 = StyledConv(inch, outch, 3, styledim, upsample=True, blur_kernel=blur_kernel, use_noise=use_noise) + self.conv2 = StyledConv(outch, outch, 3, styledim, upsample=False, use_noise=use_noise) + if inch != outch: + self.skip = ConvLayer(inch, outch, 1, activate=True, bias=True) + else: + self.skip = torch.nn.Identity() + + def forward(self, x, style): + skip = F.interpolate(self.skip(x), scale_factor=2, mode='bilinear', align_corners=False) + res = self.conv2(self.conv1(x, style), style) + return (skip + res) / math.sqrt(2) + + +class GeneratorModulation(torch.nn.Module): + def __init__(self, styledim, outch): + super().__init__() + self.scale = EqualLinear(styledim, outch) + self.bias = EqualLinear(styledim, outch) + + def forward(self, x, style): + if style.ndimension() <= 2: + return x * (1 * self.scale(style)[:, :, None, None]) + self.bias(style)[:, :, None, None] + else: + style = F.interpolate(style, size=(x.size(2), x.size(3)), mode='bilinear', align_corners=False) + return x * (1 * self.scale(style)) + self.bias(style) + + +class StyleGAN2ResnetGenerator(BaseNetwork): + """ The Generator (decoder) architecture described in Figure 18 of + Swapping Autoencoder (https://arxiv.org/abs/2007.00653). + + At high level, the architecture consists of regular and + upsampling residual blocks to transform the structure code into an RGB + image. The global code is applied at each layer as modulation. + + Here's more detailed architecture: + + 1. SpatialCodeModulation: First of all, modulate the structure code + with the global code. + 2. HeadResnetBlock: resnets at the resolution of the structure code, + which also incorporates modulation from the global code. + 3. UpsamplingResnetBlock: resnets that upsamples by factor of 2 until + the resolution of the output RGB image, along with the global code + modulation. + 4. ToRGB: Final layer that transforms the output into 3 channels (RGB). + + Each components of the layers borrow heavily from StyleGAN2 code, + implemented by Seonghyeon Kim. + https://github.com/rosinality/stylegan2-pytorch + """ + @staticmethod + def modify_commandline_options(parser, is_train): + parser.add_argument("--netG_scale_capacity", default=1.0, type=float) + parser.add_argument( + "--netG_num_base_resnet_layers", + default=2, type=int, + help="The number of resnet layers before the upsampling layers." 
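To make the channel schedule implied by the docstring above concrete, here is a small standalone sketch that mirrors the nf() helper defined just below, under assumed default options (netE_num_downsampling_sp=4, netG_scale_capacity=1.0):

# illustrative only: reproduces StyleGAN2ResnetGenerator.nf() under assumed default options
num_downsampling_sp = 4      # assumed value of --netE_num_downsampling_sp
scale_capacity = 1.0         # assumed value of --netG_scale_capacity

def nf(num_up):
    ch = 128 * (2 ** (num_downsampling_sp - num_up))
    return int(min(512, ch) * scale_capacity)

print([nf(j) for j in range(num_downsampling_sp + 1)])  # [512, 512, 512, 256, 128]
# The head resnet blocks keep the structure-code resolution; each UpsamplingResBlock
# then doubles the spatial size, and ToRGB maps the final features to 3 channels.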
+ ) + parser.add_argument("--netG_use_noise", type=util.str2bool, nargs='?', const=True, default=True) + parser.add_argument("--netG_resnet_ch", type=int, default=256) + + return parser + + def __init__(self, opt): + super().__init__(opt) + num_upsamplings = opt.netE_num_downsampling_sp + blur_kernel = [1, 3, 3, 1] if opt.use_antialias else [1] + + self.global_code_ch = opt.global_code_ch + opt.num_classes + + self.add_module( + "SpatialCodeModulation", + GeneratorModulation(self.global_code_ch, opt.spatial_code_ch)) + + in_channel = opt.spatial_code_ch + for i in range(opt.netG_num_base_resnet_layers): + # gradually increase the number of channels + out_channel = (i + 1) / opt.netG_num_base_resnet_layers * self.nf(0) + out_channel = max(opt.spatial_code_ch, round(out_channel)) + layer_name = "HeadResnetBlock%d" % i + new_layer = ResolutionPreservingResnetBlock( + opt, in_channel, out_channel, self.global_code_ch) + self.add_module(layer_name, new_layer) + in_channel = out_channel + + for j in range(num_upsamplings): + out_channel = self.nf(j + 1) + layer_name = "UpsamplingResBlock%d" % (2 ** (4 + j)) + new_layer = UpsamplingResnetBlock( + in_channel, out_channel, self.global_code_ch, + blur_kernel, opt.netG_use_noise) + self.add_module(layer_name, new_layer) + in_channel = out_channel + + last_layer = ToRGB(out_channel, self.global_code_ch, + blur_kernel=blur_kernel) + self.add_module("ToRGB", last_layer) + + def nf(self, num_up): + ch = 128 * (2 ** (self.opt.netE_num_downsampling_sp - num_up)) + ch = int(min(512, ch) * self.opt.netG_scale_capacity) + return ch + + def forward(self, spatial_code, global_code): + spatial_code = util.normalize(spatial_code) + global_code = util.normalize(global_code) + + x = self.SpatialCodeModulation(spatial_code, global_code) + for i in range(self.opt.netG_num_base_resnet_layers): + resblock = getattr(self, "HeadResnetBlock%d" % i) + x = resblock(x, global_code) + + for j in range(self.opt.netE_num_downsampling_sp): + key_name = 2 ** (4 + j) + upsampling_layer = getattr(self, "UpsamplingResBlock%d" % key_name) + x = upsampling_layer(x, global_code) + rgb = self.ToRGB(x, global_code, None) + + return rgb diff --git a/swapae/models/networks/inception.py b/swapae/models/networks/inception.py new file mode 100644 index 0000000000000000000000000000000000000000..8cdbbf2d18b45ff7f1737e298eb6be7fa720801c --- /dev/null +++ b/swapae/models/networks/inception.py @@ -0,0 +1,326 @@ +# Code from https://github.com/mseitzer/pytorch-fid/blob/master/inception.py + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchvision import models + +try: + from torchvision.models.utils import load_state_dict_from_url +except ImportError: + from torch.utils.model_zoo import load_url as load_state_dict_from_url + +# Inception weights ported to Pytorch from +# http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz +FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth' + + +class InceptionV3(nn.Module): + """Pretrained InceptionV3 network returning feature maps""" + + # Index of default block of inception to return, + # corresponds to output of final average pooling + DEFAULT_BLOCK_INDEX = 3 + + # Maps feature dimensionality to their output blocks indices + BLOCK_INDEX_BY_DIM = { + 64: 0, # First max pooling features + 192: 1, # Second max pooling featurs + 768: 2, # Pre-aux classifier features + 2048: 3 # Final average pooling features + } + + def __init__(self, 
+ output_blocks=[DEFAULT_BLOCK_INDEX], + resize_input=True, + normalize_input=True, + requires_grad=False, + use_fid_inception=True): + """Build pretrained InceptionV3 + + Parameters + ---------- + output_blocks : list of int + Indices of blocks to return features of. Possible values are: + - 0: corresponds to output of first max pooling + - 1: corresponds to output of second max pooling + - 2: corresponds to output which is fed to aux classifier + - 3: corresponds to output of final average pooling + resize_input : bool + If true, bilinearly resizes input to width and height 299 before + feeding input to model. As the network without fully connected + layers is fully convolutional, it should be able to handle inputs + of arbitrary size, so resizing might not be strictly needed + normalize_input : bool + If true, scales the input from range (0, 1) to the range the + pretrained Inception network expects, namely (-1, 1) + requires_grad : bool + If true, parameters of the model require gradients. Possibly useful + for finetuning the network + use_fid_inception : bool + If true, uses the pretrained Inception model used in Tensorflow's + FID implementation. If false, uses the pretrained Inception model + available in torchvision. The FID Inception model has different + weights and a slightly different structure from torchvision's + Inception model. If you want to compute FID scores, you are + strongly advised to set this parameter to true to get comparable + results. + """ + super(InceptionV3, self).__init__() + + self.resize_input = resize_input + self.normalize_input = normalize_input + self.output_blocks = sorted(output_blocks) + self.last_needed_block = max(output_blocks) + + assert self.last_needed_block <= 3, \ + 'Last possible output block index is 3' + + self.blocks = nn.ModuleList() + + if use_fid_inception: + inception = fid_inception_v3() + else: + inception = models.inception_v3(pretrained=True) + + # Block 0: input to maxpool1 + block0 = [ + inception.Conv2d_1a_3x3, + inception.Conv2d_2a_3x3, + inception.Conv2d_2b_3x3, + nn.MaxPool2d(kernel_size=3, stride=2) + ] + self.blocks.append(nn.Sequential(*block0)) + + # Block 1: maxpool1 to maxpool2 + if self.last_needed_block >= 1: + block1 = [ + inception.Conv2d_3b_1x1, + inception.Conv2d_4a_3x3, + nn.MaxPool2d(kernel_size=3, stride=2) + ] + self.blocks.append(nn.Sequential(*block1)) + + # Block 2: maxpool2 to aux classifier + if self.last_needed_block >= 2: + block2 = [ + inception.Mixed_5b, + inception.Mixed_5c, + inception.Mixed_5d, + inception.Mixed_6a, + inception.Mixed_6b, + inception.Mixed_6c, + inception.Mixed_6d, + inception.Mixed_6e, + ] + self.blocks.append(nn.Sequential(*block2)) + + # Block 3: aux classifier to final avgpool + if self.last_needed_block >= 3: + block3 = [ + inception.Mixed_7a, + inception.Mixed_7b, + inception.Mixed_7c, + nn.AdaptiveAvgPool2d(output_size=(1, 1)) + ] + self.blocks.append(nn.Sequential(*block3)) + + for param in self.parameters(): + param.requires_grad = requires_grad + + def forward(self, inp): + """Get Inception feature maps + + Parameters + ---------- + inp : torch.autograd.Variable + Input tensor of shape Bx3xHxW. 
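A short usage sketch of the block-index interface documented above (batch size and image resolution are arbitrary; the FID weights are downloaded on first use):

# illustrative usage: extract the 2048-d final-pool features commonly used for FID
import torch

block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]     # == 3
net = InceptionV3(output_blocks=[block_idx]).eval()
images = torch.rand(4, 3, 256, 256)                  # values in (0, 1); resized to 299x299 internally
with torch.no_grad():
    pool_features = net(images)[0]                   # shape: 4 x 2048 x 1 x 1
pool_features = pool_features.squeeze(3).squeeze(2)  # -> 4 x 2048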
Values are expected to be in + range (0, 1) + + Returns + ------- + List of torch.autograd.Variable, corresponding to the selected output + block, sorted ascending by index + """ + outp = [] + x = inp + + #if self.normalize_input: + # assert x.min() >= -0.001 and x.max() <= 1.001, "min %f, max %f is out of range" % (x.min(), x.max()) + + + if self.resize_input: + x = F.interpolate(x, + size=(299, 299), + mode='bilinear', + align_corners=False) + + if self.normalize_input: + x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1) + + for idx, block in enumerate(self.blocks): + #if idx == 1 and idx in self.output_blocks: # For Block 1, return the activations before maxpooling + # for idx2, layer in enumerate(block): + # x = layer(x) + # if idx2 == len(block) - 1: + # outp.append(x) + #else: + x = block(x) + if idx in self.output_blocks: + outp.append(x) + + if idx == self.last_needed_block: + break + + return outp + + +def fid_inception_v3(): + """Build pretrained Inception model for FID computation + + The Inception model for FID computation uses a different set of weights + and has a slightly different structure than torchvision's Inception. + + This method first constructs torchvision's Inception and then patches the + necessary parts that are different in the FID Inception model. + """ + inception = models.inception_v3(num_classes=1008, + aux_logits=False, + pretrained=False) + inception.Mixed_5b = FIDInceptionA(192, pool_features=32) + inception.Mixed_5c = FIDInceptionA(256, pool_features=64) + inception.Mixed_5d = FIDInceptionA(288, pool_features=64) + inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128) + inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160) + inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160) + inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192) + inception.Mixed_7b = FIDInceptionE_1(1280) + inception.Mixed_7c = FIDInceptionE_2(2048) + + state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True) + inception.load_state_dict(state_dict) + return inception + + +class FIDInceptionA(models.inception.InceptionA): + """InceptionA block patched for FID computation""" + + def __init__(self, in_channels, pool_features): + super(FIDInceptionA, self).__init__(in_channels, pool_features) + + def forward(self, x): + branch1x1 = self.branch1x1(x) + + branch5x5 = self.branch5x5_1(x) + branch5x5 = self.branch5x5_2(branch5x5) + + branch3x3dbl = self.branch3x3dbl_1(x) + branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) + branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) + + # Patch: Tensorflow's average pool does not use the padded zero's in + # its average calculation + branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, + count_include_pad=False) + branch_pool = self.branch_pool(branch_pool) + + outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool] + return torch.cat(outputs, 1) + + +class FIDInceptionC(models.inception.InceptionC): + """InceptionC block patched for FID computation""" + + def __init__(self, in_channels, channels_7x7): + super(FIDInceptionC, self).__init__(in_channels, channels_7x7) + + def forward(self, x): + branch1x1 = self.branch1x1(x) + + branch7x7 = self.branch7x7_1(x) + branch7x7 = self.branch7x7_2(branch7x7) + branch7x7 = self.branch7x7_3(branch7x7) + + branch7x7dbl = self.branch7x7dbl_1(x) + branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl) + branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl) + branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl) + branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl) + + # Patch: 
Tensorflow's average pool does not use the padded zero's in + # its average calculation + branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, + count_include_pad=False) + branch_pool = self.branch_pool(branch_pool) + + outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool] + return torch.cat(outputs, 1) + + +class FIDInceptionE_1(models.inception.InceptionE): + """First InceptionE block patched for FID computation""" + + def __init__(self, in_channels): + super(FIDInceptionE_1, self).__init__(in_channels) + + def forward(self, x): + branch1x1 = self.branch1x1(x) + + branch3x3 = self.branch3x3_1(x) + branch3x3 = [ + self.branch3x3_2a(branch3x3), + self.branch3x3_2b(branch3x3), + ] + branch3x3 = torch.cat(branch3x3, 1) + + branch3x3dbl = self.branch3x3dbl_1(x) + branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) + branch3x3dbl = [ + self.branch3x3dbl_3a(branch3x3dbl), + self.branch3x3dbl_3b(branch3x3dbl), + ] + branch3x3dbl = torch.cat(branch3x3dbl, 1) + + # Patch: Tensorflow's average pool does not use the padded zero's in + # its average calculation + branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, + count_include_pad=False) + branch_pool = self.branch_pool(branch_pool) + + outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] + return torch.cat(outputs, 1) + + +class FIDInceptionE_2(models.inception.InceptionE): + """Second InceptionE block patched for FID computation""" + + def __init__(self, in_channels): + super(FIDInceptionE_2, self).__init__(in_channels) + + def forward(self, x): + branch1x1 = self.branch1x1(x) + + branch3x3 = self.branch3x3_1(x) + branch3x3 = [ + self.branch3x3_2a(branch3x3), + self.branch3x3_2b(branch3x3), + ] + branch3x3 = torch.cat(branch3x3, 1) + + branch3x3dbl = self.branch3x3dbl_1(x) + branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) + branch3x3dbl = [ + self.branch3x3dbl_3a(branch3x3dbl), + self.branch3x3dbl_3b(branch3x3dbl), + ] + branch3x3dbl = torch.cat(branch3x3dbl, 1) + + # Patch: The FID Inception model uses max pooling instead of average + # pooling. This is likely an error in this specific Inception + # implementation, as other Inception models use average pooling here + # (which matches the description in the paper). 
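The count_include_pad patch applied in the average-pooling branches above is easy to check in isolation; a minimal standalone illustration:

# standalone check of the average-pooling difference patched above
import torch
import torch.nn.functional as F

x = torch.ones(1, 1, 4, 4)
tf_like = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, count_include_pad=False)
default = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
print(tf_like[0, 0, 0, 0].item())   # 1.0, the corner is averaged over the 4 valid pixels
print(default[0, 0, 0, 0].item())   # ~0.444, padded zeros are counted in the denominator (4/9)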
+ branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1) + branch_pool = self.branch_pool(branch_pool) + + outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] + return torch.cat(outputs, 1) diff --git a/swapae/models/networks/loss.py b/swapae/models/networks/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..1dabb7a292128caf495b57ee00404df7c4dc7ee6 --- /dev/null +++ b/swapae/models/networks/loss.py @@ -0,0 +1,143 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision +import swapae.util as util +from .stylegan2_layers import Downsample + + +def gan_loss(pred, should_be_classified_as_real): + bs = pred.size(0) + if should_be_classified_as_real: + return F.softplus(-pred).view(bs, -1).mean(dim=1) + else: + return F.softplus(pred).view(bs, -1).mean(dim=1) + + +def feature_matching_loss(xs, ys, equal_weights=False, num_layers=6): + loss = 0.0 + for i, (x, y) in enumerate(zip(xs[:num_layers], ys[:num_layers])): + if equal_weights: + weight = 1.0 / min(num_layers, len(xs)) + else: + weight = 1 / (2 ** (min(num_layers, len(xs)) - i)) + loss = loss + (x - y).abs().flatten(1).mean(1) * weight + return loss + + +class IntraImageNCELoss(nn.Module): + def __init__(self, opt): + super().__init__() + self.opt = opt + self.cross_entropy_loss = nn.CrossEntropyLoss(reduction='mean') + + def forward(self, query, target): + num_locations = min(query.size(2) * query.size(3), self.opt.intraimage_num_locations) + bs = query.size(0) + patch_ids = torch.randperm(num_locations, device=query.device) + + query = query.flatten(2, 3) + target = target.flatten(2, 3) + + # both query and target are of size B x C x N + query = query[:, :, patch_ids] + target = target[:, :, patch_ids] + + cosine_similarity = torch.bmm(query.transpose(1, 2), target) + cosine_similarity = cosine_similarity.flatten(0, 1) + target_label = torch.arange(num_locations, dtype=torch.long, device=query.device).repeat(bs) + loss = self.cross_entropy_loss(cosine_similarity / 0.07, target_label) + return loss + + +class VGG16Loss(torch.nn.Module): + def __init__(self): + super().__init__() + self.vgg_convs = torchvision.models.vgg16(pretrained=True).features + self.register_buffer('mean', + torch.tensor([0.485, 0.456, 0.406])[None, :, None, None] - 0.5) + self.register_buffer('stdev', + torch.tensor([0.229, 0.224, 0.225])[None, :, None, None] * 2) + self.downsample = Downsample([1, 2, 1], factor=2) + + def copy_section(self, source, start, end): + slice = torch.nn.Sequential() + for i in range(start, end): + slice.add_module(str(i), source[i]) + return slice + + def vgg_forward(self, x): + x = (x - self.mean) / self.stdev + features = [] + for name, layer in self.vgg_convs.named_children(): + if "MaxPool2d" == type(layer).__name__: + features.append(x) + if len(features) == 3: + break + x = self.downsample(x) + else: + x = layer(x) + return features + + def forward(self, x, y): + y = y.detach() + loss = 0 + weights = [1 / 32, 1 / 16, 1 / 8, 1 / 4, 1.0] + #weights = [1] * 5 + total_weights = 0.0 + for i, (xf, yf) in enumerate(zip(self.vgg_forward(x), self.vgg_forward(y))): + loss += F.l1_loss(xf, yf) * weights[i] + total_weights += weights[i] + return loss / total_weights + + +class NCELoss(nn.Module): + def __init__(self): + super().__init__() + self.cross_entropy_loss = nn.CrossEntropyLoss(reduction='mean') + + def forward(self, query, target, negatives): + query = util.normalize(query.flatten(1)) + target = util.normalize(target.flatten(1)) + negatives = 
util.normalize(negatives.flatten(1)) + bs = query.size(0) + sim_pos = (query * target).sum(dim=1, keepdim=True) + sim_neg = torch.mm(query, negatives.transpose(0, 1)) + all_similarity = torch.cat([sim_pos, sim_neg], axis=1) / 0.07 + #sim_target = util.compute_similarity_logit(query, target) + #sim_target = torch.mm(query, target.transpose(0, 1)) / 0.07 + #sim_query = util.compute_similarity_logit(query, query) + #util.set_diag_(sim_query, -20.0) + + #all_similarity = torch.cat([sim_target, sim_query], axis=1) + + #target_label = torch.arange(bs, dtype=torch.long, + # device=query.device) + target_label = torch.zeros(bs, dtype=torch.long, device=query.device) + loss = self.cross_entropy_loss(all_similarity, + target_label) + return loss + + +class ScaleInvariantReconstructionLoss(nn.Module): + def forward(self, query, target): + query_flat = query.transpose(1, 3) + target_flat = target.transpose(1, 3) + dist = 1.0 - torch.bmm( + query_flat[:, :, :, None, :].flatten(0, 2), + target_flat[:, :, :, :, None].flatten(0, 2), + ) + + target_spatially_flat = target.flatten(1, 2) + num_samples = min(target_spatially_flat.size(1), 64) + random_indices = torch.randperm(num_samples, dtype=torch.long, device=target.device) + randomly_sampled = target_spatially_flat[:, random_indices] + random_indices = torch.randperm(num_samples, dtype=torch.long, device=target.device) + another_random_sample = target_spatially_flat[:, random_indices] + + random_similarity = torch.bmm( + randomly_sampled[:, :, None, :].flatten(0, 1), + torch.flip(another_random_sample, [0])[:, :, :, None].flatten(0, 1) + ) + + return dist.mean() + random_similarity.clamp(min=0.0).mean() diff --git a/swapae/models/networks/patch_discriminator.py b/swapae/models/networks/patch_discriminator.py new file mode 100644 index 0000000000000000000000000000000000000000..327903f5751a69c8cd352a36dcdcf5c4fdce38db --- /dev/null +++ b/swapae/models/networks/patch_discriminator.py @@ -0,0 +1,255 @@ +from collections import OrderedDict +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +import swapae.util as util +from swapae.models.networks import BaseNetwork +from swapae.models.networks.stylegan2_layers import ConvLayer, ResBlock, EqualLinear + + +class BasePatchDiscriminator(BaseNetwork): + @staticmethod + def modify_commandline_options(parser, is_train): + parser.add_argument("--netPatchD_scale_capacity", default=4.0, type=float) + parser.add_argument("--netPatchD_max_nc", default=256 + 128, type=int) + parser.add_argument("--patch_size", default=128, type=int) + parser.add_argument("--max_num_tiles", default=8, type=int) + parser.add_argument("--patch_random_transformation", + type=util.str2bool, nargs='?', const=True, default=False) + return parser + + def __init__(self, opt): + super().__init__(opt) + #self.visdom = util.Visualizer(opt) + + def needs_regularization(self): + return False + + def extract_features(self, patches): + raise NotImplementedError() + + def discriminate_features(self, feature1, feature2): + raise NotImplementedError() + + def apply_random_transformation(self, patches): + B, ntiles, C, H, W = patches.size() + patches = patches.view(B * ntiles, C, H, W) + before = patches + transformer = util.RandomSpatialTransformer(self.opt, B * ntiles) + patches = transformer.forward_transform(patches, (self.opt.patch_size, self.opt.patch_size)) + #self.visdom.display_current_results({'before': before, + # 'after': patches}, 0, save_result=False) + return patches.view(B, ntiles, C, H, W) + + def 
sample_patches_old(self, img, indices): + B, C, H, W = img.size() + s = self.opt.patch_size + if H % s > 0 or W % s > 0: + y_offset = torch.randint(H % s, (), device=img.device) + x_offset = torch.randint(W % s, (), device=img.device) + img = img[:, :, + y_offset:y_offset + s * (H // s), + x_offset:x_offset + s * (W // s)] + img = img.view(B, C, H//s, s, W//s, s) + ntiles = (H // s) * (W // s) + tiles = img.permute(0, 2, 4, 1, 3, 5).reshape(B, ntiles, C, s, s) + if indices is None: + indices = torch.randperm(ntiles, device=img.device)[:self.opt.max_num_tiles] + return self.apply_random_transformation(tiles[:, indices]), indices + else: + return self.apply_random_transformation(tiles[:, indices]) + + def forward(self, real, fake, fake_only=False): + assert real is not None + real_patches, patch_ids = self.sample_patches(real, None) + if fake is None: + real_patches.requires_grad_() + real_feat = self.extract_features(real_patches) + + bs = real.size(0) + if fake is None or not fake_only: + pred_real = self.discriminate_features( + real_feat, + torch.roll(real_feat, 1, 1)) + pred_real = pred_real.view(bs, -1) + + + if fake is not None: + fake_patches = self.sample_patches(fake, patch_ids) + #self.visualizer.display_current_results({'real_A': real_patches[0], + # 'real_B': torch.roll(fake_patches, 1, 1)[0]}, 0, False, max_num_images=16) + fake_feat = self.extract_features(fake_patches) + pred_fake = self.discriminate_features( + real_feat, + torch.roll(fake_feat, 1, 1)) + pred_fake = pred_fake.view(bs, -1) + + if fake is None: + return pred_real, real_patches + elif fake_only: + return pred_fake + else: + return pred_real, pred_fake + + + +class StyleGAN2PatchDiscriminator(BasePatchDiscriminator): + @staticmethod + def modify_commandline_options(parser, is_train): + BasePatchDiscriminator.modify_commandline_options(parser, is_train) + return parser + + def __init__(self, opt): + super().__init__(opt) + channel_multiplier = self.opt.netPatchD_scale_capacity + size = self.opt.patch_size + channels = { + 4: min(self.opt.netPatchD_max_nc, int(256 * channel_multiplier)), + 8: min(self.opt.netPatchD_max_nc, int(128 * channel_multiplier)), + 16: min(self.opt.netPatchD_max_nc, int(64 * channel_multiplier)), + 32: int(32 * channel_multiplier), + 64: int(16 * channel_multiplier), + 128: int(8 * channel_multiplier), + 256: int(4 * channel_multiplier), + } + + log_size = int(math.ceil(math.log(size, 2))) + + in_channel = channels[2 ** log_size] + + blur_kernel = [1, 3, 3, 1] if self.opt.use_antialias else [1] + + convs = [('0', ConvLayer(3, in_channel, 3))] + + for i in range(log_size, 2, -1): + out_channel = channels[2 ** (i - 1)] + + layer_name = str(7 - i) if i <= 6 else "%dx%d" % (2 ** i, 2 ** i) + convs.append((layer_name, ResBlock(in_channel, out_channel, blur_kernel))) + + in_channel = out_channel + + convs.append(('5', ResBlock(in_channel, self.opt.netPatchD_max_nc * 2, downsample=False))) + convs.append(('6', ConvLayer(self.opt.netPatchD_max_nc * 2, self.opt.netPatchD_max_nc, 3, pad=0))) + + self.convs = nn.Sequential(OrderedDict(convs)) + + out_dim = 1 + + pairlinear1 = EqualLinear(channels[4] * 2 * 2, 2048, activation='fused_lrelu') + pairlinear2 = EqualLinear(2048, 2048, activation='fused_lrelu') + pairlinear3 = EqualLinear(2048, 1024, activation='fused_lrelu') + pairlinear4 = EqualLinear(1024, out_dim) + self.pairlinear = nn.Sequential(pairlinear1, pairlinear2, pairlinear3, pairlinear4) + + def extract_features(self, patches, aggregate=False): + if patches.ndim == 5: + B, T, C, H, W = 
patches.size() + flattened_patches = patches.flatten(0, 1) + else: + B, C, H, W = patches.size() + T = patches.size(1) + flattened_patches = patches + features = self.convs(flattened_patches) + features = features.view(B, T, features.size(1), features.size(2), features.size(3)) + if aggregate: + features = features.mean(1, keepdim=True).expand(-1, T, -1, -1, -1) + return features.flatten(0, 1) + + def extract_layerwise_features(self, image): + feats = [image] + for m in self.convs: + feats.append(m(feats[-1])) + + return feats + + def discriminate_features(self, feature1): + feature1 = feature1.flatten(1) + #feature2 = feature2.flatten(1) + #out = self.pairlinear(torch.cat([feature1, feature2], dim=1)) + out = self.pairlinear(feature1) + return out + """ + def discriminate_features(self, feature1, feature2): + feature1 = feature1.flatten(1) + feature2 = feature2.flatten(1) + out = self.pairlinear(torch.cat([feature1, feature2], dim=1)) + return out + """ + +class StyleGAN2COGANPatchDiscriminator(BasePatchDiscriminator): + @staticmethod + def modify_commandline_options(parser, is_train): + BasePatchDiscriminator.modify_commandline_options(parser, is_train) + return parser + + def __init__(self, opt): + super().__init__(opt) + channel_multiplier = self.opt.netPatchD_scale_capacity + size = self.opt.patch_size + channels = { + 4: min(self.opt.netPatchD_max_nc, int(256 * channel_multiplier)), + 8: min(self.opt.netPatchD_max_nc, int(128 * channel_multiplier)), + 16: min(self.opt.netPatchD_max_nc, int(64 * channel_multiplier)), + 32: int(32 * channel_multiplier), + 64: int(16 * channel_multiplier), + 128: int(8 * channel_multiplier), + 256: int(4 * channel_multiplier), + } + + log_size = int(math.ceil(math.log(size, 2))) + + in_channel = channels[2 ** log_size] + + blur_kernel = [1, 3, 3, 1] if self.opt.use_antialias else [1] + + convs = [('0', ConvLayer(3, in_channel, 3))] + + for i in range(log_size, 2, -1): + out_channel = channels[2 ** (i - 1)] + + layer_name = str(7 - i) if i <= 6 else "%dx%d" % (2 ** i, 2 ** i) + convs.append((layer_name, ResBlock(in_channel, out_channel, blur_kernel))) + + in_channel = out_channel + + convs.append(('5', ResBlock(in_channel, self.opt.netPatchD_max_nc * 2, downsample=False))) + convs.append(('6', ConvLayer(self.opt.netPatchD_max_nc * 2, self.opt.netPatchD_max_nc, 3, pad=0))) + + self.convs = nn.Sequential(OrderedDict(convs)) + + out_dim = 1 + + pairlinear1 = EqualLinear(channels[4] * 2 * 2 * 2, 2048, activation='fused_lrelu') + pairlinear2 = EqualLinear(2048, 2048, activation='fused_lrelu') + pairlinear3 = EqualLinear(2048, 1024, activation='fused_lrelu') + pairlinear4 = EqualLinear(1024, out_dim) + self.pairlinear = nn.Sequential(pairlinear1, pairlinear2, pairlinear3, pairlinear4) + + def extract_features(self, patches, aggregate=False): + if patches.ndim == 5: + B, T, C, H, W = patches.size() + flattened_patches = patches.flatten(0, 1) + else: + B, C, H, W = patches.size() + T = patches.size(1) + flattened_patches = patches + features = self.convs(flattened_patches) + features = features.view(B, T, features.size(1), features.size(2), features.size(3)) + if aggregate: + features = features.mean(1, keepdim=True).expand(-1, T, -1, -1, -1) + return features.flatten(0, 1) + + def extract_layerwise_features(self, image): + feats = [image] + for m in self.convs: + feats.append(m(feats[-1])) + + return feats + + def discriminate_features(self, feature1, feature2): + feature1 = feature1.flatten(1) + feature2 = feature2.flatten(1) + out = 
self.pairlinear(torch.cat([feature1, feature2], dim=1)) + return out diff --git a/swapae/models/networks/pyramidnet.py b/swapae/models/networks/pyramidnet.py new file mode 100644 index 0000000000000000000000000000000000000000..277a9ccea936baed3a81a6e9bc59c297162ced9f --- /dev/null +++ b/swapae/models/networks/pyramidnet.py @@ -0,0 +1,229 @@ +# Original code: https://github.com/dyhan0920/PyramidNet-PyTorch/blob/master/PyramidNet.py + +import torch +import torch.nn as nn +import math + +def conv3x3(in_planes, out_planes, stride=1): + "3x3 convolution with padding" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=1, bias=False) + + +class BasicBlock(nn.Module): + outchannel_ratio = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(BasicBlock, self).__init__() + self.bn1 = nn.BatchNorm2d(inplanes) + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn2 = nn.BatchNorm2d(planes) + self.conv2 = conv3x3(planes, planes) + self.bn3 = nn.BatchNorm2d(planes) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + + out = self.bn1(x) + out = self.conv1(out) + out = self.bn2(out) + out = self.relu(out) + out = self.conv2(out) + out = self.bn3(out) + if self.downsample is not None: + shortcut = self.downsample(x) + featuremap_size = shortcut.size()[2:4] + else: + shortcut = x + featuremap_size = out.size()[2:4] + + batch_size = out.size()[0] + residual_channel = out.size()[1] + shortcut_channel = shortcut.size()[1] + + if residual_channel != shortcut_channel: + padding = torch.autograd.Variable(torch.cuda.FloatTensor(batch_size, residual_channel - shortcut_channel, featuremap_size[0], featuremap_size[1]).fill_(0)) + out += torch.cat((shortcut, padding), 1) + else: + out += shortcut + + return out + + +class Bottleneck(nn.Module): + outchannel_ratio = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None, reduction=16): + super(Bottleneck, self).__init__() + self.bn1 = nn.BatchNorm2d(inplanes) + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, (planes), kernel_size=3, stride=stride, padding=1, bias=False, groups=1) + self.bn3 = nn.BatchNorm2d((planes)) + self.conv3 = nn.Conv2d((planes), planes * Bottleneck.outchannel_ratio, kernel_size=1, bias=False) + self.bn4 = nn.BatchNorm2d(planes * Bottleneck.outchannel_ratio) + self.relu = nn.ReLU(inplace=True) + + self.downsample = downsample + self.stride = stride + + def forward(self, x): + + out = self.bn1(x) + out = self.conv1(out) + + out = self.bn2(out) + out = self.relu(out) + out = self.conv2(out) + + out = self.bn3(out) + out = self.relu(out) + out = self.conv3(out) + + out = self.bn4(out) + if self.downsample is not None: + shortcut = self.downsample(x) + featuremap_size = shortcut.size()[2:4] + else: + shortcut = x + featuremap_size = out.size()[2:4] + + batch_size = out.size()[0] + residual_channel = out.size()[1] + shortcut_channel = shortcut.size()[1] + + if residual_channel != shortcut_channel: + padding = torch.autograd.Variable(torch.cuda.FloatTensor(batch_size, residual_channel - shortcut_channel, featuremap_size[0], featuremap_size[1]).fill_(0)) + out += torch.cat((shortcut, padding), 1) + else: + out += shortcut + + return out + + +class PyramidNet(nn.Module): + + def __init__(self, dataset, depth, alpha, num_classes, bottleneck=False): + super(PyramidNet, self).__init__() + self.dataset = dataset + if 
self.dataset.startswith('cifar'): + self.inplanes = 16 + if bottleneck == True: + n = int((depth - 2) / 9) + block = Bottleneck + else: + n = int((depth - 2) / 6) + block = BasicBlock + + self.addrate = alpha / (3*n*1.0) + + self.input_featuremap_dim = self.inplanes + self.conv1 = nn.Conv2d(3, self.input_featuremap_dim, kernel_size=3, stride=1, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(self.input_featuremap_dim) + + self.featuremap_dim = self.input_featuremap_dim + self.layer1 = self.pyramidal_make_layer(block, n) + self.layer2 = self.pyramidal_make_layer(block, n, stride=2) + self.layer3 = self.pyramidal_make_layer(block, n, stride=2) + + self.final_featuremap_dim = self.input_featuremap_dim + self.bn_final= nn.BatchNorm2d(self.final_featuremap_dim) + self.relu_final = nn.ReLU(inplace=True) + self.avgpool = nn.AvgPool2d(8) + self.fc = nn.Linear(self.final_featuremap_dim, num_classes) + + elif dataset == 'imagenet': + blocks ={18: BasicBlock, 34: BasicBlock, 50: Bottleneck, 101: Bottleneck, 152: Bottleneck, 200: Bottleneck} + layers ={18: [2, 2, 2, 2], 34: [3, 4, 6, 3], 50: [3, 4, 6, 3], 101: [3, 4, 23, 3], 152: [3, 8, 36, 3], 200: [3, 24, 36, 3]} + + if layers.get(depth) is None: + if bottleneck == True: + blocks[depth] = Bottleneck + temp_cfg = int((depth-2)/12) + else: + blocks[depth] = BasicBlock + temp_cfg = int((depth-2)/8) + + layers[depth]= [temp_cfg, temp_cfg, temp_cfg, temp_cfg] + print('=> the layer configuration for each stage is set to', layers[depth]) + + self.inplanes = 64 + self.addrate = alpha / (sum(layers[depth])*1.0) + + self.input_featuremap_dim = self.inplanes + self.conv1 = nn.Conv2d(3, self.input_featuremap_dim, kernel_size=7, stride=2, padding=3, bias=False) + self.bn1 = nn.BatchNorm2d(self.input_featuremap_dim) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + + self.featuremap_dim = self.input_featuremap_dim + self.layer1 = self.pyramidal_make_layer(blocks[depth], layers[depth][0]) + self.layer2 = self.pyramidal_make_layer(blocks[depth], layers[depth][1], stride=2) + self.layer3 = self.pyramidal_make_layer(blocks[depth], layers[depth][2], stride=2) + self.layer4 = self.pyramidal_make_layer(blocks[depth], layers[depth][3], stride=2) + + self.final_featuremap_dim = self.input_featuremap_dim + self.bn_final= nn.BatchNorm2d(self.final_featuremap_dim) + self.relu_final = nn.ReLU(inplace=True) + self.avgpool = nn.AvgPool2d(7) + self.fc = nn.Linear(self.final_featuremap_dim, num_classes) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. 
/ n)) + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + def pyramidal_make_layer(self, block, block_depth, stride=1): + downsample = None + if stride != 1: # or self.inplanes != int(round(featuremap_dim_1st)) * block.outchannel_ratio: + downsample = nn.AvgPool2d((2,2), stride = (2, 2), ceil_mode=True) + + layers = [] + self.featuremap_dim = self.featuremap_dim + self.addrate + layers.append(block(self.input_featuremap_dim, int(round(self.featuremap_dim)), stride, downsample)) + for i in range(1, block_depth): + temp_featuremap_dim = self.featuremap_dim + self.addrate + layers.append(block(int(round(self.featuremap_dim)) * block.outchannel_ratio, int(round(temp_featuremap_dim)), 1)) + self.featuremap_dim = temp_featuremap_dim + self.input_featuremap_dim = int(round(self.featuremap_dim)) * block.outchannel_ratio + + return nn.Sequential(*layers) + + def forward(self, x): + if self.dataset == 'cifar10' or self.dataset == 'cifar100': + x = self.conv1(x) + x = self.bn1(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + + x = self.bn_final(x) + x = self.relu_final(x) + x = self.avgpool(x) + x = x.view(x.size(0), -1) + x = self.fc(x) + + elif self.dataset == 'imagenet': + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + x = self.bn_final(x) + x = self.relu_final(x) + x = self.avgpool(x) + x = x.view(x.size(0), -1) + x = self.fc(x) + + return x diff --git a/swapae/models/networks/stylegan2_layers.py b/swapae/models/networks/stylegan2_layers.py new file mode 100644 index 0000000000000000000000000000000000000000..4f34673a287b8bf4ed7ebcf3a22ae3c58d6c79ff --- /dev/null +++ b/swapae/models/networks/stylegan2_layers.py @@ -0,0 +1,765 @@ +############################################################## +# from https://github.com/rosinality/stylegan2-pytorch +############################################################## +from collections import OrderedDict +import math +import random +import functools +import operator + +import torch +import models +from torch import nn +from torch.nn import functional as F +from torch.autograd import Function + +from swapae.models.networks.stylegan2_op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d + + +class PixelNorm(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input): + return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8) + + +def make_kernel(k): + k = torch.tensor(k, dtype=torch.float32) + + if k.dim() == 1: + k = k[None, :] * k[:, None] + + k /= k.sum() + + return k + + +class Upsample(nn.Module): + def __init__(self, kernel, factor=2): + super().__init__() + + self.factor = factor + kernel = make_kernel(kernel) * (factor ** 2) + self.register_buffer('kernel', kernel) + + p = kernel.shape[0] - factor + + pad0 = (p + 1) // 2 + factor - 1 + pad1 = p // 2 + + self.pad = (pad0, pad1) + + def forward(self, input): + out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad) + + return out + + +class Downsample(nn.Module): + def __init__(self, kernel, factor=2, pad=None, reflection_pad=False): + super().__init__() + + self.factor = factor + kernel = make_kernel(kernel) + self.register_buffer('kernel', kernel) + self.reflection = reflection_pad + + if pad is None: + p = kernel.shape[0] - factor + else: + p = pad + + pad0 = (p + 1) // 2 + pad1 = p // 2 + + self.pad = (pad0, pad1) + + def forward(self, input): + if 
self.reflection: + input = F.pad(input, (self.pad[0], self.pad[1], self.pad[0], self.pad[1]), mode='reflect') + pad = (0, 0) + else: + pad = self.pad + + out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=pad) + + return out + + +class Blur(nn.Module): + def __init__(self, kernel, pad, upsample_factor=1, reflection_pad=False): + super().__init__() + + kernel = make_kernel(kernel) + + if upsample_factor > 1: + kernel = kernel * (upsample_factor ** 2) + + self.register_buffer('kernel', kernel) + + self.pad = pad + self.reflection = reflection_pad + if self.reflection: + self.reflection_pad = nn.ReflectionPad2d((pad[0], pad[1], pad[0], pad[1])) + self.pad = (0, 0) + + def forward(self, input): + if self.reflection: + input = self.reflection_pad(input) + out = upfirdn2d(input, self.kernel, pad=self.pad) + + return out + + +class EqualConv2d(nn.Module): + def __init__( + self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True, lr_mul=1.0, + ): + super().__init__() + + self.weight = nn.Parameter( + torch.randn(out_channel, in_channel, kernel_size, kernel_size) + ) + self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2) * lr_mul + + self.stride = stride + self.padding = padding + + if bias: + self.bias = nn.Parameter(torch.zeros(out_channel)) + + else: + self.bias = None + + def forward(self, input): + out = F.conv2d( + input, + self.weight * self.scale, + bias=self.bias, + stride=self.stride, + padding=self.padding, + ) + + return out + + def __repr__(self): + return ( + f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},' + f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})' + ) + + +class EqualLinear(nn.Module): + def __init__( + self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None + ): + super().__init__() + + self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul)) + + if bias: + self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init)) + + else: + self.bias = None + + self.activation = activation + + self.scale = (1 / math.sqrt(in_dim)) * lr_mul + self.lr_mul = lr_mul + + def forward(self, input): + if self.activation: + if input.dim() > 2: + out = F.conv2d(input, self.weight[:, :, None, None] * self.scale) + else: + out = F.linear(input, self.weight * self.scale) + out = fused_leaky_relu(out, self.bias * self.lr_mul) + + else: + if input.dim() > 2: + out = F.conv2d(input, self.weight[:, :, None, None] * self.scale, + bias=self.bias * self.lr_mul + ) + else: + out = F.linear( + input, self.weight * self.scale, bias=self.bias * self.lr_mul + ) + + return out + + def __repr__(self): + return ( + f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})' + ) + + +class ScaledLeakyReLU(nn.Module): + def __init__(self, negative_slope=0.2): + super().__init__() + + self.negative_slope = negative_slope + + def forward(self, input): + out = F.leaky_relu(input, negative_slope=self.negative_slope) + + return out * math.sqrt(2) + + +class ModulatedConv2d(nn.Module): + def __init__( + self, + in_channel, + out_channel, + kernel_size, + style_dim, + demodulate=True, + upsample=False, + downsample=False, + blur_kernel=[1, 3, 3, 1], + ): + super().__init__() + + self.eps = 1e-8 + self.kernel_size = kernel_size + self.in_channel = in_channel + self.out_channel = out_channel + self.upsample = upsample + self.downsample = downsample + + if upsample: + factor = 2 + p = (len(blur_kernel) - factor) - (kernel_size - 1) + pad0 = (p + 1) // 2 + factor - 1 + pad1 = p // 2 + 1 + 
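+ # Worked example of the padding arithmetic above (a sketch, assuming the
+ # default blur_kernel=[1, 3, 3, 1] and kernel_size=3): p = (4 - 2) - (3 - 1) = 0,
+ # pad0 = (0 + 1) // 2 + 2 - 1 = 1, pad1 = 0 // 2 + 1 = 1, so the blur that
+ # follows the transposed convolution uses a symmetric (1, 1) pad.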
+ self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor) + + if downsample: + factor = 2 + p = (len(blur_kernel) - factor) + (kernel_size - 1) + pad0 = (p + 1) // 2 + pad1 = p // 2 + + self.blur = Blur(blur_kernel, pad=(pad0, pad1)) + + fan_in = in_channel * kernel_size ** 2 + self.scale = 1 / math.sqrt(fan_in) + self.padding = kernel_size // 2 + + self.weight = nn.Parameter( + torch.randn(1, out_channel, in_channel, kernel_size, kernel_size) + ) + + self.modulation = EqualLinear(style_dim, in_channel, bias_init=1) + + self.demodulate = demodulate + self.new_demodulation = True + + def __repr__(self): + return ( + f'{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, ' + f'upsample={self.upsample}, downsample={self.downsample})' + ) + + def forward(self, input, style): + batch, in_channel, height, width = input.shape + + if style.dim() > 2: + style = F.interpolate(style, size=(input.size(2), input.size(3)), mode='bilinear', align_corners=False) + #style = self.modulation(style).unsqueeze(1) + style = self.modulation(style) + if self.demodulate: + style = style * torch.rsqrt(style.pow(2).mean([2], keepdim=True) + 1e-8) + input = input * style + weight = self.scale * self.weight + weight = weight.repeat(batch, 1, 1, 1, 1) + else: + style = style.view(batch, style.size(1)) + style = self.modulation(style).view(batch, 1, in_channel, 1, 1) + if self.new_demodulation: + style = style[:, 0, :, :, :] + if self.demodulate: + style = style * torch.rsqrt(style.pow(2).mean([1], keepdim=True) + 1e-8) + input = input * style + weight = self.scale * self.weight + weight = weight.repeat(batch, 1, 1, 1, 1) + else: + weight = self.scale * self.weight * style + + if self.demodulate: + demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8) + weight = weight * demod.view(batch, self.out_channel, 1, 1, 1) + + weight = weight.view( + batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size + ) + + if self.upsample: + input = input.view(1, batch * in_channel, height, width) + weight = weight.view( + batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size + ) + weight = weight.transpose(1, 2).reshape( + batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size + ) + out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch) + _, _, height, width = out.shape + out = out.view(batch, self.out_channel, height, width) + out = self.blur(out) + + elif self.downsample: + input = self.blur(input) + _, _, height, width = input.shape + input = input.view(1, batch * in_channel, height, width) + out = F.conv2d(input, weight, padding=0, stride=2, groups=batch) + _, _, height, width = out.shape + out = out.view(batch, self.out_channel, height, width) + + else: + input = input.view(1, batch * in_channel, height, width) + out = F.conv2d(input, weight, padding=self.padding, groups=batch) + _, _, height, width = out.shape + out = out.view(batch, self.out_channel, height, width) + + return out + + +class NoiseInjection(nn.Module): + def __init__(self): + super().__init__() + + self.weight = nn.Parameter(torch.zeros(1)) + self.fixed_noise = None + self.image_size = None + + def forward(self, image, noise=None): + if self.image_size is None: + self.image_size = image.shape + + if noise is None and self.fixed_noise is None: + batch, _, height, width = image.shape + noise = image.new_empty(batch, 1, height, width).normal_() + elif self.fixed_noise is not None: + noise = self.fixed_noise + # to avoid error when generating 
thumbnails in demo + if image.size(2) != noise.size(2) or image.size(3) != noise.size(3): + noise = F.interpolate(noise, image.shape[2:], mode="nearest") + else: + pass # use the passed noise + + return image + self.weight * noise + + +class ConstantInput(nn.Module): + def __init__(self, channel, size=4): + super().__init__() + + self.input = nn.Parameter(torch.randn(1, channel, size, size)) + + def forward(self, input): + batch = input.shape[0] + out = self.input.repeat(batch, 1, 1, 1) + + return out + + +class StyledConv(nn.Module): + def __init__( + self, + in_channel, + out_channel, + kernel_size, + style_dim, + upsample=False, + blur_kernel=[1, 3, 3, 1], + demodulate=True, + use_noise=True, + lr_mul=1.0, + ): + super().__init__() + + self.conv = ModulatedConv2d( + in_channel, + out_channel, + kernel_size, + style_dim, + upsample=upsample, + blur_kernel=blur_kernel, + demodulate=demodulate, + ) + + self.use_noise = use_noise + self.noise = NoiseInjection() + # self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1)) + # self.activate = ScaledLeakyReLU(0.2) + self.activate = FusedLeakyReLU(out_channel) + + def forward(self, input, style, noise=None): + out = self.conv(input, style) + if self.use_noise: + out = self.noise(out, noise=noise) + # out = out + self.bias + out = self.activate(out) + + return out + + +class ToRGB(nn.Module): + def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]): + super().__init__() + + if upsample: + self.upsample = Upsample(blur_kernel) + + self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False) + self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1)) + + def forward(self, input, style, skip=None): + out = self.conv(input, style) + out = out + self.bias + + if skip is not None: + skip = self.upsample(skip) + + out = out + skip + + return out + + +class Generator(nn.Module): + def __init__( + self, + size, + style_dim, + n_mlp, + channel_multiplier=2, + blur_kernel=[1, 3, 3, 1], + lr_mlp=0.01, + ): + super().__init__() + + self.size = size + + self.style_dim = style_dim + + layers = [PixelNorm()] + + for i in range(n_mlp): + layers.append( + EqualLinear( + style_dim, style_dim, lr_mul=lr_mlp, activation='fused_lrelu' + ) + ) + + self.style = nn.Sequential(*layers) + + self.channels = { + 4: 512, + 8: 512, + 16: 512, + 32: 512, + 64: 256 * channel_multiplier, + 128: 128 * channel_multiplier, + 256: 64 * channel_multiplier, + 512: 32 * channel_multiplier, + 1024: 16 * channel_multiplier, + } + + self.input = ConstantInput(self.channels[4]) + self.conv1 = StyledConv( + self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel + ) + self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False) + + self.log_size = int(math.log(size, 2)) + self.num_layers = (self.log_size - 2) * 2 + 1 + + self.convs = nn.ModuleList() + self.upsamples = nn.ModuleList() + self.to_rgbs = nn.ModuleList() + self.noises = nn.Module() + + in_channel = self.channels[4] + + for layer_idx in range(self.num_layers): + res = (layer_idx + 5) // 2 + shape = [1, 1, 2 ** res, 2 ** res] + self.noises.register_buffer(f'noise_{layer_idx}', torch.randn(*shape)) + + for i in range(3, self.log_size + 1): + out_channel = self.channels[2 ** i] + + self.convs.append( + StyledConv( + in_channel, + out_channel, + 3, + style_dim, + upsample=True, + blur_kernel=blur_kernel, + ) + ) + + self.convs.append( + StyledConv( + out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel + ) + ) + + self.to_rgbs.append(ToRGB(out_channel, style_dim)) 
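+ # Each resolution from 8x8 up to `size` contributes one upsampling
+ # StyledConv, one regular StyledConv, and a ToRGB skip connection.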
+ + in_channel = out_channel + + self.n_latent = self.log_size * 2 - 2 + + def make_noise(self): + device = self.input.input.device + + noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)] + + for i in range(3, self.log_size + 1): + for _ in range(2): + noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device)) + + return noises + + def mean_latent(self, n_latent): + latent_in = torch.randn( + n_latent, self.style_dim, device=self.input.input.device + ) + latent = self.style(latent_in).mean(0, keepdim=True) + + return latent + + def get_latent(self, input): + return self.style(input) + + def forward( + self, + styles, + return_latents=False, + inject_index=None, + truncation=1, + truncation_latent=None, + input_is_latent=False, + noise=None, + randomize_noise=True, + ): + if not input_is_latent: + styles = [self.style(s) for s in styles] + + if noise is None: + if randomize_noise: + noise = [None] * self.num_layers + else: + noise = [ + getattr(self.noises, f'noise_{i}') for i in range(self.num_layers) + ] + + if truncation < 1: + style_t = [] + + for style in styles: + style_t.append( + truncation_latent + truncation * (style - truncation_latent) + ) + + styles = style_t + + if len(styles) < 2: + inject_index = self.n_latent + + if styles[0].dim() < 3: + latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) + + else: + latent = styles[0] + + else: + if inject_index is None: + inject_index = random.randint(1, self.n_latent - 1) + + latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) + latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1) + + latent = torch.cat([latent, latent2], 1) + + out = self.input(latent) + out = self.conv1(out, latent[:, 0], noise=noise[0]) + + skip = self.to_rgb1(out, latent[:, 1]) + + i = 1 + for conv1, conv2, noise1, noise2, to_rgb in zip( + self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs + ): + out = conv1(out, latent[:, i], noise=noise1) + out = conv2(out, latent[:, i + 1], noise=noise2) + skip = to_rgb(out, latent[:, i + 2], skip) + + i += 2 + + image = skip + + if return_latents: + return image, latent + + else: + return image, None + + +class ConvLayer(nn.Sequential): + def __init__( + self, + in_channel, + out_channel, + kernel_size, + downsample=False, + blur_kernel=[1, 3, 3, 1], + bias=True, + activate=True, + pad=None, + reflection_pad=False, + ): + layers = [] + + if downsample: + factor = 2 + if pad is None: + pad = (len(blur_kernel) - factor) + (kernel_size - 1) + pad0 = (pad + 1) // 2 + pad1 = pad // 2 + + layers.append(("Blur", Blur(blur_kernel, pad=(pad0, pad1), reflection_pad=reflection_pad))) + + stride = 2 + self.padding = 0 + else: + stride = 1 + self.padding = kernel_size // 2 if pad is None else pad + if reflection_pad: + layers.append(("RefPad", nn.ReflectionPad2d(self.padding))) + self.padding = 0 + + + layers.append(("Conv", + EqualConv2d( + in_channel, + out_channel, + kernel_size, + padding=self.padding, + stride=stride, + bias=bias and not activate, + )) + ) + + if activate: + if bias: + layers.append(("Act", FusedLeakyReLU(out_channel))) + + else: + layers.append(("Act", ScaledLeakyReLU(0.2))) + + super().__init__(OrderedDict(layers)) + + def forward(self, x): + out = super().forward(x) + return out + + + +class ResBlock(nn.Module): + def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1], reflection_pad=False, pad=None, downsample=True): + super().__init__() + + self.conv1 = ConvLayer(in_channel, in_channel, 3, reflection_pad=reflection_pad, pad=pad) 
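+ # conv1 keeps the channel count and resolution; conv2 below maps to
+ # out_channel and (when downsample=True) halves the spatial size, so its
+ # output matches the 1x1 skip branch before the residual sum scaled by 1/sqrt(2).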
+ self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=downsample, blur_kernel=blur_kernel, reflection_pad=reflection_pad, pad=pad) + + self.skip = ConvLayer( + in_channel, out_channel, 1, downsample=downsample, blur_kernel=blur_kernel, activate=False, bias=False + ) + + def forward(self, input): + #print("before first resnet layeer, ", input.shape) + out = self.conv1(input) + #print("after first resnet layer, ", out.shape) + out = self.conv2(out) + #print("after second resnet layer, ", out.shape) + + skip = self.skip(input) + out = (out + skip) / math.sqrt(2) + + return out + + +class Discriminator(nn.Module): + def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]): + super().__init__() + + channels = { + 4: 512, + 8: 512, + 16: min(512, int(512 * channel_multiplier)), + 32: min(512, int(512 * channel_multiplier)), + 64: int(256 * channel_multiplier), + 128: int(128 * channel_multiplier), + 256: int(64 * channel_multiplier), + 512: int(32 * channel_multiplier), + 1024: int(16 * channel_multiplier), + } + + original_size = size + + size = 2 ** int(round(math.log(size, 2))) + + convs = [('0', ConvLayer(3, channels[size], 1))] + + log_size = int(math.log(size, 2)) + + in_channel = channels[size] + + for i in range(log_size, 2, -1): + out_channel = channels[2 ** (i - 1)] + layer_name = str(9 - i) if i <= 8 else "%dx%d" % (2 ** i, 2 ** i) + convs.append((layer_name, ResBlock(in_channel, out_channel, blur_kernel))) + + in_channel = out_channel + + self.convs = nn.Sequential(OrderedDict(convs)) + + #self.stddev_group = 4 + #self.stddev_feat = 1 + + self.final_conv = ConvLayer(in_channel, channels[4], 3) + + side_length = int(4 * original_size / size) + self.final_linear = nn.Sequential( + EqualLinear(channels[4] * (side_length ** 2), channels[4], activation='fused_lrelu'), + EqualLinear(channels[4], 1), + ) + + def forward(self, input): + out = self.convs(input) + + batch, channel, height, width = out.shape + #group = min(batch, self.stddev_group) + #stddev = out.view( + # group, -1, self.stddev_feat, channel // self.stddev_feat, height, width + #) + #stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8) + #stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2) + #stddev = stddev.repeat(group, 1, height, width) + #out = torch.cat([out, stddev], 1) + + out = self.final_conv(out) + + out = out.view(batch, -1) + out = self.final_linear(out) + + return out + + def get_features(self, input): + return self.final_conv(self.convs(input)) + diff --git a/swapae/models/networks/stylegan2_op/__init__.py b/swapae/models/networks/stylegan2_op/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..d0918d92285955855be89f00096b888ee5597ce3 --- /dev/null +++ b/swapae/models/networks/stylegan2_op/__init__.py @@ -0,0 +1,2 @@ +from .fused_act import FusedLeakyReLU, fused_leaky_relu +from .upfirdn2d import upfirdn2d diff --git a/swapae/models/networks/stylegan2_op/__pycache__/__init__.cpython-37.pyc b/swapae/models/networks/stylegan2_op/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4a88fe03f495fc3cb989bbfb3690f18f17b16d82 Binary files /dev/null and b/swapae/models/networks/stylegan2_op/__pycache__/__init__.cpython-37.pyc differ diff --git a/swapae/models/networks/stylegan2_op/__pycache__/__init__.cpython-38.pyc b/swapae/models/networks/stylegan2_op/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..001600a789ba87ef982d44992bea3c98dd51f917 Binary 
files /dev/null and b/swapae/models/networks/stylegan2_op/__pycache__/__init__.cpython-38.pyc differ diff --git a/swapae/models/networks/stylegan2_op/__pycache__/fused_act.cpython-37.pyc b/swapae/models/networks/stylegan2_op/__pycache__/fused_act.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ade0c9b01ccff5f24865e1f2ecde76bc6f07041c Binary files /dev/null and b/swapae/models/networks/stylegan2_op/__pycache__/fused_act.cpython-37.pyc differ diff --git a/swapae/models/networks/stylegan2_op/__pycache__/fused_act.cpython-38.pyc b/swapae/models/networks/stylegan2_op/__pycache__/fused_act.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..01c53826bc5010f5827b22861b5b7e9403e97e43 Binary files /dev/null and b/swapae/models/networks/stylegan2_op/__pycache__/fused_act.cpython-38.pyc differ diff --git a/swapae/models/networks/stylegan2_op/__pycache__/upfirdn2d.cpython-37.pyc b/swapae/models/networks/stylegan2_op/__pycache__/upfirdn2d.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..199d05cfcad8dac8c41dc649b9bb8a458c7640e4 Binary files /dev/null and b/swapae/models/networks/stylegan2_op/__pycache__/upfirdn2d.cpython-37.pyc differ diff --git a/swapae/models/networks/stylegan2_op/__pycache__/upfirdn2d.cpython-38.pyc b/swapae/models/networks/stylegan2_op/__pycache__/upfirdn2d.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3d3215772fc06eb063fcb00683b2fd59501f82c8 Binary files /dev/null and b/swapae/models/networks/stylegan2_op/__pycache__/upfirdn2d.cpython-38.pyc differ diff --git a/swapae/models/networks/stylegan2_op/fused_act.py b/swapae/models/networks/stylegan2_op/fused_act.py new file mode 100755 index 0000000000000000000000000000000000000000..5d34a8ce951c5dc698a8aeb090318ba87cf00c02 --- /dev/null +++ b/swapae/models/networks/stylegan2_op/fused_act.py @@ -0,0 +1,99 @@ +import os + +import torch +from torch import nn +import torch.nn.functional as F +from torch.autograd import Function +from torch.utils.cpp_extension import load +from swapae.util import is_custom_kernel_supported as is_custom_kernel_supported + +""" +if is_custom_kernel_supported(): + module_path = os.path.dirname(__file__) + fused = load( + 'fused', + sources=[ + os.path.join(module_path, 'fused_bias_act.cpp'), + os.path.join(module_path, 'fused_bias_act_kernel.cu'), + ], + verbose=True, + ) + +use_custom_kernel = is_custom_kernel_supported() +""" +use_custom_kernel = False + +class FusedLeakyReLUFunctionBackward(Function): + @staticmethod + def forward(ctx, grad_output, out, negative_slope, scale): + ctx.save_for_backward(out) + ctx.negative_slope = negative_slope + ctx.scale = scale + + empty = grad_output.new_empty(0) + + grad_input = fused.fused_bias_act( + grad_output, empty, out, 3, 1, negative_slope, scale + ) + + dim = [0] + + if grad_input.ndim > 2: + dim += list(range(2, grad_input.ndim)) + + grad_bias = grad_input.sum(dim).detach() + + return grad_input, grad_bias + + @staticmethod + def backward(ctx, gradgrad_input, gradgrad_bias): + out, = ctx.saved_tensors + gradgrad_out = fused.fused_bias_act( + gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale + ) + + return gradgrad_out, None, None, None + + +class FusedLeakyReLUFunction(Function): + @staticmethod + def forward(ctx, input, bias, negative_slope, scale): + empty = input.new_empty(0) + out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale) + ctx.save_for_backward(out) + ctx.negative_slope 
= negative_slope + ctx.scale = scale + + return out + + @staticmethod + def backward(ctx, grad_output): + out, = ctx.saved_tensors + + grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply( + grad_output, out, ctx.negative_slope, ctx.scale + ) + + return grad_input, grad_bias, None, None + + +class FusedLeakyReLU(nn.Module): + def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5): + super().__init__() + + self.bias = nn.Parameter(torch.zeros(channel)) + self.negative_slope = negative_slope + self.scale = scale + + def forward(self, input): + return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) + + +def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5): + global use_custom_kernel + if use_custom_kernel: + return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale) + else: + dims = [1, -1] + [1] * (input.dim() - 2) + bias = bias.view(*dims) + return F.leaky_relu(input + bias, negative_slope) * scale diff --git a/swapae/models/networks/stylegan2_op/fused_bias_act.cpp b/swapae/models/networks/stylegan2_op/fused_bias_act.cpp new file mode 100755 index 0000000000000000000000000000000000000000..a054318781a20596d8f516ef86745e5572aad0f7 --- /dev/null +++ b/swapae/models/networks/stylegan2_op/fused_bias_act.cpp @@ -0,0 +1,21 @@ +#include + + +torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, + int act, int grad, float alpha, float scale); + +#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) + +torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, + int act, int grad, float alpha, float scale) { + CHECK_CUDA(input); + CHECK_CUDA(bias); + + return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)"); +} \ No newline at end of file diff --git a/swapae/models/networks/stylegan2_op/fused_bias_act_kernel.cu b/swapae/models/networks/stylegan2_op/fused_bias_act_kernel.cu new file mode 100755 index 0000000000000000000000000000000000000000..8d2f03c73605faee6723d002ba5de88cb465a80e --- /dev/null +++ b/swapae/models/networks/stylegan2_op/fused_bias_act_kernel.cu @@ -0,0 +1,99 @@ +// Copyright (c) 2019, NVIDIA Corporation. All rights reserved. +// +// This work is made available under the Nvidia Source Code License-NC. +// To view a copy of this license, visit +// https://nvlabs.github.io/stylegan2/license.html + +#include + +#include +#include +#include +#include + +#include +#include + + +template +static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref, + int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) { + int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x; + + scalar_t zero = 0.0; + + for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) { + scalar_t x = p_x[xi]; + + if (use_bias) { + x += p_b[(xi / step_b) % size_b]; + } + + scalar_t ref = use_ref ? p_ref[xi] : zero; + + scalar_t y; + + switch (act * 10 + grad) { + default: + case 10: y = x; break; + case 11: y = x; break; + case 12: y = 0.0; break; + + case 30: y = (x > 0.0) ? 
x : x * alpha; break; + case 31: y = (ref > 0.0) ? x : x * alpha; break; + case 32: y = 0.0; break; + } + + out[xi] = y * scale; + } +} + + +torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, + int act, int grad, float alpha, float scale) { + int curDevice = -1; + cudaGetDevice(&curDevice); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); + + auto x = input.contiguous(); + auto b = bias.contiguous(); + auto ref = refer.contiguous(); + + int use_bias = b.numel() ? 1 : 0; + int use_ref = ref.numel() ? 1 : 0; + + int size_x = x.numel(); + int size_b = b.numel(); + int step_b = 1; + + for (int i = 1 + 1; i < x.dim(); i++) { + step_b *= x.size(i); + } + + int loop_x = 4; + int block_size = 4 * 32; + int grid_size = (size_x - 1) / (loop_x * block_size) + 1; + + auto y = torch::empty_like(x); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] { + fused_bias_act_kernel<<>>( + y.data_ptr(), + x.data_ptr(), + b.data_ptr(), + ref.data_ptr(), + act, + grad, + alpha, + scale, + loop_x, + size_x, + step_b, + size_b, + use_bias, + use_ref + ); + }); + + return y; +} \ No newline at end of file diff --git a/swapae/models/networks/stylegan2_op/upfirdn2d.cpp b/swapae/models/networks/stylegan2_op/upfirdn2d.cpp new file mode 100755 index 0000000000000000000000000000000000000000..86472927281a5e057e430bdb32f39d623c29a8bb --- /dev/null +++ b/swapae/models/networks/stylegan2_op/upfirdn2d.cpp @@ -0,0 +1,23 @@ +#include + + +torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel, + int up_x, int up_y, int down_x, int down_y, + int pad_x0, int pad_x1, int pad_y0, int pad_y1); + +#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) + +torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel, + int up_x, int up_y, int down_x, int down_y, + int pad_x0, int pad_x1, int pad_y0, int pad_y1) { + CHECK_CUDA(input); + CHECK_CUDA(kernel); + + return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)"); +} diff --git a/swapae/models/networks/stylegan2_op/upfirdn2d.py b/swapae/models/networks/stylegan2_op/upfirdn2d.py new file mode 100755 index 0000000000000000000000000000000000000000..237417e82cda995ebf61f208cb7012fb2869033f --- /dev/null +++ b/swapae/models/networks/stylegan2_op/upfirdn2d.py @@ -0,0 +1,225 @@ +import os + +import torch +import torch.nn.functional as F +from torch.autograd import Function +from torch.utils.cpp_extension import load +from swapae.util import is_custom_kernel_supported as is_custom_kernel_supported + +""" +if is_custom_kernel_supported(): + print("Loading custom kernel...") + module_path = os.path.dirname(__file__) + upfirdn2d_op = load( + 'upfirdn2d', + sources=[ + os.path.join(module_path, 'upfirdn2d.cpp'), + os.path.join(module_path, 'upfirdn2d_kernel.cu'), + ], + verbose=True + ) + +use_custom_kernel = is_custom_kernel_supported() +""" +use_custom_kernel = False + + +class UpFirDn2dBackward(Function): + @staticmethod + def forward( + ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size + ): + + up_x, up_y = up + down_x, down_y = down + g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad + + 
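+ # The backward pass is itself an upfirdn2d call: the up and down factors
+ # are swapped and the flipped kernel (grad_kernel) is applied with the
+ # precomputed g_pad values, i.e. the adjoint of the forward resampling.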
grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1) + + grad_input = upfirdn2d_op.upfirdn2d( + grad_output, + grad_kernel, + down_x, + down_y, + up_x, + up_y, + g_pad_x0, + g_pad_x1, + g_pad_y0, + g_pad_y1, + ) + grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3]) + + ctx.save_for_backward(kernel) + + pad_x0, pad_x1, pad_y0, pad_y1 = pad + + ctx.up_x = up_x + ctx.up_y = up_y + ctx.down_x = down_x + ctx.down_y = down_y + ctx.pad_x0 = pad_x0 + ctx.pad_x1 = pad_x1 + ctx.pad_y0 = pad_y0 + ctx.pad_y1 = pad_y1 + ctx.in_size = in_size + ctx.out_size = out_size + + return grad_input + + @staticmethod + def backward(ctx, gradgrad_input): + kernel, = ctx.saved_tensors + + gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1) + + gradgrad_out = upfirdn2d_op.upfirdn2d( + gradgrad_input, + kernel, + ctx.up_x, + ctx.up_y, + ctx.down_x, + ctx.down_y, + ctx.pad_x0, + ctx.pad_x1, + ctx.pad_y0, + ctx.pad_y1, + ) + # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3]) + gradgrad_out = gradgrad_out.view( + ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1] + ) + + return gradgrad_out, None, None, None, None, None, None, None, None + + +class UpFirDn2d(Function): + @staticmethod + def forward(ctx, input, kernel, up, down, pad): + up_x, up_y = up + down_x, down_y = down + pad_x0, pad_x1, pad_y0, pad_y1 = pad + + kernel_h, kernel_w = kernel.shape + batch, channel, in_h, in_w = input.shape + ctx.in_size = input.shape + + input = input.reshape(-1, in_h, in_w, 1) + + ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1])) + + out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 + out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 + ctx.out_size = (out_h, out_w) + + ctx.up = (up_x, up_y) + ctx.down = (down_x, down_y) + ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1) + + g_pad_x0 = kernel_w - pad_x0 - 1 + g_pad_y0 = kernel_h - pad_y0 - 1 + g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1 + g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1 + + ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1) + + out = upfirdn2d_op.upfirdn2d( + input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 + ) + # out = out.view(major, out_h, out_w, minor) + out = out.view(-1, channel, out_h, out_w) + + return out + + @staticmethod + def backward(ctx, grad_output): + kernel, grad_kernel = ctx.saved_tensors + + grad_input = UpFirDn2dBackward.apply( + grad_output, + kernel, + grad_kernel, + ctx.up, + ctx.down, + ctx.pad, + ctx.g_pad, + ctx.in_size, + ctx.out_size, + ) + + return grad_input, None, None, None, None + + +def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): + global use_custom_kernel + if use_custom_kernel: + out = UpFirDn2d.apply( + input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1]) + ) + else: + out = upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1]) + + return out + + +def upfirdn2d_native( + input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 +): + bs, ch, in_h, in_w = input.shape + minor = 1 + kernel_h, kernel_w = kernel.shape + + #assert kernel_h == 1 and kernel_w == 1 + + #print("original shape ", input.shape, up_x, down_x, pad_x0, pad_x1) + + out = input.view(-1, in_h, 1, in_w, 1, minor) + if up_x > 1 or up_y > 1: + out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) + + #print("after padding ", out.shape) + out = out.view(-1, in_h * up_y, in_w * up_x, 
minor) + + #print("after reshaping ", out.shape) + + if pad_x0 > 0 or pad_x1 > 0 or pad_y0 > 0 or pad_y1 > 0: + out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]) + + #print("after second padding ", out.shape) + out = out[ + :, + max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), + max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0), + :, + ] + + #print("after trimming ", out.shape) + + out = out.permute(0, 3, 1, 2) + out = out.reshape( + [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1] + ) + + #print("after reshaping", out.shape) + w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) + out = F.conv2d(out, w) + + #print("after conv ", out.shape) + out = out.reshape( + -1, + minor, + in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, + in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, + ) + + out = out.permute(0, 2, 3, 1) + + #print("after permuting ", out.shape) + + out = out[:, ::down_y, ::down_x, :] + + out = out.view(bs, ch, out.size(1), out.size(2)) + + #print("final shape ", out.shape) + + return out diff --git a/swapae/models/networks/stylegan2_op/upfirdn2d_kernel.cu b/swapae/models/networks/stylegan2_op/upfirdn2d_kernel.cu new file mode 100755 index 0000000000000000000000000000000000000000..871d4fe2fafb6c7863ea41656f8770f8a4a61b3a --- /dev/null +++ b/swapae/models/networks/stylegan2_op/upfirdn2d_kernel.cu @@ -0,0 +1,272 @@ +// Copyright (c) 2019, NVIDIA Corporation. All rights reserved. +// +// This work is made available under the Nvidia Source Code License-NC. +// To view a copy of this license, visit +// https://nvlabs.github.io/stylegan2/license.html + +#include + +#include +#include +#include +#include + +#include +#include + + +static __host__ __device__ __forceinline__ int floor_div(int a, int b) { + int c = a / b; + + if (c * b > a) { + c--; + } + + return c; +} + + +struct UpFirDn2DKernelParams { + int up_x; + int up_y; + int down_x; + int down_y; + int pad_x0; + int pad_x1; + int pad_y0; + int pad_y1; + + int major_dim; + int in_h; + int in_w; + int minor_dim; + int kernel_h; + int kernel_w; + int out_h; + int out_w; + int loop_major; + int loop_x; +}; + + +template +__global__ void upfirdn2d_kernel(scalar_t* out, const scalar_t* input, const scalar_t* kernel, const UpFirDn2DKernelParams p) { + const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1; + const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1; + + __shared__ volatile float sk[kernel_h][kernel_w]; + __shared__ volatile float sx[tile_in_h][tile_in_w]; + + int minor_idx = blockIdx.x; + int tile_out_y = minor_idx / p.minor_dim; + minor_idx -= tile_out_y * p.minor_dim; + tile_out_y *= tile_out_h; + int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w; + int major_idx_base = blockIdx.z * p.loop_major; + + if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h | major_idx_base >= p.major_dim) { + return; + } + + for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w; tap_idx += blockDim.x) { + int ky = tap_idx / kernel_w; + int kx = tap_idx - ky * kernel_w; + scalar_t v = 0.0; + + if (kx < p.kernel_w & ky < p.kernel_h) { + v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)]; + } + + sk[ky][kx] = v; + } + + for (int loop_major = 0, major_idx = major_idx_base; loop_major < p.loop_major & major_idx < p.major_dim; loop_major++, major_idx++) { + for (int loop_x = 0, tile_out_x = tile_out_x_base; loop_x < p.loop_x & tile_out_x < p.out_w; loop_x++, tile_out_x += tile_out_w) { + int tile_mid_x = 
tile_out_x * down_x + up_x - 1 - p.pad_x0; + int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0; + int tile_in_x = floor_div(tile_mid_x, up_x); + int tile_in_y = floor_div(tile_mid_y, up_y); + + __syncthreads(); + + for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w; in_idx += blockDim.x) { + int rel_in_y = in_idx / tile_in_w; + int rel_in_x = in_idx - rel_in_y * tile_in_w; + int in_x = rel_in_x + tile_in_x; + int in_y = rel_in_y + tile_in_y; + + scalar_t v = 0.0; + + if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) { + v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim + minor_idx]; + } + + sx[rel_in_y][rel_in_x] = v; + } + + __syncthreads(); + for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w; out_idx += blockDim.x) { + int rel_out_y = out_idx / tile_out_w; + int rel_out_x = out_idx - rel_out_y * tile_out_w; + int out_x = rel_out_x + tile_out_x; + int out_y = rel_out_y + tile_out_y; + + int mid_x = tile_mid_x + rel_out_x * down_x; + int mid_y = tile_mid_y + rel_out_y * down_y; + int in_x = floor_div(mid_x, up_x); + int in_y = floor_div(mid_y, up_y); + int rel_in_x = in_x - tile_in_x; + int rel_in_y = in_y - tile_in_y; + int kernel_x = (in_x + 1) * up_x - mid_x - 1; + int kernel_y = (in_y + 1) * up_y - mid_y - 1; + + scalar_t v = 0.0; + + #pragma unroll + for (int y = 0; y < kernel_h / up_y; y++) + #pragma unroll + for (int x = 0; x < kernel_w / up_x; x++) + v += sx[rel_in_y + y][rel_in_x + x] * sk[kernel_y + y * up_y][kernel_x + x * up_x]; + + if (out_x < p.out_w & out_y < p.out_h) { + out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + minor_idx] = v; + } + } + } + } +} + + +torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel, + int up_x, int up_y, int down_x, int down_y, + int pad_x0, int pad_x1, int pad_y0, int pad_y1) { + int curDevice = -1; + cudaGetDevice(&curDevice); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); + + UpFirDn2DKernelParams p; + + auto x = input.contiguous(); + auto k = kernel.contiguous(); + + p.major_dim = x.size(0); + p.in_h = x.size(1); + p.in_w = x.size(2); + p.minor_dim = x.size(3); + p.kernel_h = k.size(0); + p.kernel_w = k.size(1); + p.up_x = up_x; + p.up_y = up_y; + p.down_x = down_x; + p.down_y = down_y; + p.pad_x0 = pad_x0; + p.pad_x1 = pad_x1; + p.pad_y0 = pad_y0; + p.pad_y1 = pad_y1; + + p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) / p.down_y; + p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) / p.down_x; + + auto out = at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options()); + + int mode = -1; + + int tile_out_h; + int tile_out_w; + + if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 4 && p.kernel_w <= 4) { + mode = 1; + tile_out_h = 16; + tile_out_w = 64; + } + + if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 3 && p.kernel_w <= 3) { + mode = 2; + tile_out_h = 16; + tile_out_w = 64; + } + + if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 4 && p.kernel_w <= 4) { + mode = 3; + tile_out_h = 16; + tile_out_w = 64; + } + + if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 2 && p.kernel_w <= 2) { + mode = 4; + tile_out_h = 16; + tile_out_w = 64; + } + + if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && p.kernel_h <= 4 && p.kernel_w <= 4) { + mode = 5; + tile_out_h = 8; + tile_out_w = 32; + } + + if (p.up_x == 1 && 
p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && p.kernel_h <= 2 && p.kernel_w <= 2) { + mode = 6; + tile_out_h = 8; + tile_out_w = 32; + } + + dim3 block_size; + dim3 grid_size; + + if (tile_out_h > 0 && tile_out_w) { + p.loop_major = (p.major_dim - 1) / 16384 + 1; + p.loop_x = 1; + block_size = dim3(32 * 8, 1, 1); + grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim, + (p.out_w - 1) / (p.loop_x * tile_out_w) + 1, + (p.major_dim - 1) / p.loop_major + 1); + } + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] { + switch (mode) { + case 1: + upfirdn2d_kernel<<>>( + out.data_ptr(), x.data_ptr(), k.data_ptr(), p + ); + + break; + + case 2: + upfirdn2d_kernel<<>>( + out.data_ptr(), x.data_ptr(), k.data_ptr(), p + ); + + break; + + case 3: + upfirdn2d_kernel<<>>( + out.data_ptr(), x.data_ptr(), k.data_ptr(), p + ); + + break; + + case 4: + upfirdn2d_kernel<<>>( + out.data_ptr(), x.data_ptr(), k.data_ptr(), p + ); + + break; + + case 5: + upfirdn2d_kernel<<>>( + out.data_ptr(), x.data_ptr(), k.data_ptr(), p + ); + + break; + + case 6: + upfirdn2d_kernel<<>>( + out.data_ptr(), x.data_ptr(), k.data_ptr(), p + ); + + break; + } + }); + + return out; +} \ No newline at end of file diff --git a/swapae/models/swapping_autoencoder_model.py b/swapae/models/swapping_autoencoder_model.py new file mode 100644 index 0000000000000000000000000000000000000000..f093ab16591b457717a9c77c349ff84c67a68139 --- /dev/null +++ b/swapae/models/swapping_autoencoder_model.py @@ -0,0 +1,275 @@ +import torch +import swapae.util as util +from swapae.models import BaseModel +import swapae.models.networks as networks +import swapae.models.networks.loss as loss + + +class SwappingAutoencoderModel(BaseModel): + @staticmethod + def modify_commandline_options(parser, is_train): + BaseModel.modify_commandline_options(parser, is_train) + parser.add_argument("--spatial_code_ch", default=8, type=int) + parser.add_argument("--global_code_ch", default=2048, type=int) + parser.add_argument("--lambda_R1", default=10.0, type=float) + parser.add_argument("--lambda_patch_R1", default=1.0, type=float) + parser.add_argument("--lambda_L1", default=1.0, type=float) + parser.add_argument("--lambda_GAN", default=1.0, type=float) + parser.add_argument("--lambda_PatchGAN", default=1.0, type=float) + parser.add_argument("--patch_min_scale", default=1 / 8, type=float) + parser.add_argument("--patch_max_scale", default=1 / 4, type=float) + parser.add_argument("--patch_num_crops", default=8, type=int) + parser.add_argument("--patch_use_aggregation", + type=util.str2bool, default=True) + return parser + + def initialize(self): + self.E = networks.create_network(self.opt, self.opt.netE, "encoder") + self.G = networks.create_network(self.opt, self.opt.netG, "generator") + if self.opt.lambda_GAN > 0.0: + self.D = networks.create_network( + self.opt, self.opt.netD, "discriminator") + if self.opt.lambda_PatchGAN > 0.0: + self.Dpatch = networks.create_network( + self.opt, self.opt.netPatchD, "patch_discriminator" + ) + + # Count the iteration count of the discriminator + # Used for lazy R1 regularization (c.f. 
Appendix B of StyleGAN2) + self.register_buffer( + "num_discriminator_iters", torch.zeros(1, dtype=torch.long) + ) + self.l1_loss = torch.nn.L1Loss() + + if (not self.opt.isTrain) or self.opt.continue_train: + self.load() + + if self.opt.num_gpus > 0: + self.to("cuda:0") + + def per_gpu_initialize(self): + pass + + def swap(self, x): + """ Swaps (or mixes) the ordering of the minibatch """ + shape = x.shape + assert shape[0] % 2 == 0, "Minibatch size must be a multiple of 2" + new_shape = [shape[0] // 2, 2] + list(shape[1:]) + x = x.view(*new_shape) + x = torch.flip(x, [1]) + return x.view(*shape) + + def compute_image_discriminator_losses(self, real, rec, mix): + if self.opt.lambda_GAN == 0.0: + return {} + + pred_real = self.D(real) + pred_rec = self.D(rec) + pred_mix = self.D(mix) + + losses = {} + losses["D_real"] = loss.gan_loss( + pred_real, should_be_classified_as_real=True + ) * self.opt.lambda_GAN + + losses["D_rec"] = loss.gan_loss( + pred_rec, should_be_classified_as_real=False + ) * (0.5 * self.opt.lambda_GAN) + losses["D_mix"] = loss.gan_loss( + pred_mix, should_be_classified_as_real=False + ) * (0.5 * self.opt.lambda_GAN) + + return losses + + def get_random_crops(self, x, crop_window=None): + """ Make random crops. + Corresponds to the yellow and blue random crops of Figure 2. + """ + crops = util.apply_random_crop( + x, self.opt.patch_size, + (self.opt.patch_min_scale, self.opt.patch_max_scale), + num_crops=self.opt.patch_num_crops + ) + return crops + + def compute_patch_discriminator_losses(self, real, mix): + losses = {} + real_feat = self.Dpatch.extract_features( + self.get_random_crops(real), + aggregate=self.opt.patch_use_aggregation + ) + target_feat = self.Dpatch.extract_features(self.get_random_crops(real)) + mix_feat = self.Dpatch.extract_features(self.get_random_crops(mix)) + + losses["PatchD_real"] = loss.gan_loss( + self.Dpatch.discriminate_features(real_feat, target_feat), + should_be_classified_as_real=True, + ) * self.opt.lambda_PatchGAN + + losses["PatchD_mix"] = loss.gan_loss( + self.Dpatch.discriminate_features(real_feat, mix_feat), + should_be_classified_as_real=False, + ) * self.opt.lambda_PatchGAN + + return losses + + def compute_discriminator_losses(self, real): + self.num_discriminator_iters.add_(1) + + sp, gl = self.E(real) + B = real.size(0) + assert B % 2 == 0, "Batch size must be even on each GPU." 
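+ # swap() flips each adjacent pair along the batch dimension, e.g. spatial
+ # codes [a, b, c, d] become [b, a, d, c]; this pairing is why the
+ # per-GPU batch size has to be even.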
+ + # To save memory, compute the GAN loss on only + # half of the reconstructed images + rec = self.G(sp[:B // 2], gl[:B // 2]) + mix = self.G(self.swap(sp), gl) + + losses = self.compute_image_discriminator_losses(real, rec, mix) + + if self.opt.lambda_PatchGAN > 0.0: + patch_losses = self.compute_patch_discriminator_losses(real, mix) + losses.update(patch_losses) + + metrics = {} # no metrics to report for the Discriminator iteration + + return losses, metrics, sp.detach(), gl.detach() + + def compute_R1_loss(self, real): + losses = {} + if self.opt.lambda_R1 > 0.0: + real.requires_grad_() + pred_real = self.D(real).sum() + grad_real, = torch.autograd.grad( + outputs=pred_real, + inputs=[real], + create_graph=True, + retain_graph=True, + ) + grad_real2 = grad_real.pow(2) + dims = list(range(1, grad_real2.ndim)) + grad_penalty = grad_real2.sum(dims) * (self.opt.lambda_R1 * 0.5) + else: + grad_penalty = 0.0 + + if self.opt.lambda_patch_R1 > 0.0: + real_crop = self.get_random_crops(real).detach() + real_crop.requires_grad_() + target_crop = self.get_random_crops(real).detach() + target_crop.requires_grad_() + + real_feat = self.Dpatch.extract_features( + real_crop, + aggregate=self.opt.patch_use_aggregation) + target_feat = self.Dpatch.extract_features(target_crop) + pred_real_patch = self.Dpatch.discriminate_features( + real_feat, target_feat + ).sum() + + grad_real, grad_target = torch.autograd.grad( + outputs=pred_real_patch, + inputs=[real_crop, target_crop], + create_graph=True, + retain_graph=True, + ) + + dims = list(range(1, grad_real.ndim)) + grad_crop_penalty = grad_real.pow(2).sum(dims) + \ + grad_target.pow(2).sum(dims) + grad_crop_penalty *= (0.5 * self.opt.lambda_patch_R1 * 0.5) + else: + grad_crop_penalty = 0.0 + + losses["D_R1"] = grad_penalty + grad_crop_penalty + + return losses + + def compute_generator_losses(self, real, sp_ma, gl_ma): + losses, metrics = {}, {} + B = real.size(0) + + sp, gl = self.E(real) + rec = self.G(sp[:B // 2], gl[:B // 2]) # only on B//2 to save memory + sp_mix = self.swap(sp) + + if self.opt.crop_size >= 1024: + # another momery-saving trick: reduce #outputs to save memory + real = real[B // 2:] + gl = gl[B // 2:] + sp_mix = sp_mix[B // 2:] + + mix = self.G(sp_mix, gl) + + # record the error of the reconstructed images for monitoring purposes + metrics["L1_dist"] = self.l1_loss(rec, real[:B // 2]) + + if self.opt.lambda_L1 > 0.0: + losses["G_L1"] = metrics["L1_dist"] * self.opt.lambda_L1 + + if self.opt.lambda_GAN > 0.0: + losses["G_GAN_rec"] = loss.gan_loss( + self.D(rec), + should_be_classified_as_real=True + ) * (self.opt.lambda_GAN * 0.5) + + losses["G_GAN_mix"] = loss.gan_loss( + self.D(mix), + should_be_classified_as_real=True + ) * (self.opt.lambda_GAN * 1.0) + + if self.opt.lambda_PatchGAN > 0.0: + real_feat = self.Dpatch.extract_features( + self.get_random_crops(real), + aggregate=self.opt.patch_use_aggregation).detach() + mix_feat = self.Dpatch.extract_features(self.get_random_crops(mix)) + + losses["G_mix"] = loss.gan_loss( + self.Dpatch.discriminate_features(real_feat, mix_feat), + should_be_classified_as_real=True, + ) * self.opt.lambda_PatchGAN + + return losses, metrics + + def get_visuals_for_snapshot(self, real): + if self.opt.isTrain: + # avoid the overhead of generating too many visuals during training + real = real[:2] if self.opt.num_gpus > 1 else real[:4] + sp, gl = self.E(real) + layout = util.resize2d_tensor(util.visualize_spatial_code(sp), real) + rec = self.G(sp, gl) + mix = self.G(sp, self.swap(gl)) + + visuals = 
{"real": real, "layout": layout, "rec": rec, "mix": mix} + + return visuals + + def fix_noise(self, sample_image=None): + """ The generator architecture is stochastic because of the noise + input at each layer (StyleGAN2 architecture). It could lead to + flickering of the outputs even when identical inputs are given. + Prevent flickering by fixing the noise injection of the generator. + """ + if sample_image is not None: + # The generator should be run at least once, + # so that the noise dimensions could be computed + sp, gl = self.E(sample_image) + self.G(sp, gl) + noise_var = self.G.fix_and_gather_noise_parameters() + return noise_var + + def encode(self, image, extract_features=False): + return self.E(image, extract_features=extract_features) + + def decode(self, spatial_code, global_code): + return self.G(spatial_code, global_code) + + def get_parameters_for_mode(self, mode): + if mode == "generator": + return list(self.G.parameters()) + list(self.E.parameters()) + elif mode == "discriminator": + Dparams = [] + if self.opt.lambda_GAN > 0.0: + Dparams += list(self.D.parameters()) + if self.opt.lambda_PatchGAN > 0.0: + Dparams += list(self.Dpatch.parameters()) + return Dparams diff --git a/swapae/optimizers/__init__.py b/swapae/optimizers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4d374cce6b644e7776a5d9eafc41d0a1f8c943dc --- /dev/null +++ b/swapae/optimizers/__init__.py @@ -0,0 +1,48 @@ +import os +import importlib +from swapae.optimizers.base_optimizer import BaseOptimizer +import torch + + +def find_optimizer_using_name(optimizer_name): + """Import the module "optimizers/[optimizer_name]_optimizer.py". + + In the file, the class called DatasetNameModel() will + be instantiated. It has to be a subclass of BaseOptimizer, + and it is case-insensitive. + """ + optimizer_filename = "swapae.optimizers." + optimizer_name + "_optimizer" + optimizerlib = importlib.import_module(optimizer_filename) + optimizer = None + target_optimizer_name = optimizer_name.replace('_', '') + 'optimizer' + for name, cls in optimizerlib.__dict__.items(): + if name.lower() == target_optimizer_name.lower() \ + and issubclass(cls, BaseOptimizer): + optimizer = cls + + if optimizer is None: + print("In %s.py, there should be a subclass of BaseOptimizer with class name that matches %s in lowercase." % (optimizer_filename, target_optimizer_name)) + exit(0) + + return optimizer + + +def get_option_setter(optimizer_name): + """Return the static method of the optimizer class.""" + optimizer_class = find_optimizer_using_name(optimizer_name) + return optimizer_class.modify_commandline_options + + +def create_optimizer(opt, model): + """Create a optimizer given the option. + + This function warps the class CustomDatasetDataLoader. 
+ This is the main interface between this package and 'train.py'/'test.py' + + Example: + >>> from optimizers import create_optimizer + >>> optimizer = create_optimizer(opt) + """ + optimizer = find_optimizer_using_name(opt.optimizer) + instance = optimizer(model) + return instance diff --git a/swapae/optimizers/base_optimizer.py b/swapae/optimizers/base_optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..37fee6eb938222c888f49a9adcbceb184e2bbcd7 --- /dev/null +++ b/swapae/optimizers/base_optimizer.py @@ -0,0 +1,19 @@ +from swapae.models import MultiGPUModelWrapper + + +class BaseOptimizer(): + @staticmethod + def modify_commandline_options(parser, is_train): + return parser + + def __init__(self, model: MultiGPUModelWrapper): + self.opt = model.opt + + def train_one_step(self, data_i, total_steps_so_far): + pass + + def get_visuals_for_snapshot(self, data_i): + return {} + + def save(self, total_steps_so_far): + pass diff --git a/swapae/optimizers/classifier_optimizer.py b/swapae/optimizers/classifier_optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..31a3f3533afcb7c129c60621378efb18a625b9a9 --- /dev/null +++ b/swapae/optimizers/classifier_optimizer.py @@ -0,0 +1,35 @@ +import torch +from models import MultiGPUModelWrapper +from swapae.optimizers.swapping_autoencoder_optimizer import SwappingAutoencoderOptimizer +import swapae.util + + +class ClassifierOptimizer(SwappingAutoencoderOptimizer): + @staticmethod + def modify_commandline_options(parser, is_train): + parser = SwappingAutoencoderOptimizer.modify_commandline_options(parser, is_train) + return parser + + def train_one_step(self, data_i, total_steps_so_far): + images_minibatch, labels = self.prepare_images(data_i) + c_losses = self.train_classifier_one_step(images_minibatch, labels) + self.adjust_lr_if_necessary(total_steps_so_far) + return util.to_numpy(c_losses) + + def train_classifier_one_step(self, images, labels): + self.set_requires_grad(self.Gparams, False) + self.optimizer_C.zero_grad() + losses, metrics = self.model(images, labels, command="compute_classifier_losses") + loss = sum([v.mean() for v in losses.values()]) + loss.backward() + self.optimizer_C.step() + losses.update(metrics) + return losses + + def get_visuals_for_snapshot(self, data_i): + images, labels = self.prepare_images(data_i) + with torch.no_grad(): + return self.model(images, labels, command="get_visuals_for_snapshot") + + def save(self, total_steps_so_far): + self.model.save(total_steps_so_far) diff --git a/swapae/optimizers/patchD_autoencoder_optimizer.py b/swapae/optimizers/patchD_autoencoder_optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..7dca8b4919d8634abaa76393a918509518384dbc --- /dev/null +++ b/swapae/optimizers/patchD_autoencoder_optimizer.py @@ -0,0 +1,114 @@ +import torch +from models import MultiGPUModelWrapper +from swapae.optimizers.base_optimizer import BaseOptimizer +import swapae.util + + +class PatchDAutoencoderOptimizer(BaseOptimizer): + @staticmethod + def modify_commandline_options(parser, is_train): + parser.add_argument("--lr", default=0.002, type=float) + parser.add_argument("--beta1", default=0.0, type=float) + parser.add_argument("--beta2", default=0.99, type=float) + parser.add_argument("--R1_once_every", default=16, type=int, + help="lazy R1 regularization. 
R1 loss is computed once in 1/R1_freq times") + + return parser + + def __init__(self, model: MultiGPUModelWrapper): + self.opt = model.opt + opt = self.opt + self.model = model + self.training_mode_index = 0 + + self.Gparams = self.model.get_parameters_for_mode("generator") + self.Dparams = self.model.get_parameters_for_mode("discriminator") + + self.num_discriminator_iters = 0 + self.optimizer_G = torch.optim.Adam(self.Gparams, lr=opt.lr, + betas=(opt.beta1, opt.beta2)) + # StyleGAN2 Appendix B + c = opt.R1_once_every / (1 + opt.R1_once_every) + self.optimizer_D = torch.optim.Adam(self.Dparams, + lr=opt.lr * c, + betas=(opt.beta1 ** c, + opt.beta2 ** c)) + + def set_requires_grad(self, params, requires_grad): + for p in params: + p.requires_grad_(requires_grad) + + def prepare_images(self, data_i): + A = data_i["real_A"] + if "real_B" in data_i: + B = data_i["real_B"] + A = torch.cat([A, B], dim=0) + A = A[torch.randperm(A.size(0))] + return A + + def toggle_training_mode(self): + all_modes = ["generator", "discriminator"] + self.training_mode_index = (self.training_mode_index + 1) % len(all_modes) + return all_modes[self.training_mode_index] + + def train_one_step(self, data_i, total_steps_so_far): + images_minibatch = self.prepare_images(data_i) + if self.toggle_training_mode() == "generator": + losses = self.train_discriminator_one_step(images_minibatch) + else: + losses = self.train_generator_one_step(images_minibatch) + return util.to_numpy(losses) + + def train_generator_one_step(self, images): + self.set_requires_grad(self.Dparams, False) + self.set_requires_grad(self.Gparams, True) + _, gl_ma = self.model(images, command="encode", + use_momentum_encoder=True) + self.optimizer_G.zero_grad() + g_losses, g_metrics = self.model(images, gl_ma, + command="compute_generator_losses") + g_loss = sum([v.mean() for v in g_losses.values()]) + g_loss.backward() + self.optimizer_G.step() + g_losses.update(g_metrics) + return g_losses + + def train_discriminator_one_step(self, images): + if self.opt.lambda_GAN == 0.0 and self.opt.lambda_PatchGAN == 0.0: + return {} + self.set_requires_grad(self.Dparams, True) + self.set_requires_grad(self.Gparams, False) + self.num_discriminator_iters += 1 + self.optimizer_D.zero_grad() + + d_losses, d_metrics, features = self.model(images, + command="compute_discriminator_losses") + nce_losses, nce_metrics = self.model.singlegpu_model(*features, + command="compute_discriminator_nce_losses") + d_losses.update(nce_losses) + d_metrics.update(nce_metrics) + d_loss = sum([v.mean() for v in d_losses.values()]) + d_loss.backward() + self.optimizer_D.step() + needs_R1 = (self.opt.lambda_R1 > 0.0 or self.opt.lambda_patch_R1) and \ + (self.num_discriminator_iters % self.opt.R1_once_every == 0) + if needs_R1: + self.optimizer_D.zero_grad() + r1_losses = self.model(images, + command="compute_R1_loss") + d_losses.update(r1_losses) + r1_loss = sum([v.mean() for v in r1_losses.values()]) + r1_loss = r1_loss * self.opt.R1_once_every + r1_loss.backward() + self.optimizer_D.step() + + d_losses.update(d_metrics) + return d_losses + + def get_visuals_for_snapshot(self, data_i): + images = self.prepare_images(data_i) + with torch.no_grad(): + return self.model(images, command="get_visuals_for_snapshot") + + def save(self, total_steps_so_far): + self.model.save(total_steps_so_far) diff --git a/swapae/optimizers/swapping_autoencoder_optimizer.py b/swapae/optimizers/swapping_autoencoder_optimizer.py new file mode 100644 index 
0000000000000000000000000000000000000000..3095eac2052c0ea3b8a63e99c756d9ae7da33398 --- /dev/null +++ b/swapae/optimizers/swapping_autoencoder_optimizer.py @@ -0,0 +1,119 @@ +import torch +import swapae.util as util +from swapae.models import MultiGPUModelWrapper +from swapae.optimizers.base_optimizer import BaseOptimizer + + +class SwappingAutoencoderOptimizer(BaseOptimizer): + """ Class for running the optimization of the model parameters. + Implements Generator / Discriminator training, R1 gradient penalty, + decaying learning rates, and reporting training progress. + """ + @staticmethod + def modify_commandline_options(parser, is_train): + parser.add_argument("--lr", default=0.002, type=float) + parser.add_argument("--beta1", default=0.0, type=float) + parser.add_argument("--beta2", default=0.99, type=float) + parser.add_argument( + "--R1_once_every", default=16, type=int, + help="lazy R1 regularization. R1 loss is computed " + "once in 1/R1_freq times", + ) + return parser + + def __init__(self, model: MultiGPUModelWrapper): + self.opt = model.opt + opt = self.opt + self.model = model + self.train_mode_counter = 0 + self.discriminator_iter_counter = 0 + + self.Gparams = self.model.get_parameters_for_mode("generator") + self.Dparams = self.model.get_parameters_for_mode("discriminator") + + self.optimizer_G = torch.optim.Adam( + self.Gparams, lr=opt.lr, betas=(opt.beta1, opt.beta2) + ) + + # c.f. StyleGAN2 (https://arxiv.org/abs/1912.04958) Appendix B + c = opt.R1_once_every / (1 + opt.R1_once_every) + self.optimizer_D = torch.optim.Adam( + self.Dparams, lr=opt.lr * c, betas=(opt.beta1 ** c, opt.beta2 ** c) + ) + + def set_requires_grad(self, params, requires_grad): + """ For more efficient optimization, turn on and off + recording of gradients for |params|. 
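The discriminator's Adam settings above implement the lazy-R1 correction cited from StyleGAN2 Appendix B: because the R1 penalty is only applied every --R1_once_every steps, the learning rate and betas are rescaled by c = N / (N + 1). A minimal sketch of the arithmetic with the defaults used in this file (lr=0.002, beta1=0.0, beta2=0.99, N=16); the numbers are illustrative only:

    # Lazy R1 rescaling (StyleGAN2, Appendix B): the penalty is added once every
    # N discriminator steps, so that optimizer's hyperparameters are rescaled by c.
    N = 16                                # --R1_once_every
    lr, beta1, beta2 = 0.002, 0.0, 0.99   # base settings shared with the generator
    c = N / (N + 1)                       # ~0.941
    d_lr = lr * c                         # ~0.00188
    d_betas = (beta1 ** c, beta2 ** c)    # ~(0.0, 0.9906)
    print(d_lr, d_betas)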
+ """ + for p in params: + p.requires_grad_(requires_grad) + + def prepare_images(self, data_i): + return data_i["real_A"] + + def toggle_training_mode(self): + modes = ["discriminator", "generator"] + self.train_mode_counter = (self.train_mode_counter + 1) % len(modes) + return modes[self.train_mode_counter] + + def train_one_step(self, data_i, total_steps_so_far): + images_minibatch = self.prepare_images(data_i) + if self.toggle_training_mode() == "generator": + losses = self.train_discriminator_one_step(images_minibatch) + else: + losses = self.train_generator_one_step(images_minibatch) + return util.to_numpy(losses) + + def train_generator_one_step(self, images): + self.set_requires_grad(self.Dparams, False) + self.set_requires_grad(self.Gparams, True) + sp_ma, gl_ma = None, None + self.optimizer_G.zero_grad() + g_losses, g_metrics = self.model( + images, sp_ma, gl_ma, command="compute_generator_losses" + ) + g_loss = sum([v.mean() for v in g_losses.values()]) + g_loss.backward() + self.optimizer_G.step() + g_losses.update(g_metrics) + return g_losses + + def train_discriminator_one_step(self, images): + if self.opt.lambda_GAN == 0.0 and self.opt.lambda_PatchGAN == 0.0: + return {} + self.set_requires_grad(self.Dparams, True) + self.set_requires_grad(self.Gparams, False) + self.discriminator_iter_counter += 1 + self.optimizer_D.zero_grad() + d_losses, d_metrics, sp, gl = self.model( + images, command="compute_discriminator_losses" + ) + self.previous_sp = sp.detach() + self.previous_gl = gl.detach() + d_loss = sum([v.mean() for v in d_losses.values()]) + d_loss.backward() + self.optimizer_D.step() + + needs_R1 = self.opt.lambda_R1 > 0.0 or self.opt.lambda_patch_R1 > 0.0 + needs_R1_at_current_iter = needs_R1 and \ + self.discriminator_iter_counter % self.opt.R1_once_every == 0 + if needs_R1_at_current_iter: + self.optimizer_D.zero_grad() + r1_losses = self.model(images, command="compute_R1_loss") + d_losses.update(r1_losses) + r1_loss = sum([v.mean() for v in r1_losses.values()]) + r1_loss = r1_loss * self.opt.R1_once_every + r1_loss.backward() + self.optimizer_D.step() + + d_losses["D_total"] = sum([v.mean() for v in d_losses.values()]) + d_losses.update(d_metrics) + return d_losses + + def get_visuals_for_snapshot(self, data_i): + images = self.prepare_images(data_i) + with torch.no_grad(): + return self.model(images, command="get_visuals_for_snapshot") + + def save(self, total_steps_so_far): + self.model.save(total_steps_so_far) diff --git a/swapae/options/__init__.py b/swapae/options/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b0a0bfe84d3719d3c52cab7323e51e953e0600b2 --- /dev/null +++ b/swapae/options/__init__.py @@ -0,0 +1,212 @@ +import argparse +import shlex +import os +import pickle + +import swapae.util as util +import swapae.models as models +import swapae.models.networks as networks +import swapae.data as data +import swapae.evaluation as evaluation +import swapae.optimizers as optimizers +from swapae.util import IterationCounter +from swapae.util import Visualizer + + +class BaseOptions(): + def initialize(self, parser): + # experiment specifics + parser.add_argument('--name', type=str, default="ffhq512_pretrained", help='name of the experiment. It decides where to store samples and models') + parser.add_argument('--easy_label', type=str, default="") + + parser.add_argument('--num_gpus', type=int, default=1, help='#GPUs to use. 
0 means CPU mode') + parser.add_argument('--checkpoints_dir', type=str, default='/home/xtli/Documents/GITHUB/swapping-autoencoder-pytorch/checkpoints/', help='models are saved here') + parser.add_argument('--model', type=str, default='swapping_autoencoder', help='which model to use') + parser.add_argument('--optimizer', type=str, default='swapping_autoencoder', help='which model to use') + parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc') + parser.add_argument('--resume_iter', type=str, default="latest", + help="# iterations (in thousands) to resume") + parser.add_argument('--num_classes', type=int, default=0) + + # input/output sizes + parser.add_argument('--batch_size', type=int, default=1, help='input batch size') + parser.add_argument('--preprocess', type=str, default='resize', help='scaling and cropping of images at load time.') + parser.add_argument('--load_size', type=int, default=512, help='Scale images to this size. The final image will be cropped to --crop_size.') + parser.add_argument('--crop_size', type=int, default=512, help='Crop to the width of crop_size (after initially scaling the images to load_size.)') + parser.add_argument('--preprocess_crop_padding', type=int, default=None, help='padding parameter of transforms.RandomCrop(). It is not used if --preprocess does not contain crop option.') + parser.add_argument('--no_flip', action='store_true') + parser.add_argument('--shuffle_dataset', type=str, default=None, choices=('true', 'false')) + + # for setting inputs + parser.add_argument('--dataroot', type=str, default="/home/xtli/Dropbox/swapping-autoencoder-pytorch/testphotos/ffhq512/fig9/") + parser.add_argument('--dataset_mode', type=str, default='imagefolder') + parser.add_argument('--nThreads', default=8, type=int, help='# threads for loading data') + + # networks + parser.add_argument("--netG", default="StyleGAN2Resnet") + parser.add_argument("--netD", default="StyleGAN2") + parser.add_argument("--netE", default="StyleGAN2Resnet") + parser.add_argument("--netPatchD", default="StyleGAN2") + parser.add_argument("--use_antialias", type=util.str2bool, default=True) + + parser.add_argument("-f", "--config_file", type=str, default='models/swap/json/sem_cons.json', help='json files including all arguments') + parser.add_argument("--local_rank", type=int) + + return parser + + def gather_options(self, command=None): + parser = AugmentedArgumentParser() + parser.custom_command = command + + # get basic options + parser = self.initialize(parser) + + # get the basic options + opt, unknown = parser.parse_known_args() + + # modify model-related parser options + model_name = opt.model + model_option_setter = models.get_option_setter(model_name) + parser = model_option_setter(parser, self.isTrain) + + # modify network-related parser options + parser = networks.modify_commandline_options(parser, self.isTrain) + + # modify optimizer-related parser options + optimizer_name = opt.optimizer + optimizer_option_setter = optimizers.get_option_setter(optimizer_name) + parser = optimizer_option_setter(parser, self.isTrain) + + # modify dataset-related parser options + dataset_mode = opt.dataset_mode + dataset_option_setter = data.get_option_setter(dataset_mode) + parser = dataset_option_setter(parser, self.isTrain) + + # modify parser options related to iteration_counting + parser = Visualizer.modify_commandline_options(parser, self.isTrain) + + # modify parser options related to iteration_counting + parser = 
IterationCounter.modify_commandline_options(parser, self.isTrain) + + # modify evaluation-related parser options + evaluation_option_setter = evaluation.get_option_setter() + parser = evaluation_option_setter(parser, self.isTrain) + + opt, unknown = parser.parse_known_args() + + opt = parser.parse_args() + self.parser = parser + return opt + + def print_options(self, opt): + """Print and save options + + It will print both current options and default values(if different). + It will save options into a text file / [checkpoints_dir] / opt.txt + """ + message = '' + message += '----------------- Options ---------------\n' + for k, v in sorted(vars(opt).items()): + comment = '' + default = self.parser.get_default(k) + if v != default: + comment = '\t[default: %s]' % str(default) + message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment) + message += '----------------- End -------------------' + print(message) + + def option_file_path(self, opt, makedir=False): + expr_dir = os.path.join(opt.checkpoints_dir, opt.name) + if makedir: + util.mkdirs(expr_dir) + file_name = os.path.join(expr_dir, 'opt') + return file_name + + def save_options(self, opt): + file_name = self.option_file_path(opt, makedir=True) + with open(file_name + '.txt', 'wt') as opt_file: + for k, v in sorted(vars(opt).items()): + comment = '' + default = self.parser.get_default(k) + if v != default: + comment = '\t[default: %s]' % str(default) + opt_file.write('{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)) + + with open(file_name + '.pkl', 'wb') as opt_file: + pickle.dump(opt, opt_file) + + def parse(self, save=False, command=None): + opt = self.gather_options(command) + opt.isTrain = self.isTrain # train or test + self.print_options(opt) + if opt.isTrain: + self.save_options(opt) + + opt.dataroot = os.path.expanduser(opt.dataroot) + + assert opt.num_gpus <= opt.batch_size, "Batch size must not be smaller than num_gpus" + return opt + + + +class TrainOptions(BaseOptions): + def __init__(self): + super().__init__() + self.isTrain = True + + def initialize(self, parser): + super().initialize(parser) + parser.add_argument('--continue_train', type=util.str2bool, default=False, help="resume training from last checkpoint") + parser.add_argument('--pretrained_name', type=str, default=None, + help="Load weights from the checkpoint of another experiment") + + return parser + + +class TestOptions(BaseOptions): + def __init__(self): + super().__init__() + self.isTrain = False + + def initialize(self, parser): + super().initialize(parser) + parser.add_argument("--result_dir", type=str, default="results") + return parser + + +class AugmentedArgumentParser(argparse.ArgumentParser): + def parse_args(self, args=None, namespace=None): + """ Enables passing bash commands as arguments to the class. + """ + print("parsing args...") + if args is None and hasattr(self, 'custom_command') and self.custom_command is not None: + print('using custom command') + print(self.custom_command) + args = shlex.split(self.custom_command)[2:] + return super().parse_args(args, namespace) + + def parse_known_args(self, args=None, namespace=None): + if args is None and hasattr(self, 'custom_command') and self.custom_command is not None: + args = shlex.split(self.custom_command)[2:] + return super().parse_known_args(args, namespace) + + def add_argument(self, *args, **kwargs): + """ Support for providing a new argument type called "str2bool" + + Example: + parser.add_argument("--my_option", type=util.str2bool, default=|bool|) + + 1. 
"python train.py" sets my_option to be |bool| + 2. "python train.py --my_option" sets my_option to be True + 3. "python train.py --my_option False" sets my_option to be False + 4. "python train.py --my_option True" sets my_options to be True + + https://stackoverflow.com/a/43357954 + """ + + if 'type' in kwargs and kwargs['type'] == util.str2bool: + if 'nargs' not in kwargs: + kwargs['nargs'] = "?" + if 'const' not in kwargs: + kwargs['const'] = True + super().add_argument(*args, **kwargs) diff --git a/swapae/test.py b/swapae/test.py new file mode 100644 index 0000000000000000000000000000000000000000..92befa3ea1fbecf912106f465e27049a8605e95e --- /dev/null +++ b/swapae/test.py @@ -0,0 +1,52 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import sys +sys.path.append('./') + +import time +from swapae.options import TestOptions +import swapae.models as models +from swapae.evaluation import GroupEvaluator +import swapae.data as data +import torchvision.utils as vutils + +import torch +from PIL import Image +import torchvision.transforms as transforms + +opt = TestOptions().parse() +#dataset = data.create_dataset(opt) +#evaluators = GroupEvaluator(opt) + +model = models.create_model(opt) + +#evaluators.evaluate(model, dataset, opt.resume_iter) + +structure_path = '/home/xtli/Dropbox/swapping-autoencoder-pytorch/testphotos/ffhq512/fig9/structure/12000.png' +style_path = '/home/xtli/Dropbox/swapping-autoencoder-pytorch/testphotos/ffhq512/fig9/style/11104.png' + +structure_img = Image.open(structure_path) +style_img = Image.open(style_path) + +transform = transforms.Compose([ + transforms.Resize(256), + transforms.CenterCrop(256), + transforms.ToTensor()]) + +structure_img = transform(structure_img).unsqueeze(0) +style_img = transform(style_img).unsqueeze(0) +structure_img = structure_img * 2 - 1 +style_img = style_img * 2 - 1 + +s_time = time.time() +with torch.no_grad(): + structure_feat = model(structure_img, command="encode")[0] + style_feat = model(style_img, command="encode")[1] + rec = model(structure_feat, style_feat, command="decode") +e_time = time.time() +print(e_time - s_time) +rec = (rec + 1) / 2 +vutils.save_image(rec, 'rec.png') + diff --git a/swapae/util/__init__.py b/swapae/util/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f7c1cf4dbd1004f5ebae4f4b8ecebef15266a7b7 --- /dev/null +++ b/swapae/util/__init__.py @@ -0,0 +1,6 @@ +from .iter_counter import IterationCounter +from .visualizer import Visualizer +from .metric_tracker import MetricTracker +from .util import * +from .html import HTML +#from .pca import PCA diff --git a/swapae/util/__pycache__/__init__.cpython-37.pyc b/swapae/util/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5a7355728bd25cfc0fbf6c3bf4fe7f34a9f367f5 Binary files /dev/null and b/swapae/util/__pycache__/__init__.cpython-37.pyc differ diff --git a/swapae/util/__pycache__/__init__.cpython-38.pyc b/swapae/util/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c9290c56d1677c7d257a3a388d75d3aa2a4206f0 Binary files /dev/null and b/swapae/util/__pycache__/__init__.cpython-38.pyc differ diff --git a/swapae/util/__pycache__/html.cpython-37.pyc b/swapae/util/__pycache__/html.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7a5d906e95ced3292fc7c837159e44921fd23248 Binary files /dev/null and 
b/swapae/util/__pycache__/html.cpython-37.pyc differ diff --git a/swapae/util/__pycache__/html.cpython-38.pyc b/swapae/util/__pycache__/html.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ea5673ab38a593375e25771caf4f4e53462b9484 Binary files /dev/null and b/swapae/util/__pycache__/html.cpython-38.pyc differ diff --git a/swapae/util/__pycache__/iter_counter.cpython-37.pyc b/swapae/util/__pycache__/iter_counter.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a16af132223394aa4dca828a5e2fee6877d644c4 Binary files /dev/null and b/swapae/util/__pycache__/iter_counter.cpython-37.pyc differ diff --git a/swapae/util/__pycache__/iter_counter.cpython-38.pyc b/swapae/util/__pycache__/iter_counter.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4cfe6e7fe26b80ca215c2da8ee05a75d99b1ab18 Binary files /dev/null and b/swapae/util/__pycache__/iter_counter.cpython-38.pyc differ diff --git a/swapae/util/__pycache__/metric_tracker.cpython-37.pyc b/swapae/util/__pycache__/metric_tracker.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..690bb071fc6062b5f8ad7fd142f26a36bea8f5ac Binary files /dev/null and b/swapae/util/__pycache__/metric_tracker.cpython-37.pyc differ diff --git a/swapae/util/__pycache__/metric_tracker.cpython-38.pyc b/swapae/util/__pycache__/metric_tracker.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..80b2690864ba1b687df161228df207cc351ef4f9 Binary files /dev/null and b/swapae/util/__pycache__/metric_tracker.cpython-38.pyc differ diff --git a/swapae/util/__pycache__/util.cpython-37.pyc b/swapae/util/__pycache__/util.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..151d11151de0a729d74ca8ec5054bed478cd4e7d Binary files /dev/null and b/swapae/util/__pycache__/util.cpython-37.pyc differ diff --git a/swapae/util/__pycache__/util.cpython-38.pyc b/swapae/util/__pycache__/util.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..673f2d854f4768a68fee99b566d43cfe5a030aa9 Binary files /dev/null and b/swapae/util/__pycache__/util.cpython-38.pyc differ diff --git a/swapae/util/__pycache__/visualizer.cpython-37.pyc b/swapae/util/__pycache__/visualizer.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8811bfeeb64f224d5ed57f8d1729fc2617ce37fc Binary files /dev/null and b/swapae/util/__pycache__/visualizer.cpython-37.pyc differ diff --git a/swapae/util/__pycache__/visualizer.cpython-38.pyc b/swapae/util/__pycache__/visualizer.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f0da3027d01e7afdd43f7829fe2a1cc422469a5c Binary files /dev/null and b/swapae/util/__pycache__/visualizer.cpython-38.pyc differ diff --git a/swapae/util/html.py b/swapae/util/html.py new file mode 100644 index 0000000000000000000000000000000000000000..ca5daf051767e141b655d673ba6ffc63068ff13a --- /dev/null +++ b/swapae/util/html.py @@ -0,0 +1,108 @@ +from PIL import Image +import dominate +from dominate.tags import meta, h3, table, tr, td, p, a, img, br +import os + + +class HTML: + """This HTML class allows us to save images and write texts into a single HTML file. + + It consists of functions such as (add a text header to the HTML file), + (add a row of images to the HTML file), and (save the HTML to the disk). + It is based on Python library 'dominate', a Python library for creating and manipulating HTML documents using a DOM API. 
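Besides lists of image paths, the add_images method below also accepts PIL Images directly and saves them into the page's image directory before building the table. A hypothetical usage sketch (assuming the constructor creates the web/ and web/images/ directories as in the standard dominate-based helper):

    from PIL import Image
    from swapae.util.html import HTML

    page = HTML('web/', 'pil_inputs')     # writes web/index.html on save()
    page.add_header('solid color tiles')
    tiles = [Image.new('RGB', (64, 64), c) for c in ('red', 'green', 'blue')]
    page.add_images(tiles, txts=['red', 'green', 'blue'])  # saved as red.png, ...
    page.save()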
+ """ + + def __init__(self, web_dir, title, refresh=0): + """Initialize the HTML classes + + Parameters: + web_dir (str) -- a directory that stores the webpage. HTML file will be created at /index.html; images will be saved at 0: + with self.doc.head: + meta(http_equiv="refresh", content=str(refresh)) + + self.add_header(title) + + def get_image_dir(self): + """Return the directory that stores images""" + return self.img_dir + + def add_header(self, text): + """Insert a header to the HTML file + + Parameters: + text (str) -- the header text + """ + with self.doc: + h3(text) + + def add_images(self, ims, txts, links=None, width=None): + """add images to the HTML file + + Parameters: + ims (str list) -- a list of image paths + txts (str list) -- a list of image names shown on the website + links (str list) -- a list of hyperref links; when you click an image, it will redirect you to a new page + """ + if links is None: + links = ims + input_is_path = type(ims[0]) == str + if not input_is_path: # input is PIL Image + assert type(ims[0]) == Image.Image + paths = [] + for im, name in zip(ims, txts): + if "." not in name[-5:]: + name = name + ".png" + savepath = os.path.join(self.img_dir, name) + im.save(savepath) + paths.append(savepath) + names = [os.path.basename(p) for p in paths] + return self.add_images(names, + names, + names, + width) + + self.t = table(border=1, style="table-layout: fixed;") # Insert a table + self.doc.add(self.t) + with self.t: + with tr(): + for im, txt, link in zip(ims, txts, links): + with td(style="word-wrap: break-word;", halign="center", valign="top"): + with p(): + with a(href=os.path.join('images', link)): + #img(style="width:%dpx" % width, src=os.path.join('images', im)) + img(src=os.path.join('images', im)) + br() + p(txt, style="font-size:9px") + + def save(self): + """save the current content to the HMTL file""" + html_file = '%s/index.html' % self.web_dir + f = open(html_file, 'wt') + f.write(self.doc.render()) + f.close() + + +if __name__ == '__main__': # we show an example usage here. 
+ html = HTML('web/', 'test_html') + html.add_header('hello world') + + ims, txts, links = [], [], [] + for n in range(4): + ims.append('image_%d.png' % n) + txts.append('text_%d' % n) + links.append('image_%d.png' % n) + html.add_images(ims, txts, links) + html.save() diff --git a/swapae/util/iter_counter.py b/swapae/util/iter_counter.py new file mode 100644 index 0000000000000000000000000000000000000000..abee50ed496c42809642631bb8c5cfbd2b86a415 --- /dev/null +++ b/swapae/util/iter_counter.py @@ -0,0 +1,93 @@ +import os +import numpy as np +import torch +import time + + +class IterationCounter(): + @staticmethod + def modify_commandline_options(parser, is_train): + parser.add_argument("--total_nimgs", default=25 * + (1000 ** 2), type=int) + parser.add_argument("--save_freq", default=50000, type=int) + parser.add_argument("--evaluation_freq", default=50000, type=int) + parser.add_argument("--print_freq", default=480, type=int) + parser.add_argument("--display_freq", default=1600, type=int) + return parser + + def __init__(self, opt): + self.opt = opt + self.iter_record_path = os.path.join( + self.opt.checkpoints_dir, self.opt.name, 'iter.txt') + self.steps_so_far = 0 + if "unaligned" in opt.dataset_mode: + self.batch_size = opt.batch_size * 2 + else: + self.batch_size = opt.batch_size + self.time_measurements = {} + + automatically_find_resume_iter = opt.isTrain and opt.continue_train \ + and opt.resume_iter == "latest" and opt.pretrained_name is None + resume_at_specified_iter = opt.isTrain and opt.continue_train \ + and opt.resume_iter.replace("k", "").isnumeric() + if automatically_find_resume_iter: + try: + self.steps_so_far = np.loadtxt( + self.iter_record_path, delimiter=',', dtype=int) + print('Resuming from iteration %d' % (self.steps_so_far)) + except Exception: + print('Could not load iteration record at %s. ' + 'Starting from beginning.' 
% self.iter_record_path) + elif resume_at_specified_iter: + steps = int(opt.resume_iter.replace("k", "")) + if "k" in opt.resume_iter: + steps *= 1000 + self.steps_so_far = steps + else: + self.steps_so_far = 0 + + def record_one_iteration(self): + if self.needs_saving(): + np.savetxt(self.iter_record_path, + [self.steps_so_far], delimiter=',', fmt='%d') + print("Saved current iter count at %s" % self.iter_record_path) + self.steps_so_far += self.batch_size + + def needs_saving(self): + return (self.steps_so_far % self.opt.save_freq) < self.batch_size + + def needs_evaluation(self): + return (self.steps_so_far >= self.opt.evaluation_freq) and \ + ((self.steps_so_far % self.opt.evaluation_freq) < self.batch_size) + + def needs_printing(self): + return (self.steps_so_far % self.opt.print_freq) < self.batch_size + + def needs_displaying(self): + return (self.steps_so_far % self.opt.display_freq) < self.batch_size + + def completed_training(self): + return (self.steps_so_far >= self.opt.total_nimgs) + + class TimeMeasurement: + def __init__(self, name, parent): + self.name = name + self.parent = parent + + def __enter__(self): + self.start_time = time.time() + + def __exit__(self, type, value, traceback): + torch.cuda.synchronize() + start_time = self.start_time + elapsed_time = (time.time() - start_time) / self.parent.batch_size + + if self.name not in self.parent.time_measurements: + self.parent.time_measurements[self.name] = elapsed_time + else: + prev_time = self.parent.time_measurements[self.name] + updated_time = prev_time * 0.98 + elapsed_time * 0.02 + self.parent.time_measurements[self.name] = updated_time + + def time_measurement(self, name): + return IterationCounter.TimeMeasurement(name, self) diff --git a/swapae/util/kmeans.py b/swapae/util/kmeans.py new file mode 100644 index 0000000000000000000000000000000000000000..eeb548bcb4b52e19974fb2242a3c40056d08507a --- /dev/null +++ b/swapae/util/kmeans.py @@ -0,0 +1,146 @@ +# From kmeans_pytorch + +import numpy as np +import torch +from tqdm import tqdm + + +def initialize(X, num_clusters): + """ + initialize cluster centers + :param X: (torch.tensor) matrix + :param num_clusters: (int) number of clusters + :return: (np.array) initial state + """ + num_samples = len(X) + indices = np.random.choice(num_samples, num_clusters, replace=False) + initial_state = X[indices] + return initial_state + + +def kmeans( + X, + num_clusters, + distance='euclidean', + tol=1e-4, + device=torch.device('cuda') +): + """ + perform kmeans + :param X: (torch.tensor) matrix + :param num_clusters: (int) number of clusters + :param distance: (str) distance [options: 'euclidean', 'cosine'] [default: 'euclidean'] + :param tol: (float) threshold [default: 0.0001] + :param device: (torch.device) device [default: cpu] + :return: (torch.tensor, torch.tensor) cluster ids, cluster centers + """ + print(f'running k-means on {device}..') + + if distance == 'euclidean': + pairwise_distance_function = pairwise_distance + elif distance == 'cosine': + pairwise_distance_function = pairwise_cosine + else: + raise NotImplementedError + + # convert to float + X = X.float() + + # transfer to device + X = X.to(device) + + # initialize + initial_state = initialize(X, num_clusters) + + iteration = 0 + tqdm_meter = tqdm(desc='[running kmeans]') + while True: + dis = pairwise_distance_function(X, initial_state) + + choice_cluster = torch.argmin(dis, dim=1) + + initial_state_pre = initial_state.clone() + + for index in range(num_clusters): + selected = torch.nonzero(choice_cluster == 
index).squeeze().to(device) + + selected = torch.index_select(X, 0, selected) + initial_state[index] = selected.mean(dim=0) + + center_shift = torch.sum( + torch.sqrt( + torch.sum((initial_state - initial_state_pre) ** 2, dim=1) + )) + + # increment iteration + iteration = iteration + 1 + + # update tqdm meter + tqdm_meter.set_postfix( + iteration=f'{iteration}', + center_shift=f'{center_shift ** 2:0.6f}', + tol=f'{tol:0.6f}' + ) + tqdm_meter.update() + if center_shift ** 2 < tol: + break + + return choice_cluster, initial_state + + +def kmeans_predict( + X, + cluster_centers, + distance='euclidean', + device=torch.device('cpu') +): + """ + predict using cluster centers + :param X: (torch.tensor) matrix + :param cluster_centers: (torch.tensor) cluster centers + :param distance: (str) distance [options: 'euclidean', 'cosine'] [default: 'euclidean'] + :param device: (torch.device) device [default: 'cpu'] + :return: (torch.tensor) cluster ids + """ + print(f'predicting on {device}..') + + if distance == 'euclidean': + pairwise_distance_function = pairwise_distance + elif distance == 'cosine': + pairwise_distance_function = pairwise_cosine + else: + raise NotImplementedError + + # convert to float + X = X.float() + + # transfer to device + X = X.to(device) + + dis = pairwise_distance_function(X, cluster_centers) + choice_cluster = torch.argmin(dis, dim=1) + + return choice_cluster.cpu() + + +def pairwise_distance(data1, data2): + return torch.cdist(data1[None, :, :], data2[None, :, :])[0] + + +def pairwise_cosine(data1, data2): + + # N*1*M + A = data1.unsqueeze(dim=1) + + # 1*N*M + B = data2.unsqueeze(dim=0) + + # normalize the points | [0.3, 0.4] -> [0.3/sqrt(0.09 + 0.16), 0.4/sqrt(0.09 + 0.16)] = [0.3/0.5, 0.4/0.5] + A_normalized = A / A.norm(dim=-1, keepdim=True) + B_normalized = B / B.norm(dim=-1, keepdim=True) + + cosine = A_normalized * B_normalized + + # return N*N matrix for pairwise distance + cosine_dis = 1 - cosine.sum(dim=-1).squeeze() + return cosine_dis diff --git a/swapae/util/metric_tracker.py b/swapae/util/metric_tracker.py new file mode 100644 index 0000000000000000000000000000000000000000..bb4f4752ec01a517267e0297d43da8652d89cc38 --- /dev/null +++ b/swapae/util/metric_tracker.py @@ -0,0 +1,28 @@ +from collections import OrderedDict + + +class MetricTracker: + def __init__(self, opt): + self.opt = opt + self.metrics = {} + + def moving_average(self, old, new): + s = 0.98 + return old * (s) + new * (1 - s) + + def update_metrics(self, metric_dict, smoothe=True): + default_smoothe = smoothe + for k, v in metric_dict.items(): + if k == "D_R1": + smoothe = False + else: + smoothe = default_smoothe + if k in self.metrics and smoothe: + self.metrics[k] = self.moving_average(self.metrics[k], v) + else: + self.metrics[k] = v + + def current_metrics(self): + keys = sorted(list(self.metrics.keys())) + ordered_metrics = OrderedDict([(k, self.metrics[k]) for k in keys]) + return ordered_metrics diff --git a/swapae/util/pca.py b/swapae/util/pca.py new file mode 100644 index 0000000000000000000000000000000000000000..50936b182047ed4b6ee420ec7aa7e4e0473a84ce --- /dev/null +++ b/swapae/util/pca.py @@ -0,0 +1,69 @@ +import time +import numpy as np +import torch +import swapae.util as util + +np.set_printoptions(precision=4, suppress=True, edgeitems=10) + + +class PCA: + def __init__(self, X, ndim=128, var_fraction=0.99, l2_normalized=True, first_direction=None): + + self.l2_normalized = l2_normalized + if l2_normalized: + X = X[:, :-1] + + assert len(X.shape) == 2 + torch.cuda.synchronize() + 
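For reference, a small usage sketch of the k-means helpers above; the tensors are made up and the device is set to CPU so it runs without a GPU (the function's default device is CUDA):

    import torch
    from swapae.util.kmeans import kmeans, kmeans_predict

    X = torch.randn(1000, 32)             # 1000 feature vectors of dimension 32
    cluster_ids, centers = kmeans(X, num_clusters=10, distance='euclidean',
                                  tol=1e-4, device=torch.device('cpu'))
    # assign unseen samples to the learned centers
    new_ids = kmeans_predict(torch.randn(5, 32), centers,
                             distance='euclidean', device=torch.device('cpu'))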
start_time = time.time() + + self.mean = torch.mean(X, dim=0, keepdim=True) + #self.mean = 0 + #self.std = torch.std(X, dim=0, keepdim=True) + 1e-6 + self.std = 1 + #print("std is ", self.std[:, :10].cpu().numpy()) + #X_orig = X + X = (X - self.mean) / self.std + + U, S, V = torch.svd(X) + S = S[:ndim] + V = V[:, :ndim] + self.proj = V + scale = torch.mm(X, self.proj).std(dim=0) + torch.cuda.synchronize() + print("PCA time taken on vectors of size %s : %f" % (str(X.size()), time.time() - start_time)) + print("largest std of each PC: ", scale[:10].cpu().numpy()) + print("smallest std of each PC: ", scale[-10:].cpu().numpy()) + self.sinvals = S + print("largest sinvals: ", self.sinvals[:10].cpu().numpy()) + self.inv_proj = V.transpose(0, 1) + self.N = X.size(0) + + def project(self, x): + if self.l2_normalized: + last_dim = x[:, -1:] + x = x[:, :-1] + #x = (x - self.mean) / self.std + z = torch.mm(x, self.proj) + if self.l2_normalized: + return torch.cat([z, last_dim], dim=1) + else: + return z + + def scale(self): + return self.sinvals / np.sqrt(self.N) + + def pc(self, idx): + # return self.inv_proj[idx:idx + 1] * (self.std * np.sqrt(self.inv_proj.size(1))) + return self.inv_proj[idx:idx + 1] + + def inverse(self, z): + if self.l2_normalized: + last_dim = z[:, -1:] + z = z[:, :-1] + x = torch.mm(z, self.inv_proj) + #x = x * self.std + self.mean + if self.l2_normalized: + return torch.cat([x, last_dim], dim=1) + else: + return x diff --git a/swapae/util/util.py b/swapae/util/util.py new file mode 100644 index 0000000000000000000000000000000000000000..934ec56dbda48d78d3a581d0db437701abe0e4ce --- /dev/null +++ b/swapae/util/util.py @@ -0,0 +1,556 @@ +"""This module contains simple helper functions """ +from __future__ import print_function +import torch +import numbers +import torch.nn as nn +import torchvision +import torch.nn.functional as F +import math +import numpy as np +from PIL import Image +import os +import importlib +import argparse +from argparse import Namespace +from sklearn.decomposition import PCA as PCA + + +def normalize(v): + if type(v) == list: + return [normalize(vv) for vv in v] + + return v * torch.rsqrt((torch.sum(v ** 2, dim=1, keepdim=True) + 1e-8)) + +def slerp(a, b, r): + d = torch.sum(a * b, dim=-1, keepdim=True) + p = r * torch.acos(d * (1 - 1e-4)) + c = normalize(b - d * a) + d = a * torch.cos(p) + c * torch.sin(p) + return normalize(d) + + +def lerp(a, b, r): + if type(a) == list or type(a) == tuple: + return [lerp(aa, bb, r) for aa, bb in zip(a, b)] + return a * (1 - r) + b * r + + +def madd(a, b, r): + if type(a) == list or type(a) == tuple: + return [madd(aa, bb, r) for aa, bb in zip(a, b)] + return a + b * r + +def str2bool(v): + if isinstance(v, bool): + return v + if v.lower() in ('yes', 'true', 't', 'y', '1'): + return True + elif v.lower() in ('no', 'false', 'f', 'n', '0'): + return False + else: + raise argparse.ArgumentTypeError('Boolean value expected.') + + +def copyconf(default_opt, **kwargs): + conf = Namespace(**vars(default_opt)) + for key in kwargs: + setattr(conf, key, kwargs[key]) + return conf + + +def find_class_in_module(target_cls_name, module): + target_cls_name = target_cls_name.replace('_', '').lower() + clslib = importlib.import_module(module) + cls = None + for name, clsobj in clslib.__dict__.items(): + if name.lower() == target_cls_name: + cls = clsobj + + assert cls is not None, "In %s, there should be a class whose name matches %s in lowercase without underscore(_)" % (module, target_cls_name) + + return cls + + +def 
tile_images(imgs, picturesPerRow=4): + """ Code borrowed from + https://stackoverflow.com/questions/26521365/cleanly-tile-numpy-array-of-images-stored-in-a-flattened-1d-format/26521997 + """ + + # Padding + if imgs.shape[0] % picturesPerRow == 0: + rowPadding = 0 + else: + rowPadding = picturesPerRow - imgs.shape[0] % picturesPerRow + if rowPadding > 0: + imgs = np.concatenate([imgs, np.zeros((rowPadding, *imgs.shape[1:]), dtype=imgs.dtype)], axis=0) + + # Tiling Loop (The conditionals are not necessary anymore) + tiled = [] + for i in range(0, imgs.shape[0], picturesPerRow): + tiled.append(np.concatenate([imgs[j] for j in range(i, i + picturesPerRow)], axis=1)) + + tiled = np.concatenate(tiled, axis=0) + return tiled + + +# Converts a Tensor into a Numpy array +# |imtype|: the desired type of the converted numpy array +def tensor2im(image_tensor, imtype=np.uint8, normalize=True, tile=2): + if isinstance(image_tensor, list): + image_numpy = [] + for i in range(len(image_tensor)): + image_numpy.append(tensor2im(image_tensor[i], imtype, normalize)) + return image_numpy + + if len(image_tensor.shape) == 4: + # transform each image in the batch + images_np = [] + for b in range(image_tensor.shape[0]): + one_image = image_tensor[b] + one_image_np = tensor2im(one_image) + images_np.append(one_image_np.reshape(1, *one_image_np.shape)) + images_np = np.concatenate(images_np, axis=0) + if tile is not False: + tile = max(min(images_np.shape[0] // 2, 4), 1) if tile is True else tile + images_tiled = tile_images(images_np, picturesPerRow=tile) + return images_tiled + else: + return images_np + + if len(image_tensor.shape) == 2: + assert False + #imagce_tensor = image_tensor.unsqueeze(0) + image_numpy = image_tensor.detach().cpu().numpy() if type(image_tensor) is not np.ndarray else image_tensor + if normalize: + image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 + else: + image_numpy = np.transpose(image_numpy, (1, 2, 0)) * 255.0 + image_numpy = np.clip(image_numpy, 0, 255) + if image_numpy.shape[2] == 1: + image_numpy = np.repeat(image_numpy, 3, axis=2) + return image_numpy.astype(imtype) + + +def toPILImage(images, tile=None): + if isinstance(images, list): + if all(['tensor' in str(type(image)).lower() for image in images]): + return toPILImage(torch.cat([im.cpu() for im in images], dim=0), tile) + return [toPILImage(image, tile=tile) for image in images] + + if 'ndarray' in str(type(images)).lower(): + return toPILImage(torch.from_numpy(images)) + + assert 'tensor' in str(type(images)).lower(), "input of type %s cannot be handled." 
% str(type(images)) + + if tile is None: + max_width = 2560 + tile = min(images.size(0), int(max_width / images.size(3))) + + return Image.fromarray(tensor2im(images, tile=tile)) + + +def diagnose_network(net, name='network'): + """Calculate and print the mean of average absolute(gradients) + + Parameters: + net (torch network) -- Torch network + name (str) -- the name of the network + """ + mean = 0.0 + count = 0 + for param in net.parameters(): + if param.grad is not None: + mean += torch.mean(torch.abs(param.grad.data)) + count += 1 + if count > 0: + mean = mean / count + print(name) + print(mean) + + +def save_image(image_numpy, image_path, aspect_ratio=1.0): + """Save a numpy image to the disk + + Parameters: + image_numpy (numpy array) -- input numpy array + image_path (str) -- the path of the image + """ + + image_pil = Image.fromarray(image_numpy) + h, w, _ = image_numpy.shape + + if aspect_ratio is None: + pass + elif aspect_ratio > 1.0: + image_pil = image_pil.resize((h, int(w * aspect_ratio)), Image.BICUBIC) + elif aspect_ratio < 1.0: + image_pil = image_pil.resize((int(h / aspect_ratio), w), Image.BICUBIC) + image_pil.save(image_path) + + + +def print_numpy(x, val=True, shp=False): + """Print the mean, min, max, median, std, and size of a numpy array + + Parameters: + val (bool) -- if print the values of the numpy array + shp (bool) -- if print the shape of the numpy array + """ + x = x.astype(np.float64) + if shp: + print('shape,', x.shape) + if val: + x = x.flatten() + print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % ( + np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x))) + + +def mkdirs(paths): + """create empty directories if they don't exist + + Parameters: + paths (str list) -- a list of directory paths + """ + if isinstance(paths, list) and not isinstance(paths, str): + for path in paths: + mkdir(path) + else: + mkdir(paths) + + +def mkdir(path): + """create a single empty directory if it didn't exist + + Parameters: + path (str) -- a single directory path + """ + if not os.path.exists(path): + os.makedirs(path) + + +def visualize_spatial_code(sp): + device = sp.device + #sp = (sp - sp.min()) / (sp.max() - sp.min() + 1e-7) + if sp.size(1) <= 2: + sp = sp.repeat([1, 3, 1, 1])[:, :3, :, :] + if sp.size(1) == 3: + pass + else: + sp = sp.detach().cpu().numpy() + X = np.transpose(sp, (0, 2, 3, 1)) + B, H, W = X.shape[0], X.shape[1], X.shape[2] + X = np.reshape(X, (-1, X.shape[3])) + X = X - X.mean(axis=0, keepdims=True) + try: + Z = PCA(3).fit_transform(X) + except ValueError: + print("Running PCA on the structure code has failed.") + print("This is likely a bug of scikit-learn in version 0.18.1.") + print("https://stackoverflow.com/a/42764378") + print("The visualization of the structure code on visdom won't work.") + return torch.zeros(B, 3, H, W, device=device) + sp = np.transpose(np.reshape(Z, (B, H, W, -1)), (0, 3, 1, 2)) + sp = (sp - sp.min()) / (sp.max() - sp.min()) * 2 - 1 + sp = torch.from_numpy(sp).to(device) + return sp + + +def blank_tensor(w, h): + return torch.ones(1, 3, h, w) + + + +class RandomSpatialTransformer: + def __init__(self, opt, bs): + self.opt = opt + #self.resample_transformation(bs) + + + def create_affine_transformation(self, ref, rot, sx, sy, tx, ty): + return torch.stack([-ref * sx * torch.cos(rot), -sy * torch.sin(rot), tx, + -ref * sx * torch.sin(rot), sy * torch.cos(rot), ty], axis=1) + + def resample_transformation(self, bs, device, reflection=None, rotation=None, scale=None, translation=None): + dev = 
device + zero = torch.zeros((bs), device=dev) + if reflection is None: + #if "ref" in self.opt.random_transformation_mode: + ref = torch.round(torch.rand((bs), device=dev)) * 2 - 1 + #else: + # ref = 1.0 + else: + ref = reflection + + if rotation is None: + #if "rot" in self.opt.random_transformation_mode: + max_rotation = 30 * math.pi / 180 + rot = torch.rand((bs), device=dev) * (2 * max_rotation) - max_rotation + #else: + # rot = 0.0 + else: + rot = rotation + + if scale is None: + #if "scale" in self.opt.random_transformation_mode: + min_scale = 1.0 + max_scale = 1.0 + sx = torch.rand((bs), device=dev) * (max_scale - min_scale) + min_scale + sy = torch.rand((bs), device=dev) * (max_scale - min_scale) + min_scale + #else: + # sx, sy = 1.0, 1.0 + else: + sx, sy = scale + + tx, ty = zero, zero + + A = torch.stack([ref * sx * torch.cos(rot), -sy * torch.sin(rot), tx, + ref * sx * torch.sin(rot), sy * torch.cos(rot), ty], axis=1) + return A.view(bs, 2, 3) + + + + def forward_transform(self, x, size): + if type(x) == list: + return [self.forward_transform(xx) for xx in x] + + affine_param = self.resample_transformation(x.size(0), x.device) + affine_grid = F.affine_grid(affine_param, (x.size(0), x.size(1), size[0], size[1]), align_corners=False) + x = F.grid_sample(x, affine_grid, padding_mode='reflection', align_corners=False) + + return x + + +def apply_random_crop(x, target_size, scale_range, num_crops=1, return_rect=False): + # build grid + B = x.size(0) * num_crops + flip = torch.round(torch.rand(B, 1, 1, 1, device=x.device)) * 2 - 1.0 + unit_grid_x = torch.linspace(-1.0, 1.0, target_size, device=x.device)[np.newaxis, np.newaxis, :, np.newaxis].repeat(B, target_size, 1, 1) + unit_grid_y = unit_grid_x.transpose(1, 2) + unit_grid = torch.cat([unit_grid_x * flip, unit_grid_y], dim=3) + + + #crops = [] + x = x.unsqueeze(1).expand(-1, num_crops, -1, -1, -1).flatten(0, 1) + #for i in range(num_crops): + scale = torch.rand(B, 1, 1, 2, device=x.device) * (scale_range[1] - scale_range[0]) + scale_range[0] + offset = (torch.rand(B, 1, 1, 2, device=x.device) * 2 - 1) * (1 - scale) + sampling_grid = unit_grid * scale + offset + crop = F.grid_sample(x, sampling_grid, align_corners=False) + #crops.append(crop) + #crop = torch.stack(crops, dim=1) + crop = crop.view(B // num_crops, num_crops, crop.size(1), crop.size(2), crop.size(3)) + + return crop + + + + +def five_crop_noresize(A): + Y, X = A.size(2) // 3, A.size(3) // 3 + H, W = Y * 2, X * 2 + return torch.stack([A[:, :, 0:0+H, 0:0+W], + A[:, :, Y:Y+H, 0:0+W], + A[:, :, Y:Y+H, X:X+W], + A[:, :, 0:0+H, X:X+W], + A[:, :, Y//2:Y//2+H, X//2:X//2+W]], + dim=1) # return 5-dim tensor + + +def random_crop_noresize(A, crop_size): + offset_y = np.random.randint(A.size(2) - crop_size[0]) + offset_x = np.random.randint(A.size(3) - crop_size[1]) + return A[:, :, offset_y:offset_y + crop_size[0], offset_x:offset_x + crop_size[1]], (offset_y, offset_x) + + +def random_crop_with_resize(A, crop_size): + #size_y = np.random.randint(crop_size[0], A.size(2) + 1) + #size_x = np.random.randint(crop_size[1], A.size(3) + 1) + #size_y, size_x = crop_size + size_y = max(crop_size[0], np.random.randint(A.size(2) // 3, A.size(2) + 1)) + size_x = max(crop_size[1], np.random.randint(A.size(3) // 3, A.size(3) + 1)) + offset_y = np.random.randint(A.size(2) - size_y + 1) + offset_x = np.random.randint(A.size(3) - size_x + 1) + crop_rect = (offset_y, offset_x, size_y, size_x) + resized = crop_with_resize(A, crop_rect, crop_size) + #print('resized %s to %s' % (A.size(), 
resized.size())) + return resized, crop_rect + + +def crop_with_resize(A, crop_rect, return_size): + offset_y, offset_x, size_y, size_x = crop_rect + crop = A[:, :, offset_y:offset_y + size_y, offset_x:offset_x + size_x] + resized = F.interpolate(crop, size=return_size, mode='bilinear', align_corners=False) + #print('resized %s to %s' % (A.size(), resized.size())) + return resized + + +def compute_similarity_logit(x, y, p=1, compute_interdistances=True): + + def compute_dist(x, y, p): + if p == 2: + return ((x - y) ** 2).sum(dim=-1).sqrt() + else: + return (x - y).abs().sum(dim=-1) + C = x.shape[-1] + + if len(x.shape) == 2: + if compute_interdistances: + dist = torch.cdist(x[None, :, :], y[None, :, :], p)[0] + else: + dist = compute_dist(x, y, p) + if len(x.shape) == 3: + if compute_interdistances: + dist = torch.cdist(x, y, p) + else: + dist = compute_dist(x, y, p) + + if p == 1: + dist = 1 - dist / math.sqrt(C) + elif p == 2: + dist = 1 - 0.5 * (dist ** 2) + + return dist / 0.07 + + +def set_diag_(x, value): + assert x.size(-2) == x.size(-1) + L = x.size(-2) + identity = torch.eye(L, dtype=torch.bool, device=x.device) + identity = identity.view([1] * (len(x.shape) - 2) + [L, L]) + x.masked_fill_(identity, value) + + +def to_numpy(metric_dict): + new_dict = {} + for k, v in metric_dict.items(): + if "numpy" not in str(type(v)): + v = v.detach().cpu().mean().numpy() + new_dict[k] = v + return new_dict + + +def is_custom_kernel_supported(): + version_str = str(torch.version.cuda).split(".") + major = version_str[0] + minor = version_str[1] + return int(major) >= 10 and int(minor) >= 1 + + +def shuffle_batch(x): + B = x.size(0) + perm = torch.randperm(B, dtype=torch.long, device=x.device) + return x[perm] + + +def unravel_index(index, shape): + out = [] + for dim in reversed(shape): + out.append(index % dim) + index = index // dim + return tuple(reversed(out)) + + +def quantize_color(x, num=64): + return (x * num / 2).round() * (2 / num) + + +def resize2d_tensor(x, size_or_tensor_of_size): + if torch.is_tensor(size_or_tensor_of_size): + size = size_or_tensor_of_size.size() + elif isinstance(size_or_tensor_of_size, np.ndarray): + size = size_or_tensor_of_size.shape + else: + size = size_or_tensor_of_size + + if isinstance(size, tuple) or isinstance(size, list): + return F.interpolate(x, size[-2:], + mode='bilinear', align_corners=False) + else: + raise ValueError("%s is unrecognized" % str(type(size))) + + +def correct_resize(t, size, mode=Image.BICUBIC): + device = t.device + t = t.detach().cpu() + resized = [] + for i in range(t.size(0)): + one_t = t[i:i+1] + one_image = Image.fromarray(tensor2im(one_t, tile=1)).resize(size, Image.BICUBIC) + resized_t = torchvision.transforms.functional.to_tensor(one_image) * 2 - 1.0 + resized.append(resized_t) + return torch.stack(resized, dim=0).to(device) + + + + +class GaussianSmoothing(nn.Module): + """ + Apply gaussian smoothing on a + 1d, 2d or 3d tensor. Filtering is performed seperately for each channel + in the input using a depthwise convolution. + Arguments: + channels (int, sequence): Number of channels of the input tensors. Output will + have this number of channels as well. + kernel_size (int, sequence): Size of the gaussian kernel. + sigma (float, sequence): Standard deviation of the gaussian kernel. + dim (int, optional): The number of dimensions of the data. + Default value is 2 (spatial). 
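A minimal usage sketch of the GaussianSmoothing module defined here; shapes are illustrative. Because forward() reflection-pads by kernel_size // 2 before the depthwise convolution, the output keeps the input's spatial size:

    import torch
    from swapae.util.util import GaussianSmoothing

    smoother = GaussianSmoothing(channels=3, kernel_size=5, sigma=1.0, dim=2)
    x = torch.rand(1, 3, 64, 64)          # NCHW image batch
    blurred = smoother(x)                 # same shape: (1, 3, 64, 64)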
+ """ + def __init__(self, channels, kernel_size, sigma, dim=2): + super(GaussianSmoothing, self).__init__() + if isinstance(kernel_size, numbers.Number): + self.pad_size = kernel_size // 2 + kernel_size = [kernel_size] * dim + else: + raise NotImplementedError() + + if isinstance(sigma, numbers.Number): + sigma = [sigma] * dim + + # The gaussian kernel is the product of the + # gaussian function of each dimension. + kernel = 1 + meshgrids = torch.meshgrid( + [ + torch.arange(size, dtype=torch.float32) + for size in kernel_size + ] + ) + for size, std, mgrid in zip(kernel_size, sigma, meshgrids): + mean = (size - 1) / 2 + kernel *= 1 / (std * math.sqrt(2 * math.pi)) * \ + torch.exp(-((mgrid - mean) / std) ** 2 / 2) + + # Make sure sum of values in gaussian kernel equals 1. + kernel = kernel / (torch.sum(kernel)) + + + # Reshape to depthwise convolutional weight + kernel = kernel.view(1, 1, *kernel.size()) + kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1)) + + self.register_buffer('weight', kernel) + self.groups = channels + + if dim == 1: + self.conv = F.conv1d + elif dim == 2: + self.conv = F.conv2d + elif dim == 3: + self.conv = F.conv3d + else: + raise RuntimeError( + 'Only 1, 2 and 3 dimensions are supported. Received {}.'.format(dim) + ) + + + def forward(self, input): + """ + Apply gaussian filter to input. + Arguments: + input (torch.Tensor): Input to apply gaussian filter on. + Returns: + filtered (torch.Tensor): Filtered output. + """ + x = F.pad(input, [self.pad_size] * 4, mode="reflect") + return self.conv(x, weight=self.weight, groups=self.groups) diff --git a/swapae/util/visualizer.py b/swapae/util/visualizer.py new file mode 100644 index 0000000000000000000000000000000000000000..c62e6d2cb8677df675e995feb41355ff9cd1b23d --- /dev/null +++ b/swapae/util/visualizer.py @@ -0,0 +1,291 @@ +import numpy as np +import torch +import os +import sys +import ntpath +import time +from . import util, html +from subprocess import Popen, PIPE +from func_timeout import func_timeout, FunctionTimedOut + + +if sys.version_info[0] == 2: + VisdomExceptionBase = Exception +else: + VisdomExceptionBase = ConnectionError + + +def save_images(webpage, visuals, image_path, aspect_ratio=1.0, width=256): + """Save images to the disk. + + Parameters: + webpage (the HTML class) -- the HTML webpage class that stores these imaegs (see html.py for more details) + visuals (OrderedDict) -- an ordered dictionary that stores (name, images (either tensor or numpy) ) pairs + image_path (str) -- the string is used to create image paths + aspect_ratio (float) -- the aspect ratio of saved images + width (int) -- the images will be resized to width x width + + This function will save images stored in 'visuals' to the HTML file specified by 'webpage'. + """ + image_dir = webpage.get_image_dir() + short_path = ntpath.basename(image_path[0]) + name = os.path.splitext(short_path)[0] + + webpage.add_header(name) + ims, txts, links = [], [], [] + + for label, im_data in visuals.items(): + im = util.tensor2im(im_data) + image_name = '%s/%s.png' % (label, name) + os.makedirs(os.path.join(image_dir, label), exist_ok=True) + save_path = os.path.join(image_dir, image_name) + util.save_image(im, save_path, aspect_ratio=aspect_ratio) + ims.append(image_name) + txts.append(label) + links.append(image_name) + webpage.add_images(ims, txts, links, width=width) + + +class Visualizer(): + """This class includes several functions that can display/save images and print/save logging information. 
+ + It uses a Python library 'visdom' for display, and a Python library 'dominate' (wrapped in 'HTML') for creating HTML files with images. + """ + + @staticmethod + def modify_commandline_options(parser, is_train): + parser.add_argument("--display_port", default=2004) + parser.add_argument("--display_ncols", default=2) + parser.add_argument("--display_env", default="main") + parser.add_argument("--no_html", type=util.str2bool, nargs='?', const=True, default=True) + + return parser + + def __init__(self, opt): + """Initialize the Visualizer class + + Parameters: + opt -- stores all the experiment flags; needs to be a subclass of BaseOptions + Step 1: Cache the training/test options + Step 2: connect to a visdom server + Step 3: create an HTML object for saveing HTML filters + Step 4: create a logging file to store training losses + """ + self.opt = opt # cache the option + self.display_id = np.random.randint(1000000) * 10 # just a random display id + self.use_html = opt.isTrain and not opt.no_html + self.win_size = opt.crop_size + self.name = opt.name + self.port = opt.display_port + self.saved = False + if self.display_id > 0: + # connect to a visdom server + import visdom + self.plot_data = {} + self.ncols = opt.display_ncols + if "tensorboard_base_url" in os.environ: + self.vis = visdom.Visdom( + port=2004, + base_url=os.environ['tensorboard_base_url'] + '/visdom', + env=opt.display_env, + #raise_exceptions=False, + ) + print("setting up visdom server for sensei") + else: + self.vis = visdom.Visdom( + server="http://localhost", + port=opt.display_port, + env=opt.display_env, + raise_exceptions=False) + if not self.vis.check_connection(): + self.create_visdom_connections() + + if self.use_html: + # Create an HTML object at /web/; + # Images will be saved under /web/images/ + self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web') + self.img_dir = os.path.join(self.web_dir, 'images') + print('create web directory %s...' % self.web_dir) + util.mkdirs([self.web_dir, self.img_dir]) + + # create a logging file to store training losses + self.log_name = os.path.join( + opt.checkpoints_dir, opt.name, 'loss_log.txt') + with open(self.log_name, "a") as log_file: + now = time.strftime("%c") + log_file.write('================ Training Loss (%s) ================\n' % now) + + def reset(self): + """Reset the self.saved status""" + self.saved = False + + def create_visdom_connections(self): + """If the program could not connect to Visdom server, + this function will start a new server at port < self.port > """ + cmd = sys.executable + ' -m visdom.server -p %d &>/dev/null &' % self.port + print('\n\nCould not connect to Visdom server. ' + '\n Trying to start a server....') + print('Command: %s' % cmd) + Popen(cmd, shell=True, stdout=PIPE, stderr=PIPE) + + def display_current_results(self, visuals, epoch, + save_result=None, max_num_images=4): + """Display current results on visdom; + save current results to an HTML file. + + Parameters: + visuals (OrderedDict) - - dictionary of images to display or save + epoch (int) - - the current epoch + save_result (bool) - - if save the current results to an HTML file + """ + if save_result is None: + save_result = not self.opt.no_html + if self.display_id > 0: # show images in the browser using visdom + ncols = self.ncols + if ncols > 0: # show all the images in one visdom panel + ncols = min(ncols, len(visuals)) + h, w = next(iter(visuals.values())).shape[:2] + table_css = """""" % (w, h) # create a table css + # create a table of images. 
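IterationCounter, MetricTracker, and Visualizer are meant to be driven from a training loop. The following is a hypothetical sketch of that wiring, not the project's actual train.py; opt, dataset, and optimizer are assumed to come from the options, data, and optimizers packages elsewhere in this diff:

    from swapae.util import IterationCounter, Visualizer, MetricTracker

    iter_counter = IterationCounter(opt)
    visualizer = Visualizer(opt)
    metric_tracker = MetricTracker(opt)
    while not iter_counter.completed_training():
        data_i = next(dataset)                      # assumed iterable of batches
        with iter_counter.time_measurement("train"):
            losses = optimizer.train_one_step(data_i, iter_counter.steps_so_far)
            metric_tracker.update_metrics(losses)
        if iter_counter.needs_printing():
            visualizer.print_current_losses(iter_counter.steps_so_far,
                                            iter_counter.time_measurements,
                                            metric_tracker.current_metrics())
        if iter_counter.needs_displaying():
            visuals = optimizer.get_visuals_for_snapshot(data_i)
            visualizer.display_current_results(visuals, iter_counter.steps_so_far)
        if iter_counter.needs_saving():
            optimizer.save(iter_counter.steps_so_far)
        iter_counter.record_one_iteration()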
+ title = self.name + label_html = '' + label_html_row = '' + images = [] + idx = 0 + for label, image in visuals.items(): + if image.size(3) < 64: + image = torch.nn.functional.interpolate( + image, size=(64, 64), + mode='bilinear', align_corners=False) + image_numpy = util.tensor2im(image[:max_num_images]) + label_html_row += '%s' % label + images.append(image_numpy.transpose([2, 0, 1])) + idx += 1 + if idx % ncols == 0: + label_html += '%s' % label_html_row + label_html_row = '' + white_image = np.ones_like( + image_numpy.transpose([2, 0, 1])) * 255 + while idx % ncols != 0: + images.append(white_image) + label_html_row += '' + idx += 1 + if label_html_row != '': + label_html += '%s' % label_html_row + try: + func_timeout(15, self.vis.images, + args=(images, ncols, 2, self.display_id + 1, + None, dict(title=title + ' images'))) + label_html = '%s
' % label_html + self.vis.text(table_css + label_html, + win=self.display_id + 2, + opts=dict(title=title + ' labels')) + except FunctionTimedOut: + print("visdom call to display image timed out") + pass + except VisdomExceptionBase: + self.create_visdom_connections() + + else: # show each image in a separate visdom panel; + idx = 1 + try: + for label, image in visuals.items(): + image_numpy = util.tensor2im(image[:4]) + try: + func_timeout(5, self.vis.image, args=( + image_numpy.transpose([2, 0, 1]), + self.display_id + idx, + None, + dict(title=label) + )) + except FunctionTimedOut: + print("visdom call to display image timed out") + pass + idx += 1 + except VisdomExceptionBase: + self.create_visdom_connections() + + needs_save = save_result or not self.saved + if self.use_html and needs_save: + self.saved = True + # save images to the disk + for label, image in visuals.items(): + image_numpy = util.tensor2im(image[:4]) + img_path = os.path.join( + self.img_dir, 'epoch%.3d_%s.png' % (epoch, label)) + util.save_image(image_numpy, img_path) + + # update website + webpage = html.HTML( + self.web_dir, 'Experiment name = %s' % self.name, refresh=0) + for n in range(epoch, 0, -1): + webpage.add_header('epoch [%d]' % n) + ims, txts, links = [], [], [] + + for label, image_numpy in visuals.items(): + image_numpy = util.tensor2im(image) + img_path = 'epoch%.3d_%s.png' % (n, label) + ims.append(img_path) + txts.append(label) + links.append(img_path) + webpage.add_images(ims, txts, links, width=self.win_size) + webpage.save() + + def plot_current_losses(self, epoch, counter_ratio, losses): + """display the current losses on visdom display: dictionary of error labels and values + + Parameters: + epoch (int) -- current epoch + counter_ratio (float) -- progress (percentage) in the current epoch, between 0 to 1 + losses (OrderedDict) -- training losses stored in the format of (name, float) pairs + """ + if len(losses) == 0: + return + + plot_name = '_'.join(list(losses.keys())) + + if plot_name not in self.plot_data: + self.plot_data[plot_name] = {'X': [], 'Y': [], 'legend': list(losses.keys())} + + plot_data = self.plot_data[plot_name] + plot_id = list(self.plot_data.keys()).index(plot_name) + + plot_data['X'].append(epoch + counter_ratio) + plot_data['Y'].append([losses[k] for k in plot_data['legend']]) + try: + self.vis.line( + X=np.stack([np.array(plot_data['X'])] * len(plot_data['legend']), 1), + Y=np.array(plot_data['Y']), + opts={ + 'title': self.name, + 'legend': plot_data['legend'], + 'xlabel': 'epoch', + 'ylabel': 'loss'}, + win=self.display_id - plot_id) + except VisdomExceptionBase: + self.create_visdom_connections() + + # losses: same format as |losses| of plot_current_losses + def print_current_losses(self, iters, times, losses): + """print current losses on console; also save the losses to the disk + + Parameters: + epoch (int) -- current epoch + iters (int) -- current training iteration during this epoch (reset to 0 at the end of every epoch) + losses (OrderedDict) -- training losses stored in the format of (name, float) pairs + t_comp (float) -- computational time per data point (normalized by batch_size) + t_data (float) -- data loading time per data point (normalized by batch_size) + """ + message = '(iters: %d' % (iters) + for k, v in times.items(): + message += ", %s: %.3f" % (k, v) + message += ") " + for k, v in losses.items(): + message += '%s: %.3f ' % (k, v.mean()) + + print(message) # print the message + with open(self.log_name, "a") as log_file: + log_file.write('%s\n' % 
message) # save the message diff --git a/tmp/0.png b/tmp/0.png new file mode 100644 index 0000000000000000000000000000000000000000..9ed2bcb09e9ed8c3a8d76b3fb76bb404e994d8f3 Binary files /dev/null and b/tmp/0.png differ diff --git a/tmp/1.png b/tmp/1.png new file mode 100644 index 0000000000000000000000000000000000000000..bdffc23b77c123cc712b13e5a33cafa2fa60c7fa Binary files /dev/null and b/tmp/1.png differ diff --git a/tmp/2.png b/tmp/2.png new file mode 100644 index 0000000000000000000000000000000000000000..e9ab8d1a5d8e6d22ff23d2ac56876adb003f643c Binary files /dev/null and b/tmp/2.png differ diff --git a/tmp/3.png b/tmp/3.png new file mode 100644 index 0000000000000000000000000000000000000000..43bdb509c9397f8475d4fa7a86c6819b982d3d59 Binary files /dev/null and b/tmp/3.png differ diff --git a/tmp/4.png b/tmp/4.png new file mode 100644 index 0000000000000000000000000000000000000000..bcd3493a42c3d785ae10ee9f3b695ca92d3d160a Binary files /dev/null and b/tmp/4.png differ diff --git a/tmp/5.png b/tmp/5.png new file mode 100644 index 0000000000000000000000000000000000000000..8d9bdc3d80b5b7a6ba0abb1ec075b649cf8a6137 Binary files /dev/null and b/tmp/5.png differ diff --git a/tmp/6.png b/tmp/6.png new file mode 100644 index 0000000000000000000000000000000000000000..97b8f2b4a2fd4158e6bab3d10b8be8c0d2bc0ee5 Binary files /dev/null and b/tmp/6.png differ diff --git a/tmp/7.png b/tmp/7.png new file mode 100644 index 0000000000000000000000000000000000000000..e7c09f986865295994e14f785503eca54936b6cc Binary files /dev/null and b/tmp/7.png differ diff --git a/tmp/8.png b/tmp/8.png new file mode 100644 index 0000000000000000000000000000000000000000..db5fa3c6c89c1cd584c2a91b313002d1c24a7f22 Binary files /dev/null and b/tmp/8.png differ diff --git a/tmp/9.png b/tmp/9.png new file mode 100644 index 0000000000000000000000000000000000000000..324b305c4fed1d0a2bd2f128fb6e74c83920eaef Binary files /dev/null and b/tmp/9.png differ diff --git a/weights/.gitattributes b/weights/.gitattributes new file mode 100644 index 0000000000000000000000000000000000000000..ec4a626fbb7799f6a25b45fb86344b2bf7b37e64 --- /dev/null +++ b/weights/.gitattributes @@ -0,0 +1 @@ +*.pth filter=lfs diff=lfs merge=lfs -text diff --git a/weights/12003/cpk.pth b/weights/12003/cpk.pth new file mode 100644 index 0000000000000000000000000000000000000000..19491a20be6d55b77a2e505453afa90e22a6531e --- /dev/null +++ b/weights/12003/cpk.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9c61f11b2987362d958f9c93c7c20f99e06b586c022bc036b9aa78f3ba9ed16e +size 186192007 diff --git a/weights/12003/exp_args.json b/weights/12003/exp_args.json new file mode 100644 index 0000000000000000000000000000000000000000..9a93810d0cfbc36f94471051ed8c155025afc885 --- /dev/null +++ b/weights/12003/exp_args.json @@ -0,0 +1,63 @@ +{ + "data_path": "/home/xtli/DATA/BSR_processed/train", + "img_path": "data/images/12003.jpg", + "test_path": null, + "crop_size": 224, + "scale_size": null, + "batch_size": 1, + "workers": 4, + "pretrained_path": "/home/xtli/WORKDIR/04-15/single_scale_grouping_resume/cpk.pth", + "hidden_dim": 256, + "spatial_code_dim": 32, + "tex_code_dim": 256, + "exp_name": "04-18/12003", + "project_name": "ssn_transformer", + "nepochs": 20, + "lr": 5e-05, + "momentum": 0.5, + "beta": 0.999, + "lr_decay_freq": 3000, + "save_freq": 1000, + "save_freq_iter": 2000000000000, + "log_freq": 10, + "display_freq": 100, + "use_wandb": 0, + "work_dir": "/home/xtli/WORKDIR", + "out_dir": "/home/xtli/WORKDIR/04-18/12003", + "local_rank": 0, + 
"dataset": "dataset", + "config_file": "models/week0417/json/single_scale_grouping_ft.json", + "lambda_L1": 1, + "lambda_Perceptual": 1.0, + "lambda_PatchGAN": 1.0, + "lambda_GAN": 1, + "add_gan_epoch": 0, + "lambda_kld_loss": 1e-06, + "lambda_style_loss": 1.0, + "lambda_feat": 10.0, + "use_slic": true, + "patch_size": 64, + "netPatchD_scale_capacity": 4.0, + "netPatchD_max_nc": 384, + "netPatchD": "StyleGAN2", + "use_antialias": true, + "patch_use_aggregation": false, + "lambda_R1": 1.0, + "lambda_ffl_loss": 1.0, + "lambda_patch_R1": 1.0, + "R1_once_every": 16, + "add_self_loops": 1, + "test_time": false, + "sp_num": 196, + "label_path": "/home/xtli/DATA/BSR/BSDS500/data/groundTruth", + "model_name": "model", + "num_D": 2, + "n_layers_D": 3, + "n_cluster": 10, + "temperature": 23, + "add_gcn_epoch": 0, + "add_clustering_epoch": 0, + "add_texture_epoch": 0, + "dec_input_mode": "sine_wave_noise", + "sine_weight": true +} diff --git a/weights/12003/exp_args.txt b/weights/12003/exp_args.txt new file mode 100644 index 0000000000000000000000000000000000000000..1aeac13ba316748a1cd5321b1f17fd70725406ab --- /dev/null +++ b/weights/12003/exp_args.txt @@ -0,0 +1,60 @@ + add_clustering_epoch: 0 [default: 1000] + add_gcn_epoch: 0 [default: None] + add_self_loops: 1 + add_texture_epoch: 0 [default: 1000] + batch_size: 1 + beta: 0.999 + config_file: models/week0417/json/single_scale_grouping_ft.json [default: None] + crop_size: 224 + data_path: /home/xtli/DATA/BSR_processed/train_extend + dataset: dataset [default: None] + dec_input_mode: sine_wave_noise [default: None] + display_freq: 100 + exp_name: 04-18/12003 [default: None] + gumbel: 0 + hidden_dim: 256 [default: None] + img_path: None + l1_loss_wt: 1.0 + lambda_GAN: 1 [default: None] + lambda_L1: 1 [default: None] + lambda_style_loss: 1.0 [default: None] + local_rank: None + log_freq: 10 + lr: 5e-05 [default: 0.1] + lr_decay_freq: 3000 + maxIter: 1000 + model_name: model [default: None] + momentum: 0.5 + nChannel: 100 + nConv: 2 + n_cluster: 10 [default: None] + n_layers_D: 3 [default: None] + nepochs: 100 [default: None] + netE_nc_steepness: 2.0 + netE_num_downsampling_gl: 2 + netE_num_downsampling_sp: 4 + netE_scale_capacity: 1.0 +netG_num_base_resnet_layers: 2 + netG_resnet_ch: 256 + netG_scale_capacity: 1.0 + no_ganFeat_loss: False + num_D: 2 [default: None] + num_classes: 0 + out_dir: /home/xtli/WORKDIR/04-18/12003 [default: None] + patch_size: 40 + perceptual_loss_wt: 1.0 + pretrained_ae: /home/xtli/WORKDIR/07-16/transformer/cpk.pth + pretrained_path: /home/xtli/WORKDIR/04-15/single_scale_grouping_resume/cpk.pth [default: None] + project_name: test_time + save_freq: 1000 [default: 2000] + sine_weight: 1 [default: None] + sp_num: None + spatial_code_ch: 8 + spatial_code_dim: 32 [default: 256] + temperature: 23 [default: 1] + test_time: 0 + texture_code_ch: 256 + use_slic: True + use_wandb: False + work_dir: /home/xtli/WORKDIR + workers: 4