AlexZou committed
Commit d03bb00
1 Parent(s): 2f3486f

Upload 3 files

Files changed (3)
  1. model/IAT_main.py +133 -0
  2. model/blocks.py +281 -0
  3. model/global_net.py +132 -0
model/IAT_main.py ADDED
@@ -0,0 +1,133 @@
+ import torch
+ import numpy as np
+ from torch import nn
+ import torch.nn.functional as F
+ import os
+ import math
+ 
+ from timm.models.layers import trunc_normal_
+ from model.blocks import CBlock_ln, SwinTransformerBlock
+ from model.global_net import Global_pred
+ 
+ class Local_pred(nn.Module):
+     def __init__(self, dim=16, number=4, type='ccc'):
+         super(Local_pred, self).__init__()
+         # initial convolution
+         self.conv1 = nn.Conv2d(3, dim, 3, padding=1, groups=1)
+         self.relu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+         # main blocks
+         block = CBlock_ln(dim)
+         block_t = SwinTransformerBlock(dim)  # head number
+         if type == 'ccc':
+             # blocks1, blocks2 = [block for _ in range(number)], [block for _ in range(number)]
+             blocks1 = [CBlock_ln(16, drop_path=0.01), CBlock_ln(16, drop_path=0.05), CBlock_ln(16, drop_path=0.1)]
+             blocks2 = [CBlock_ln(16, drop_path=0.01), CBlock_ln(16, drop_path=0.05), CBlock_ln(16, drop_path=0.1)]
+         elif type == 'ttt':
+             blocks1, blocks2 = [block_t for _ in range(number)], [block_t for _ in range(number)]
+         elif type == 'cct':
+             blocks1, blocks2 = [block, block, block_t], [block, block, block_t]
+         # block1 = [CBlock_ln(16), nn.Conv2d(16, 24, 3, 1, 1)]
+         self.mul_blocks = nn.Sequential(*blocks1, nn.Conv2d(dim, 3, 3, 1, 1), nn.ReLU())
+         self.add_blocks = nn.Sequential(*blocks2, nn.Conv2d(dim, 3, 3, 1, 1), nn.Tanh())
+ 
+     def forward(self, img):
+         img1 = self.relu(self.conv1(img))
+         mul = self.mul_blocks(img1)
+         add = self.add_blocks(img1)
+         return mul, add
+ 
+ 
+ # Short-cut connection on the final layer
+ class Local_pred_S(nn.Module):
+     def __init__(self, in_dim=3, dim=16, number=4, type='ccc'):
+         super(Local_pred_S, self).__init__()
+         # initial convolution
+         self.conv1 = nn.Conv2d(in_dim, dim, 3, padding=1, groups=1)
+         self.relu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+         # main blocks
+         block = CBlock_ln(dim)
+         block_t = SwinTransformerBlock(dim)  # head number
+         if type == 'ccc':
+             blocks1 = [CBlock_ln(16, drop_path=0.01), CBlock_ln(16, drop_path=0.05), CBlock_ln(16, drop_path=0.1)]
+             blocks2 = [CBlock_ln(16, drop_path=0.01), CBlock_ln(16, drop_path=0.05), CBlock_ln(16, drop_path=0.1)]
+         elif type == 'ttt':
+             blocks1, blocks2 = [block_t for _ in range(number)], [block_t for _ in range(number)]
+         elif type == 'cct':
+             blocks1, blocks2 = [block, block, block_t], [block, block, block_t]
+         # block1 = [CBlock_ln(16), nn.Conv2d(16, 24, 3, 1, 1)]
+         self.mul_blocks = nn.Sequential(*blocks1)
+         self.add_blocks = nn.Sequential(*blocks2)
+ 
+         self.mul_end = nn.Sequential(nn.Conv2d(dim, 3, 3, 1, 1), nn.ReLU())
+         self.add_end = nn.Sequential(nn.Conv2d(dim, 3, 3, 1, 1), nn.Tanh())
+         self.apply(self._init_weights)
+ 
+     def _init_weights(self, m):
+         if isinstance(m, nn.Linear):
+             trunc_normal_(m.weight, std=.02)
+             if isinstance(m, nn.Linear) and m.bias is not None:
+                 nn.init.constant_(m.bias, 0)
+         elif isinstance(m, nn.LayerNorm):
+             nn.init.constant_(m.bias, 0)
+             nn.init.constant_(m.weight, 1.0)
+         elif isinstance(m, nn.Conv2d):
+             fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+             fan_out //= m.groups
+             m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
+             if m.bias is not None:
+                 m.bias.data.zero_()
+ 
+     def forward(self, img):
+         img1 = self.relu(self.conv1(img))
+         # short-cut connection
+         mul = self.mul_blocks(img1) + img1
+         add = self.add_blocks(img1) + img1
+         mul = self.mul_end(mul)
+         add = self.add_end(add)
+         return mul, add
+ 
+ 
+ class IAT(nn.Module):
+     def __init__(self, in_dim=3, with_global=True, type='lol'):
+         super(IAT, self).__init__()
+         # self.local_net = Local_pred()
+         self.local_net = Local_pred_S(in_dim=in_dim)
+ 
+         self.with_global = with_global
+         if self.with_global:
+             self.global_net = Global_pred(in_channels=in_dim, type=type)
+ 
+     def apply_color(self, image, ccm):
+         shape = image.shape
+         image = image.view(-1, 3)
+         image = torch.tensordot(image, ccm, dims=[[-1], [-1]])
+         image = image.view(shape)
+         return torch.clamp(image, 1e-8, 1.0)
+ 
+     def forward(self, img_low):
+         mul, add = self.local_net(img_low)
+         img_high = (img_low.mul(mul)).add(add)
+ 
+         if not self.with_global:
+             return img_high
+         else:
+             gamma, color = self.global_net(img_low)
+             b = img_high.shape[0]
+             img_high = img_high.permute(0, 2, 3, 1)  # (B, C, H, W) -> (B, H, W, C)
+             img_high = torch.stack([self.apply_color(img_high[i, :, :, :], color[i, :, :]) ** gamma[i, :] for i in range(b)], dim=0)
+             img_high = img_high.permute(0, 3, 1, 2)  # (B, H, W, C) -> (B, C, H, W)
+             return img_high
+ 
+ 
+ if __name__ == "__main__":
+     os.environ['CUDA_VISIBLE_DEVICES'] = '3'
+     img = torch.Tensor(1, 3, 400, 600)
+     net = IAT()
+     print('total parameters:', sum(param.numel() for param in net.parameters()))
+     high = net(img)  # IAT.forward returns a single enhanced-image tensor
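
Usage sketch for model/IAT_main.py: a minimal, hedged example of running IAT for inference. Only the IAT class and its defaults come from the file above; the checkpoint path is hypothetical and the [0, 1] input scaling is an assumption.

import torch
from model.IAT_main import IAT

# Build the model; with_global=True adds the gamma / 3x3 color-matrix branch (Global_pred).
net = IAT(in_dim=3, with_global=True, type='lol')

# Hypothetical checkpoint path -- replace with wherever trained weights actually live.
# net.load_state_dict(torch.load('checkpoints/iat_lol.pth', map_location='cpu'))
net.eval()

# Assumed input format: an RGB batch (B, 3, H, W) scaled to [0, 1].
img_low = torch.rand(1, 3, 400, 600)

with torch.no_grad():
    img_high = net(img_low)  # forward returns the enhanced image only

print(img_high.shape)  # torch.Size([1, 3, 400, 600])
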
model/blocks.py ADDED
@@ -0,0 +1,281 @@
+ """
+ Code copied from the UniFormer source code:
+ https://github.com/Sense-X/UniFormer
+ """
+ import os
+ import torch
+ import torch.nn as nn
+ from functools import partial
+ import math
+ from timm.models.vision_transformer import VisionTransformer, _cfg
+ from timm.models.registry import register_model
+ from timm.models.layers import trunc_normal_, DropPath, to_2tuple
+ 
+ # ResMLP's normalization
+ class Aff(nn.Module):
+     def __init__(self, dim):
+         super().__init__()
+         # learnable
+         self.alpha = nn.Parameter(torch.ones([1, 1, dim]))
+         self.beta = nn.Parameter(torch.zeros([1, 1, dim]))
+ 
+     def forward(self, x):
+         x = x * self.alpha + self.beta
+         return x
+ 
+ # Color normalization
+ class Aff_channel(nn.Module):
+     def __init__(self, dim, channel_first=True):
+         super().__init__()
+         # learnable
+         self.alpha = nn.Parameter(torch.ones([1, 1, dim]))
+         self.beta = nn.Parameter(torch.zeros([1, 1, dim]))
+         self.color = nn.Parameter(torch.eye(dim))
+         self.channel_first = channel_first
+ 
+     def forward(self, x):
+         if self.channel_first:
+             x1 = torch.tensordot(x, self.color, dims=[[-1], [-1]])
+             x2 = x1 * self.alpha + self.beta
+         else:
+             x1 = x * self.alpha + self.beta
+             x2 = torch.tensordot(x1, self.color, dims=[[-1], [-1]])
+         return x2
+ 
+ class Mlp(nn.Module):
+     # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
+     def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+         super().__init__()
+         out_features = out_features or in_features
+         hidden_features = hidden_features or in_features
+         self.fc1 = nn.Linear(in_features, hidden_features)
+         self.act = act_layer()
+         self.fc2 = nn.Linear(hidden_features, out_features)
+         self.drop = nn.Dropout(drop)
+ 
+     def forward(self, x):
+         x = self.fc1(x)
+         x = self.act(x)
+         x = self.drop(x)
+         x = self.fc2(x)
+         x = self.drop(x)
+         return x
+ 
+ class CMlp(nn.Module):
+     # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
+     def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+         super().__init__()
+         out_features = out_features or in_features
+         hidden_features = hidden_features or in_features
+         self.fc1 = nn.Conv2d(in_features, hidden_features, 1)
+         self.act = act_layer()
+         self.fc2 = nn.Conv2d(hidden_features, out_features, 1)
+         self.drop = nn.Dropout(drop)
+ 
+     def forward(self, x):
+         x = self.fc1(x)
+         x = self.act(x)
+         x = self.drop(x)
+         x = self.fc2(x)
+         x = self.drop(x)
+         return x
+ 
+ class CBlock_ln(nn.Module):
+     def __init__(self, dim, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
+                  drop_path=0., act_layer=nn.GELU, norm_layer=Aff_channel, init_values=1e-4):
+         super().__init__()
+         self.pos_embed = nn.Conv2d(dim, dim, 3, padding=1, groups=dim)
+         # self.norm1 = Aff_channel(dim)
+         self.norm1 = norm_layer(dim)
+         self.conv1 = nn.Conv2d(dim, dim, 1)
+         self.conv2 = nn.Conv2d(dim, dim, 1)
+         self.attn = nn.Conv2d(dim, dim, 5, padding=2, groups=dim)
+         # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+         # self.norm2 = Aff_channel(dim)
+         self.norm2 = norm_layer(dim)
+         mlp_hidden_dim = int(dim * mlp_ratio)
+         self.gamma_1 = nn.Parameter(init_values * torch.ones((1, dim, 1, 1)), requires_grad=True)
+         self.gamma_2 = nn.Parameter(init_values * torch.ones((1, dim, 1, 1)), requires_grad=True)
+         self.mlp = CMlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+ 
+     def forward(self, x):
+         x = x + self.pos_embed(x)
+         B, C, H, W = x.shape
+         norm_x = x.flatten(2).transpose(1, 2)
+         norm_x = self.norm1(norm_x)
+         norm_x = norm_x.view(B, H, W, C).permute(0, 3, 1, 2)
+ 
+         x = x + self.drop_path(self.gamma_1 * self.conv2(self.attn(self.conv1(norm_x))))
+         norm_x = x.flatten(2).transpose(1, 2)
+         norm_x = self.norm2(norm_x)
+         norm_x = norm_x.view(B, H, W, C).permute(0, 3, 1, 2)
+         x = x + self.drop_path(self.gamma_2 * self.mlp(norm_x))
+         return x
+ 
+ 
+ def window_partition(x, window_size):
+     """
+     Args:
+         x: (B, H, W, C)
+         window_size (int): window size
+     Returns:
+         windows: (num_windows*B, window_size, window_size, C)
+     """
+     B, H, W, C = x.shape
+     x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
+     windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
+     return windows
+ 
+ 
+ def window_reverse(windows, window_size, H, W):
+     """
+     Args:
+         windows: (num_windows*B, window_size, window_size, C)
+         window_size (int): Window size
+         H (int): Height of image
+         W (int): Width of image
+     Returns:
+         x: (B, H, W, C)
+     """
+     B = int(windows.shape[0] / (H * W / window_size / window_size))
+     x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
+     x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
+     return x
+ 
+ 
+ class WindowAttention(nn.Module):
+     r""" Window-based multi-head self attention (W-MSA) module with relative position bias.
+     It supports both shifted and non-shifted windows.
+     Args:
+         dim (int): Number of input channels.
+         window_size (tuple[int]): The height and width of the window.
+         num_heads (int): Number of attention heads.
+         qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+         qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
+         attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
+         proj_drop (float, optional): Dropout ratio of output. Default: 0.0
+     """
+ 
+     def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
+         super().__init__()
+         self.dim = dim
+         self.window_size = window_size  # Wh, Ww
+         self.num_heads = num_heads
+         head_dim = dim // num_heads
+         self.scale = qk_scale or head_dim ** -0.5
+ 
+         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+         self.attn_drop = nn.Dropout(attn_drop)
+         self.proj = nn.Linear(dim, dim)
+         self.proj_drop = nn.Dropout(proj_drop)
+ 
+         self.softmax = nn.Softmax(dim=-1)
+ 
+     def forward(self, x):
+         B_, N, C = x.shape
+         qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+         q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
+ 
+         q = q * self.scale
+         attn = (q @ k.transpose(-2, -1))
+ 
+         attn = self.softmax(attn)
+         attn = self.attn_drop(attn)
+ 
+         x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
+         x = self.proj(x)
+         x = self.proj_drop(x)
+         return x
+ 
+ ## Layer_norm, Aff_norm, Aff_channel_norm
+ class SwinTransformerBlock(nn.Module):
+     r""" Swin Transformer Block.
+     Args:
+         dim (int): Number of input channels.
+         input_resolution (tuple[int]): Input resolution.
+         num_heads (int): Number of attention heads.
+         window_size (int): Window size.
+         shift_size (int): Shift size for SW-MSA.
+         mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+         qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+         qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
+         drop (float, optional): Dropout rate. Default: 0.0
+         attn_drop (float, optional): Attention dropout rate. Default: 0.0
+         drop_path (float, optional): Stochastic depth rate. Default: 0.0
+         act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
+         norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+     """
+ 
+     def __init__(self, dim, num_heads=2, window_size=8, shift_size=0,
+                  mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
+                  act_layer=nn.GELU, norm_layer=Aff_channel):
+         super().__init__()
+         self.dim = dim
+         self.num_heads = num_heads
+         self.window_size = window_size
+         self.shift_size = shift_size
+         self.mlp_ratio = mlp_ratio
+ 
+         self.pos_embed = nn.Conv2d(dim, dim, 3, padding=1, groups=dim)
+         self.norm1 = norm_layer(dim)
+         self.attn = WindowAttention(
+             dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
+             qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
+ 
+         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+         self.norm2 = norm_layer(dim)
+         mlp_hidden_dim = int(dim * mlp_ratio)
+         self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+ 
+     def forward(self, x):
+         x = x + self.pos_embed(x)
+         B, C, H, W = x.shape
+         x = x.flatten(2).transpose(1, 2)
+ 
+         shortcut = x
+         x = self.norm1(x)
+         x = x.view(B, H, W, C)
+ 
+         # cyclic shift
+         if self.shift_size > 0:
+             shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
+         else:
+             shifted_x = x
+ 
+         # partition windows
+         x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
+         x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C
+ 
+         # W-MSA/SW-MSA
+         attn_windows = self.attn(x_windows)  # nW*B, window_size*window_size, C
+ 
+         # merge windows
+         attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
+         shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C
+ 
+         # reverse cyclic shift (no-op for the default shift_size=0)
+         if self.shift_size > 0:
+             x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
+         else:
+             x = shifted_x
+         x = x.view(B, H * W, C)
+ 
+         # FFN
+         x = shortcut + self.drop_path(x)
+         x = x + self.drop_path(self.mlp(self.norm2(x)))
+         x = x.transpose(1, 2).reshape(B, C, H, W)
+ 
+         return x
+ 
+ 
+ if __name__ == "__main__":
+     os.environ['CUDA_VISIBLE_DEVICES'] = '1'
+     cb_block = CBlock_ln(dim=16)
+     x = torch.Tensor(1, 16, 400, 600)
+     swin = SwinTransformerBlock(dim=16, num_heads=4)
+     x = cb_block(x)
+     print(x.shape)
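
Sanity-check sketch for model/blocks.py: window_partition and window_reverse are exact inverses whenever H and W are divisible by the window size, and SwinTransformerBlock consumes channel-first feature maps under the same constraint. The shapes below are illustrative.

import torch
from model.blocks import window_partition, window_reverse, SwinTransformerBlock

B, H, W, C, ws = 2, 64, 96, 16, 8  # H and W must be multiples of the window size

x = torch.randn(B, H, W, C)
windows = window_partition(x, ws)            # (B * H/ws * W/ws, ws, ws, C)
assert windows.shape == (B * (H // ws) * (W // ws), ws, ws, C)

x_back = window_reverse(windows, ws, H, W)   # back to (B, H, W, C)
assert torch.equal(x, x_back)                # the partition/reverse pair round-trips exactly

blk = SwinTransformerBlock(dim=C, num_heads=4, window_size=ws)
feat = torch.randn(B, C, H, W)               # channel-first input, as in the file's __main__
print(blk(feat).shape)                       # torch.Size([2, 16, 64, 96])
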
model/global_net.py ADDED
@@ -0,0 +1,132 @@
+ import torch
+ import torch.nn as nn
+ from timm.models.layers import trunc_normal_, DropPath, to_2tuple
+ import os
+ from model.blocks import Mlp
+ 
+ 
+ class query_Attention(nn.Module):
+     def __init__(self, dim, num_heads=2, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
+         super().__init__()
+         self.num_heads = num_heads
+         head_dim = dim // num_heads
+         # NOTE: scale factor was wrong in my original version, can set manually to be compat with prev weights
+         self.scale = qk_scale or head_dim ** -0.5
+ 
+         self.q = nn.Parameter(torch.ones((1, 10, dim)), requires_grad=True)
+         self.k = nn.Linear(dim, dim, bias=qkv_bias)
+         self.v = nn.Linear(dim, dim, bias=qkv_bias)
+         self.attn_drop = nn.Dropout(attn_drop)
+         self.proj = nn.Linear(dim, dim)
+         self.proj_drop = nn.Dropout(proj_drop)
+ 
+     def forward(self, x):
+         B, N, C = x.shape
+         k = self.k(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
+         v = self.v(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
+         q = self.q.expand(B, -1, -1).view(B, -1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
+ 
+         # k = self.k(x).reshape(B, N, self.num_heads, torch.div(C, self.num_heads, rounding_mode='floor')).permute(0, 2, 1, 3)
+         # v = self.v(x).reshape(B, N, self.num_heads, torch.div(C, self.num_heads, rounding_mode='floor')).permute(0, 2, 1, 3)
+         # q = self.q.expand(B, -1, -1).view(B, -1, self.num_heads, torch.div(C, self.num_heads, rounding_mode='floor')).permute(0, 2, 1, 3)
+         attn = (q @ k.transpose(-2, -1)) * self.scale
+         attn = attn.softmax(dim=-1)
+         attn = self.attn_drop(attn)
+ 
+         x = (attn @ v).transpose(1, 2).reshape(B, 10, C)
+         x = self.proj(x)
+         x = self.proj_drop(x)
+         return x
+ 
+ 
+ class query_SABlock(nn.Module):
+     def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
+                  drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
+         super().__init__()
+         self.pos_embed = nn.Conv2d(dim, dim, 3, padding=1, groups=dim)
+         self.norm1 = norm_layer(dim)
+         self.attn = query_Attention(
+             dim,
+             num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
+             attn_drop=attn_drop, proj_drop=drop)
+         # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+         self.norm2 = norm_layer(dim)
+         mlp_hidden_dim = int(dim * mlp_ratio)
+         self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+ 
+     def forward(self, x):
+         x = x + self.pos_embed(x)
+         x = x.flatten(2).transpose(1, 2)
+         x = self.drop_path(self.attn(self.norm1(x)))
+         x = x + self.drop_path(self.mlp(self.norm2(x)))
+         return x
+ 
+ 
+ class conv_embedding(nn.Module):
+     def __init__(self, in_channels, out_channels):
+         super(conv_embedding, self).__init__()
+         self.proj = nn.Sequential(
+             nn.Conv2d(in_channels, out_channels // 2, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
+             nn.BatchNorm2d(out_channels // 2),
+             nn.GELU(),
+             # nn.Conv2d(out_channels // 2, out_channels // 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
+             # nn.BatchNorm2d(out_channels // 2),
+             # nn.GELU(),
+             nn.Conv2d(out_channels // 2, out_channels, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
+             nn.BatchNorm2d(out_channels),
+         )
+ 
+     def forward(self, x):
+         x = self.proj(x)
+         return x
+ 
+ 
+ class Global_pred(nn.Module):
+     def __init__(self, in_channels=3, out_channels=64, num_heads=4, type='exp'):
+         super(Global_pred, self).__init__()
+         if type == 'exp':
+             self.gamma_base = nn.Parameter(torch.ones((1)), requires_grad=False)  # frozen for exposure correction
+         else:
+             self.gamma_base = nn.Parameter(torch.ones((1)), requires_grad=True)
+         self.color_base = nn.Parameter(torch.eye(3), requires_grad=True)  # basic color matrix
+         # main blocks
+         self.conv_large = conv_embedding(in_channels, out_channels)
+         self.generator = query_SABlock(dim=out_channels, num_heads=num_heads)
+         self.gamma_linear = nn.Linear(out_channels, 1)
+         self.color_linear = nn.Linear(out_channels, 1)
+ 
+         self.apply(self._init_weights)
+ 
+         for name, p in self.named_parameters():
+             if name == 'generator.attn.v.weight':
+                 nn.init.constant_(p, 0)
+ 
+     def _init_weights(self, m):
+         if isinstance(m, nn.Linear):
+             trunc_normal_(m.weight, std=.02)
+             if isinstance(m, nn.Linear) and m.bias is not None:
+                 nn.init.constant_(m.bias, 0)
+         elif isinstance(m, nn.LayerNorm):
+             nn.init.constant_(m.bias, 0)
+             nn.init.constant_(m.weight, 1.0)
+ 
+     def forward(self, x):
+         x = self.conv_large(x)
+         x = self.generator(x)
+         gamma, color = x[:, 0].unsqueeze(1), x[:, 1:]
+         gamma = self.gamma_linear(gamma).squeeze(-1) + self.gamma_base
+         color = self.color_linear(color).squeeze(-1).view(-1, 3, 3) + self.color_base
+         return gamma, color
+ 
+ 
+ if __name__ == "__main__":
+     os.environ['CUDA_VISIBLE_DEVICES'] = '3'
+     # net = Local_pred_new().cuda()
+     img = torch.Tensor(8, 3, 400, 600)
+     global_net = Global_pred()
+     gamma, color = global_net(img)
+     print(gamma.shape, color.shape)
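
Usage sketch for model/global_net.py: the shapes of the gamma / color outputs of Global_pred, and a mirror of how IAT.apply_color in model/IAT_main.py consumes them (a per-pixel 3x3 color matrix followed by a gamma curve). The input tensor here is illustrative.

import torch
from model.global_net import Global_pred

global_net = Global_pred(in_channels=3, type='lol')  # any type other than 'exp' keeps gamma_base trainable
img_low = torch.rand(2, 3, 400, 600)

gamma, color = global_net(img_low)
print(gamma.shape, color.shape)  # torch.Size([2, 1]) torch.Size([2, 3, 3])

# Mirror of IAT.apply_color plus the gamma adjustment, for the first image in the batch:
pixels = img_low[0].permute(1, 2, 0).reshape(-1, 3)               # (H*W, 3)
corrected = torch.tensordot(pixels, color[0], dims=[[-1], [-1]])  # apply the 3x3 color matrix per pixel
corrected = torch.clamp(corrected, 1e-8, 1.0) ** gamma[0]         # gamma curve
print(corrected.view(400, 600, 3).shape)                          # torch.Size([400, 600, 3])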