Spaces:
Running
Running
File size: 12,595 Bytes
c59c099 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 |
import torch
from torch import nn as nn
from torch.nn import functional as F
from basicsr.utils.registry import ARCH_REGISTRY
from .arch_util import ResidualBlockNoBN, flow_warp, make_layer
from .edvr_arch import PCDAlignment, TSAFusion
from .spynet_arch import SpyNet
@ARCH_REGISTRY.register()
class BasicVSR(nn.Module):
"""A recurrent network for video SR. Now only x4 is supported.
Args:
num_feat (int): Number of channels. Default: 64.
num_block (int): Number of residual blocks for each branch. Default: 15
spynet_path (str): Path to the pretrained weights of SPyNet. Default: None.
"""
def __init__(self, num_feat=64, num_block=15, spynet_path=None):
super().__init__()
self.num_feat = num_feat
# alignment
self.spynet = SpyNet(spynet_path)
# propagation
self.backward_trunk = ConvResidualBlocks(num_feat + 3, num_feat, num_block)
self.forward_trunk = ConvResidualBlocks(num_feat + 3, num_feat, num_block)
# reconstruction
self.fusion = nn.Conv2d(num_feat * 2, num_feat, 1, 1, 0, bias=True)
self.upconv1 = nn.Conv2d(num_feat, num_feat * 4, 3, 1, 1, bias=True)
self.upconv2 = nn.Conv2d(num_feat, 64 * 4, 3, 1, 1, bias=True)
self.conv_hr = nn.Conv2d(64, 64, 3, 1, 1)
self.conv_last = nn.Conv2d(64, 3, 3, 1, 1)
self.pixel_shuffle = nn.PixelShuffle(2)
# activation functions
self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
def get_flow(self, x):
b, n, c, h, w = x.size()
x_1 = x[:, :-1, :, :, :].reshape(-1, c, h, w)
x_2 = x[:, 1:, :, :, :].reshape(-1, c, h, w)
flows_backward = self.spynet(x_1, x_2).view(b, n - 1, 2, h, w)
flows_forward = self.spynet(x_2, x_1).view(b, n - 1, 2, h, w)
return flows_forward, flows_backward
def forward(self, x):
"""Forward function of BasicVSR.
Args:
x: Input frames with shape (b, n, c, h, w). n is the temporal dimension / number of frames.
"""
flows_forward, flows_backward = self.get_flow(x)
b, n, _, h, w = x.size()
# backward branch
out_l = []
feat_prop = x.new_zeros(b, self.num_feat, h, w)
for i in range(n - 1, -1, -1):
x_i = x[:, i, :, :, :]
if i < n - 1:
flow = flows_backward[:, i, :, :, :]
feat_prop = flow_warp(feat_prop, flow.permute(0, 2, 3, 1))
feat_prop = torch.cat([x_i, feat_prop], dim=1)
feat_prop = self.backward_trunk(feat_prop)
out_l.insert(0, feat_prop)
# forward branch
feat_prop = torch.zeros_like(feat_prop)
for i in range(0, n):
x_i = x[:, i, :, :, :]
if i > 0:
flow = flows_forward[:, i - 1, :, :, :]
feat_prop = flow_warp(feat_prop, flow.permute(0, 2, 3, 1))
feat_prop = torch.cat([x_i, feat_prop], dim=1)
feat_prop = self.forward_trunk(feat_prop)
# upsample
out = torch.cat([out_l[i], feat_prop], dim=1)
out = self.lrelu(self.fusion(out))
out = self.lrelu(self.pixel_shuffle(self.upconv1(out)))
out = self.lrelu(self.pixel_shuffle(self.upconv2(out)))
out = self.lrelu(self.conv_hr(out))
out = self.conv_last(out)
base = F.interpolate(x_i, scale_factor=4, mode='bilinear', align_corners=False)
out += base
out_l[i] = out
return torch.stack(out_l, dim=1)
class ConvResidualBlocks(nn.Module):
"""Conv and residual block used in BasicVSR.
Args:
num_in_ch (int): Number of input channels. Default: 3.
num_out_ch (int): Number of output channels. Default: 64.
num_block (int): Number of residual blocks. Default: 15.
"""
def __init__(self, num_in_ch=3, num_out_ch=64, num_block=15):
super().__init__()
self.main = nn.Sequential(
nn.Conv2d(num_in_ch, num_out_ch, 3, 1, 1, bias=True), nn.LeakyReLU(negative_slope=0.1, inplace=True),
make_layer(ResidualBlockNoBN, num_block, num_feat=num_out_ch))
def forward(self, fea):
return self.main(fea)
@ARCH_REGISTRY.register()
class IconVSR(nn.Module):
"""IconVSR, proposed also in the BasicVSR paper.
Args:
num_feat (int): Number of channels. Default: 64.
num_block (int): Number of residual blocks for each branch. Default: 15.
keyframe_stride (int): Keyframe stride. Default: 5.
temporal_padding (int): Temporal padding. Default: 2.
spynet_path (str): Path to the pretrained weights of SPyNet. Default: None.
edvr_path (str): Path to the pretrained EDVR model. Default: None.
"""
def __init__(self,
num_feat=64,
num_block=15,
keyframe_stride=5,
temporal_padding=2,
spynet_path=None,
edvr_path=None):
super().__init__()
self.num_feat = num_feat
self.temporal_padding = temporal_padding
self.keyframe_stride = keyframe_stride
# keyframe_branch
self.edvr = EDVRFeatureExtractor(temporal_padding * 2 + 1, num_feat, edvr_path)
# alignment
self.spynet = SpyNet(spynet_path)
# propagation
self.backward_fusion = nn.Conv2d(2 * num_feat, num_feat, 3, 1, 1, bias=True)
self.backward_trunk = ConvResidualBlocks(num_feat + 3, num_feat, num_block)
self.forward_fusion = nn.Conv2d(2 * num_feat, num_feat, 3, 1, 1, bias=True)
self.forward_trunk = ConvResidualBlocks(2 * num_feat + 3, num_feat, num_block)
# reconstruction
self.upconv1 = nn.Conv2d(num_feat, num_feat * 4, 3, 1, 1, bias=True)
self.upconv2 = nn.Conv2d(num_feat, 64 * 4, 3, 1, 1, bias=True)
self.conv_hr = nn.Conv2d(64, 64, 3, 1, 1)
self.conv_last = nn.Conv2d(64, 3, 3, 1, 1)
self.pixel_shuffle = nn.PixelShuffle(2)
# activation functions
self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
def pad_spatial(self, x):
"""Apply padding spatially.
Since the PCD module in EDVR requires that the resolution is a multiple
of 4, we apply padding to the input LR images if their resolution is
not divisible by 4.
Args:
x (Tensor): Input LR sequence with shape (n, t, c, h, w).
Returns:
Tensor: Padded LR sequence with shape (n, t, c, h_pad, w_pad).
"""
n, t, c, h, w = x.size()
pad_h = (4 - h % 4) % 4
pad_w = (4 - w % 4) % 4
# padding
x = x.view(-1, c, h, w)
x = F.pad(x, [0, pad_w, 0, pad_h], mode='reflect')
return x.view(n, t, c, h + pad_h, w + pad_w)
def get_flow(self, x):
b, n, c, h, w = x.size()
x_1 = x[:, :-1, :, :, :].reshape(-1, c, h, w)
x_2 = x[:, 1:, :, :, :].reshape(-1, c, h, w)
flows_backward = self.spynet(x_1, x_2).view(b, n - 1, 2, h, w)
flows_forward = self.spynet(x_2, x_1).view(b, n - 1, 2, h, w)
return flows_forward, flows_backward
def get_keyframe_feature(self, x, keyframe_idx):
if self.temporal_padding == 2:
x = [x[:, [4, 3]], x, x[:, [-4, -5]]]
elif self.temporal_padding == 3:
x = [x[:, [6, 5, 4]], x, x[:, [-5, -6, -7]]]
x = torch.cat(x, dim=1)
num_frames = 2 * self.temporal_padding + 1
feats_keyframe = {}
for i in keyframe_idx:
feats_keyframe[i] = self.edvr(x[:, i:i + num_frames].contiguous())
return feats_keyframe
def forward(self, x):
b, n, _, h_input, w_input = x.size()
x = self.pad_spatial(x)
h, w = x.shape[3:]
keyframe_idx = list(range(0, n, self.keyframe_stride))
if keyframe_idx[-1] != n - 1:
keyframe_idx.append(n - 1) # last frame is a keyframe
# compute flow and keyframe features
flows_forward, flows_backward = self.get_flow(x)
feats_keyframe = self.get_keyframe_feature(x, keyframe_idx)
# backward branch
out_l = []
feat_prop = x.new_zeros(b, self.num_feat, h, w)
for i in range(n - 1, -1, -1):
x_i = x[:, i, :, :, :]
if i < n - 1:
flow = flows_backward[:, i, :, :, :]
feat_prop = flow_warp(feat_prop, flow.permute(0, 2, 3, 1))
if i in keyframe_idx:
feat_prop = torch.cat([feat_prop, feats_keyframe[i]], dim=1)
feat_prop = self.backward_fusion(feat_prop)
feat_prop = torch.cat([x_i, feat_prop], dim=1)
feat_prop = self.backward_trunk(feat_prop)
out_l.insert(0, feat_prop)
# forward branch
feat_prop = torch.zeros_like(feat_prop)
for i in range(0, n):
x_i = x[:, i, :, :, :]
if i > 0:
flow = flows_forward[:, i - 1, :, :, :]
feat_prop = flow_warp(feat_prop, flow.permute(0, 2, 3, 1))
if i in keyframe_idx:
feat_prop = torch.cat([feat_prop, feats_keyframe[i]], dim=1)
feat_prop = self.forward_fusion(feat_prop)
feat_prop = torch.cat([x_i, out_l[i], feat_prop], dim=1)
feat_prop = self.forward_trunk(feat_prop)
# upsample
out = self.lrelu(self.pixel_shuffle(self.upconv1(feat_prop)))
out = self.lrelu(self.pixel_shuffle(self.upconv2(out)))
out = self.lrelu(self.conv_hr(out))
out = self.conv_last(out)
base = F.interpolate(x_i, scale_factor=4, mode='bilinear', align_corners=False)
out += base
out_l[i] = out
return torch.stack(out_l, dim=1)[..., :4 * h_input, :4 * w_input]
class EDVRFeatureExtractor(nn.Module):
"""EDVR feature extractor used in IconVSR.
Args:
num_input_frame (int): Number of input frames.
num_feat (int): Number of feature channels
load_path (str): Path to the pretrained weights of EDVR. Default: None.
"""
def __init__(self, num_input_frame, num_feat, load_path):
super(EDVRFeatureExtractor, self).__init__()
self.center_frame_idx = num_input_frame // 2
# extract pyramid features
self.conv_first = nn.Conv2d(3, num_feat, 3, 1, 1)
self.feature_extraction = make_layer(ResidualBlockNoBN, 5, num_feat=num_feat)
self.conv_l2_1 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)
self.conv_l2_2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
self.conv_l3_1 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)
self.conv_l3_2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
# pcd and tsa module
self.pcd_align = PCDAlignment(num_feat=num_feat, deformable_groups=8)
self.fusion = TSAFusion(num_feat=num_feat, num_frame=num_input_frame, center_frame_idx=self.center_frame_idx)
# activation function
self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
if load_path:
self.load_state_dict(torch.load(load_path, map_location=lambda storage, loc: storage)['params'])
def forward(self, x):
b, n, c, h, w = x.size()
# extract features for each frame
# L1
feat_l1 = self.lrelu(self.conv_first(x.view(-1, c, h, w)))
feat_l1 = self.feature_extraction(feat_l1)
# L2
feat_l2 = self.lrelu(self.conv_l2_1(feat_l1))
feat_l2 = self.lrelu(self.conv_l2_2(feat_l2))
# L3
feat_l3 = self.lrelu(self.conv_l3_1(feat_l2))
feat_l3 = self.lrelu(self.conv_l3_2(feat_l3))
feat_l1 = feat_l1.view(b, n, -1, h, w)
feat_l2 = feat_l2.view(b, n, -1, h // 2, w // 2)
feat_l3 = feat_l3.view(b, n, -1, h // 4, w // 4)
# PCD alignment
ref_feat_l = [ # reference feature list
feat_l1[:, self.center_frame_idx, :, :, :].clone(), feat_l2[:, self.center_frame_idx, :, :, :].clone(),
feat_l3[:, self.center_frame_idx, :, :, :].clone()
]
aligned_feat = []
for i in range(n):
nbr_feat_l = [ # neighboring feature list
feat_l1[:, i, :, :, :].clone(), feat_l2[:, i, :, :, :].clone(), feat_l3[:, i, :, :, :].clone()
]
aligned_feat.append(self.pcd_align(nbr_feat_l, ref_feat_l))
aligned_feat = torch.stack(aligned_feat, dim=1) # (b, t, c, h, w)
# TSA fusion
return self.fusion(aligned_feat)
|