Upload 17 files
- net/CMSFFT.py +377 -0
- net/IntmdSequential.py +19 -0
- net/PositionalEncoding.py +35 -0
- net/SGFMT.py +126 -0
- net/Transformer.py +126 -0
- net/Ushape_Trans.py +378 -0
- net/__pycache__/CMSFFT.cpython-37.pyc +0 -0
- net/__pycache__/CTrans.cpython-37.pyc +0 -0
- net/__pycache__/IntmdSequential.cpython-37.pyc +0 -0
- net/__pycache__/PositionalEncoding.cpython-37.pyc +0 -0
- net/__pycache__/SGFMT.cpython-37.pyc +0 -0
- net/__pycache__/Transformer.cpython-37.pyc +0 -0
- net/__pycache__/Ushape_Trans.cpython-37.pyc +0 -0
- net/__pycache__/block.cpython-37.pyc +0 -0
- net/__pycache__/utils.cpython-37.pyc +0 -0
- net/block.py +477 -0
- net/utils.py +86 -0
net/CMSFFT.py
ADDED
@@ -0,0 +1,377 @@
# -*- coding: utf-8 -*-
# @Author : Lintao Peng
# @File : CMSFFT.py
# coding=utf-8
# Design based on the CTrans
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import copy
import logging
import math
import torch
import torch.nn as nn
import numpy as np
from torch.nn import Dropout, Softmax, Conv2d, LayerNorm
from torch.nn.modules.utils import _pair


# KV_size = 480
# transformer.num_heads = 4
# transformer.num_layers = 4
# expand_ratio = 4


# Linear patch embedding
class Channel_Embeddings(nn.Module):
    """Construct the embeddings from patch, position embeddings."""
    def __init__(self, patchsize, img_size, in_channels):
        super().__init__()
        img_size = _pair(img_size)
        patch_size = _pair(patchsize)
        n_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1])

        self.patch_embeddings = Conv2d(in_channels=in_channels,
                                       out_channels=in_channels,
                                       kernel_size=patch_size,
                                       stride=patch_size)
        self.position_embeddings = nn.Parameter(torch.zeros(1, n_patches, in_channels))
        self.dropout = Dropout(0.1)

    def forward(self, x):
        if x is None:
            return None
        x = self.patch_embeddings(x)  # (B, hidden, n_patches^(1/2), n_patches^(1/2))
        x = x.flatten(2)
        x = x.transpose(-1, -2)  # (B, n_patches, hidden)
        embeddings = x + self.position_embeddings
        embeddings = self.dropout(embeddings)
        return embeddings


# Feature reconstruction
class Reconstruct(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, scale_factor):
        super(Reconstruct, self).__init__()
        if kernel_size == 3:
            padding = 1
        else:
            padding = 0
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding)
        self.norm = nn.BatchNorm2d(out_channels)
        self.activation = nn.ReLU(inplace=True)
        self.scale_factor = scale_factor

    def forward(self, x):
        if x is None:
            return None

        # reshape from (B, n_patch, hidden) to (B, h, w, hidden)
        B, n_patch, hidden = x.size()
        h, w = int(np.sqrt(n_patch)), int(np.sqrt(n_patch))
        x = x.permute(0, 2, 1)
        x = x.contiguous().view(B, hidden, h, w)
        x = nn.Upsample(scale_factor=self.scale_factor)(x)

        out = self.conv(x)
        out = self.norm(out)
        out = self.activation(out)
        return out


class Attention_org(nn.Module):
    def __init__(self, vis, channel_num, KV_size=480, num_heads=4):
        super(Attention_org, self).__init__()
        self.vis = vis
        self.KV_size = KV_size
        self.channel_num = channel_num
        self.num_attention_heads = num_heads

        self.query1 = nn.ModuleList()
        self.query2 = nn.ModuleList()
        self.query3 = nn.ModuleList()
        self.query4 = nn.ModuleList()
        self.key = nn.ModuleList()
        self.value = nn.ModuleList()

        for _ in range(num_heads):
            query1 = nn.Linear(channel_num[0], channel_num[0], bias=False)
            query2 = nn.Linear(channel_num[1], channel_num[1], bias=False)
            query3 = nn.Linear(channel_num[2], channel_num[2], bias=False)
            query4 = nn.Linear(channel_num[3], channel_num[3], bias=False)
            key = nn.Linear(self.KV_size, self.KV_size, bias=False)
            value = nn.Linear(self.KV_size, self.KV_size, bias=False)
            # deepcopy gives every head its own independent copy of each projection layer
            self.query1.append(copy.deepcopy(query1))
            self.query2.append(copy.deepcopy(query2))
            self.query3.append(copy.deepcopy(query3))
            self.query4.append(copy.deepcopy(query4))
            self.key.append(copy.deepcopy(key))
            self.value.append(copy.deepcopy(value))
        self.psi = nn.InstanceNorm2d(self.num_attention_heads)
        self.softmax = Softmax(dim=3)
        self.out1 = nn.Linear(channel_num[0], channel_num[0], bias=False)
        self.out2 = nn.Linear(channel_num[1], channel_num[1], bias=False)
        self.out3 = nn.Linear(channel_num[2], channel_num[2], bias=False)
        self.out4 = nn.Linear(channel_num[3], channel_num[3], bias=False)
        self.attn_dropout = Dropout(0.1)
        self.proj_dropout = Dropout(0.1)

    def forward(self, emb1, emb2, emb3, emb4, emb_all):
        multi_head_Q1_list = []
        multi_head_Q2_list = []
        multi_head_Q3_list = []
        multi_head_Q4_list = []
        multi_head_K_list = []
        multi_head_V_list = []
        if emb1 is not None:
            for query1 in self.query1:
                Q1 = query1(emb1)
                multi_head_Q1_list.append(Q1)
        if emb2 is not None:
            for query2 in self.query2:
                Q2 = query2(emb2)
                multi_head_Q2_list.append(Q2)
        if emb3 is not None:
            for query3 in self.query3:
                Q3 = query3(emb3)
                multi_head_Q3_list.append(Q3)
        if emb4 is not None:
            for query4 in self.query4:
                Q4 = query4(emb4)
                multi_head_Q4_list.append(Q4)
        for key in self.key:
            K = key(emb_all)
            multi_head_K_list.append(K)
        for value in self.value:
            V = value(emb_all)
            multi_head_V_list.append(V)
        # print(len(multi_head_Q4_list))

        multi_head_Q1 = torch.stack(multi_head_Q1_list, dim=1) if emb1 is not None else None
        multi_head_Q2 = torch.stack(multi_head_Q2_list, dim=1) if emb2 is not None else None
        multi_head_Q3 = torch.stack(multi_head_Q3_list, dim=1) if emb3 is not None else None
        multi_head_Q4 = torch.stack(multi_head_Q4_list, dim=1) if emb4 is not None else None
        multi_head_K = torch.stack(multi_head_K_list, dim=1)
        multi_head_V = torch.stack(multi_head_V_list, dim=1)

        multi_head_Q1 = multi_head_Q1.transpose(-1, -2) if emb1 is not None else None
        multi_head_Q2 = multi_head_Q2.transpose(-1, -2) if emb2 is not None else None
        multi_head_Q3 = multi_head_Q3.transpose(-1, -2) if emb3 is not None else None
        multi_head_Q4 = multi_head_Q4.transpose(-1, -2) if emb4 is not None else None

        attention_scores1 = torch.matmul(multi_head_Q1, multi_head_K) if emb1 is not None else None
        attention_scores2 = torch.matmul(multi_head_Q2, multi_head_K) if emb2 is not None else None
        attention_scores3 = torch.matmul(multi_head_Q3, multi_head_K) if emb3 is not None else None
        attention_scores4 = torch.matmul(multi_head_Q4, multi_head_K) if emb4 is not None else None

        attention_scores1 = attention_scores1 / math.sqrt(self.KV_size) if emb1 is not None else None
        attention_scores2 = attention_scores2 / math.sqrt(self.KV_size) if emb2 is not None else None
        attention_scores3 = attention_scores3 / math.sqrt(self.KV_size) if emb3 is not None else None
        attention_scores4 = attention_scores4 / math.sqrt(self.KV_size) if emb4 is not None else None

        attention_probs1 = self.softmax(self.psi(attention_scores1)) if emb1 is not None else None
        attention_probs2 = self.softmax(self.psi(attention_scores2)) if emb2 is not None else None
        attention_probs3 = self.softmax(self.psi(attention_scores3)) if emb3 is not None else None
        attention_probs4 = self.softmax(self.psi(attention_scores4)) if emb4 is not None else None
        # print(attention_probs4.size())

        if self.vis:
            weights = []
            weights.append(attention_probs1.mean(1))
            weights.append(attention_probs2.mean(1))
            weights.append(attention_probs3.mean(1))
            weights.append(attention_probs4.mean(1))
        else:
            weights = None

        attention_probs1 = self.attn_dropout(attention_probs1) if emb1 is not None else None
        attention_probs2 = self.attn_dropout(attention_probs2) if emb2 is not None else None
        attention_probs3 = self.attn_dropout(attention_probs3) if emb3 is not None else None
        attention_probs4 = self.attn_dropout(attention_probs4) if emb4 is not None else None

        multi_head_V = multi_head_V.transpose(-1, -2)
        context_layer1 = torch.matmul(attention_probs1, multi_head_V) if emb1 is not None else None
        context_layer2 = torch.matmul(attention_probs2, multi_head_V) if emb2 is not None else None
        context_layer3 = torch.matmul(attention_probs3, multi_head_V) if emb3 is not None else None
        context_layer4 = torch.matmul(attention_probs4, multi_head_V) if emb4 is not None else None

        context_layer1 = context_layer1.permute(0, 3, 2, 1).contiguous() if emb1 is not None else None
        context_layer2 = context_layer2.permute(0, 3, 2, 1).contiguous() if emb2 is not None else None
        context_layer3 = context_layer3.permute(0, 3, 2, 1).contiguous() if emb3 is not None else None
        context_layer4 = context_layer4.permute(0, 3, 2, 1).contiguous() if emb4 is not None else None
        context_layer1 = context_layer1.mean(dim=3) if emb1 is not None else None
        context_layer2 = context_layer2.mean(dim=3) if emb2 is not None else None
        context_layer3 = context_layer3.mean(dim=3) if emb3 is not None else None
        context_layer4 = context_layer4.mean(dim=3) if emb4 is not None else None

        O1 = self.out1(context_layer1) if emb1 is not None else None
        O2 = self.out2(context_layer2) if emb2 is not None else None
        O3 = self.out3(context_layer3) if emb3 is not None else None
        O4 = self.out4(context_layer4) if emb4 is not None else None
        O1 = self.proj_dropout(O1) if emb1 is not None else None
        O2 = self.proj_dropout(O2) if emb2 is not None else None
        O3 = self.proj_dropout(O3) if emb3 is not None else None
        O4 = self.proj_dropout(O4) if emb4 is not None else None
        return O1, O2, O3, O4, weights


class Mlp(nn.Module):
    def __init__(self, in_channel, mlp_channel):
        super(Mlp, self).__init__()
        self.fc1 = nn.Linear(in_channel, mlp_channel)
        self.fc2 = nn.Linear(mlp_channel, in_channel)
        self.act_fn = nn.GELU()
        self.dropout = Dropout(0.0)
        self._init_weights()

    def _init_weights(self):
        nn.init.xavier_uniform_(self.fc1.weight)
        nn.init.xavier_uniform_(self.fc2.weight)
        nn.init.normal_(self.fc1.bias, std=1e-6)
        nn.init.normal_(self.fc2.bias, std=1e-6)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act_fn(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x


class Block_ViT(nn.Module):
    def __init__(self, vis, channel_num, expand_ratio=4, KV_size=480):
        super(Block_ViT, self).__init__()
        expand_ratio = 4
        self.attn_norm1 = LayerNorm(channel_num[0], eps=1e-6)
        self.attn_norm2 = LayerNorm(channel_num[1], eps=1e-6)
        self.attn_norm3 = LayerNorm(channel_num[2], eps=1e-6)
        self.attn_norm4 = LayerNorm(channel_num[3], eps=1e-6)
        self.attn_norm = LayerNorm(KV_size, eps=1e-6)
        self.channel_attn = Attention_org(vis, channel_num)

        self.ffn_norm1 = LayerNorm(channel_num[0], eps=1e-6)
        self.ffn_norm2 = LayerNorm(channel_num[1], eps=1e-6)
        self.ffn_norm3 = LayerNorm(channel_num[2], eps=1e-6)
        self.ffn_norm4 = LayerNorm(channel_num[3], eps=1e-6)
        self.ffn1 = Mlp(channel_num[0], channel_num[0] * expand_ratio)
        self.ffn2 = Mlp(channel_num[1], channel_num[1] * expand_ratio)
        self.ffn3 = Mlp(channel_num[2], channel_num[2] * expand_ratio)
        self.ffn4 = Mlp(channel_num[3], channel_num[3] * expand_ratio)

    def forward(self, emb1, emb2, emb3, emb4):
        embcat = []
        org1 = emb1
        org2 = emb2
        org3 = emb3
        org4 = emb4
        for i in range(4):
            var_name = "emb" + str(i + 1)  # emb1, emb2, emb3, emb4
            tmp_var = locals()[var_name]
            if tmp_var is not None:
                embcat.append(tmp_var)

        emb_all = torch.cat(embcat, dim=2)
        cx1 = self.attn_norm1(emb1) if emb1 is not None else None
        cx2 = self.attn_norm2(emb2) if emb2 is not None else None
        cx3 = self.attn_norm3(emb3) if emb3 is not None else None
        cx4 = self.attn_norm4(emb4) if emb4 is not None else None
        emb_all = self.attn_norm(emb_all)
        cx1, cx2, cx3, cx4, weights = self.channel_attn(cx1, cx2, cx3, cx4, emb_all)
        # residual connection
        cx1 = org1 + cx1 if emb1 is not None else None
        cx2 = org2 + cx2 if emb2 is not None else None
        cx3 = org3 + cx3 if emb3 is not None else None
        cx4 = org4 + cx4 if emb4 is not None else None

        org1 = cx1
        org2 = cx2
        org3 = cx3
        org4 = cx4
        x1 = self.ffn_norm1(cx1) if emb1 is not None else None
        x2 = self.ffn_norm2(cx2) if emb2 is not None else None
        x3 = self.ffn_norm3(cx3) if emb3 is not None else None
        x4 = self.ffn_norm4(cx4) if emb4 is not None else None
        x1 = self.ffn1(x1) if emb1 is not None else None
        x2 = self.ffn2(x2) if emb2 is not None else None
        x3 = self.ffn3(x3) if emb3 is not None else None
        x4 = self.ffn4(x4) if emb4 is not None else None
        # residual connection
        x1 = x1 + org1 if emb1 is not None else None
        x2 = x2 + org2 if emb2 is not None else None
        x3 = x3 + org3 if emb3 is not None else None
        x4 = x4 + org4 if emb4 is not None else None

        return x1, x2, x3, x4, weights


class Encoder(nn.Module):
    def __init__(self, vis, channel_num, num_layers=4):
        super(Encoder, self).__init__()
        self.vis = vis
        self.layer = nn.ModuleList()
        self.encoder_norm1 = LayerNorm(channel_num[0], eps=1e-6)
        self.encoder_norm2 = LayerNorm(channel_num[1], eps=1e-6)
        self.encoder_norm3 = LayerNorm(channel_num[2], eps=1e-6)
        self.encoder_norm4 = LayerNorm(channel_num[3], eps=1e-6)
        for _ in range(num_layers):
            layer = Block_ViT(vis, channel_num)
            self.layer.append(copy.deepcopy(layer))

    def forward(self, emb1, emb2, emb3, emb4):
        attn_weights = []
        for layer_block in self.layer:
            emb1, emb2, emb3, emb4, weights = layer_block(emb1, emb2, emb3, emb4)
            if self.vis:
                attn_weights.append(weights)
        emb1 = self.encoder_norm1(emb1) if emb1 is not None else None
        emb2 = self.encoder_norm2(emb2) if emb2 is not None else None
        emb3 = self.encoder_norm3(emb3) if emb3 is not None else None
        emb4 = self.encoder_norm4(emb4) if emb4 is not None else None
        return emb1, emb2, emb3, emb4, attn_weights


class ChannelTransformer(nn.Module):
    def __init__(self, vis=False, img_size=256, channel_num=[64, 128, 256, 512], patchSize=[32, 16, 8, 4]):
        super().__init__()

        self.patchSize_1 = patchSize[0]
        self.patchSize_2 = patchSize[1]
        self.patchSize_3 = patchSize[2]
        self.patchSize_4 = patchSize[3]
        self.embeddings_1 = Channel_Embeddings(self.patchSize_1, img_size=img_size, in_channels=channel_num[0])
        self.embeddings_2 = Channel_Embeddings(self.patchSize_2, img_size=img_size // 2, in_channels=channel_num[1])
        self.embeddings_3 = Channel_Embeddings(self.patchSize_3, img_size=img_size // 4, in_channels=channel_num[2])
        self.embeddings_4 = Channel_Embeddings(self.patchSize_4, img_size=img_size // 8, in_channels=channel_num[3])
        self.encoder = Encoder(vis, channel_num)

        self.reconstruct_1 = Reconstruct(channel_num[0], channel_num[0], kernel_size=1, scale_factor=(self.patchSize_1, self.patchSize_1))
        self.reconstruct_2 = Reconstruct(channel_num[1], channel_num[1], kernel_size=1, scale_factor=(self.patchSize_2, self.patchSize_2))
        self.reconstruct_3 = Reconstruct(channel_num[2], channel_num[2], kernel_size=1, scale_factor=(self.patchSize_3, self.patchSize_3))
        self.reconstruct_4 = Reconstruct(channel_num[3], channel_num[3], kernel_size=1, scale_factor=(self.patchSize_4, self.patchSize_4))

    def forward(self, en1, en2, en3, en4):

        emb1 = self.embeddings_1(en1)
        emb2 = self.embeddings_2(en2)
        emb3 = self.embeddings_3(en3)
        emb4 = self.embeddings_4(en4)

        encoded1, encoded2, encoded3, encoded4, attn_weights = self.encoder(emb1, emb2, emb3, emb4)  # (B, n_patch, hidden)
        x1 = self.reconstruct_1(encoded1) if en1 is not None else None
        x2 = self.reconstruct_2(encoded2) if en2 is not None else None
        x3 = self.reconstruct_3(encoded3) if en3 is not None else None
        x4 = self.reconstruct_4(encoded4) if en4 is not None else None

        x1 = x1 + en1 if en1 is not None else None
        x2 = x2 + en2 if en2 is not None else None
        x3 = x3 + en3 if en3 is not None else None
        x4 = x4 + en4 if en4 is not None else None

        return x1, x2, x3, x4, attn_weights
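A minimal smoke test of the ChannelTransformer above, not part of the upload, assuming the package is importable as net.CMSFFT. The channel configuration [32, 64, 128, 256] is the one Ushape_Trans.py passes in; it also satisfies the requirement that the four channel widths sum to the default KV_size of 480, which the shared key/value projections and the joint LayerNorm expect.

import torch
from net.CMSFFT import ChannelTransformer

mtc = ChannelTransformer(vis=False, img_size=256,
                         channel_num=[32, 64, 128, 256],
                         patchSize=[32, 16, 8, 4])
# Four random feature maps shaped like the U-Net encoder outputs.
e1 = torch.randn(2, 32, 256, 256)
e2 = torch.randn(2, 64, 128, 128)
e3 = torch.randn(2, 128, 64, 64)
e4 = torch.randn(2, 256, 32, 32)
x1, x2, x3, x4, attn = mtc(e1, e2, e3, e4)
# Each output keeps the shape of its input; attn stays empty unless vis=True.
print(x1.shape, x2.shape, x3.shape, x4.shape)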
net/IntmdSequential.py
ADDED
@@ -0,0 +1,19 @@
import torch.nn as nn


class IntermediateSequential(nn.Sequential):
    def __init__(self, *args, return_intermediate=False):
        super().__init__(*args)
        self.return_intermediate = return_intermediate

    def forward(self, input):
        if not self.return_intermediate:
            return super().forward(input)

        intermediate_outputs = {}
        output = input
        for name, module in self.named_children():
            output = intermediate_outputs[name] = module(output)

        return output, intermediate_outputs
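A short usage sketch, assuming the module is importable as net.IntmdSequential: with return_intermediate=True the wrapper behaves like nn.Sequential but also returns a dict of every child module's output, keyed by the child's name.

import torch
import torch.nn as nn
from net.IntmdSequential import IntermediateSequential

net = IntermediateSequential(nn.Linear(8, 8), nn.ReLU(), return_intermediate=True)
out, intermediates = net(torch.randn(4, 8))
print(out.shape, list(intermediates.keys()))  # torch.Size([4, 8]) ['0', '1']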
net/PositionalEncoding.py
ADDED
@@ -0,0 +1,35 @@
import torch
import torch.nn as nn


# Positional encoding implementations
class FixedPositionalEncoding(nn.Module):
    def __init__(self, embedding_dim, max_length=512):
        super(FixedPositionalEncoding, self).__init__()

        pe = torch.zeros(max_length, embedding_dim)
        position = torch.arange(0, max_length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, embedding_dim, 2).float()
            * (-torch.log(torch.tensor(10000.0)) / embedding_dim)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[: x.size(0), :]
        return x


class LearnedPositionalEncoding(nn.Module):
    def __init__(self, max_position_embeddings, embedding_dim, seq_length):
        super(LearnedPositionalEncoding, self).__init__()

        self.position_embeddings = nn.Parameter(torch.zeros(1, 256, 512))  # 8x

    def forward(self, x, position_ids=None):

        position_embeddings = self.position_embeddings
        return x + position_embeddings
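A hypothetical shape check, not part of the upload: FixedPositionalEncoding indexes its table by x.size(0), so it expects (seq_len, batch, dim) input, while LearnedPositionalEncoding is hard-coded to a (1, 256, 512) table, i.e. 16x16 patches of a 512-dimensional embedding as used by the Generator.

import torch
from net.PositionalEncoding import FixedPositionalEncoding, LearnedPositionalEncoding

fixed = FixedPositionalEncoding(embedding_dim=512)
learned = LearnedPositionalEncoding(256, 512, 256)
print(fixed(torch.randn(256, 2, 512)).shape)    # torch.Size([256, 2, 512])
print(learned(torch.randn(2, 256, 512)).shape)  # torch.Size([2, 256, 512])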
net/SGFMT.py
ADDED
@@ -0,0 +1,126 @@
# -*- coding: utf-8 -*-
# @Author : Lintao Peng
# @File : SGFMT.py
# coding=utf-8
# Design based on the Vit

import torch.nn as nn
from net.IntmdSequential import IntermediateSequential


# Self-attention transformer, used as the bottleneck of the U-Net
class SelfAttention(nn.Module):
    def __init__(
        self, dim, heads=8, qkv_bias=False, qk_scale=None, dropout_rate=0.0
    ):
        super().__init__()
        self.num_heads = heads
        head_dim = dim // heads
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(dropout_rate)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(dropout_rate)

    def forward(self, x):
        B, N, C = x.shape
        qkv = (
            self.qkv(x)
            .reshape(B, N, 3, self.num_heads, C // self.num_heads)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = (
            qkv[0],
            qkv[1],
            qkv[2],
        )  # make torchscript happy (cannot use tensor as tuple)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        return self.fn(x) + x


class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x):
        return self.fn(self.norm(x))


class PreNormDrop(nn.Module):
    def __init__(self, dim, dropout_rate, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.dropout = nn.Dropout(p=dropout_rate)
        self.fn = fn

    def forward(self, x):
        return self.dropout(self.fn(self.norm(x)))


class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout_rate):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(p=dropout_rate),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(p=dropout_rate),
        )

    def forward(self, x):
        return self.net(x)


class TransformerModel(nn.Module):
    def __init__(
        self,
        dim,  # 512
        depth,  # 4
        heads,  # 8
        mlp_dim,  # 4096
        dropout_rate=0.1,
        attn_dropout_rate=0.1,
    ):
        super().__init__()
        layers = []
        for _ in range(depth):
            layers.extend(
                [
                    Residual(
                        PreNormDrop(
                            dim,
                            dropout_rate,
                            SelfAttention(dim, heads=heads, dropout_rate=attn_dropout_rate),
                        )
                    ),
                    Residual(
                        PreNorm(dim, FeedForward(dim, mlp_dim, dropout_rate))
                    ),
                ]
            )
            # dim = dim / 2
        self.net = IntermediateSequential(*layers)

    def forward(self, x):
        return self.net(x)
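A minimal sketch of the bottleneck transformer, not part of the upload, assuming the configuration Ushape_Trans.py passes in (dim=512, depth=4, heads=8, and its hidden_dim of 256 as mlp_dim). The input is a batch of 16x16 = 256 tokens of width 512.

import torch
from net.SGFMT import TransformerModel

model = TransformerModel(dim=512, depth=4, heads=8, mlp_dim=256)
tokens = torch.randn(2, 256, 512)  # (B, num_patches, embedding_dim)
out = model(tokens)
print(out.shape)  # torch.Size([2, 256, 512])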
net/Transformer.py
ADDED
@@ -0,0 +1,126 @@
# -*- coding: utf-8 -*-
# @Author : Lintao Peng
# @File : Transformer.py
# coding=utf-8
# Design based on the Vit

import torch.nn as nn
from net.IntmdSequential import IntermediateSequential


# Self-attention transformer, used as the bottleneck of the U-Net
class SelfAttention(nn.Module):
    def __init__(
        self, dim, heads=8, qkv_bias=False, qk_scale=None, dropout_rate=0.0
    ):
        super().__init__()
        self.num_heads = heads
        head_dim = dim // heads
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(dropout_rate)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(dropout_rate)

    def forward(self, x):
        B, N, C = x.shape
        qkv = (
            self.qkv(x)
            .reshape(B, N, 3, self.num_heads, C // self.num_heads)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = (
            qkv[0],
            qkv[1],
            qkv[2],
        )  # make torchscript happy (cannot use tensor as tuple)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        return self.fn(x) + x


class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x):
        return self.fn(self.norm(x))


class PreNormDrop(nn.Module):
    def __init__(self, dim, dropout_rate, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.dropout = nn.Dropout(p=dropout_rate)
        self.fn = fn

    def forward(self, x):
        return self.dropout(self.fn(self.norm(x)))


class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout_rate):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(p=dropout_rate),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(p=dropout_rate),
        )

    def forward(self, x):
        return self.net(x)


class TransformerModel(nn.Module):
    def __init__(
        self,
        dim,  # 512
        depth,  # 4
        heads,  # 8
        mlp_dim,  # 4096
        dropout_rate=0.1,
        attn_dropout_rate=0.1,
    ):
        super().__init__()
        layers = []
        for _ in range(depth):
            layers.extend(
                [
                    Residual(
                        PreNormDrop(
                            dim,
                            dropout_rate,
                            SelfAttention(dim, heads=heads, dropout_rate=attn_dropout_rate),
                        )
                    ),
                    Residual(
                        PreNorm(dim, FeedForward(dim, mlp_dim, dropout_rate))
                    ),
                ]
            )
            # dim = dim / 2
        self.net = IntermediateSequential(*layers)

    def forward(self, x):
        return self.net(x)
net/Ushape_Trans.py
ADDED
@@ -0,0 +1,378 @@
# -*- coding: utf-8 -*-
# @Author : Lintao Peng
# @File : Ushape_Trans.py
# coding=utf-8
# Design based on the pix2pix

import torch.nn as nn
import torch.nn.functional as F
import torch
import datetime
import os
import time
import timeit
import copy
import numpy as np
from torch.nn import ModuleList
from torch.nn import Conv2d
from torch.nn import LeakyReLU
from net.block import *
from net.block import _equalized_conv2d
from net.SGFMT import TransformerModel
from net.PositionalEncoding import FixedPositionalEncoding, LearnedPositionalEncoding
from net.CMSFFT import ChannelTransformer


## Weight initialization
def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find("BatchNorm2d") != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)


class Generator(nn.Module):
    """
    Generator of the MSG-UNet-GAN
    """
    def __init__(self,
                 img_dim=256,
                 patch_dim=16,
                 embedding_dim=512,
                 num_channels=3,
                 num_heads=8,
                 num_layers=4,
                 hidden_dim=256,
                 dropout_rate=0.0,
                 attn_dropout_rate=0.0,
                 in_ch=3,
                 out_ch=3,
                 conv_patch_representation=True,
                 positional_encoding_type="learned",
                 use_eql=True):
        super(Generator, self).__init__()
        assert embedding_dim % num_heads == 0
        assert img_dim % patch_dim == 0

        self.out_ch = out_ch  # number of output channels
        self.in_ch = in_ch  # number of input channels
        self.img_dim = img_dim  # input image size
        self.embedding_dim = embedding_dim  # 512
        self.num_heads = num_heads  # number of heads in multi-head attention
        self.patch_dim = patch_dim  # size of each patch
        self.num_channels = num_channels  # number of image channels
        self.dropout_rate = dropout_rate  # dropout rate
        self.attn_dropout_rate = attn_dropout_rate  # dropout rate of the attention module
        self.conv_patch_representation = conv_patch_representation  # True

        self.num_patches = int((img_dim // patch_dim) ** 2)  # number of patches the three-channel image is split into
        self.seq_length = self.num_patches  # sequence length equals the number of patches
        self.flatten_dim = 128 * num_channels  # 128*3=384

        # linear embedding
        self.linear_encoding = nn.Linear(self.flatten_dim, self.embedding_dim)
        # positional encoding
        if positional_encoding_type == "learned":
            self.position_encoding = LearnedPositionalEncoding(
                self.seq_length, self.embedding_dim, self.seq_length
            )
        elif positional_encoding_type == "fixed":
            self.position_encoding = FixedPositionalEncoding(
                self.embedding_dim,
            )

        self.pe_dropout = nn.Dropout(p=self.dropout_rate)

        self.transformer = TransformerModel(
            embedding_dim,  # 512
            num_layers,  # 4
            num_heads,  # 8
            hidden_dim,
            self.dropout_rate,
            self.attn_dropout_rate,
        )

        # layer norm
        self.pre_head_ln = nn.LayerNorm(embedding_dim)

        if self.conv_patch_representation:

            self.Conv_x = nn.Conv2d(
                256,
                self.embedding_dim,  # 512
                kernel_size=3,
                stride=1,
                padding=1
            )

        self.bn = nn.BatchNorm2d(256)
        self.relu = nn.ReLU(inplace=True)

        # modulelist
        self.rgb_to_feature = ModuleList([from_rgb(32), from_rgb(64), from_rgb(128)])
        self.feature_to_rgb = ModuleList([to_rgb(32), to_rgb(64), to_rgb(128), to_rgb(256)])

        self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.Conv1 = conv_block(self.in_ch, 16)
        self.Conv1_1 = conv_block(16, 32)
        self.Conv2 = conv_block(32, 32)
        self.Conv2_1 = conv_block(32, 64)
        self.Conv3 = conv_block(64, 64)
        self.Conv3_1 = conv_block(64, 128)
        self.Conv4 = conv_block(128, 128)
        self.Conv4_1 = conv_block(128, 256)

        self.Conv5 = conv_block(512, 256)

        # self.Conv_x = conv_block(256, 512)
        self.mtc = ChannelTransformer(channel_num=[32, 64, 128, 256],
                                      patchSize=[32, 16, 8, 4])

        self.Up5 = up_conv(256, 256)
        self.coatt5 = CCA(F_g=256, F_x=256)
        self.Up_conv5 = conv_block(512, 256)
        self.Up_conv5_1 = conv_block(256, 256)

        self.Up4 = up_conv(256, 128)
        self.coatt4 = CCA(F_g=128, F_x=128)
        self.Up_conv4 = conv_block(256, 128)
        self.Up_conv4_1 = conv_block(128, 128)

        self.Up3 = up_conv(128, 64)
        self.coatt3 = CCA(F_g=64, F_x=64)
        self.Up_conv3 = conv_block(128, 64)
        self.Up_conv3_1 = conv_block(64, 64)

        self.Up2 = up_conv(64, 32)
        self.coatt2 = CCA(F_g=32, F_x=32)
        self.Up_conv2 = conv_block(64, 32)
        self.Up_conv2_1 = conv_block(32, 32)

        self.Conv = nn.Conv2d(32, self.out_ch, kernel_size=1, stride=1, padding=0)

        # self.active = torch.nn.Sigmoid()

    def reshape_output(self, x):  # reshape the transformer output back to a feature map
        x = x.view(
            x.size(0),
            int(self.img_dim / self.patch_dim),
            int(self.img_dim / self.patch_dim),
            self.embedding_dim,
        )  # B,16,16,512
        x = x.permute(0, 3, 1, 2).contiguous()

        return x

    def forward(self, x):
        # print(x.shape)

        output = []

        x_1 = self.Maxpool(x)
        x_2 = self.Maxpool(x_1)
        x_3 = self.Maxpool(x_2)

        e1 = self.Conv1(x)
        # print(e1.shape)
        e1 = self.Conv1_1(e1)
        e2 = self.Maxpool1(e1)
        # 32*128*128

        x_1 = self.rgb_to_feature[0](x_1)
        # e2 = torch.cat((x_1, e2), dim=1)
        e2 = x_1 + e2
        e2 = self.Conv2(e2)
        e2 = self.Conv2_1(e2)
        e3 = self.Maxpool2(e2)
        # 64*64*64

        x_2 = self.rgb_to_feature[1](x_2)
        # e3 = torch.cat((x_2, e3), dim=1)
        e3 = x_2 + e3
        e3 = self.Conv3(e3)
        e3 = self.Conv3_1(e3)
        e4 = self.Maxpool3(e3)
        # 128*32*32

        x_3 = self.rgb_to_feature[2](x_3)
        # e4 = torch.cat((x_3, e4), dim=1)
        e4 = x_3 + e4
        e4 = self.Conv4(e4)
        e4 = self.Conv4_1(e4)
        e5 = self.Maxpool4(e4)
        # 256*16*16

        # channel-wise transformer-based attention
        e1, e2, e3, e4, att_weights = self.mtc(e1, e2, e3, e4)

        # spatial-wise transformer-based attention
        residual = e5
        # bottleneck latent
        # Conv_x takes 256 channels and outputs the 512-channel latent
        e5 = self.bn(e5)
        e5 = self.relu(e5)
        e5 = self.Conv_x(e5)  # out->512*16*16 shape->B,512,16,16
        e5 = e5.permute(0, 2, 3, 1).contiguous()  # B,512,16,16 -> B,16,16,512
        e5 = e5.view(e5.size(0), -1, self.embedding_dim)  # B,16,16,512 -> B,16*16,512  linear mapping
        e5 = self.position_encoding(e5)  # positional encoding
        e5 = self.pe_dropout(e5)  # dropout before the transformer
        # apply transformer
        e5 = self.transformer(e5)
        e5 = self.pre_head_ln(e5)
        e5 = self.reshape_output(e5)  # out->512*16*16 shape->B,512,16,16
        e5 = self.Conv5(e5)  # out->256,16,16 shape->B,256,16,16
        # should the residual branch also get BN and ReLU?
        e5 = e5 + residual

        d5 = self.Up5(e5)
        e4_att = self.coatt5(g=d5, x=e4)
        d5 = torch.cat((e4_att, d5), dim=1)
        d5 = self.Up_conv5(d5)
        d5 = self.Up_conv5_1(d5)
        # 256
        out3 = self.feature_to_rgb[3](d5)
        output.append(out3)  # 32*32 or H/8,W/8

        d4 = self.Up4(d5)
        e3_att = self.coatt4(g=d4, x=e3)
        d4 = torch.cat((e3_att, d4), dim=1)
        d4 = self.Up_conv4(d4)
        d4 = self.Up_conv4_1(d4)
        # 128
        out2 = self.feature_to_rgb[2](d4)
        output.append(out2)  # 64*64 or H/4,W/4

        d3 = self.Up3(d4)
        e2_att = self.coatt3(g=d3, x=e2)
        d3 = torch.cat((e2_att, d3), dim=1)
        d3 = self.Up_conv3(d3)
        d3 = self.Up_conv3_1(d3)
        # 64
        out1 = self.feature_to_rgb[1](d3)
        output.append(out1)  # 128*128 or H/2,W/2

        d2 = self.Up2(d3)
        e1_att = self.coatt2(g=d2, x=e1)
        d2 = torch.cat((e1_att, d2), dim=1)
        d2 = self.Up_conv2(d2)
        d2 = self.Up_conv2_1(d2)
        # 32
        out0 = self.feature_to_rgb[0](d2)
        output.append(out0)  # 256*256

        # out = self.Conv(d2)

        # d1 = self.active(out)
        # output = np.array(output)

        return output[3]


class Discriminator(nn.Module):
    def __init__(self, in_channels=3, use_eql=True):
        super(Discriminator, self).__init__()

        self.use_eql = use_eql
        self.in_channels = in_channels

        # modulelist
        self.rgb_to_feature1 = ModuleList([from_rgb(32), from_rgb(64), from_rgb(128)])
        self.rgb_to_feature2 = ModuleList([from_rgb(32), from_rgb(64), from_rgb(128)])

        self.layer = _equalized_conv2d(self.in_channels * 2, 64, (1, 1), bias=True)
        # pixel_wise feature normalizer:
        self.pixNorm = PixelwiseNorm()
        # leaky_relu:
        self.lrelu = LeakyReLU(0.2)

        self.layer0 = DisGeneralConvBlock(64, 64, use_eql=self.use_eql)
        # 128*128*32

        self.layer1 = DisGeneralConvBlock(128, 128, use_eql=self.use_eql)
        # 64*64*64

        self.layer2 = DisGeneralConvBlock(256, 256, use_eql=self.use_eql)
        # 32*32*128

        self.layer3 = DisGeneralConvBlock(512, 512, use_eql=self.use_eql)
        # 16*16*256

        self.layer4 = DisFinalBlock(512, use_eql=self.use_eql)
        # 8*8*512

    def forward(self, img_A, inputs):
        # inputs are ordered from the smallest to the largest scale
        # Concatenate image and condition image by channels to produce input
        # img_input = torch.cat((img_A, img_B), 1)
        # img_A_128 = F.interpolate(img_A, size=[128, 128])
        # img_A_64 = F.interpolate(img_A, size=[64, 64])
        # img_A_32 = F.interpolate(img_A, size=[32, 32])

        x = torch.cat((img_A[3], inputs[3]), 1)
        y = self.pixNorm(self.lrelu(self.layer(x)))

        y = self.layer0(y)
        # 128*128*64

        x1 = self.rgb_to_feature1[0](img_A[2])
        x2 = self.rgb_to_feature2[0](inputs[2])
        x = torch.cat((x1, x2), 1)
        y = torch.cat((x, y), 1)
        y = self.layer1(y)
        # 64*64*128

        x1 = self.rgb_to_feature1[1](img_A[1])
        x2 = self.rgb_to_feature2[1](inputs[1])
        x = torch.cat((x1, x2), 1)
        y = torch.cat((x, y), 1)
        y = self.layer2(y)
        # 32*32*256

        x1 = self.rgb_to_feature1[2](img_A[0])
        x2 = self.rgb_to_feature2[2](inputs[0])
        x = torch.cat((x1, x2), 1)
        y = torch.cat((x, y), 1)
        y = self.layer3(y)
        # 16*16*512

        y = self.layer4(y)
        # 8*8*512

        return y
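A hypothetical end-to-end check of the Generator, not part of the upload, assuming the remaining symbols pulled in via net.block (CCA, to_rgb, and the rest of that file, whose listing below is itself truncated) are present as uploaded. forward() builds four multi-scale predictions but returns only output[3], the full-resolution one.

import torch
from net.Ushape_Trans import Generator

G = Generator()  # defaults: img_dim=256, patch_dim=16, embedding_dim=512
img = torch.randn(1, 3, 256, 256)
out = G(img)
print(out.shape)  # expected: torch.Size([1, 3, 256, 256])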
net/__pycache__/CMSFFT.cpython-37.pyc
ADDED
Binary file (11.5 kB)
net/__pycache__/CTrans.cpython-37.pyc
ADDED
Binary file (11.4 kB)
net/__pycache__/IntmdSequential.cpython-37.pyc
ADDED
Binary file (919 Bytes)
net/__pycache__/PositionalEncoding.cpython-37.pyc
ADDED
Binary file (1.78 kB)
net/__pycache__/SGFMT.cpython-37.pyc
ADDED
Binary file (3.98 kB)
net/__pycache__/Transformer.cpython-37.pyc
ADDED
Binary file (3.93 kB)
net/__pycache__/Ushape_Trans.cpython-37.pyc
ADDED
Binary file (6.65 kB)
net/__pycache__/block.cpython-37.pyc
ADDED
Binary file (13.1 kB)
net/__pycache__/utils.cpython-37.pyc
ADDED
Binary file (2.8 kB)
net/block.py
ADDED
@@ -0,0 +1,477 @@
1 |
+
import torch.nn as nn
|
2 |
+
import torch.nn.functional as F
|
3 |
+
import torch as th
|
4 |
+
import datetime
|
5 |
+
import os
|
6 |
+
import time
|
7 |
+
import timeit
|
8 |
+
import copy
|
9 |
+
import numpy as np
|
10 |
+
from torch.nn import ModuleList
|
11 |
+
from torch.nn import Conv2d
|
12 |
+
from torch.nn import LeakyReLU
|
13 |
+
|
14 |
+
|
15 |
+
|
16 |
+
|
17 |
+
#PixelwiseNorm代替了BatchNorm
|
18 |
+
class PixelwiseNorm(th.nn.Module):
|
19 |
+
def __init__(self):
|
20 |
+
super(PixelwiseNorm, self).__init__()
|
21 |
+
|
22 |
+
def forward(self, x, alpha=1e-8):
|
23 |
+
"""
|
24 |
+
forward pass of the module
|
25 |
+
:param x: input activations volume
|
26 |
+
:param alpha: small number for numerical stability
|
27 |
+
:return: y => pixel normalized activations
|
28 |
+
"""
|
29 |
+
y = x.pow(2.).mean(dim=1, keepdim=True).add(alpha).sqrt() # [N1HW]
|
30 |
+
y = x / y # normalize the input x volume
|
31 |
+
return y
|
32 |
+
|
33 |
+
|
34 |
+
|
35 |
+
class MinibatchStdDev(th.nn.Module):
|
36 |
+
"""
|
37 |
+
Minibatch standard deviation layer for the discriminator
|
38 |
+
"""
|
39 |
+
|
40 |
+
def __init__(self):
|
41 |
+
"""
|
42 |
+
derived class constructor
|
43 |
+
"""
|
44 |
+
super().__init__()
|
45 |
+
|
46 |
+
def forward(self, x, alpha=1e-8):
|
47 |
+
"""
|
48 |
+
forward pass of the layer
|
49 |
+
:param x: input activation volume
|
50 |
+
:param alpha: small number for numerical stability
|
51 |
+
:return: y => x appended with standard deviation constant map
|
52 |
+
"""
|
53 |
+
batch_size, _, height, width = x.shape
|
54 |
+
|
55 |
+
# [B x C x H x W] Subtract mean over batch.
|
56 |
+
y = x - x.mean(dim=0, keepdim=True)
|
57 |
+
|
58 |
+
# [1 x C x H x W] Calc standard deviation over batch
|
59 |
+
y = th.sqrt(y.pow(2.).mean(dim=0, keepdim=False) + alpha)
|
60 |
+
|
61 |
+
# [1] Take average over feature_maps and pixels.
|
62 |
+
y = y.mean().view(1, 1, 1, 1)
|
63 |
+
|
64 |
+
# [B x 1 x H x W] Replicate over group and pixels.
|
65 |
+
y = y.repeat(batch_size, 1, height, width)
|
66 |
+
|
67 |
+
# [B x C x H x W] Append as new feature_map.
|
68 |
+
y = th.cat([x, y], 1)
|
69 |
+
|
70 |
+
# return the computed values:
|
71 |
+
return y
|
72 |
+
|
73 |
+
|
74 |
+
|
75 |
+
|
76 |
+
|
77 |
+
# ==========================================================
|
78 |
+
# Equalized learning rate blocks:
|
79 |
+
# extending Conv2D and Deconv2D layers for equalized learning rate logic
|
80 |
+
# ==========================================================
|
81 |
+
class _equalized_conv2d(th.nn.Module):
|
82 |
+
""" conv2d with the concept of equalized learning rate
|
83 |
+
Args:
|
84 |
+
:param c_in: input channels
|
85 |
+
:param c_out: output channels
|
86 |
+
:param k_size: kernel size (h, w) should be a tuple or a single integer
|
87 |
+
:param stride: stride for conv
|
88 |
+
:param pad: padding
|
89 |
+
:param bias: whether to use bias or not
|
90 |
+
"""
|
91 |
+
|
92 |
+
def __init__(self, c_in, c_out, k_size, stride=1, pad=0, bias=True):
|
93 |
+
""" constructor for the class """
|
94 |
+
from torch.nn.modules.utils import _pair
|
95 |
+
from numpy import sqrt, prod
|
96 |
+
|
97 |
+
super().__init__()
|
98 |
+
|
99 |
+
# define the weight and bias if to be used
|
100 |
+
self.weight = th.nn.Parameter(th.nn.init.normal_(
|
101 |
+
th.empty(c_out, c_in, *_pair(k_size))
|
102 |
+
))
|
103 |
+
|
104 |
+
self.use_bias = bias
|
105 |
+
self.stride = stride
|
106 |
+
self.pad = pad
|
107 |
+
|
108 |
+
if self.use_bias:
|
109 |
+
self.bias = th.nn.Parameter(th.FloatTensor(c_out).fill_(0))
|
110 |
+
|
111 |
+
fan_in = prod(_pair(k_size)) * c_in # value of fan_in
|
112 |
+
self.scale = sqrt(2) / sqrt(fan_in)
|
113 |
+
|
114 |
+
def forward(self, x):
|
115 |
+
"""
|
116 |
+
forward pass of the network
|
117 |
+
:param x: input
|
118 |
+
:return: y => output
|
119 |
+
"""
|
120 |
+
from torch.nn.functional import conv2d
|
121 |
+
|
122 |
+
return conv2d(input=x,
|
123 |
+
weight=self.weight * self.scale, # scale the weight on runtime
|
124 |
+
bias=self.bias if self.use_bias else None,
|
125 |
+
stride=self.stride,
|
126 |
+
padding=self.pad)
|
127 |
+
|
128 |
+
def extra_repr(self):
|
129 |
+
return ", ".join(map(str, self.weight.shape))
|
130 |
+
|
131 |
+
|
132 |
+
class _equalized_deconv2d(th.nn.Module):
|
133 |
+
""" Transpose convolution using the equalized learning rate
|
134 |
+
Args:
|
135 |
+
:param c_in: input channels
|
136 |
+
:param c_out: output channels
|
137 |
+
:param k_size: kernel size
|
138 |
+
:param stride: stride for convolution transpose
|
139 |
+
:param pad: padding
|
140 |
+
:param bias: whether to use bias or not
|
141 |
+
"""
|
142 |
+
|
143 |
+
def __init__(self, c_in, c_out, k_size, stride=1, pad=0, bias=True):
|
144 |
+
""" constructor for the class """
|
145 |
+
from torch.nn.modules.utils import _pair
|
146 |
+
from numpy import sqrt
|
147 |
+
|
148 |
+
super().__init__()
|
149 |
+
|
150 |
+
# define the weight and bias if to be used
|
151 |
+
self.weight = th.nn.Parameter(th.nn.init.normal_(
|
152 |
+
th.empty(c_in, c_out, *_pair(k_size))
|
153 |
+
))
|
154 |
+
|
155 |
+
self.use_bias = bias
|
156 |
+
self.stride = stride
|
157 |
+
self.pad = pad
|
158 |
+
|
159 |
+
if self.use_bias:
|
160 |
+
self.bias = th.nn.Parameter(th.FloatTensor(c_out).fill_(0))
|
161 |
+
|
162 |
+
fan_in = c_in # value of fan_in for deconv
|
163 |
+
self.scale = sqrt(2) / sqrt(fan_in)
|
164 |
+
|
165 |
+
def forward(self, x):
|
166 |
+
"""
|
167 |
+
forward pass of the layer
|
168 |
+
:param x: input
|
169 |
+
:return: y => output
|
170 |
+
"""
|
171 |
+
from torch.nn.functional import conv_transpose2d
|
172 |
+
|
173 |
+
return conv_transpose2d(input=x,
|
174 |
+
weight=self.weight * self.scale, # scale the weight on runtime
|
175 |
+
bias=self.bias if self.use_bias else None,
|
176 |
+
stride=self.stride,
|
177 |
+
padding=self.pad)
|
178 |
+
|
179 |
+
def extra_repr(self):
|
180 |
+
return ", ".join(map(str, self.weight.shape))
|
181 |
+
|
182 |
+
|
183 |
+
|
184 |
+
#basic block of the encoding part of the genarater
|
185 |
+
#编码器的基本卷积块
|
186 |
+
class conv_block(nn.Module):
|
187 |
+
"""
|
188 |
+
Convolution Block
|
189 |
+
with two convolution layers
|
190 |
+
"""
|
191 |
+
def __init__(self, in_ch, out_ch,use_eql=True):
|
192 |
+
super(conv_block, self).__init__()
|
193 |
+
|
194 |
+
if use_eql:
|
195 |
+
self.conv_1= _equalized_conv2d(in_ch, out_ch, (1, 1),
|
196 |
+
pad=0, bias=True)
|
197 |
+
self.conv_2 = _equalized_conv2d(out_ch, out_ch, (3, 3),
|
198 |
+
pad=1, bias=True)
|
199 |
+
self.conv_3 = _equalized_conv2d(out_ch, out_ch, (3, 3),
|
200 |
+
pad=1, bias=True)
|
201 |
+
|
202 |
+
else:
|
203 |
+
self.conv_1 = Conv2d(in_ch, out_ch, (3, 3),
|
204 |
+
padding=1, bias=True)
|
205 |
+
self.conv_2 = Conv2d(out_ch, out_ch, (3, 3),
|
206 |
+
padding=1, bias=True)
|
207 |
+
|
208 |
+
# pixel_wise feature normalizer:
|
209 |
+
self.pixNorm = PixelwiseNorm()
|
210 |
+
|
211 |
+
# leaky_relu:
|
212 |
+
self.lrelu = LeakyReLU(0.2)
|
213 |
+
|
214 |
+
def forward(self, x):
|
215 |
+
"""
|
216 |
+
forward pass of the block
|
217 |
+
:param x: input
|
218 |
+
:return: y => output
|
219 |
+
"""
|
220 |
+
from torch.nn.functional import interpolate
|
221 |
+
|
222 |
+
#y = interpolate(x, scale_factor=2)
|
223 |
+
y=self.conv_1(self.lrelu(self.pixNorm(x)))
|
224 |
+
residual=y
|
225 |
+
y=self.conv_2(self.lrelu(self.pixNorm(y)))
|
226 |
+
y=self.conv_3(self.lrelu(self.pixNorm(y)))
|
227 |
+
y=y+residual
|
228 |
+
|
229 |
+
|
230 |
+
return y
|
231 |
+
|
232 |
+
|
233 |
+
|
234 |
+
|
235 |
+
#basic up convolution block of the encoding part of the genarater
|
236 |
+
#编码器的基本卷积块
|
237 |
+
class up_conv(nn.Module):
|
238 |
+
"""
|
239 |
+
Up Convolution Block
|
240 |
+
"""
|
241 |
+
def __init__(self, in_ch, out_ch,use_eql=True):
|
242 |
+
super(up_conv, self).__init__()
|
243 |
+
if use_eql:
|
244 |
+
self.conv_1= _equalized_conv2d(in_ch, out_ch, (1, 1),
|
245 |
+
pad=0, bias=True)
|
246 |
+
self.conv_2 = _equalized_conv2d(out_ch, out_ch, (3, 3),
|
247 |
+
pad=1, bias=True)
|
248 |
+
self.conv_3 = _equalized_conv2d(out_ch, out_ch, (3, 3),
|
249 |
+
pad=1, bias=True)
|
250 |
+
|
251 |
+
else:
|
252 |
+
self.conv_1 = Conv2d(in_ch, out_ch, (3, 3),
|
253 |
+
padding=1, bias=True)
|
254 |
+
self.conv_2 = Conv2d(out_ch, out_ch, (3, 3),
|
255 |
+
padding=1, bias=True)
|
256 |
+
|
257 |
+
# pixel_wise feature normalizer:
|
258 |
+
self.pixNorm = PixelwiseNorm()
|
259 |
+
|
260 |
+
# leaky_relu:
|
261 |
+
self.lrelu = LeakyReLU(0.2)
|
262 |
+
|
263 |
+
def forward(self, x):
|
264 |
+
"""
|
265 |
+
forward pass of the block
|
266 |
+
:param x: input
|
267 |
+
:return: y => output
|
268 |
+
"""
|
269 |
+
from torch.nn.functional import interpolate
|
270 |
+
|
271 |
+
x = interpolate(x, scale_factor=2, mode="bilinear")
|
272 |
+
y=self.conv_1(self.lrelu(self.pixNorm(x)))
|
273 |
+
residual=y
|
274 |
+
y=self.conv_2(self.lrelu(self.pixNorm(y)))
|
275 |
+
y=self.conv_3(self.lrelu(self.pixNorm(y)))
|
276 |
+
y=y+residual
|
277 |
+
|
278 |
+
return y
|
279 |
+
|
280 |
+
|
281 |
+
|
282 |
+
|
# final block of the discriminator
class DisFinalBlock(th.nn.Module):
    """ Final block for the Discriminator """

    def __init__(self, in_channels, use_eql=True):
        """
        constructor of the class
        :param in_channels: number of input channels
        :param use_eql: whether to use equalized learning rate
        """
        from torch.nn import LeakyReLU
        from torch.nn import Conv2d

        super().__init__()

        # declare the required modules for forward pass
        self.batch_discriminator = MinibatchStdDev()

        if use_eql:
            self.conv_1 = _equalized_conv2d(in_channels + 1, in_channels, (3, 3),
                                            pad=1, bias=True)
            self.conv_2 = _equalized_conv2d(in_channels, in_channels, (4, 4), stride=2, pad=1,
                                            bias=True)

            # final layer emulates the fully connected layer
            self.conv_3 = _equalized_conv2d(in_channels, 1, (1, 1), bias=True)

        else:
            # modules required:
            self.conv_1 = Conv2d(in_channels + 1, in_channels, (3, 3), padding=1, bias=True)
            self.conv_2 = Conv2d(in_channels, in_channels, (4, 4), bias=True)

            # final conv layer emulates a fully connected layer
            self.conv_3 = Conv2d(in_channels, 1, (1, 1), bias=True)

        # leaky_relu:
        self.lrelu = LeakyReLU(0.2)

    def forward(self, x):
        """
        forward pass of the FinalBlock
        :param x: input
        :return: y => output
        """
        # minibatch_std_dev layer appends one statistics channel
        y = self.batch_discriminator(x)

        # define the computations
        y = self.lrelu(self.conv_1(y))
        y = self.lrelu(self.conv_2(y))

        # fully connected layer
        y = self.conv_3(y)  # this layer has linear activation

        # raw (un-flattened) discriminator scores
        return y

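# Minimal usage sketch for DisFinalBlock (sizes are illustrative): MinibatchStdDev
# appends one statistics channel, the stride-2 conv_2 halves the spatial size, and the
# 1x1 conv_3 produces a single-channel score map with linear activation.
#
#     final = DisFinalBlock(in_channels=512)
#     scores = final(th.randn(4, 512, 4, 4))   # -> torch.Size([4, 1, 2, 2])
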
# basic convolution block of the discriminator
class DisGeneralConvBlock(th.nn.Module):
    """ General block in the discriminator """

    def __init__(self, in_channels, out_channels, use_eql=True):
        """
        constructor of the class
        :param in_channels: number of input channels
        :param out_channels: number of output channels
        :param use_eql: whether to use equalized learning rate
        """
        from torch.nn import AvgPool2d, LeakyReLU
        from torch.nn import Conv2d

        super().__init__()

        if use_eql:
            self.conv_1 = _equalized_conv2d(in_channels, in_channels, (3, 3),
                                            pad=1, bias=True)
            self.conv_2 = _equalized_conv2d(in_channels, out_channels, (3, 3),
                                            pad=1, bias=True)
        else:
            # convolutional modules
            self.conv_1 = Conv2d(in_channels, in_channels, (3, 3),
                                 padding=1, bias=True)
            self.conv_2 = Conv2d(in_channels, out_channels, (3, 3),
                                 padding=1, bias=True)

        self.downSampler = AvgPool2d(2)  # downsampler

        # leaky_relu:
        self.lrelu = LeakyReLU(0.2)

    def forward(self, x):
        """
        forward pass of the module
        :param x: input
        :return: y => output
        """
        # define the computations
        y = self.lrelu(self.conv_1(x))
        y = self.lrelu(self.conv_2(y))
        y = self.downSampler(y)

        return y

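# Minimal usage sketch for DisGeneralConvBlock (sizes are illustrative): two 3x3
# convolutions change the channel count and AvgPool2d(2) halves the resolution.
#
#     block = DisGeneralConvBlock(in_channels=128, out_channels=256)
#     y = block(th.randn(1, 128, 64, 64))   # -> torch.Size([1, 256, 32, 32])
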
class from_rgb(nn.Module):
    """
    The RGB image is transformed into a multi-channel feature map so that it can be
    concatenated with the feature map that has the same number of channels in the network
    """
    def __init__(self, outchannels, use_eql=True):
        super(from_rgb, self).__init__()
        if use_eql:
            self.conv_1 = _equalized_conv2d(3, outchannels, (1, 1), bias=True)
        else:
            self.conv_1 = nn.Conv2d(3, outchannels, (1, 1), bias=True)
        # pixel-wise feature normalizer:
        self.pixNorm = PixelwiseNorm()

        # leaky_relu:
        self.lrelu = LeakyReLU(0.2)

    def forward(self, x):
        """
        forward pass of the block
        :param x: input
        :return: y => output
        """
        y = self.pixNorm(self.lrelu(self.conv_1(x)))
        return y

class to_rgb(nn.Module):
    """
    The multi-channel feature map is converted into a 3-channel RGB image so that it
    can be fed to the discriminator
    """
    def __init__(self, inchannels, use_eql=True):
        super(to_rgb, self).__init__()
        if use_eql:
            self.conv_1 = _equalized_conv2d(inchannels, 3, (1, 1), bias=True)
        else:
            self.conv_1 = nn.Conv2d(inchannels, 3, (1, 1), bias=True)

    def forward(self, x):
        """
        forward pass of the block
        :param x: input
        :return: y => output
        """
        y = self.conv_1(x)

        return y

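# Minimal usage sketch for the RGB adapters (channel counts are illustrative): from_rgb
# lifts a 3-channel image to a feature map with a 1x1 convolution, and to_rgb projects a
# feature map back to a 3-channel image, so the two can bracket any intermediate scale.
#
#     lift = from_rgb(outchannels=64)
#     proj = to_rgb(inchannels=64)
#     feats = lift(th.randn(1, 3, 128, 128))   # -> (1, 64, 128, 128)
#     image = proj(feats)                      # -> (1, 3, 128, 128)
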
class Flatten(nn.Module):
    def forward(self, x):
        return x.view(x.size(0), -1)


class CCA(nn.Module):
    """
    CCA Block
    """
    def __init__(self, F_g, F_x):
        super().__init__()
        self.mlp_x = nn.Sequential(
            Flatten(),
            nn.Linear(F_x, F_x))
        self.mlp_g = nn.Sequential(
            Flatten(),
            nn.Linear(F_g, F_x))
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        # channel-wise attention
        avg_pool_x = F.avg_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
        channel_att_x = self.mlp_x(avg_pool_x)
        avg_pool_g = F.avg_pool2d(g, (g.size(2), g.size(3)), stride=(g.size(2), g.size(3)))
        channel_att_g = self.mlp_g(avg_pool_g)
        channel_att_sum = (channel_att_x + channel_att_g) / 2.0
        scale = th.sigmoid(channel_att_sum).unsqueeze(2).unsqueeze(3).expand_as(x)
        x_after_channel = x * scale
        out = self.relu(x_after_channel)
        return out
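# Minimal usage sketch for CCA (channel counts are illustrative): both inputs are
# globally average-pooled, mapped to F_x logits by their MLPs, averaged, and the
# sigmoid of the result rescales the channels of x (the skip/feature branch).
#
#     cca = CCA(F_g=256, F_x=128)
#     g = th.randn(2, 256, 16, 16)   # gating features
#     x = th.randn(2, 128, 32, 32)   # features to be re-weighted
#     out = cca(g, x)                # -> same shape as x: (2, 128, 32, 32)
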
net/utils.py
ADDED
@@ -0,0 +1,86 @@
import math
import torch
import torch.nn as nn
import numpy as np
try:
    # scikit-image >= 0.16 moved PSNR to skimage.metrics
    from skimage.metrics import peak_signal_noise_ratio as compare_psnr
except ImportError:
    from skimage.measure.simple_metrics import compare_psnr
from torchvision import models


def weights_init_kaiming(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
    elif classname.find('Linear') != -1:
        nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
    elif classname.find('BatchNorm') != -1:
        # nn.init.uniform(m.weight.data, 1.0, 0.02)
        m.weight.data.normal_(mean=0, std=math.sqrt(2./9./64.)).clamp_(-0.025, 0.025)
        nn.init.constant_(m.bias.data, 0.0)

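# Minimal usage sketch: the initializer is meant to be applied module-by-module via
# nn.Module.apply (the model below is illustrative, not taken from this repo):
#
#     model = nn.Sequential(nn.Conv2d(3, 64, 3, padding=1), nn.BatchNorm2d(64))
#     model.apply(weights_init_kaiming)
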
class VGG19_PercepLoss(nn.Module):
    """ Calculates perceptual loss in vgg19 space
    """
    def __init__(self, _pretrained_=True):
        super(VGG19_PercepLoss, self).__init__()
        self.vgg = models.vgg19(pretrained=_pretrained_).features
        for param in self.vgg.parameters():
            param.requires_grad_(False)

    def get_features(self, image, layers=None):
        if layers is None:
            layers = {'30': 'conv5_2'}  # may add other layers
        features = {}
        x = image
        for name, layer in self.vgg._modules.items():
            x = layer(x)
            if name in layers:
                features[layers[name]] = x
        return features

    def forward(self, pred, true, layer='conv5_2'):
        true_f = self.get_features(true)
        pred_f = self.get_features(pred)
        return torch.mean((true_f[layer] - pred_f[layer]) ** 2)


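# Minimal usage sketch (tensor shapes are illustrative): the loss compares VGG19
# conv5_2 activations of the prediction and the ground truth, so both inputs must be
# 3-channel image batches on the same device as the loss module.
#
#     perceptual = VGG19_PercepLoss()
#     loss = perceptual(pred, gt)   # pred, gt: (N, 3, H, W) tensors
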
def batch_PSNR(img, imclean, data_range):
    Img = img.data.cpu().numpy().astype(np.float32)
    Iclean = imclean.data.cpu().numpy().astype(np.float32)
    PSNR = 0
    for i in range(Img.shape[0]):
        PSNR += compare_psnr(Iclean[i, :, :, :], Img[i, :, :, :], data_range=data_range)
    return (PSNR / Img.shape[0])

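# Minimal usage sketch (values are illustrative): both tensors are (N, C, H, W) and
# data_range is the maximum possible pixel value, e.g. 1.0 for outputs in [0, 1]:
#
#     psnr = batch_PSNR(output.clamp(0., 1.), target, 1.)
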
def data_augmentation(image, mode):
    # expects a (C, H, W) array; work in (H, W, C) and transpose back at the end
    out = np.transpose(image, (1, 2, 0))
    if mode == 0:
        # original
        out = out
    elif mode == 1:
        # flip up and down
        out = np.flipud(out)
    elif mode == 2:
        # rotate counterclockwise 90 degrees
        out = np.rot90(out)
    elif mode == 3:
        # rotate 90 degrees and flip up and down
        out = np.rot90(out)
        out = np.flipud(out)
    elif mode == 4:
        # rotate 180 degrees
        out = np.rot90(out, k=2)
    elif mode == 5:
        # rotate 180 degrees and flip
        out = np.rot90(out, k=2)
        out = np.flipud(out)
    elif mode == 6:
        # rotate 270 degrees
        out = np.rot90(out, k=3)
    elif mode == 7:
        # rotate 270 degrees and flip
        out = np.rot90(out, k=3)
        out = np.flipud(out)
    return np.transpose(out, (2, 0, 1))
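# Minimal usage sketch (array shapes are illustrative): mode is drawn from 0-7; the
# result may be a negatively-strided view, so copy it before torch.from_numpy:
#
#     img_chw = np.random.rand(3, 256, 256).astype(np.float32)
#     aug = data_augmentation(img_chw, mode=np.random.randint(0, 8))
#     tensor = torch.from_numpy(np.ascontiguousarray(aug))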