3587jjh committed
Commit e87459e
1 Parent(s): 436e09f

Update models/pcsr.py

Files changed (1):
  1. models/pcsr.py +197 -196

models/pcsr.py CHANGED
@@ -1,197 +1,198 @@
-import math
-import numpy as np
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-import models
-from models import register
-from fast_pytorch_kmeans import KMeans
-from utils import *
-
-
-@register('pcsr-phase0')
-class PCSR(nn.Module):
-    def __init__(self, encoder_spec, heavy_sampler_spec):
-        super().__init__()
-        self.encoder = models.make(encoder_spec)
-        in_dim = self.encoder.out_dim
-        self.heavy_sampler = models.make(heavy_sampler_spec,
-            args={'in_dim': in_dim, 'out_dim': 3})
-
-    def forward(self, lr, coord, cell, **kwargs):
-        if self.training:
-            return self.forward_train(lr, coord, cell)
-        else:
-            return self.forward_test(lr, coord, cell, **kwargs)
-
-    def forward_train(self, lr, coord, cell):
-        feat = self.encoder(lr)
-        res = F.grid_sample(lr, coord.flip(-1).unsqueeze(1), mode='bilinear',
-            padding_mode='border', align_corners=False)[:,:,0,:].permute(0,2,1)
-        pred_heavy = self.heavy_sampler(feat, coord, cell) + res
-        return pred_heavy
-
-    def forward_test(self, lr, coord, cell, pixel_batch_size=None):
-        feat = self.encoder(lr)
-        b,q = coord.shape[:2]
-        tot = b*q
-        if not pixel_batch_size:
-            pixel_batch_size = q
-
-        preds = []
-        for i in range(b): # for each image
-            pred = torch.zeros((q,3), device=lr.device)
-            l = 0
-            while l < q:
-                r = min(q, l+pixel_batch_size)
-                coord_split = coord[i:i+1,l:r,:]
-                cell_split = cell[i:i+1,l:r,:]
-                res = F.grid_sample(lr[i:i+1], coord_split.flip(-1).unsqueeze(1), mode='bilinear',
-                    padding_mode='border', align_corners=False)[:,:,0,:].squeeze(0).transpose(0,1)
-                pred[l:r] = self.heavy_sampler(feat[i:i+1], coord_split, cell_split) + res
-                l = r
-            preds.append(pred)
-        pred = torch.stack(preds, dim=0)
-        return pred
-
-
-@register('pcsr-phase1')
-class PCSR(nn.Module):
-
-    def __init__(self, encoder_spec, heavy_sampler_spec, light_sampler_spec, classifier_spec):
-        super().__init__()
-        self.encoder = models.make(encoder_spec)
-        in_dim = self.encoder.out_dim
-        self.heavy_sampler = models.make(heavy_sampler_spec,
-            args={'in_dim': in_dim, 'out_dim': 3})
-        self.light_sampler = models.make(light_sampler_spec,
-            args={'in_dim': in_dim, 'out_dim': 3})
-        self.classifier = models.make(classifier_spec,
-            args={'in_dim': in_dim, 'out_dim': 2})
-        self.kmeans = KMeans(n_clusters=2, max_iter=20, mode='euclidean', verbose=0)
-        self.cost_list = {}
-
-    def forward(self, lr, coord, cell, **kwargs):
-        if self.training:
-            return self.forward_train(lr, coord, cell)
-        else:
-            return self.forward_test(lr, coord, cell, **kwargs)
-
-    def forward_train(self, lr, coord, cell):
-        feat = self.encoder(lr)
-        prob = self.classifier(feat, coord, cell)
-        prob = F.softmax(prob, dim=-1) # (b,q,2)
-
-        pred_heavy = self.heavy_sampler(feat, coord, cell)
-        pred_light = self.light_sampler(feat, coord, cell)
-        pred = prob[:,:,0:1] * pred_light + prob[:,:,1:2] * pred_heavy
-
-        res = F.grid_sample(lr, coord.flip(-1).unsqueeze(1), mode='bilinear',
-            padding_mode='border', align_corners=False)[:,:,0,:].permute(0,2,1)
-        pred = pred + res
-        return pred, prob
-
-    def forward_test(self, lr, coord, cell, scale=None, hr_size=None, k=0., pixel_batch_size=None, adaptive_cluster=False, refinement=True):
-        h,w = lr.shape[-2:]
-        if not scale and hr_size:
-            H,W = hr_size
-            scale = round((H/h + W/w)/2, 1)
-        else:
-            assert scale and not hr_size
-            H,W = round(h*scale), round(w*scale)
-            hr_size = (H,W)
-
-        if scale not in self.cost_list:
-            h0,w0 = 16,16
-            H0,W0 = round(h0*scale), round(w0*scale)
-            inp_coord = make_coord((H0,W0), flatten=True, device='cuda').unsqueeze(0)
-            inp_cell = torch.ones_like(inp_coord)
-            inp_cell[:,:,0] *= 2/H0
-            inp_cell[:,:,1] *= 2/W0
-            inp_encoder = torch.zeros((1,3,h0,w0), device='cuda')
-            flops_encoder = get_model_flops(self.encoder, inp_encoder)
-            inp_sampler = torch.zeros((1,self.encoder.out_dim,h0,w0), device='cuda')
-            x = get_model_flops(self.light_sampler, inp_sampler, coord=inp_coord, cell=inp_cell)
-            y = get_model_flops(self.heavy_sampler, inp_sampler, coord=inp_coord, cell=inp_cell)
-            cost_list = torch.FloatTensor([x,y]).cuda() + flops_encoder
-            cost_list = cost_list / cost_list.sum()
-            self.cost_list[scale] = cost_list
-            print('cost_list calculated (x{}): {}'.format(scale, cost_list))
-        cost_list = self.cost_list[scale]
-
-        feat = self.encoder(lr)
-        b,q = coord.shape[:2]
-        assert H*W == q
-        tot = b*q
-        if not pixel_batch_size:
-            pixel_batch_size = q
-
-        # pre-calculate flag
-        prob = torch.zeros((b,q,2), device=lr.device)
-        pb = pixel_batch_size//b*b
-        assert pb > 0
-        l = 0
-        while l < q:
-            r = min(q, l+pb)
-            coord_split = coord[:,l:r,:]
-            cell_split = cell[:,l:r,:]
-            prob_split = self.classifier(feat, coord_split, cell_split)
-            prob[:,l:r] = F.softmax(prob_split, dim=-1)
-            l = r
-
-        if adaptive_cluster: # auto-decide threshold
-            diff = prob[:,:,1].view(-1,1) # (tot,1)
-            assert diff.max() > diff.min()
-            diff = (diff - diff.min()) / (diff.max() - diff.min())
-            centroids = torch.FloatTensor([[0.5]]).cuda()
-            flag = self.kmeans.fit_predict(diff, centroids=centroids)
-            _, min_index = torch.min(diff.flatten(), dim=0)
-            if flag[min_index] == 1:
-                flag = 1 - flag # (tot,)
-            flag = flag.view(b,q)
-        else:
-            prob = prob / torch.pow(cost_list, k).view(1,1,2)
-            flag = torch.argmax(prob, dim=-1) # (b,q)
-
-        # inference per image
-        # more efficient implementation may exist
-        preds = []
-        for i in range(b):
-            pred = torch.zeros((q,3), device=lr.device)
-            l = 0
-            while l < q:
-                r = min(q, l+pixel_batch_size)
-                coord_split = coord[i:i+1,l:r,:]
-                cell_split = cell[i:i+1,l:r,:]
-                flg = flag[i,l:r]
-
-                idx_easy = torch.where(flg == 0)[0]
-                idx_hard = torch.where(flg == 1)[0]
-                num_easy, num_hard = len(idx_easy), len(idx_hard)
-                if num_easy > 0:
-                    pred[l+idx_easy] = self.light_sampler(feat[i:i+1], coord_split[:,idx_easy,:], cell_split[:,idx_easy,:]).squeeze(0)
-                if num_hard > 0:
-                    pred[l+idx_hard] = self.heavy_sampler(feat[i:i+1], coord_split[:,idx_hard,:], cell_split[:,idx_hard,:]).squeeze(0)
-                res = F.grid_sample(lr[i:i+1], coord_split.flip(-1).unsqueeze(1), mode='bilinear',
-                    padding_mode='border', align_corners=False)[:,:,0,:].squeeze(0).transpose(0,1)
-                pred[l:r] += res
-                l = r
-            preds.append(pred)
-        pred = torch.stack(preds, dim=0) # (b,q,3)
-
-        if refinement:
-            pred = pred.transpose(1,2).view(-1,3,H,W)
-            pred_unfold = F.pad(pred, (1,1,1,1), mode='replicate')
-            pred_unfold = F.unfold(pred_unfold, 3, padding=0).view(-1,3,9,H,W).mean(dim=2) # (b,3,H,W)
-            flag = flag.view(-1,1,H,W)
-            flag_unfold = F.pad(flag.float(), (1,1,1,1), mode='replicate')
-            flag_unfold = F.unfold(flag_unfold, 3, padding=0).view(-1,1,9,H,W).int().sum(dim=2) # (b,1,H,W)
-
-            cond = (flag==0) & (flag_unfold>0)
-            cond[:,:,[0,-1],:] = cond[:,:,:,[0,-1]] = False
-            #print('refined: {} / {}'.format(cond.sum().item(), tot))
-            pred = torch.where(cond, pred_unfold, pred)
-            pred = pred.view(-1,3,q).transpose(1,2)
-        flag = flag.view(b,q,1)
+import math
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from huggingface_hub import PyTorchModelHubMixin
+
+import models
+from models import register
+from fast_pytorch_kmeans import KMeans
+from utils import *
+
+
+@register('pcsr-phase0')
+class PCSR(nn.Module, PyTorchModelHubMixin):
+    def __init__(self, encoder_spec, heavy_sampler_spec):
+        super().__init__()
+        self.encoder = models.make(encoder_spec)
+        in_dim = self.encoder.out_dim
+        self.heavy_sampler = models.make(heavy_sampler_spec,
+            args={'in_dim': in_dim, 'out_dim': 3})
+
+    def forward(self, lr, coord, cell, **kwargs):
+        if self.training:
+            return self.forward_train(lr, coord, cell)
+        else:
+            return self.forward_test(lr, coord, cell, **kwargs)
+
+    def forward_train(self, lr, coord, cell):
+        feat = self.encoder(lr)
+        res = F.grid_sample(lr, coord.flip(-1).unsqueeze(1), mode='bilinear',
+            padding_mode='border', align_corners=False)[:,:,0,:].permute(0,2,1)
+        pred_heavy = self.heavy_sampler(feat, coord, cell) + res
+        return pred_heavy
+
+    def forward_test(self, lr, coord, cell, pixel_batch_size=None):
+        feat = self.encoder(lr)
+        b,q = coord.shape[:2]
+        tot = b*q
+        if not pixel_batch_size:
+            pixel_batch_size = q
+
+        preds = []
+        for i in range(b): # for each image
+            pred = torch.zeros((q,3), device=lr.device)
+            l = 0
+            while l < q:
+                r = min(q, l+pixel_batch_size)
+                coord_split = coord[i:i+1,l:r,:]
+                cell_split = cell[i:i+1,l:r,:]
+                res = F.grid_sample(lr[i:i+1], coord_split.flip(-1).unsqueeze(1), mode='bilinear',
+                    padding_mode='border', align_corners=False)[:,:,0,:].squeeze(0).transpose(0,1)
+                pred[l:r] = self.heavy_sampler(feat[i:i+1], coord_split, cell_split) + res
+                l = r
+            preds.append(pred)
+        pred = torch.stack(preds, dim=0)
+        return pred
+
+
+@register('pcsr-phase1')
+class PCSR(nn.Module, PyTorchModelHubMixin):
+
+    def __init__(self, encoder_spec, heavy_sampler_spec, light_sampler_spec, classifier_spec):
+        super().__init__()
+        self.encoder = models.make(encoder_spec)
+        in_dim = self.encoder.out_dim
+        self.heavy_sampler = models.make(heavy_sampler_spec,
+            args={'in_dim': in_dim, 'out_dim': 3})
+        self.light_sampler = models.make(light_sampler_spec,
+            args={'in_dim': in_dim, 'out_dim': 3})
+        self.classifier = models.make(classifier_spec,
+            args={'in_dim': in_dim, 'out_dim': 2})
+        self.kmeans = KMeans(n_clusters=2, max_iter=20, mode='euclidean', verbose=0)
+        self.cost_list = {}
+
+    def forward(self, lr, coord, cell, **kwargs):
+        if self.training:
+            return self.forward_train(lr, coord, cell)
+        else:
+            return self.forward_test(lr, coord, cell, **kwargs)
+
+    def forward_train(self, lr, coord, cell):
+        feat = self.encoder(lr)
+        prob = self.classifier(feat, coord, cell)
+        prob = F.softmax(prob, dim=-1) # (b,q,2)
+
+        pred_heavy = self.heavy_sampler(feat, coord, cell)
+        pred_light = self.light_sampler(feat, coord, cell)
+        pred = prob[:,:,0:1] * pred_light + prob[:,:,1:2] * pred_heavy
+
+        res = F.grid_sample(lr, coord.flip(-1).unsqueeze(1), mode='bilinear',
+            padding_mode='border', align_corners=False)[:,:,0,:].permute(0,2,1)
+        pred = pred + res
+        return pred, prob
+
+    def forward_test(self, lr, coord, cell, scale=None, hr_size=None, k=0., pixel_batch_size=None, adaptive_cluster=False, refinement=True):
+        h,w = lr.shape[-2:]
+        if not scale and hr_size:
+            H,W = hr_size
+            scale = round((H/h + W/w)/2, 1)
+        else:
+            assert scale and not hr_size
+            H,W = round(h*scale), round(w*scale)
+            hr_size = (H,W)
+
+        if scale not in self.cost_list:
+            h0,w0 = 16,16
+            H0,W0 = round(h0*scale), round(w0*scale)
+            inp_coord = make_coord((H0,W0), flatten=True, device='cuda').unsqueeze(0)
+            inp_cell = torch.ones_like(inp_coord)
+            inp_cell[:,:,0] *= 2/H0
+            inp_cell[:,:,1] *= 2/W0
+            inp_encoder = torch.zeros((1,3,h0,w0), device='cuda')
+            flops_encoder = get_model_flops(self.encoder, inp_encoder)
+            inp_sampler = torch.zeros((1,self.encoder.out_dim,h0,w0), device='cuda')
+            x = get_model_flops(self.light_sampler, inp_sampler, coord=inp_coord, cell=inp_cell)
+            y = get_model_flops(self.heavy_sampler, inp_sampler, coord=inp_coord, cell=inp_cell)
+            cost_list = torch.FloatTensor([x,y]).cuda() + flops_encoder
+            cost_list = cost_list / cost_list.sum()
+            self.cost_list[scale] = cost_list
+            print('cost_list calculated (x{}): {}'.format(scale, cost_list))
+        cost_list = self.cost_list[scale]
+
+        feat = self.encoder(lr)
+        b,q = coord.shape[:2]
+        assert H*W == q
+        tot = b*q
+        if not pixel_batch_size:
+            pixel_batch_size = q
+
+        # pre-calculate flag
+        prob = torch.zeros((b,q,2), device=lr.device)
+        pb = pixel_batch_size//b*b
+        assert pb > 0
+        l = 0
+        while l < q:
+            r = min(q, l+pb)
+            coord_split = coord[:,l:r,:]
+            cell_split = cell[:,l:r,:]
+            prob_split = self.classifier(feat, coord_split, cell_split)
+            prob[:,l:r] = F.softmax(prob_split, dim=-1)
+            l = r
+
+        if adaptive_cluster: # auto-decide threshold
+            diff = prob[:,:,1].view(-1,1) # (tot,1)
+            assert diff.max() > diff.min()
+            diff = (diff - diff.min()) / (diff.max() - diff.min())
+            centroids = torch.FloatTensor([[0.5]]).cuda()
+            flag = self.kmeans.fit_predict(diff, centroids=centroids)
+            _, min_index = torch.min(diff.flatten(), dim=0)
+            if flag[min_index] == 1:
+                flag = 1 - flag # (tot,)
+            flag = flag.view(b,q)
+        else:
+            prob = prob / torch.pow(cost_list, k).view(1,1,2)
+            flag = torch.argmax(prob, dim=-1) # (b,q)
+
+        # inference per image
+        # more efficient implementation may exist
+        preds = []
+        for i in range(b):
+            pred = torch.zeros((q,3), device=lr.device)
+            l = 0
+            while l < q:
+                r = min(q, l+pixel_batch_size)
+                coord_split = coord[i:i+1,l:r,:]
+                cell_split = cell[i:i+1,l:r,:]
+                flg = flag[i,l:r]
+
+                idx_easy = torch.where(flg == 0)[0]
+                idx_hard = torch.where(flg == 1)[0]
+                num_easy, num_hard = len(idx_easy), len(idx_hard)
+                if num_easy > 0:
+                    pred[l+idx_easy] = self.light_sampler(feat[i:i+1], coord_split[:,idx_easy,:], cell_split[:,idx_easy,:]).squeeze(0)
+                if num_hard > 0:
+                    pred[l+idx_hard] = self.heavy_sampler(feat[i:i+1], coord_split[:,idx_hard,:], cell_split[:,idx_hard,:]).squeeze(0)
+                res = F.grid_sample(lr[i:i+1], coord_split.flip(-1).unsqueeze(1), mode='bilinear',
+                    padding_mode='border', align_corners=False)[:,:,0,:].squeeze(0).transpose(0,1)
+                pred[l:r] += res
+                l = r
+            preds.append(pred)
+        pred = torch.stack(preds, dim=0) # (b,q,3)
+
+        if refinement:
+            pred = pred.transpose(1,2).view(-1,3,H,W)
+            pred_unfold = F.pad(pred, (1,1,1,1), mode='replicate')
+            pred_unfold = F.unfold(pred_unfold, 3, padding=0).view(-1,3,9,H,W).mean(dim=2) # (b,3,H,W)
+            flag = flag.view(-1,1,H,W)
+            flag_unfold = F.pad(flag.float(), (1,1,1,1), mode='replicate')
+            flag_unfold = F.unfold(flag_unfold, 3, padding=0).view(-1,1,9,H,W).int().sum(dim=2) # (b,1,H,W)
+
+            cond = (flag==0) & (flag_unfold>0)
+            cond[:,:,[0,-1],:] = cond[:,:,:,[0,-1]] = False
+            #print('refined: {} / {}'.format(cond.sum().item(), tot))
+            pred = torch.where(cond, pred_unfold, pred)
+            pred = pred.view(-1,3,q).transpose(1,2)
+        flag = flag.view(b,q,1)
         return pred, flag
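
The only functional change in this commit is mixing PyTorchModelHubMixin into both registered PCSR classes (plus the matching import); everything else in the file is unchanged. The mixin grafts save_pretrained / from_pretrained / push_to_hub onto the classes. Below is a minimal sketch of that mechanism on a toy module, assuming only torch and huggingface_hub are installed; the Toy class, the directory name, and the repo id are hypothetical placeholders, not taken from this repository.

import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin

# Hypothetical toy module: the PCSR classes above gain the same methods.
class Toy(nn.Module, PyTorchModelHubMixin):
    def __init__(self, hidden: int = 8):
        super().__init__()
        self.fc = nn.Linear(hidden, hidden)

m = Toy(hidden=16)
m.save_pretrained("toy-ckpt")         # writes the weights plus a config.json of the init kwargs
m2 = Toy.from_pretrained("toy-ckpt")  # re-instantiates Toy(hidden=16) and loads the weights
# m.push_to_hub("user/toy-ckpt")      # same API targets the Hub (hypothetical repo id; needs auth)

Since PCSR's constructor arguments are plain spec dicts, they should serialize into config.json the same way, letting PCSR.from_pretrained(...) rebuild the registered model without manual checkpoint plumbing.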