Meloo commited on
Commit
2dab112
1 Parent(s): dfccca1

Create models/safmn_arch.py

Browse files
Files changed (1) hide show
  1. models/safmn_arch.py +113 -0
models/safmn_arch.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
+ # Layer Norm
7
+ class LayerNorm(nn.Module):
8
+ def __init__(self, normalized_shape, eps=1e-6, data_format="channels_first"):
9
+ super().__init__()
10
+ self.weight = nn.Parameter(torch.ones(normalized_shape))
11
+ self.bias = nn.Parameter(torch.zeros(normalized_shape))
12
+ self.eps = eps
13
+ self.data_format = data_format
14
+ if self.data_format not in ["channels_last", "channels_first"]:
15
+ raise NotImplementedError
16
+ self.normalized_shape = (normalized_shape, )
17
+
18
+ def forward(self, x):
19
+ if self.data_format == "channels_last":
20
+ return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
21
+ elif self.data_format == "channels_first":
22
+ u = x.mean(1, keepdim=True)
23
+ s = (x - u).pow(2).mean(1, keepdim=True)
24
+ x = (x - u) / torch.sqrt(s + self.eps)
25
+ x = self.weight[:, None, None] * x + self.bias[:, None, None]
26
+ return x
27
+
28
+ # CCM
29
+ class CCM(nn.Module):
30
+ def __init__(self, dim, growth_rate=2.0):
31
+ super().__init__()
32
+ hidden_dim = int(dim * growth_rate)
33
+
34
+ self.ccm = nn.Sequential(
35
+ nn.Conv2d(dim, hidden_dim, 3, 1, 1),
36
+ nn.GELU(),
37
+ nn.Conv2d(hidden_dim, dim, 1, 1, 0)
38
+ )
39
+
40
+ def forward(self, x):
41
+ return self.ccm(x)
42
+
43
+
44
+ # SAFM
45
+ class SAFM(nn.Module):
46
+ def __init__(self, dim, n_levels=4):
47
+ super().__init__()
48
+ self.n_levels = n_levels
49
+ chunk_dim = dim // n_levels
50
+
51
+ # Spatial Weighting
52
+ self.mfr = nn.ModuleList([nn.Conv2d(chunk_dim, chunk_dim, 3, 1, 1, groups=chunk_dim) for i in range(self.n_levels)])
53
+
54
+ # # Feature Aggregation
55
+ self.aggr = nn.Conv2d(dim, dim, 1, 1, 0)
56
+
57
+ # Activation
58
+ self.act = nn.GELU()
59
+
60
+ def forward(self, x):
61
+ h, w = x.size()[-2:]
62
+
63
+ xc = x.chunk(self.n_levels, dim=1)
64
+ out = []
65
+ for i in range(self.n_levels):
66
+ if i > 0:
67
+ p_size = (h//2**i, w//2**i)
68
+ s = F.adaptive_max_pool2d(xc[i], p_size)
69
+ s = self.mfr[i](s)
70
+ s = F.interpolate(s, size=(h, w), mode='nearest')
71
+ else:
72
+ s = self.mfr[i](xc[i])
73
+ out.append(s)
74
+
75
+ out = self.aggr(torch.cat(out, dim=1))
76
+ out = self.act(out) * x
77
+ return out
78
+
79
+ class AttBlock(nn.Module):
80
+ def __init__(self, dim, ffn_scale=2.0):
81
+ super().__init__()
82
+
83
+ self.norm1 = LayerNorm(dim)
84
+ self.norm2 = LayerNorm(dim)
85
+
86
+ # Multiscale Block
87
+ self.safm = SAFM(dim)
88
+ # Feedforward layer
89
+ self.ccm = CCM(dim, ffn_scale)
90
+
91
+ def forward(self, x):
92
+ x = self.safm(self.norm1(x)) + x
93
+ x = self.ccm(self.norm2(x)) + x
94
+ return x
95
+
96
+
97
+ class SAFMN(nn.Module):
98
+ def __init__(self, dim, n_blocks=8, ffn_scale=2.0, upscaling_factor=4):
99
+ super().__init__()
100
+ self.to_feat = nn.Conv2d(3, dim, 3, 1, 1)
101
+
102
+ self.feats = nn.Sequential(*[AttBlock(dim, ffn_scale) for _ in range(n_blocks)])
103
+
104
+ self.to_img = nn.Sequential(
105
+ nn.Conv2d(dim, 3 * upscaling_factor**2, 3, 1, 1),
106
+ nn.PixelShuffle(upscaling_factor)
107
+ )
108
+
109
+ def forward(self, x):
110
+ x = self.to_feat(x)
111
+ x = self.feats(x) + x
112
+ x = self.to_img(x)
113
+ return x