feng2022 commited on
Commit
bd964f4
1 Parent(s): 48eab26

add contextual_Loss

Browse files
Time_TravelRephotography/losses/contextual_loss/.gitignore ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ *.egg-info/
24
+ .installed.cfg
25
+ *.egg
26
+ MANIFEST
27
+
28
+ # PyInstaller
29
+ # Usually these files are written by a python script from a template
30
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
31
+ *.manifest
32
+ *.spec
33
+
34
+ # Installer logs
35
+ pip-log.txt
36
+ pip-delete-this-directory.txt
37
+
38
+ # Unit test / coverage reports
39
+ htmlcov/
40
+ .tox/
41
+ .coverage
42
+ .coverage.*
43
+ .cache
44
+ nosetests.xml
45
+ coverage.xml
46
+ *.cover
47
+ .hypothesis/
48
+ .pytest_cache/
49
+
50
+ # Translations
51
+ *.mo
52
+ *.pot
53
+
54
+ # Django stuff:
55
+ *.log
56
+ local_settings.py
57
+ db.sqlite3
58
+
59
+ # Flask stuff:
60
+ instance/
61
+ .webassets-cache
62
+
63
+ # Scrapy stuff:
64
+ .scrapy
65
+
66
+ # Sphinx documentation
67
+ docs/_build/
68
+
69
+ # PyBuilder
70
+ target/
71
+
72
+ # Jupyter Notebook
73
+ .ipynb_checkpoints
74
+
75
+ # pyenv
76
+ .python-version
77
+
78
+ # celery beat schedule file
79
+ celerybeat-schedule
80
+
81
+ # SageMath parsed files
82
+ *.sage.py
83
+
84
+ # Environments
85
+ .env
86
+ .venv
87
+ env/
88
+ venv/
89
+ ENV/
90
+ env.bak/
91
+ venv.bak/
92
+
93
+ # Spyder project settings
94
+ .spyderproject
95
+ .spyproject
96
+
97
+ # Rope project settings
98
+ .ropeproject
99
+
100
+ # mkdocs documentation
101
+ /site
102
+
103
+ # mypy
104
+ .mypy_cache/
Time_TravelRephotography/losses/contextual_loss/LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2019 Sou Uchida
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
Time_TravelRephotography/losses/contextual_loss/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .modules import *
Time_TravelRephotography/losses/contextual_loss/config.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ # TODO: add supports for L1, L2 etc.
2
+ LOSS_TYPES = ['cosine', 'l1', 'l2']
Time_TravelRephotography/losses/contextual_loss/functional.py ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
+ from .config import LOSS_TYPES
5
+
6
+ __all__ = ['contextual_loss', 'contextual_bilateral_loss']
7
+
8
+
9
+ def contextual_loss(x: torch.Tensor,
10
+ y: torch.Tensor,
11
+ band_width: float = 0.5,
12
+ loss_type: str = 'cosine',
13
+ all_dist: bool = False):
14
+ """
15
+ Computes contextual loss between x and y.
16
+ The most of this code is copied from
17
+ https://gist.github.com/yunjey/3105146c736f9c1055463c33b4c989da.
18
+
19
+ Parameters
20
+ ---
21
+ x : torch.Tensor
22
+ features of shape (N, C, H, W).
23
+ y : torch.Tensor
24
+ features of shape (N, C, H, W).
25
+ band_width : float, optional
26
+ a band-width parameter used to convert distance to similarity.
27
+ in the paper, this is described as :math:`h`.
28
+ loss_type : str, optional
29
+ a loss type to measure the distance between features.
30
+ Note: `l1` and `l2` frequently raises OOM.
31
+
32
+ Returns
33
+ ---
34
+ cx_loss : torch.Tensor
35
+ contextual loss between x and y (Eq (1) in the paper)
36
+ """
37
+
38
+ assert x.size() == y.size(), 'input tensor must have the same size.'
39
+ assert loss_type in LOSS_TYPES, f'select a loss type from {LOSS_TYPES}.'
40
+
41
+ N, C, H, W = x.size()
42
+
43
+ if loss_type == 'cosine':
44
+ dist_raw = compute_cosine_distance(x, y)
45
+ elif loss_type == 'l1':
46
+ dist_raw = compute_l1_distance(x, y)
47
+ elif loss_type == 'l2':
48
+ dist_raw = compute_l2_distance(x, y)
49
+
50
+ dist_tilde = compute_relative_distance(dist_raw)
51
+ cx = compute_cx(dist_tilde, band_width)
52
+ if all_dist:
53
+ return cx
54
+
55
+ cx = torch.mean(torch.max(cx, dim=1)[0], dim=1) # Eq(1)
56
+ cx_loss = torch.mean(-torch.log(cx + 1e-5)) # Eq(5)
57
+
58
+ return cx_loss
59
+
60
+
61
+ # TODO: Operation check
62
+ def contextual_bilateral_loss(x: torch.Tensor,
63
+ y: torch.Tensor,
64
+ weight_sp: float = 0.1,
65
+ band_width: float = 1.,
66
+ loss_type: str = 'cosine'):
67
+ """
68
+ Computes Contextual Bilateral (CoBi) Loss between x and y,
69
+ proposed in https://arxiv.org/pdf/1905.05169.pdf.
70
+
71
+ Parameters
72
+ ---
73
+ x : torch.Tensor
74
+ features of shape (N, C, H, W).
75
+ y : torch.Tensor
76
+ features of shape (N, C, H, W).
77
+ band_width : float, optional
78
+ a band-width parameter used to convert distance to similarity.
79
+ in the paper, this is described as :math:`h`.
80
+ loss_type : str, optional
81
+ a loss type to measure the distance between features.
82
+ Note: `l1` and `l2` frequently raises OOM.
83
+
84
+ Returns
85
+ ---
86
+ cx_loss : torch.Tensor
87
+ contextual loss between x and y (Eq (1) in the paper).
88
+ k_arg_max_NC : torch.Tensor
89
+ indices to maximize similarity over channels.
90
+ """
91
+
92
+ assert x.size() == y.size(), 'input tensor must have the same size.'
93
+ assert loss_type in LOSS_TYPES, f'select a loss type from {LOSS_TYPES}.'
94
+
95
+ # spatial loss
96
+ grid = compute_meshgrid(x.shape).to(x.device)
97
+ dist_raw = compute_l2_distance(grid, grid)
98
+ dist_tilde = compute_relative_distance(dist_raw)
99
+ cx_sp = compute_cx(dist_tilde, band_width)
100
+
101
+ # feature loss
102
+ if loss_type == 'cosine':
103
+ dist_raw = compute_cosine_distance(x, y)
104
+ elif loss_type == 'l1':
105
+ dist_raw = compute_l1_distance(x, y)
106
+ elif loss_type == 'l2':
107
+ dist_raw = compute_l2_distance(x, y)
108
+ dist_tilde = compute_relative_distance(dist_raw)
109
+ cx_feat = compute_cx(dist_tilde, band_width)
110
+
111
+ # combined loss
112
+ cx_combine = (1. - weight_sp) * cx_feat + weight_sp * cx_sp
113
+
114
+ k_max_NC, _ = torch.max(cx_combine, dim=2, keepdim=True)
115
+
116
+ cx = k_max_NC.mean(dim=1)
117
+ cx_loss = torch.mean(-torch.log(cx + 1e-5))
118
+
119
+ return cx_loss
120
+
121
+
122
+ def compute_cx(dist_tilde, band_width):
123
+ w = torch.exp((1 - dist_tilde) / band_width) # Eq(3)
124
+ cx = w / torch.sum(w, dim=2, keepdim=True) # Eq(4)
125
+ return cx
126
+
127
+
128
+ def compute_relative_distance(dist_raw):
129
+ dist_min, _ = torch.min(dist_raw, dim=2, keepdim=True)
130
+ dist_tilde = dist_raw / (dist_min + 1e-5)
131
+ return dist_tilde
132
+
133
+
134
+ def compute_cosine_distance(x, y):
135
+ # mean shifting by channel-wise mean of `y`.
136
+ y_mu = y.mean(dim=(0, 2, 3), keepdim=True)
137
+ x_centered = x - y_mu
138
+ y_centered = y - y_mu
139
+
140
+ # L2 normalization
141
+ x_normalized = F.normalize(x_centered, p=2, dim=1)
142
+ y_normalized = F.normalize(y_centered, p=2, dim=1)
143
+
144
+ # channel-wise vectorization
145
+ N, C, *_ = x.size()
146
+ x_normalized = x_normalized.reshape(N, C, -1) # (N, C, H*W)
147
+ y_normalized = y_normalized.reshape(N, C, -1) # (N, C, H*W)
148
+
149
+ # consine similarity
150
+ cosine_sim = torch.bmm(x_normalized.transpose(1, 2),
151
+ y_normalized) # (N, H*W, H*W)
152
+
153
+ # convert to distance
154
+ dist = 1 - cosine_sim
155
+
156
+ return dist
157
+
158
+
159
+ # TODO: Considering avoiding OOM.
160
+ def compute_l1_distance(x: torch.Tensor, y: torch.Tensor):
161
+ N, C, H, W = x.size()
162
+ x_vec = x.view(N, C, -1)
163
+ y_vec = y.view(N, C, -1)
164
+
165
+ dist = x_vec.unsqueeze(2) - y_vec.unsqueeze(3)
166
+ dist = dist.abs().sum(dim=1)
167
+ dist = dist.transpose(1, 2).reshape(N, H*W, H*W)
168
+ dist = dist.clamp(min=0.)
169
+
170
+ return dist
171
+
172
+
173
+ # TODO: Considering avoiding OOM.
174
+ def compute_l2_distance(x, y):
175
+ N, C, H, W = x.size()
176
+ x_vec = x.view(N, C, -1)
177
+ y_vec = y.view(N, C, -1)
178
+ x_s = torch.sum(x_vec ** 2, dim=1)
179
+ y_s = torch.sum(y_vec ** 2, dim=1)
180
+
181
+ A = y_vec.transpose(1, 2) @ x_vec
182
+ dist = y_s - 2 * A + x_s.transpose(0, 1)
183
+ dist = dist.transpose(1, 2).reshape(N, H*W, H*W)
184
+ dist = dist.clamp(min=0.)
185
+
186
+ return dist
187
+
188
+
189
+ def compute_meshgrid(shape):
190
+ N, C, H, W = shape
191
+ rows = torch.arange(0, H, dtype=torch.float32) / (H + 1)
192
+ cols = torch.arange(0, W, dtype=torch.float32) / (W + 1)
193
+
194
+ feature_grid = torch.meshgrid(rows, cols)
195
+ feature_grid = torch.stack(feature_grid).unsqueeze(0)
196
+ feature_grid = torch.cat([feature_grid for _ in range(N)], dim=0)
197
+
198
+ return feature_grid
Time_TravelRephotography/losses/contextual_loss/modules/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from .contextual import ContextualLoss
2
+ from .contextual_bilateral import ContextualBilateralLoss
3
+
4
+ __all__ = ['ContextualLoss', 'ContextualBilateralLoss']
Time_TravelRephotography/losses/contextual_loss/modules/contextual.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ from typing import (
3
+ Iterable,
4
+ List,
5
+ Optional,
6
+ )
7
+
8
+ import numpy as np
9
+ import torch
10
+ import torch.nn as nn
11
+
12
+ from .vgg import VGG19
13
+ from .. import functional as F
14
+ from ..config import LOSS_TYPES
15
+
16
+
17
+ class ContextualLoss(nn.Module):
18
+ """
19
+ Creates a criterion that measures the contextual loss.
20
+
21
+ Parameters
22
+ ---
23
+ band_width : int, optional
24
+ a band_width parameter described as :math:`h` in the paper.
25
+ use_vgg : bool, optional
26
+ if you want to use VGG feature, set this `True`.
27
+ vgg_layer : str, optional
28
+ intermidiate layer name for VGG feature.
29
+ Now we support layer names:
30
+ `['relu1_2', 'relu2_2', 'relu3_4', 'relu4_4', 'relu5_4']`
31
+ """
32
+
33
+ def __init__(
34
+ self,
35
+ band_width: float = 0.5,
36
+ loss_type: str = 'cosine',
37
+ use_vgg: bool = False,
38
+ vgg_layers: List[str] = ['relu3_4'],
39
+ feature_1d_size: int = 64,
40
+ ):
41
+
42
+ super().__init__()
43
+
44
+ assert band_width > 0, 'band_width parameter must be positive.'
45
+ assert loss_type in LOSS_TYPES,\
46
+ f'select a loss type from {LOSS_TYPES}.'
47
+
48
+ self.loss_type = loss_type
49
+ self.band_width = band_width
50
+ self.feature_1d_size = feature_1d_size
51
+
52
+ if use_vgg:
53
+ self.vgg_model = VGG19()
54
+ self.vgg_layers = vgg_layers
55
+ self.register_buffer(
56
+ name='vgg_mean',
57
+ tensor=torch.tensor(
58
+ [[[0.485]], [[0.456]], [[0.406]]], requires_grad=False)
59
+ )
60
+ self.register_buffer(
61
+ name='vgg_std',
62
+ tensor=torch.tensor(
63
+ [[[0.229]], [[0.224]], [[0.225]]], requires_grad=False)
64
+ )
65
+
66
+ def forward(self, x: torch.Tensor, y: torch.Tensor, all_dist: bool = False):
67
+ if not hasattr(self, 'vgg_model'):
68
+ return self.contextual_loss(x, y, self.feature_1d_size, self.band_width, all_dist=all_dist)
69
+
70
+
71
+ x = self.forward_vgg(x)
72
+ y = self.forward_vgg(y)
73
+
74
+ loss = 0
75
+ for layer in self.vgg_layers:
76
+ # picking up vgg feature maps
77
+ fx = getattr(x, layer)
78
+ fy = getattr(y, layer)
79
+ loss = loss + self.contextual_loss(
80
+ fx, fy, self.feature_1d_size, self.band_width, all_dist=all_dist, loss_type=self.loss_type
81
+ )
82
+ return loss
83
+
84
+ def forward_vgg(self, x: torch.Tensor):
85
+ assert x.shape[1] == 3, 'VGG model takes 3 chennel images.'
86
+ # [-1, 1] -> [0, 1]
87
+ x = (x + 1) * 0.5
88
+
89
+ # normalization
90
+ x = x.sub(self.vgg_mean.detach()).div(self.vgg_std)
91
+ return self.vgg_model(x)
92
+
93
+ @classmethod
94
+ def contextual_loss(
95
+ cls,
96
+ x: torch.Tensor, y: torch.Tensor,
97
+ feature_1d_size: int,
98
+ band_width: int,
99
+ all_dist: bool = False,
100
+ loss_type: str = 'cosine',
101
+ ) -> torch.Tensor:
102
+ feature_size = feature_1d_size ** 2
103
+ if np.prod(x.shape[2:]) > feature_size or np.prod(y.shape[2:]) > feature_size:
104
+ x, indices = cls.random_sampling(x, feature_1d_size=feature_1d_size)
105
+ y, _ = cls.random_sampling(y, feature_1d_size=feature_1d_size, indices=indices)
106
+
107
+ return F.contextual_loss(x, y, band_width, all_dist=all_dist, loss_type=loss_type)
108
+
109
+ @staticmethod
110
+ def random_sampling(
111
+ tensor_NCHW: torch.Tensor, feature_1d_size: int, indices: Optional[List] = None
112
+ ):
113
+ N, C, H, W = tensor_NCHW.shape
114
+ S = H * W
115
+ tensor_NCS = tensor_NCHW.reshape([N, C, S])
116
+ if indices is None:
117
+ all_indices = list(range(S))
118
+ random.shuffle(all_indices)
119
+ indices = all_indices[:feature_1d_size**2]
120
+ res = tensor_NCS[:, :, indices].reshape(N, -1, feature_1d_size, feature_1d_size)
121
+ return res, indices
Time_TravelRephotography/losses/contextual_loss/modules/contextual_bilateral.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from .vgg import VGG19
5
+ from .. import functional as F
6
+ from ..config import LOSS_TYPES
7
+
8
+
9
+ class ContextualBilateralLoss(nn.Module):
10
+ """
11
+ Creates a criterion that measures the contextual bilateral loss.
12
+
13
+ Parameters
14
+ ---
15
+ weight_sp : float, optional
16
+ a balancing weight between spatial and feature loss.
17
+ band_width : int, optional
18
+ a band_width parameter described as :math:`h` in the paper.
19
+ use_vgg : bool, optional
20
+ if you want to use VGG feature, set this `True`.
21
+ vgg_layer : str, optional
22
+ intermidiate layer name for VGG feature.
23
+ Now we support layer names:
24
+ `['relu1_2', 'relu2_2', 'relu3_4', 'relu4_4', 'relu5_4']`
25
+ """
26
+
27
+ def __init__(self,
28
+ weight_sp: float = 0.1,
29
+ band_width: float = 0.5,
30
+ loss_type: str = 'cosine',
31
+ use_vgg: bool = False,
32
+ vgg_layer: str = 'relu3_4'):
33
+
34
+ super(ContextualBilateralLoss, self).__init__()
35
+
36
+ assert band_width > 0, 'band_width parameter must be positive.'
37
+ assert loss_type in LOSS_TYPES,\
38
+ f'select a loss type from {LOSS_TYPES}.'
39
+
40
+ self.band_width = band_width
41
+
42
+ if use_vgg:
43
+ self.vgg_model = VGG19()
44
+ self.vgg_layer = vgg_layer
45
+ self.register_buffer(
46
+ name='vgg_mean',
47
+ tensor=torch.tensor(
48
+ [[[0.485]], [[0.456]], [[0.406]]], requires_grad=False)
49
+ )
50
+ self.register_buffer(
51
+ name='vgg_std',
52
+ tensor=torch.tensor(
53
+ [[[0.229]], [[0.224]], [[0.225]]], requires_grad=False)
54
+ )
55
+
56
+ def forward(self, x, y):
57
+ if hasattr(self, 'vgg_model'):
58
+ assert x.shape[1] == 3 and y.shape[1] == 3,\
59
+ 'VGG model takes 3 chennel images.'
60
+
61
+ # normalization
62
+ x = x.sub(self.vgg_mean.detach()).div(self.vgg_std.detach())
63
+ y = y.sub(self.vgg_mean.detach()).div(self.vgg_std.detach())
64
+
65
+ # picking up vgg feature maps
66
+ x = getattr(self.vgg_model(x), self.vgg_layer)
67
+ y = getattr(self.vgg_model(y), self.vgg_layer)
68
+
69
+ return F.contextual_bilateral_loss(x, y, self.band_width)
Time_TravelRephotography/losses/contextual_loss/modules/vgg.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import namedtuple
2
+
3
+ import torch.nn as nn
4
+ import torchvision.models.vgg as vgg
5
+
6
+
7
+ class VGG19(nn.Module):
8
+ def __init__(self, requires_grad=False):
9
+ super(VGG19, self).__init__()
10
+ vgg_pretrained_features = vgg.vgg19(pretrained=True).features
11
+ self.slice1 = nn.Sequential()
12
+ self.slice2 = nn.Sequential()
13
+ self.slice3 = nn.Sequential()
14
+ self.slice4 = nn.Sequential()
15
+ self.slice5 = nn.Sequential()
16
+ for x in range(4):
17
+ self.slice1.add_module(str(x), vgg_pretrained_features[x])
18
+ for x in range(4, 9):
19
+ self.slice2.add_module(str(x), vgg_pretrained_features[x])
20
+ for x in range(9, 18):
21
+ self.slice3.add_module(str(x), vgg_pretrained_features[x])
22
+ for x in range(18, 27):
23
+ self.slice4.add_module(str(x), vgg_pretrained_features[x])
24
+ for x in range(27, 36):
25
+ self.slice5.add_module(str(x), vgg_pretrained_features[x])
26
+ if not requires_grad:
27
+ for param in self.parameters():
28
+ param.requires_grad = False
29
+
30
+ def forward(self, X):
31
+ h = self.slice1(X)
32
+ h_relu1_2 = h
33
+ h = self.slice2(h)
34
+ h_relu2_2 = h
35
+ h = self.slice3(h)
36
+ h_relu3_4 = h
37
+ h = self.slice4(h)
38
+ h_relu4_4 = h
39
+ h = self.slice5(h)
40
+ h_relu5_4 = h
41
+
42
+ vgg_outputs = namedtuple(
43
+ "VggOutputs", ['relu1_2', 'relu2_2',
44
+ 'relu3_4', 'relu4_4', 'relu5_4'])
45
+ out = vgg_outputs(h_relu1_2, h_relu2_2,
46
+ h_relu3_4, h_relu4_4, h_relu5_4)
47
+
48
+ return out