Upload 10 files
- carn-pcsr-phase1.pth +3 -0
- demo.py +58 -0
- models/__init__.py +3 -0
- models/carn.py +78 -0
- models/mlp.py +32 -0
- models/models.py +23 -0
- models/pcsr.py +197 -0
- models/sampler.py +40 -0
- models/utils.py +214 -0
- utils.py +210 -0
carn-pcsr-phase1.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:dfe84ddd3923b35d14a977dabda5613a9d59da3b0961e004be786f108d3f8508
size 755333
demo.py
ADDED
@@ -0,0 +1,58 @@
import torch
import models
from torchvision import transforms
from utils import *
from PIL import Image
import numpy as np

img_path = '/workspace/datasets/test/myimage/HR/X4/FOTO-BOX-18-1024x1024.png' # only .png is supported
scale = 4 # only x4 is supported

'''
k: hyperparameter for traversing the PSNR-FLOPs trade-off; a smaller k means larger FLOPs and higher PSNR. The useful range is about [-1, 2].
adaptive: whether to decide k automatically
no_refinement: whether to skip pixel-wise refinement (post-processing that reduces artifacts)
parser.add_argument('--opacity', type=float, default=0.65, help='opacity for colored visualization')
parser.add_argument('--pixel_batch_size', type=int, default=300000)
'''

resume_path = 'carn-pcsr-phase1.pth'
sv_file = torch.load(resume_path)
model = models.make(sv_file['model'], load_sd=True).cuda()
model.eval()

rgb_mean = torch.tensor([0.4488, 0.4371, 0.4040], device='cuda').view(1,3,1,1)
rgb_std = torch.tensor([1.0, 1.0, 1.0], device='cuda').view(1,3,1,1)

with torch.no_grad():
    # prepare inputs
    lr = transforms.ToTensor()(Image.open(img_path)).unsqueeze(0).cuda() # (1,3,h,w), range=[0,1]
    h,w = lr.shape[-2:]
    H,W = h*scale, w*scale
    coord = make_coord((H,W), flatten=True, device='cuda').unsqueeze(0)
    cell = torch.ones_like(coord)
    cell[:,:,0] *= 2/H
    cell[:,:,1] *= 2/W
    inp_lr = (lr - rgb_mean) / rgb_std

    pred, flag = model(inp_lr, coord=coord, cell=cell, scale=scale, k=0,
        pixel_batch_size=300000, adaptive_cluster=True, refinement=True)
    flops = get_model_flops(model, inp_lr, coord=coord, cell=cell, scale=scale, k=0,
        pixel_batch_size=300000, adaptive_cluster=True, refinement=True)
    max_flops = get_model_flops(model, inp_lr, coord=coord, cell=cell, scale=scale, k=-25,
        pixel_batch_size=300000, adaptive_cluster=False, refinement=True)
    print('flops: {:.1f}G ({:.1f} %) | max_flops: {:.1f}G (100 %)'.format(flops/1e9,
        (flops / max_flops)*100, max_flops/1e9))

    pred = pred.transpose(1,2).view(-1,3,H,W)
    pred = pred * rgb_std + rgb_mean
    pred = tensor2numpy(pred)
    Image.fromarray(pred).save('output.png')

    flag = flag.view(-1,1,H,W).repeat(1,3,1,1).squeeze(0).detach().cpu()
    H,W = pred.shape[:2]
    vis_img = np.zeros_like(pred)
    vis_img[flag[0] == 0] = np.array([0,255,0]) # easy pixels (light sampler): green
    vis_img[flag[0] == 1] = np.array([255,0,0]) # hard pixels (heavy sampler): red
    vis_img = vis_img*0.35 + pred*0.65
    Image.fromarray(vis_img.astype('uint8')).save('output_vis.png')
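A minimal sketch of sweeping the k hyperparameter described in the comment block of demo.py; it reuses the tensors the script has already prepared, and the output file names are illustrative, not part of the original script:

with torch.no_grad():
    for k in [-1.0, 0.0, 1.0, 2.0]:  # smaller k -> more FLOPs, higher PSNR
        pred, flag = model(inp_lr, coord=coord, cell=cell, scale=scale, k=k,
            pixel_batch_size=300000, adaptive_cluster=False, refinement=True)
        out = tensor2numpy(pred.transpose(1,2).view(-1,3,H,W) * rgb_std + rgb_mean)
        Image.fromarray(out).save('output_k{}.png'.format(k))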
models/__init__.py
ADDED
@@ -0,0 +1,3 @@
from .models import register, make
from . import mlp, pcsr, sampler
from . import carn
models/carn.py
ADDED
@@ -0,0 +1,78 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

import models.utils as mutils
from models import register


class Block(nn.Module):
    def __init__(self, nf, group=1):
        super(Block, self).__init__()
        self.b1 = mutils.EResidualBlock(nf, nf, group=group)
        self.c1 = mutils.BasicBlock(nf*2, nf, 1, 1, 0)
        self.c2 = mutils.BasicBlock(nf*3, nf, 1, 1, 0)
        self.c3 = mutils.BasicBlock(nf*4, nf, 1, 1, 0)

    def forward(self, x):
        c0 = o0 = x

        b1 = self.b1(o0)
        c1 = torch.cat([c0, b1], dim=1)
        o1 = self.c1(c1)

        b2 = self.b1(o1)
        c2 = torch.cat([c1, b2], dim=1)
        o2 = self.c2(c2)

        b3 = self.b1(o2)
        c3 = torch.cat([c2, b3], dim=1)
        o3 = self.c3(c3)

        return o3


@register('carn')
class CARN_M(nn.Module):
    def __init__(self, in_nc=3, out_nc=3, nf=64, scale=4, group=4, no_upsampling=False):
        super(CARN_M, self).__init__()
        self.scale = scale
        self.out_dim = nf

        self.entry = nn.Conv2d(in_nc, nf, 3, 1, 1)
        self.b1 = Block(nf, group=group)
        self.b2 = Block(nf, group=group)
        self.b3 = Block(nf, group=group)

        self.c1 = mutils.BasicBlock(nf*2, nf, 1, 1, 0)
        self.c2 = mutils.BasicBlock(nf*3, nf, 1, 1, 0)
        self.c3 = mutils.BasicBlock(nf*4, nf, 1, 1, 0)

        self.no_upsampling = no_upsampling
        if not no_upsampling:
            self.upsample = mutils.UpsampleBlock(nf, scale=scale, multi_scale=False, group=group)
            self.exit = nn.Conv2d(nf, out_nc, 3, 1, 1)

    def forward(self, x):
        #x = self.sub_mean(x)
        x = self.entry(x)
        c0 = o0 = x

        b1 = self.b1(o0)
        c1 = torch.cat([c0, b1], dim=1)
        o1 = self.c1(c1)

        b2 = self.b2(o1)
        c2 = torch.cat([c1, b2], dim=1)
        o2 = self.c2(c2)

        b3 = self.b3(o2)
        c3 = torch.cat([c2, b3], dim=1)
        o3 = self.c3(c3)
        out = o3.clone()

        if not self.no_upsampling:
            out = self.upsample(out, scale=self.scale)
            out = self.exit(out)
            #out = self.add_mean(out)
        return out
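A quick shape check, assuming CARN_M is used as the PCSR encoder (with no_upsampling=True it returns nf-channel features at the input resolution) and the repo's dependencies such as fast_pytorch_kmeans are installed:

import torch
from models.carn import CARN_M

net = CARN_M(nf=64, scale=4, group=4, no_upsampling=True)
x = torch.zeros(1, 3, 16, 16)
print(net(x).shape)  # torch.Size([1, 64, 16, 16])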
models/mlp.py
ADDED
@@ -0,0 +1,32 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from models import register


@register('mlp')
class MLP(nn.Module):

    def __init__(self, in_dim, out_dim, hidden_list, residual=False):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.hidden_list = hidden_list
        self.residual = residual
        if residual:
            self.convert = nn.Linear(in_dim, out_dim)

        layers = []
        lastv = in_dim
        for hidden in hidden_list:
            layers.append(nn.Linear(lastv, hidden))
            layers.append(nn.ReLU())
            lastv = hidden
        layers.append(nn.Linear(lastv, out_dim))
        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        y = self.layers(x)
        if self.residual:
            y = y + self.convert(x)
        return y
models/models.py
ADDED
@@ -0,0 +1,23 @@
import copy


models = {}


def register(name):
    def decorator(cls):
        models[name] = cls
        return cls
    return decorator


def make(model_spec, args=None, load_sd=False):
    if args is not None:
        model_args = copy.deepcopy(model_spec['args'])
        model_args.update(args)
    else:
        model_args = model_spec['args']
    model = models[model_spec['name']](**model_args)
    if load_sd:
        model.load_state_dict(model_spec['sd'])
    return model
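For reference, a small usage sketch of this registry; the 'toy' model below is made up for illustration:

import torch.nn as nn
from models import register, make

@register('toy')
class Toy(nn.Module):
    def __init__(self, dim=8):
        super().__init__()
        self.fc = nn.Linear(dim, dim)
    def forward(self, x):
        return self.fc(x)

spec = {'name': 'toy', 'args': {'dim': 16}}
toy = make(spec)  # builds Toy(dim=16)
# make(spec, load_sd=True) would additionally restore weights from spec['sd']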
models/pcsr.py
ADDED
@@ -0,0 +1,197 @@
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import models
from models import register
from fast_pytorch_kmeans import KMeans
from utils import *


@register('pcsr-phase0')
class PCSR(nn.Module):
    def __init__(self, encoder_spec, heavy_sampler_spec):
        super().__init__()
        self.encoder = models.make(encoder_spec)
        in_dim = self.encoder.out_dim
        self.heavy_sampler = models.make(heavy_sampler_spec,
            args={'in_dim': in_dim, 'out_dim': 3})

    def forward(self, lr, coord, cell, **kwargs):
        if self.training:
            return self.forward_train(lr, coord, cell)
        else:
            return self.forward_test(lr, coord, cell, **kwargs)

    def forward_train(self, lr, coord, cell):
        feat = self.encoder(lr)
        res = F.grid_sample(lr, coord.flip(-1).unsqueeze(1), mode='bilinear',
            padding_mode='border', align_corners=False)[:,:,0,:].permute(0,2,1)
        pred_heavy = self.heavy_sampler(feat, coord, cell) + res
        return pred_heavy

    def forward_test(self, lr, coord, cell, pixel_batch_size=None):
        feat = self.encoder(lr)
        b,q = coord.shape[:2]
        tot = b*q
        if not pixel_batch_size:
            pixel_batch_size = q

        preds = []
        for i in range(b): # for each image
            pred = torch.zeros((q,3), device=lr.device)
            l = 0
            while l < q:
                r = min(q, l+pixel_batch_size)
                coord_split = coord[i:i+1,l:r,:]
                cell_split = cell[i:i+1,l:r,:]
                res = F.grid_sample(lr[i:i+1], coord_split.flip(-1).unsqueeze(1), mode='bilinear',
                    padding_mode='border', align_corners=False)[:,:,0,:].squeeze(0).transpose(0,1)
                pred[l:r] = self.heavy_sampler(feat[i:i+1], coord_split, cell_split) + res
                l = r
            preds.append(pred)
        pred = torch.stack(preds, dim=0)
        return pred


@register('pcsr-phase1')
class PCSR(nn.Module):

    def __init__(self, encoder_spec, heavy_sampler_spec, light_sampler_spec, classifier_spec):
        super().__init__()
        self.encoder = models.make(encoder_spec)
        in_dim = self.encoder.out_dim
        self.heavy_sampler = models.make(heavy_sampler_spec,
            args={'in_dim': in_dim, 'out_dim': 3})
        self.light_sampler = models.make(light_sampler_spec,
            args={'in_dim': in_dim, 'out_dim': 3})
        self.classifier = models.make(classifier_spec,
            args={'in_dim': in_dim, 'out_dim': 2})
        self.kmeans = KMeans(n_clusters=2, max_iter=20, mode='euclidean', verbose=0)
        self.cost_list = {}

    def forward(self, lr, coord, cell, **kwargs):
        if self.training:
            return self.forward_train(lr, coord, cell)
        else:
            return self.forward_test(lr, coord, cell, **kwargs)

    def forward_train(self, lr, coord, cell):
        feat = self.encoder(lr)
        prob = self.classifier(feat, coord, cell)
        prob = F.softmax(prob, dim=-1) # (b,q,2)

        pred_heavy = self.heavy_sampler(feat, coord, cell)
        pred_light = self.light_sampler(feat, coord, cell)
        pred = prob[:,:,0:1] * pred_light + prob[:,:,1:2] * pred_heavy

        res = F.grid_sample(lr, coord.flip(-1).unsqueeze(1), mode='bilinear',
            padding_mode='border', align_corners=False)[:,:,0,:].permute(0,2,1)
        pred = pred + res
        return pred, prob

    def forward_test(self, lr, coord, cell, scale=None, hr_size=None, k=0., pixel_batch_size=None, adaptive_cluster=False, refinement=True):
        h,w = lr.shape[-2:]
        if not scale and hr_size:
            H,W = hr_size
            scale = round((H/h + W/w)/2, 1)
        else:
            assert scale and not hr_size
            H,W = round(h*scale), round(w*scale)
            hr_size = (H,W)

        if scale not in self.cost_list:
            h0,w0 = 16,16
            H0,W0 = round(h0*scale), round(w0*scale)
            inp_coord = make_coord((H0,W0), flatten=True, device='cuda').unsqueeze(0)
            inp_cell = torch.ones_like(inp_coord)
            inp_cell[:,:,0] *= 2/H0
            inp_cell[:,:,1] *= 2/W0
            inp_encoder = torch.zeros((1,3,h0,w0), device='cuda')
            flops_encoder = get_model_flops(self.encoder, inp_encoder)
            inp_sampler = torch.zeros((1,self.encoder.out_dim,h0,w0), device='cuda')
            x = get_model_flops(self.light_sampler, inp_sampler, coord=inp_coord, cell=inp_cell)
            y = get_model_flops(self.heavy_sampler, inp_sampler, coord=inp_coord, cell=inp_cell)
            cost_list = torch.FloatTensor([x,y]).cuda() + flops_encoder
            cost_list = cost_list / cost_list.sum()
            self.cost_list[scale] = cost_list
            print('cost_list calculated (x{}): {}'.format(scale, cost_list))
        cost_list = self.cost_list[scale]

        feat = self.encoder(lr)
        b,q = coord.shape[:2]
        assert H*W == q
        tot = b*q
        if not pixel_batch_size:
            pixel_batch_size = q

        # pre-calculate flag
        prob = torch.zeros((b,q,2), device=lr.device)
        pb = pixel_batch_size//b*b
        assert pb > 0
        l = 0
        while l < q:
            r = min(q, l+pb)
            coord_split = coord[:,l:r,:]
            cell_split = cell[:,l:r,:]
            prob_split = self.classifier(feat, coord_split, cell_split)
            prob[:,l:r] = F.softmax(prob_split, dim=-1)
            l = r

        if adaptive_cluster: # auto-decide threshold
            diff = prob[:,:,1].view(-1,1) # (tot,1)
            assert diff.max() > diff.min()
            diff = (diff - diff.min()) / (diff.max() - diff.min())
            centroids = torch.FloatTensor([[0.5]]).cuda()
            flag = self.kmeans.fit_predict(diff, centroids=centroids)
            _, min_index = torch.min(diff.flatten(), dim=0)
            if flag[min_index] == 1:
                flag = 1 - flag # (tot,)
            flag = flag.view(b,q)
        else:
            prob = prob / torch.pow(cost_list, k).view(1,1,2)
            flag = torch.argmax(prob, dim=-1) # (b,q)

        # inference per image
        # (a more efficient implementation may exist)
        preds = []
        for i in range(b):
            pred = torch.zeros((q,3), device=lr.device)
            l = 0
            while l < q:
                r = min(q, l+pixel_batch_size)
                coord_split = coord[i:i+1,l:r,:]
                cell_split = cell[i:i+1,l:r,:]
                flg = flag[i,l:r]

                idx_easy = torch.where(flg == 0)[0]
                idx_hard = torch.where(flg == 1)[0]
                num_easy, num_hard = len(idx_easy), len(idx_hard)
                if num_easy > 0:
                    pred[l+idx_easy] = self.light_sampler(feat[i:i+1], coord_split[:,idx_easy,:], cell_split[:,idx_easy,:]).squeeze(0)
                if num_hard > 0:
                    pred[l+idx_hard] = self.heavy_sampler(feat[i:i+1], coord_split[:,idx_hard,:], cell_split[:,idx_hard,:]).squeeze(0)
                res = F.grid_sample(lr[i:i+1], coord_split.flip(-1).unsqueeze(1), mode='bilinear',
                    padding_mode='border', align_corners=False)[:,:,0,:].squeeze(0).transpose(0,1)
                pred[l:r] += res
                l = r
            preds.append(pred)
        pred = torch.stack(preds, dim=0) # (b,q,3)

        if refinement:
            pred = pred.transpose(1,2).view(-1,3,H,W)
            pred_unfold = F.pad(pred, (1,1,1,1), mode='replicate')
            pred_unfold = F.unfold(pred_unfold, 3, padding=0).view(-1,3,9,H,W).mean(dim=2) # (b,3,H,W)
            flag = flag.view(-1,1,H,W)
            flag_unfold = F.pad(flag.float(), (1,1,1,1), mode='replicate')
            flag_unfold = F.unfold(flag_unfold, 3, padding=0).view(-1,1,9,H,W).int().sum(dim=2) # (b,1,H,W)

            cond = (flag==0) & (flag_unfold>0) # easy pixels with at least one hard neighbor
            cond[:,:,[0,-1],:] = cond[:,:,:,[0,-1]] = False
            #print('refined: {} / {}'.format(cond.sum().item(), tot))
            pred = torch.where(cond, pred_unfold, pred)
            pred = pred.view(-1,3,q).transpose(1,2)
        flag = flag.view(b,q,1)
        return pred, flag
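In the non-adaptive branch of forward_test above, each class probability is divided by its relative cost raised to the power k before the argmax, so k trades PSNR for FLOPs. A toy illustration with made-up numbers:

import torch

cost_list = torch.tensor([0.2, 0.8])  # light sampler ~4x cheaper (illustrative)
prob = torch.tensor([0.45, 0.55])     # classifier slightly prefers 'hard'
for k in [-1.0, 0.0, 1.0]:
    flag = torch.argmax(prob / cost_list.pow(k), dim=-1).item()
    print(k, 'hard' if flag == 1 else 'easy')
# k=-1 and k=0 keep this pixel on the heavy sampler; k=1 penalizes its
# cost enough to route it to the light sampler.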
models/sampler.py
ADDED
@@ -0,0 +1,40 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

import models
from models import register
from utils import make_coord


@register('liif-sampler')
class LIIF_Sampler(nn.Module):
    # feature unfolding and local ensemble are not supported
    def __init__(self, imnet_spec, in_dim, out_dim):
        super().__init__()
        self.imnet = models.make(imnet_spec, args={'in_dim': in_dim+4, 'out_dim': out_dim})

    def make_inp(self, feat, coord, cell):
        feat_coord = make_coord(feat.shape[-2:], flatten=False, device=feat.device)\
            .permute(2,0,1).unsqueeze(0).expand(feat.shape[0], 2, *feat.shape[-2:])
        q_feat = F.grid_sample(feat, coord.flip(-1).unsqueeze(1), mode='nearest',
            align_corners=False)[:,:,0,:].permute(0,2,1)
        q_coord = F.grid_sample(feat_coord, coord.flip(-1).unsqueeze(1), mode='nearest',
            align_corners=False)[:,:,0,:].permute(0,2,1)

        rel_coord = coord - q_coord
        rel_coord[:,:,0] *= feat.shape[-2]
        rel_coord[:,:,1] *= feat.shape[-1]

        rel_cell = cell.clone()
        rel_cell[:,:,0] *= feat.shape[-2]
        rel_cell[:,:,1] *= feat.shape[-1]

        inp = torch.cat([q_feat, rel_coord, rel_cell], dim=-1)
        return inp

    def forward(self, x, coord=None, cell=None):
        if coord is not None:
            x = self.make_inp(x, coord, cell)
        x = self.imnet(x)
        return x
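Note the imnet is built with in_dim+4 input features: make_inp concatenates the nearest-neighbor feature (in_dim channels) with the 2-D relative coordinate and the 2-D relative cell size. A minimal shape check, assuming it is run from the repo root with toy sizes:

import torch
import models  # importing the package registers 'mlp'
from models.sampler import LIIF_Sampler
from utils import make_coord

feat = torch.zeros(1, 64, 8, 8)                          # (b, in_dim, h, w)
coord = make_coord((16, 16), flatten=True).unsqueeze(0)  # (1, 256, 2)
cell = torch.ones_like(coord)
sampler = LIIF_Sampler({'name': 'mlp', 'args': {'hidden_list': [16]}},
                       in_dim=64, out_dim=3)
inp = sampler.make_inp(feat, coord, cell)
print(inp.shape)  # torch.Size([1, 256, 68]) = 64 + 2 + 2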
models/utils.py
ADDED
@@ -0,0 +1,214 @@
# https://github.com/XPixelGroup/ClassSR
import math
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F


def initialize_weights(net_l, scale=1):
    if not isinstance(net_l, list):
        net_l = [net_l]
    for net in net_l:
        for m in net.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, a=0, mode='fan_in')
                m.weight.data *= scale  # for residual block
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                init.kaiming_normal_(m.weight, a=0, mode='fan_in')
                m.weight.data *= scale
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias.data, 0.0)


def make_layer(block, n_layers):
    layers = []
    for _ in range(n_layers):
        layers.append(block())
    return nn.Sequential(*layers)


class ResidualBlock_noBN(nn.Module):
    '''Residual block w/o BN
    ---Conv-ReLU-Conv-+-
     |________________|
    '''

    def __init__(self, nf=64):
        super(ResidualBlock_noBN, self).__init__()
        self.conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)

        # initialization
        initialize_weights([self.conv1, self.conv2], 0.1)

    def forward(self, x):
        identity = x
        out = F.relu(self.conv1(x), inplace=True)
        out = self.conv2(out)
        return identity + out


def default_conv(in_channels, out_channels, kernel_size, bias=True):
    return nn.Conv2d(
        in_channels, out_channels, kernel_size,
        padding=(kernel_size//2), bias=bias)

class MeanShift(nn.Conv2d):
    def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1):
        super(MeanShift, self).__init__(3, 3, kernel_size=1)
        std = torch.Tensor(rgb_std)
        self.weight.data = torch.eye(3).view(3, 3, 1, 1)
        self.weight.data.div_(std.view(3, 1, 1, 1))
        self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean)
        self.bias.data.div_(std)
        # freeze parameters ('self.requires_grad = False' on the module
        # would only set a plain attribute and not affect the weights)
        for p in self.parameters():
            p.requires_grad = False

class BasicBlock(nn.Sequential):
    def __init__(
        self, in_channels, out_channels, kernel_size, stride=1, bias=False,
        bn=True, act=nn.ReLU(True)):

        m = [nn.Conv2d(
            in_channels, out_channels, kernel_size,
            padding=(kernel_size//2), stride=stride, bias=bias)
        ]
        if bn: m.append(nn.BatchNorm2d(out_channels))
        if act is not None: m.append(act)
        super(BasicBlock, self).__init__(*m)

class ResBlock(nn.Module):
    def __init__(
        self, conv, n_feat, kernel_size,
        bias=True, bn=False, act=nn.ReLU(True), res_scale=1):

        super(ResBlock, self).__init__()
        m = []
        for i in range(2):
            m.append(conv(n_feat, n_feat, kernel_size, bias=bias))
            if bn: m.append(nn.BatchNorm2d(n_feat))
            if i == 0: m.append(act)

        self.body = nn.Sequential(*m)
        self.res_scale = res_scale

    def forward(self, x):
        res = self.body(x).mul(self.res_scale)
        res += x

        return res

class Upsampler(nn.Sequential):
    def __init__(self, conv, scale, n_feat, bn=False, act=False, bias=True):

        m = []
        if (scale & (scale - 1)) == 0:  # is scale a power of 2?
            for _ in range(int(math.log(scale, 2))):
                m.append(conv(n_feat, 4 * n_feat, 3, bias))
                m.append(nn.PixelShuffle(2))
                if bn: m.append(nn.BatchNorm2d(n_feat))
                if act: m.append(act())
        elif scale == 3:
            m.append(conv(n_feat, 9 * n_feat, 3, bias))
            m.append(nn.PixelShuffle(3))
            if bn: m.append(nn.BatchNorm2d(n_feat))
            if act: m.append(act())
        else:
            raise NotImplementedError

        super(Upsampler, self).__init__(*m)


class EResidualBlock(nn.Module):
    def __init__(self,
                 in_channels, out_channels,
                 group=1):
        super(EResidualBlock, self).__init__()

        self.body = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1, groups=group),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1, groups=group),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 1, 1, 0),
        )

    def forward(self, x):
        out = self.body(x)
        out = F.relu(out + x)
        return out


class Upsampler(nn.Sequential):
    # duplicate of the Upsampler class above; this second definition shadows the first
    def __init__(self, conv, scale, n_feat, bn=False, act=False, bias=True):

        m = []
        if (scale & (scale - 1)) == 0:  # is scale a power of 2?
            for _ in range(int(math.log(scale, 2))):
                m.append(conv(n_feat, 4 * n_feat, 3, bias))
                m.append(nn.PixelShuffle(2))
                if bn: m.append(nn.BatchNorm2d(n_feat))
                if act: m.append(act())
        elif scale == 3:
            m.append(conv(n_feat, 9 * n_feat, 3, bias))
            m.append(nn.PixelShuffle(3))
            if bn: m.append(nn.BatchNorm2d(n_feat))
            if act: m.append(act())
        else:
            raise NotImplementedError

        super(Upsampler, self).__init__(*m)


class UpsampleBlock(nn.Module):
    def __init__(self,
                 n_channels, scale, multi_scale,
                 group=1):
        super(UpsampleBlock, self).__init__()

        if multi_scale:
            self.up2 = _UpsampleBlock(n_channels, scale=2, group=group)
            self.up3 = _UpsampleBlock(n_channels, scale=3, group=group)
            self.up4 = _UpsampleBlock(n_channels, scale=4, group=group)
        else:
            self.up = _UpsampleBlock(n_channels, scale=scale, group=group)

        self.multi_scale = multi_scale

    def forward(self, x, scale):
        if self.multi_scale:
            if scale == 2:
                return self.up2(x)
            elif scale == 3:
                return self.up3(x)
            elif scale == 4:
                return self.up4(x)
        else:
            return self.up(x)


class _UpsampleBlock(nn.Module):
    def __init__(self,
                 n_channels, scale,
                 group=1):
        super(_UpsampleBlock, self).__init__()

        modules = []
        if scale == 2 or scale == 4 or scale == 8:
            for _ in range(int(math.log(scale, 2))):
                modules += [nn.Conv2d(n_channels, 4 * n_channels, 3, 1, 1, groups=group), nn.ReLU(inplace=True)]
                modules += [nn.PixelShuffle(2)]
        elif scale == 3:
            modules += [nn.Conv2d(n_channels, 9 * n_channels, 3, 1, 1, groups=group), nn.ReLU(inplace=True)]
            modules += [nn.PixelShuffle(3)]

        self.body = nn.Sequential(*modules)

    def forward(self, x):
        out = self.body(x)
        return out
utils.py
ADDED
@@ -0,0 +1,210 @@
import numpy as np
import torch
import torch.nn as nn

from collections import OrderedDict
import pandas as pd
import warnings
warnings.filterwarnings("ignore")

def tensor2numpy(tensor, rgb_range=1.):
    rgb_coefficient = 255 / rgb_range
    img = tensor.mul(rgb_coefficient).clamp(0, 255).round()
    img = img[0].data if img.ndim==4 else img.data
    img = np.transpose(img.cpu().numpy(), (1, 2, 0)).astype(np.uint8)
    return img

def center_crop(img, size):
    h,w = img.shape[-2:]
    cut_h, cut_w = h-size[0], w-size[1]

    lh = cut_h // 2
    rh = h - (cut_h - lh)
    lw = cut_w // 2
    rw = w - (cut_w - lw)

    img = img[:,:, lh:rh, lw:rw]
    return img

def make_coord(shape, ranges=None, flatten=True, device='cpu'):
    # Make coordinates at grid centers.
    coord_seqs = []
    for i, n in enumerate(shape):
        if ranges is None:
            v0, v1 = -1, 1
        else:
            v0, v1 = ranges[i]
        r = (v1 - v0) / (2 * n)
        seq = v0 + r + (2 * r) * torch.arange(n, device=device).float()
        coord_seqs.append(seq)
    ret = torch.stack(torch.meshgrid(*coord_seqs), dim=-1)
    if flatten:
        ret = ret.view(-1, ret.shape[-1])
    return ret

def compute_num_params(model, text=False):
    tot = int(sum([np.prod(p.shape) for p in model.parameters()]))
    if text:
        if tot >= 1e6:
            return '{:.3f}M'.format(tot / 1e6)
        elif tot >= 1e3:
            return '{:.2f}K'.format(tot / 1e3)
        else:
            return '{}'.format(tot)
    else:
        return tot


def get_names_dict(model):
    """Recursive walk to get names including path."""
    names = {}

    def _get_names(module, parent_name=""):
        for key, m in module.named_children():
            cls_name = str(m.__class__).split(".")[-1].split("'")[0]
            num_named_children = len(list(m.named_children()))
            if num_named_children > 0:
                name = parent_name + "." + key if parent_name else key
            else:
                name = parent_name + "." + cls_name + "_" + key if parent_name else key
            names[name] = m

            if isinstance(m, nn.Module):
                _get_names(m, parent_name=name)

    _get_names(model)
    return names

# https://github.com/chenbong/ARM-Net/blob/main/utils/util.py
def get_model_flops(model, x, *args, **kwargs):
    """Summarize the given model.
    Summarized information includes 1) output shape, 2) kernel shape,
    3) number of parameters and 4) operations (Mult-Adds).
    Args:
        model (Module): model to summarize
        x (Tensor): input tensor of the model with [N, C, H, W] shape;
            dtype and device have to match the model
        args, kwargs: other arguments used in `model.forward`
    """
    model.eval()
    if hasattr(model, 'module'):
        model = model.module
    #x = torch.zeros(input_size).to(next(model.parameters()).device)

    def register_hook(module):
        def hook(module, inputs, outputs):
            cls_name = str(module.__class__).split(".")[-1].split("'")[0]
            module_idx = len(summary)
            key = None
            for name, item in module_names.items():
                if item == module:
                    key = "{}_{}".format(module_idx, name)
                    break
            assert key

            info = OrderedDict()
            info["id"] = id(module)
            if isinstance(outputs, (list, tuple)):
                try:
                    info["out"] = list(outputs[0].size())
                except AttributeError:
                    info["out"] = list(outputs[0].data.size())
            else:
                info["out"] = list(outputs.size())

            info["ksize"] = "-"
            info["inner"] = OrderedDict()
            info["params_nt"], info["params"], info["flops"] = 0, 0, 0

            for name, param in module.named_parameters():
                info["params"] += param.nelement() * param.requires_grad
                info["params_nt"] += param.nelement() * (not param.requires_grad)

                if name == "weight":
                    ksize = list(param.size())
                    if len(ksize) > 1:
                        ksize[0], ksize[1] = ksize[1], ksize[0]
                    info["ksize"] = ksize

                    if isinstance(module, nn.Conv2d) or isinstance(module, nn.ConvTranspose2d):
                        assert len(inputs[0].size()) == 4 and len(inputs[0].size()) == len(outputs[0].size())+1

                        in_c, in_h, in_w = inputs[0].size()[1:]
                        k_h, k_w = module.kernel_size
                        out_c, out_h, out_w = outputs[0].size()
                        groups = module.groups
                        kernel_mul = k_h * k_w * (in_c // groups)

                        kernel_mul_group = kernel_mul * out_h * out_w * (out_c // groups)
                        total_mul = kernel_mul_group * groups
                        info["flops"] += 2 * total_mul * inputs[0].size()[0]  # total

                    elif isinstance(module, nn.BatchNorm2d):
                        info["flops"] += 2 * inputs[0].numel()

                    elif isinstance(module, nn.InstanceNorm2d):
                        info["flops"] += 6 * inputs[0].numel()

                    elif isinstance(module, nn.LayerNorm):
                        info["flops"] += 8 * inputs[0].numel()

                    elif isinstance(module, nn.Linear):
                        q = inputs[0].numel() // inputs[0].shape[-1]
                        info["flops"] += 2*q * module.in_features * module.out_features  # total

                    elif isinstance(module, nn.PReLU) or isinstance(module, nn.ReLU):
                        info["flops"] += inputs[0].numel()
                    else:
                        print('not supported:', module)
                        exit()
                        info["flops"] += param.nelement()

                elif "weight" in name:
                    info["inner"][name] = list(param.size())
                    info["flops"] += param.nelement()

            if list(module.named_parameters()):
                for v in summary.values():
                    if info["id"] == v["id"]:
                        info["params"] = "(recursive)"

            #if info["params"] == 0:
            #    info["params"], info["flops"] = "-", "-"
            summary[key] = info

        if not module._modules:
            hooks.append(module.register_forward_hook(hook))

    module_names = get_names_dict(model)
    hooks = []
    summary = OrderedDict()

    model.apply(register_hook)
    try:
        with torch.no_grad():
            model(x) if not (kwargs or args) else model(x, *args, **kwargs)
    finally:
        for hook in hooks:
            hook.remove()
    # Use pandas to align the columns
    df = pd.DataFrame(summary).T

    df["Mult-Adds"] = pd.to_numeric(df["flops"], errors="coerce")
    df["Params"] = pd.to_numeric(df["params"], errors="coerce")
    df["Non-trainable params"] = pd.to_numeric(df["params_nt"], errors="coerce")
    df = df.rename(columns=dict(
        ksize="Kernel Shape",
        out="Output Shape",
    ))
    return df['Mult-Adds'].sum()
    '''
    with warnings.catch_warnings():
        warnings.filterwarnings('ignore')
        df_sum = df.sum()

    df.index.name = "Layer"

    df = df[["Kernel Shape", "Output Shape", "Params", "Mult-Adds"]]
    max_repr_width = max([len(row) for row in df.to_string().split("\n")])
    return df_sum["Mult-Adds"]
    '''
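For intuition, make_coord from utils.py places one coordinate at the center of each grid cell in [-1, 1], row index first; a quick check:

from utils import make_coord

print(make_coord((2, 2)))
# tensor([[-0.5000, -0.5000],
#         [-0.5000,  0.5000],
#         [ 0.5000, -0.5000],
#         [ 0.5000,  0.5000]])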