3587jjh committed
Commit 61522a1
1 Parent(s): d6922e8

Upload 10 files

Files changed (10):
  1. carn-pcsr-phase1.pth +3 -0
  2. demo.py +58 -0
  3. models/__init__.py +3 -0
  4. models/carn.py +78 -0
  5. models/mlp.py +32 -0
  6. models/models.py +23 -0
  7. models/pcsr.py +197 -0
  8. models/sampler.py +40 -0
  9. models/utils.py +214 -0
  10. utils.py +210 -0
carn-pcsr-phase1.pth ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:dfe84ddd3923b35d14a977dabda5613a9d59da3b0961e004be786f108d3f8508
size 755333
demo.py ADDED
@@ -0,0 +1,58 @@
import torch
import numpy as np
from torchvision import transforms
from PIL import Image

import models
from utils import *

img_path = '/workspace/datasets/test/myimage/HR/X4/FOTO-BOX-18-1024x1024.png' # only supports .png
scale = 4 # only supports x4

'''
Test-time options:
    k: hyperparameter to traverse the PSNR-FLOPs trade-off; a smaller k gives
        larger FLOPs and higher PSNR. The useful range is about [-1, 2].
    adaptive_cluster: whether to decide the easy/hard threshold automatically (ignores k)
    refinement: whether to apply pixel-wise refinement (postprocessing that reduces artifacts)
    pixel_batch_size: number of query pixels processed per forward pass (300000 below)
    The colored visualization at the end uses an overlay opacity of 0.65.
'''

resume_path = 'carn-pcsr-phase1.pth'
sv_file = torch.load(resume_path)
model = models.make(sv_file['model'], load_sd=True).cuda()
model.eval()

rgb_mean = torch.tensor([0.4488, 0.4371, 0.4040], device='cuda').view(1,3,1,1)
rgb_std = torch.tensor([1.0, 1.0, 1.0], device='cuda').view(1,3,1,1)

with torch.no_grad():
    # prepare inputs
    lr = transforms.ToTensor()(Image.open(img_path)).unsqueeze(0).cuda() # (1,3,h,w), range=[0,1]
    h,w = lr.shape[-2:]
    H,W = h*scale, w*scale
    coord = make_coord((H,W), flatten=True, device='cuda').unsqueeze(0)
    cell = torch.ones_like(coord)
    cell[:,:,0] *= 2/H
    cell[:,:,1] *= 2/W
    inp_lr = (lr - rgb_mean) / rgb_std

    pred, flag = model(inp_lr, coord=coord, cell=cell, scale=scale, k=0,
        pixel_batch_size=300000, adaptive_cluster=True, refinement=True)
    flops = get_model_flops(model, inp_lr, coord=coord, cell=cell, scale=scale, k=0,
        pixel_batch_size=300000, adaptive_cluster=True, refinement=True)
    max_flops = get_model_flops(model, inp_lr, coord=coord, cell=cell, scale=scale, k=-25,
        pixel_batch_size=300000, adaptive_cluster=False, refinement=True) # k=-25: route every pixel to the heavy sampler
    print('flops: {:.1f}G ({:.1f} %) | max_flops: {:.1f}G (100 %)'.format(flops/1e9,
        (flops / max_flops)*100, max_flops/1e9))

    pred = pred.transpose(1,2).view(-1,3,H,W)
    pred = pred * rgb_std + rgb_mean
    pred = tensor2numpy(pred)
    Image.fromarray(pred).save('output.png')

    # visualize the routing: green = easy (light sampler), red = hard (heavy sampler)
    flag = flag.view(H,W).cpu().numpy()
    vis_img = np.zeros_like(pred)
    vis_img[flag == 0] = np.array([0,255,0])
    vis_img[flag == 1] = np.array([255,0,0])
    vis_img = vis_img*0.35 + pred*0.65
    Image.fromarray(vis_img.astype('uint8')).save('output_vis.png')
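
For reference, a sketch of sweeping k by hand (reusing model, inp_lr, coord, cell, and scale from demo.py above; the k values are illustrative):

for k in [-1.0, 0.0, 1.0, 2.0]:
    with torch.no_grad():
        pred, flag = model(inp_lr, coord=coord, cell=cell, scale=scale, k=k,
            pixel_batch_size=300000, adaptive_cluster=False, refinement=True)
    hard = flag.float().mean().item() # fraction of pixels routed to the heavy sampler
    print('k={:+.1f}: {:.1f}% hard pixels'.format(k, hard*100))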
models/__init__.py ADDED
@@ -0,0 +1,3 @@
from .models import register, make
from . import mlp, pcsr, sampler
from . import carn
models/carn.py ADDED
@@ -0,0 +1,78 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

import models.utils as mutils
from models import register


class Block(nn.Module):
    def __init__(self, nf, group=1):
        super(Block, self).__init__()
        # a single efficient residual block is shared across all three
        # stages below (weight sharing, as in CARN-M)
        self.b1 = mutils.EResidualBlock(nf, nf, group=group)
        self.c1 = mutils.BasicBlock(nf*2, nf, 1, 1, 0)
        self.c2 = mutils.BasicBlock(nf*3, nf, 1, 1, 0)
        self.c3 = mutils.BasicBlock(nf*4, nf, 1, 1, 0)

    def forward(self, x):
        c0 = o0 = x

        b1 = self.b1(o0)
        c1 = torch.cat([c0, b1], dim=1)
        o1 = self.c1(c1)

        b2 = self.b1(o1)
        c2 = torch.cat([c1, b2], dim=1)
        o2 = self.c2(c2)

        b3 = self.b1(o2)
        c3 = torch.cat([c2, b3], dim=1)
        o3 = self.c3(c3)

        return o3


@register('carn')
class CARN_M(nn.Module):
    def __init__(self, in_nc=3, out_nc=3, nf=64, scale=4, group=4, no_upsampling=False):
        super(CARN_M, self).__init__()
        self.scale = scale
        self.out_dim = nf

        self.entry = nn.Conv2d(in_nc, nf, 3, 1, 1)
        self.b1 = Block(nf, group=group)
        self.b2 = Block(nf, group=group)
        self.b3 = Block(nf, group=group)

        self.c1 = mutils.BasicBlock(nf*2, nf, 1, 1, 0)
        self.c2 = mutils.BasicBlock(nf*3, nf, 1, 1, 0)
        self.c3 = mutils.BasicBlock(nf*4, nf, 1, 1, 0)

        self.no_upsampling = no_upsampling
        if not no_upsampling:
            self.upsample = mutils.UpsampleBlock(nf, scale=scale, multi_scale=False, group=group)
            self.exit = nn.Conv2d(nf, out_nc, 3, 1, 1)

    def forward(self, x):
        #x = self.sub_mean(x)
        x = self.entry(x)
        c0 = o0 = x

        b1 = self.b1(o0)
        c1 = torch.cat([c0, b1], dim=1)
        o1 = self.c1(c1)

        b2 = self.b2(o1)
        c2 = torch.cat([c1, b2], dim=1)
        o2 = self.c2(c2)

        b3 = self.b3(o2)
        c3 = torch.cat([c2, b3], dim=1)
        o3 = self.c3(c3)
        out = o3.clone()

        if not self.no_upsampling:
            out = self.upsample(out, scale=self.scale)
            out = self.exit(out)
        #out = self.add_mean(out)
        return out
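
A quick shape check (a sketch; the 32x32 input size is arbitrary): with no_upsampling=True the network acts as the PCSR feature encoder and keeps the input resolution, exposing out_dim = nf channels.

import torch
from models.carn import CARN_M

enc = CARN_M(no_upsampling=True) # feature extractor used as the PCSR encoder
print(enc(torch.randn(1, 3, 32, 32)).shape) # torch.Size([1, 64, 32, 32])

sr = CARN_M(scale=4) # standalone x4 super-resolution network
print(sr(torch.randn(1, 3, 32, 32)).shape) # torch.Size([1, 3, 128, 128])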
models/mlp.py ADDED
@@ -0,0 +1,32 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from models import register


@register('mlp')
class MLP(nn.Module):

    def __init__(self, in_dim, out_dim, hidden_list, residual=False):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.hidden_list = hidden_list
        self.residual = residual
        if residual:
            self.convert = nn.Linear(in_dim, out_dim)

        layers = []
        lastv = in_dim
        for hidden in hidden_list:
            layers.append(nn.Linear(lastv, hidden))
            layers.append(nn.ReLU())
            lastv = hidden
        layers.append(nn.Linear(lastv, out_dim))
        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        y = self.layers(x)
        if self.residual:
            y = y + self.convert(x)
        return y
models/models.py ADDED
@@ -0,0 +1,23 @@
import copy


models = {}


def register(name):
    def decorator(cls):
        models[name] = cls
        return cls
    return decorator


def make(model_spec, args=None, load_sd=False):
    if args is not None:
        model_args = copy.deepcopy(model_spec['args'])
        model_args.update(args)
    else:
        model_args = model_spec['args']
    model = models[model_spec['name']](**model_args)
    if load_sd:
        model.load_state_dict(model_spec['sd'])
    return model
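
For reference, a minimal sketch of how the registry is used ('mlp' is registered in models/mlp.py; the override in the second call is illustrative):

import models

spec = {'name': 'mlp', 'args': {'in_dim': 64, 'out_dim': 3, 'hidden_list': [256, 256]}}
net = models.make(spec) # build from a spec dict
net2 = models.make(spec, args={'out_dim': 2}) # deep-copies spec['args'], then overrides out_dim
# Checkpoints store {'name': ..., 'args': ..., 'sd': state_dict}, so
# models.make(sv_file['model'], load_sd=True) rebuilds a model and restores its weights.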
models/pcsr.py ADDED
@@ -0,0 +1,197 @@
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import models
from models import register
from fast_pytorch_kmeans import KMeans
from utils import *


@register('pcsr-phase0')
class PCSR(nn.Module):
    def __init__(self, encoder_spec, heavy_sampler_spec):
        super().__init__()
        self.encoder = models.make(encoder_spec)
        in_dim = self.encoder.out_dim
        self.heavy_sampler = models.make(heavy_sampler_spec,
            args={'in_dim': in_dim, 'out_dim': 3})

    def forward(self, lr, coord, cell, **kwargs):
        if self.training:
            return self.forward_train(lr, coord, cell)
        else:
            return self.forward_test(lr, coord, cell, **kwargs)

    def forward_train(self, lr, coord, cell):
        feat = self.encoder(lr)
        res = F.grid_sample(lr, coord.flip(-1).unsqueeze(1), mode='bilinear',
            padding_mode='border', align_corners=False)[:,:,0,:].permute(0,2,1)
        pred_heavy = self.heavy_sampler(feat, coord, cell) + res
        return pred_heavy

    def forward_test(self, lr, coord, cell, pixel_batch_size=None):
        feat = self.encoder(lr)
        b,q = coord.shape[:2]
        tot = b*q
        if not pixel_batch_size:
            pixel_batch_size = q

        preds = []
        for i in range(b): # for each image
            pred = torch.zeros((q,3), device=lr.device)
            l = 0
            while l < q:
                r = min(q, l+pixel_batch_size)
                coord_split = coord[i:i+1,l:r,:]
                cell_split = cell[i:i+1,l:r,:]
                res = F.grid_sample(lr[i:i+1], coord_split.flip(-1).unsqueeze(1), mode='bilinear',
                    padding_mode='border', align_corners=False)[:,:,0,:].squeeze(0).transpose(0,1)
                pred[l:r] = self.heavy_sampler(feat[i:i+1], coord_split, cell_split) + res
                l = r
            preds.append(pred)
        pred = torch.stack(preds, dim=0)
        return pred


@register('pcsr-phase1')
class PCSR(nn.Module): # shadows the phase-0 class at module level; both stay reachable via the registry

    def __init__(self, encoder_spec, heavy_sampler_spec, light_sampler_spec, classifier_spec):
        super().__init__()
        self.encoder = models.make(encoder_spec)
        in_dim = self.encoder.out_dim
        self.heavy_sampler = models.make(heavy_sampler_spec,
            args={'in_dim': in_dim, 'out_dim': 3})
        self.light_sampler = models.make(light_sampler_spec,
            args={'in_dim': in_dim, 'out_dim': 3})
        self.classifier = models.make(classifier_spec,
            args={'in_dim': in_dim, 'out_dim': 2})
        self.kmeans = KMeans(n_clusters=2, max_iter=20, mode='euclidean', verbose=0)
        self.cost_list = {}

    def forward(self, lr, coord, cell, **kwargs):
        if self.training:
            return self.forward_train(lr, coord, cell)
        else:
            return self.forward_test(lr, coord, cell, **kwargs)

    def forward_train(self, lr, coord, cell):
        feat = self.encoder(lr)
        prob = self.classifier(feat, coord, cell)
        prob = F.softmax(prob, dim=-1) # (b,q,2)

        pred_heavy = self.heavy_sampler(feat, coord, cell)
        pred_light = self.light_sampler(feat, coord, cell)
        pred = prob[:,:,0:1] * pred_light + prob[:,:,1:2] * pred_heavy

        res = F.grid_sample(lr, coord.flip(-1).unsqueeze(1), mode='bilinear',
            padding_mode='border', align_corners=False)[:,:,0,:].permute(0,2,1)
        pred = pred + res
        return pred, prob

    def forward_test(self, lr, coord, cell, scale=None, hr_size=None, k=0., pixel_batch_size=None,
                     adaptive_cluster=False, refinement=True):
        h,w = lr.shape[-2:]
        if not scale and hr_size:
            H,W = hr_size
            scale = round((H/h + W/w)/2, 1)
        else:
            assert scale and not hr_size
            H,W = round(h*scale), round(w*scale)
            hr_size = (H,W)

        # estimate the relative FLOPs of the light/heavy paths once per scale
        if scale not in self.cost_list:
            h0,w0 = 16,16
            H0,W0 = round(h0*scale), round(w0*scale)
            inp_coord = make_coord((H0,W0), flatten=True, device='cuda').unsqueeze(0)
            inp_cell = torch.ones_like(inp_coord)
            inp_cell[:,:,0] *= 2/H0
            inp_cell[:,:,1] *= 2/W0
            inp_encoder = torch.zeros((1,3,h0,w0), device='cuda')
            flops_encoder = get_model_flops(self.encoder, inp_encoder)
            inp_sampler = torch.zeros((1,self.encoder.out_dim,h0,w0), device='cuda')
            x = get_model_flops(self.light_sampler, inp_sampler, coord=inp_coord, cell=inp_cell)
            y = get_model_flops(self.heavy_sampler, inp_sampler, coord=inp_coord, cell=inp_cell)
            cost_list = torch.FloatTensor([x,y]).cuda() + flops_encoder
            cost_list = cost_list / cost_list.sum()
            self.cost_list[scale] = cost_list
            print('cost_list calculated (x{}): {}'.format(scale, cost_list))
        cost_list = self.cost_list[scale]

        feat = self.encoder(lr)
        b,q = coord.shape[:2]
        assert H*W == q
        tot = b*q
        if not pixel_batch_size:
            pixel_batch_size = q

        # pre-calculate the easy/hard flag
        prob = torch.zeros((b,q,2), device=lr.device)
        pb = pixel_batch_size//b*b
        assert pb > 0
        l = 0
        while l < q:
            r = min(q, l+pb)
            coord_split = coord[:,l:r,:]
            cell_split = cell[:,l:r,:]
            prob_split = self.classifier(feat, coord_split, cell_split)
            prob[:,l:r] = F.softmax(prob_split, dim=-1)
            l = r

        if adaptive_cluster: # auto-decide the easy/hard threshold
            diff = prob[:,:,1].view(-1,1) # (tot,1)
            assert diff.max() > diff.min()
            diff = (diff - diff.min()) / (diff.max() - diff.min())
            centroids = torch.FloatTensor([[0.5]]).cuda()
            flag = self.kmeans.fit_predict(diff, centroids=centroids)
            _, min_index = torch.min(diff.flatten(), dim=0)
            if flag[min_index] == 1: # ensure cluster 0 holds the easiest pixel
                flag = 1 - flag # (tot,)
            flag = flag.view(b,q)
        else:
            # penalize each path's probability by its relative cost; a larger k favors the cheap path
            prob = prob / torch.pow(cost_list, k).view(1,1,2)
            flag = torch.argmax(prob, dim=-1) # (b,q)

        # inference per image; a more efficient implementation may exist
        preds = []
        for i in range(b):
            pred = torch.zeros((q,3), device=lr.device)
            l = 0
            while l < q:
                r = min(q, l+pixel_batch_size)
                coord_split = coord[i:i+1,l:r,:]
                cell_split = cell[i:i+1,l:r,:]
                flg = flag[i,l:r]

                idx_easy = torch.where(flg == 0)[0]
                idx_hard = torch.where(flg == 1)[0]
                num_easy, num_hard = len(idx_easy), len(idx_hard)
                if num_easy > 0:
                    pred[l+idx_easy] = self.light_sampler(feat[i:i+1], coord_split[:,idx_easy,:], cell_split[:,idx_easy,:]).squeeze(0)
                if num_hard > 0:
                    pred[l+idx_hard] = self.heavy_sampler(feat[i:i+1], coord_split[:,idx_hard,:], cell_split[:,idx_hard,:]).squeeze(0)
                res = F.grid_sample(lr[i:i+1], coord_split.flip(-1).unsqueeze(1), mode='bilinear',
                    padding_mode='border', align_corners=False)[:,:,0,:].squeeze(0).transpose(0,1)
                pred[l:r] += res
                l = r
            preds.append(pred)
        pred = torch.stack(preds, dim=0) # (b,q,3)

        if refinement:
            # replace easy pixels that touch a hard pixel with their 3x3 mean
            pred = pred.transpose(1,2).view(-1,3,H,W)
            pred_unfold = F.pad(pred, (1,1,1,1), mode='replicate')
            pred_unfold = F.unfold(pred_unfold, 3, padding=0).view(-1,3,9,H,W).mean(dim=2) # (b,3,H,W)
            flag = flag.view(-1,1,H,W)
            flag_unfold = F.pad(flag.float(), (1,1,1,1), mode='replicate')
            flag_unfold = F.unfold(flag_unfold, 3, padding=0).view(-1,1,9,H,W).int().sum(dim=2) # (b,1,H,W)

            cond = (flag==0) & (flag_unfold>0) # easy pixels with a hard pixel in their 3x3 neighborhood
            cond[:,:,[0,-1],:] = cond[:,:,:,[0,-1]] = False # skip image borders
            #print('refined: {} / {}'.format(cond.sum().item(), tot))
            pred = torch.where(cond, pred_unfold, pred)
            pred = pred.view(-1,3,q).transpose(1,2)
        flag = flag.view(b,q,1)
        return pred, flag
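
The non-adaptive routing divides each path's probability by its relative cost raised to k and takes the argmax; a tiny numeric sketch with made-up values:

import torch

prob = torch.tensor([0.45, 0.55]) # classifier output for one pixel: P(easy), P(hard)
cost = torch.tensor([0.30, 0.70]) # normalized FLOPs of the light/heavy paths
for k in [0.0, 1.0, 2.0]:
    flag = torch.argmax(prob / torch.pow(cost, k)).item()
    print(k, flag) # k=0.0 -> 1 (hard); k=1.0, 2.0 -> 0 (easy)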
models/sampler.py ADDED
@@ -0,0 +1,40 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

import models
from models import register
from utils import make_coord


@register('liif-sampler')
class LIIF_Sampler(nn.Module):
    # feature unfolding and local ensemble are not supported
    def __init__(self, imnet_spec, in_dim, out_dim):
        super().__init__()
        self.imnet = models.make(imnet_spec, args={'in_dim': in_dim+4, 'out_dim': out_dim})

    def make_inp(self, feat, coord, cell):
        feat_coord = make_coord(feat.shape[-2:], flatten=False, device=feat.device)\
            .permute(2,0,1).unsqueeze(0).expand(feat.shape[0], 2, *feat.shape[-2:])
        q_feat = F.grid_sample(feat, coord.flip(-1).unsqueeze(1), mode='nearest',
            align_corners=False)[:,:,0,:].permute(0,2,1)
        q_coord = F.grid_sample(feat_coord, coord.flip(-1).unsqueeze(1), mode='nearest',
            align_corners=False)[:,:,0,:].permute(0,2,1)

        rel_coord = coord - q_coord
        rel_coord[:,:,0] *= feat.shape[-2]
        rel_coord[:,:,1] *= feat.shape[-1]

        rel_cell = cell.clone()
        rel_cell[:,:,0] *= feat.shape[-2]
        rel_cell[:,:,1] *= feat.shape[-1]

        inp = torch.cat([q_feat, rel_coord, rel_cell], dim=-1)
        return inp

    def forward(self, x, coord=None, cell=None):
        if coord is not None:
            x = self.make_inp(x, coord, cell)
        x = self.imnet(x)
        return x
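
A shape sketch with hypothetical sizes: the sampler concatenates the nearest feature vector with the relative coordinate and cell size (hence in_dim+4 inputs to the MLP):

import torch
import models
from utils import make_coord

sampler = models.make({'name': 'liif-sampler', 'args': {
    'imnet_spec': {'name': 'mlp', 'args': {'hidden_list': [256, 256]}},
    'in_dim': 64, 'out_dim': 3}})

feat = torch.randn(1, 64, 16, 16) # encoder features
coord = make_coord((64, 64), flatten=True).unsqueeze(0) # (1, 4096, 2) query coordinates
cell = torch.ones_like(coord) * (2 / 64) # query cell sizes
print(sampler(feat, coord=coord, cell=cell).shape) # torch.Size([1, 4096, 3])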
models/utils.py ADDED
@@ -0,0 +1,214 @@
# https://github.com/XPixelGroup/ClassSR
import math
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F


def initialize_weights(net_l, scale=1):
    if not isinstance(net_l, list):
        net_l = [net_l]
    for net in net_l:
        for m in net.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, a=0, mode='fan_in')
                m.weight.data *= scale # for residual block
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                init.kaiming_normal_(m.weight, a=0, mode='fan_in')
                m.weight.data *= scale
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias.data, 0.0)


def make_layer(block, n_layers):
    layers = []
    for _ in range(n_layers):
        layers.append(block())
    return nn.Sequential(*layers)


class ResidualBlock_noBN(nn.Module):
    '''Residual block w/o BN
    ---Conv-ReLU-Conv-+-
     |________________|
    '''

    def __init__(self, nf=64):
        super(ResidualBlock_noBN, self).__init__()
        self.conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)

        # initialization
        initialize_weights([self.conv1, self.conv2], 0.1)

    def forward(self, x):
        identity = x
        out = F.relu(self.conv1(x), inplace=True)
        out = self.conv2(out)
        return identity + out


def default_conv(in_channels, out_channels, kernel_size, bias=True):
    return nn.Conv2d(
        in_channels, out_channels, kernel_size,
        padding=(kernel_size//2), bias=bias)


class MeanShift(nn.Conv2d):
    def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1):
        super(MeanShift, self).__init__(3, 3, kernel_size=1)
        std = torch.Tensor(rgb_std)
        self.weight.data = torch.eye(3).view(3, 3, 1, 1)
        self.weight.data.div_(std.view(3, 1, 1, 1))
        self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean)
        self.bias.data.div_(std)
        for p in self.parameters(): # freeze the mean-shift layer
            p.requires_grad = False


class BasicBlock(nn.Sequential):
    def __init__(
        self, in_channels, out_channels, kernel_size, stride=1, bias=False,
        bn=True, act=nn.ReLU(True)):

        m = [nn.Conv2d(
            in_channels, out_channels, kernel_size,
            padding=(kernel_size//2), stride=stride, bias=bias)
        ]
        if bn: m.append(nn.BatchNorm2d(out_channels))
        if act is not None: m.append(act)
        super(BasicBlock, self).__init__(*m)


class ResBlock(nn.Module):
    def __init__(
        self, conv, n_feat, kernel_size,
        bias=True, bn=False, act=nn.ReLU(True), res_scale=1):

        super(ResBlock, self).__init__()
        m = []
        for i in range(2):
            m.append(conv(n_feat, n_feat, kernel_size, bias=bias))
            if bn: m.append(nn.BatchNorm2d(n_feat))
            if i == 0: m.append(act)

        self.body = nn.Sequential(*m)
        self.res_scale = res_scale

    def forward(self, x):
        res = self.body(x).mul(self.res_scale)
        res += x

        return res


class Upsampler(nn.Sequential):
    def __init__(self, conv, scale, n_feat, bn=False, act=False, bias=True):

        m = []
        if (scale & (scale - 1)) == 0: # is scale a power of 2?
            for _ in range(int(math.log(scale, 2))):
                m.append(conv(n_feat, 4 * n_feat, 3, bias))
                m.append(nn.PixelShuffle(2))
                if bn: m.append(nn.BatchNorm2d(n_feat))
                if act: m.append(act())
        elif scale == 3:
            m.append(conv(n_feat, 9 * n_feat, 3, bias))
            m.append(nn.PixelShuffle(3))
            if bn: m.append(nn.BatchNorm2d(n_feat))
            if act: m.append(act())
        else:
            raise NotImplementedError

        super(Upsampler, self).__init__(*m)


class EResidualBlock(nn.Module):
    def __init__(self,
                 in_channels, out_channels,
                 group=1):
        super(EResidualBlock, self).__init__()

        self.body = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1, groups=group),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1, groups=group),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 1, 1, 0),
        )

    def forward(self, x):
        out = self.body(x)
        out = F.relu(out + x)
        return out


class UpsampleBlock(nn.Module):
    def __init__(self,
                 n_channels, scale, multi_scale,
                 group=1):
        super(UpsampleBlock, self).__init__()

        if multi_scale:
            self.up2 = _UpsampleBlock(n_channels, scale=2, group=group)
            self.up3 = _UpsampleBlock(n_channels, scale=3, group=group)
            self.up4 = _UpsampleBlock(n_channels, scale=4, group=group)
        else:
            self.up = _UpsampleBlock(n_channels, scale=scale, group=group)

        self.multi_scale = multi_scale

    def forward(self, x, scale):
        if self.multi_scale:
            if scale == 2:
                return self.up2(x)
            elif scale == 3:
                return self.up3(x)
            elif scale == 4:
                return self.up4(x)
        else:
            return self.up(x)


class _UpsampleBlock(nn.Module):
    def __init__(self,
                 n_channels, scale,
                 group=1):
        super(_UpsampleBlock, self).__init__()

        modules = []
        if scale == 2 or scale == 4 or scale == 8:
            for _ in range(int(math.log(scale, 2))):
                modules += [nn.Conv2d(n_channels, 4 * n_channels, 3, 1, 1, groups=group), nn.ReLU(inplace=True)]
                modules += [nn.PixelShuffle(2)]
        elif scale == 3:
            modules += [nn.Conv2d(n_channels, 9 * n_channels, 3, 1, 1, groups=group), nn.ReLU(inplace=True)]
            modules += [nn.PixelShuffle(3)]

        self.body = nn.Sequential(*modules)

    def forward(self, x):
        out = self.body(x)
        return out
utils.py ADDED
@@ -0,0 +1,210 @@
import numpy as np
import torch
import torch.nn as nn

from collections import OrderedDict
import pandas as pd
import warnings
warnings.filterwarnings("ignore")


def tensor2numpy(tensor, rgb_range=1.):
    rgb_coefficient = 255 / rgb_range
    img = tensor.mul(rgb_coefficient).clamp(0, 255).round()
    img = img[0].data if img.ndim==4 else img.data
    img = np.transpose(img.cpu().numpy(), (1, 2, 0)).astype(np.uint8)
    return img


def center_crop(img, size):
    h,w = img.shape[-2:]
    cut_h, cut_w = h-size[0], w-size[1]

    lh = cut_h // 2
    rh = h - (cut_h - lh)
    lw = cut_w // 2
    rw = w - (cut_w - lw)

    img = img[:,:, lh:rh, lw:rw]
    return img


def make_coord(shape, ranges=None, flatten=True, device='cpu'):
    """Make coordinates at grid centers."""
    coord_seqs = []
    for i, n in enumerate(shape):
        if ranges is None:
            v0, v1 = -1, 1
        else:
            v0, v1 = ranges[i]
        r = (v1 - v0) / (2 * n)
        seq = v0 + r + (2 * r) * torch.arange(n, device=device).float()
        coord_seqs.append(seq)
    ret = torch.stack(torch.meshgrid(*coord_seqs), dim=-1)
    if flatten:
        ret = ret.view(-1, ret.shape[-1])
    return ret
+
45
+ def compute_num_params(model, text=False):
46
+ tot = int(sum([np.prod(p.shape) for p in model.parameters()]))
47
+ if text:
48
+ if tot >= 1e6:
49
+ return '{:.3f}M'.format(tot / 1e6)
50
+ elif tot >= 1e3:
51
+ return '{:.2f}K'.format(tot / 1e3)
52
+ else:
53
+ return '{}'.format(tot)
54
+ else:
55
+ return tot
56
+
57
+
58
+ def get_names_dict(model):
59
+ """Recursive walk to get names including path."""
60
+ names = {}
61
+
62
+ def _get_names(module, parent_name=""):
63
+ for key, m in module.named_children():
64
+ cls_name = str(m.__class__).split(".")[-1].split("'")[0]
65
+ num_named_children = len(list(m.named_children()))
66
+ if num_named_children > 0:
67
+ name = parent_name + "." + key if parent_name else key
68
+ else:
69
+ name = parent_name + "." + cls_name + "_"+ key if parent_name else key
70
+ names[name] = m
71
+
72
+ if isinstance(m, nn.Module):
73
+ _get_names(m, parent_name=name)
74
+
75
+ _get_names(model)
76
+ return names
77
+
78
+ # https://github.com/chenbong/ARM-Net/blob/main/utils/util.py
79
+ def get_model_flops(model, x, *args, **kwargs):
80
+ """Summarize the given input model.
81
+ Summarized information are 1) output shape, 2) kernel shape,
82
+ 3) number of the parameters and 4) operations (Mult-Adds)
83
+ Args:
84
+ model (Module): Model to summarize
85
+ x (Tensor): Input tensor of the model with [N, C, H, W] shape
86
+ dtype and device have to match to the model
87
+ args, kwargs: Other argument used in `model.forward` function
88
+ """
89
+ model.eval()
90
+ if hasattr(model, 'module'):
91
+ model = model.module
92
+ #x = torch.zeros(input_size).to(next(model.parameters()).device)
93
+
94
+ def register_hook(module):
95
+ def hook(module, inputs, outputs):
96
+ cls_name = str(module.__class__).split(".")[-1].split("'")[0]
97
+ module_idx = len(summary)
98
+ key = None
99
+ for name, item in module_names.items():
100
+ if item == module:
101
+ key = "{}_{}".format(module_idx, name)
102
+ break
103
+ assert key
104
+
105
+ info = OrderedDict()
106
+ info["id"] = id(module)
107
+ if isinstance(outputs, (list, tuple)):
108
+ try:
109
+ info["out"] = list(outputs[0].size())
110
+ except AttributeError:
111
+ info["out"] = list(outputs[0].data.size())
112
+ else:
113
+ info["out"] = list(outputs.size())
114
+
115
+ info["ksize"] = "-"
116
+ info["inner"] = OrderedDict()
117
+ info["params_nt"], info["params"], info["flops"] = 0, 0, 0
118
+
119
+ for name, param in module.named_parameters():
120
+ info["params"] += param.nelement() * param.requires_grad
121
+ info["params_nt"] += param.nelement() * (not param.requires_grad)
122
+
123
+ if name == "weight":
124
+ ksize = list(param.size())
125
+ if len(ksize) > 1:
126
+ ksize[0], ksize[1] = ksize[1], ksize[0]
127
+ info["ksize"] = ksize
128
+
129
+ if isinstance(module, nn.Conv2d) or isinstance(module, nn.ConvTranspose2d):
130
+ assert len(inputs[0].size()) == 4 and len(inputs[0].size()) == len(outputs[0].size())+1
131
+
132
+ in_c, in_h, in_w = inputs[0].size()[1:]
133
+ k_h, k_w = module.kernel_size
134
+ out_c, out_h, out_w = outputs[0].size()
135
+ groups = module.groups
136
+ kernel_mul = k_h * k_w * (in_c // groups)
137
+
138
+ kernel_mul_group = kernel_mul * out_h * out_w * (out_c // groups)
139
+ total_mul = kernel_mul_group * groups
140
+ info["flops"] += 2 * total_mul * inputs[0].size()[0] # total
141
+
142
+ elif isinstance(module, nn.BatchNorm2d):
143
+ info["flops"] += 2 * inputs[0].numel()
144
+
145
+ elif isinstance(module, nn.InstanceNorm2d):
146
+ info["flops"] += 6 * inputs[0].numel()
147
+
148
+ elif isinstance(module, nn.LayerNorm):
149
+ info["flops"] += 8 * inputs[0].numel()
150
+
151
+ elif isinstance(module, nn.Linear):
152
+ q = inputs[0].numel() // inputs[0].shape[-1]
153
+ info["flops"] += 2*q * module.in_features * module.out_features # total
154
+
155
+ elif isinstance(module, nn.PReLU) or isinstance(module, nn.ReLU):
156
+ info["flops"] += inputs[0].numel()
157
+ else:
158
+ print('not supported:', module)
159
+ exit()
160
+ info["flops"] += param.nelement()
161
+
162
+ elif "weight" in name:
163
+ info["inner"][name] = list(param.size())
164
+ info["flops"] += param.nelement()
165
+
166
+ if list(module.named_parameters()):
167
+ for v in summary.values():
168
+ if info["id"] == v["id"]:
169
+ info["params"] = "(recursive)"
170
+
171
+ #if info["params"] == 0:
172
+ # info["params"], info["flops"] = "-", "-"
173
+ summary[key] = info
174
+
175
+ if not module._modules:
176
+ hooks.append(module.register_forward_hook(hook))
177
+
178
+ module_names = get_names_dict(model)
179
+ hooks = []
180
+ summary = OrderedDict()
181
+
182
+ model.apply(register_hook)
183
+ try:
184
+ with torch.no_grad():
185
+ model(x) if not (kwargs or args) else model(x, *args, **kwargs)
186
+ finally:
187
+ for hook in hooks:
188
+ hook.remove()
189
+ # Use pandas to align the columns
190
+ df = pd.DataFrame(summary).T
191
+
192
+ df["Mult-Adds"] = pd.to_numeric(df["flops"], errors="coerce")
193
+ df["Params"] = pd.to_numeric(df["params"], errors="coerce")
194
+ df["Non-trainable params"] = pd.to_numeric(df["params_nt"], errors="coerce")
195
+ df = df.rename(columns=dict(
196
+ ksize="Kernel Shape",
197
+ out="Output Shape",
198
+ ))
199
+ return df['Mult-Adds'].sum()
200
+ '''
201
+ with warnings.catch_warnings():
202
+ warnings.filterwarnings('ignore')
203
+ df_sum = df.sum()
204
+
205
+ df.index.name = "Layer"
206
+
207
+ df = df[["Kernel Shape", "Output Shape", "Params", "Mult-Adds"]]
208
+ max_repr_width = max([len(row) for row in df.to_string().split("\n")])
209
+ return df_sum["Mult-Adds"]
210
+ '''
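
A minimal sketch of get_model_flops on a toy module (note that activation layers contribute only when they carry parameters, since FLOPs are tallied inside the named_parameters loop):

import torch
import torch.nn as nn
from utils import get_model_flops

net = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(inplace=True))
x = torch.zeros(1, 3, 16, 16)
# conv mult-adds: 2 * (3*3*3) * (8*16*16) = 110592; the ReLU adds nothing here
print(get_model_flops(net, x)) # 110592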