Spaces:
Runtime error
Runtime error
| import torch | |
| from torch import nn as nn | |
| from torch.nn import functional as F | |
| from basicsr.utils.registry import ARCH_REGISTRY | |
| from .arch_util import DCNv2Pack, ResidualBlockNoBN, make_layer | |
class PCDAlignment(nn.Module):
    """Pyramid, Cascading and Deformable (PCD) alignment module used in EDVR.

    Offsets are estimated on a three-level feature pyramid, starting at the
    coarsest level (L3) and refined towards the finest level (L1); a final
    cascading deformable convolution refines the L1 result.

    ``Paper: EDVR: Video Restoration with Enhanced Deformable Convolutional Networks``

    Args:
        num_feat (int): Channel number of middle features. Default: 64.
        deformable_groups (int): Deformable groups. Defaults: 8.
    """

    def __init__(self, num_feat=64, deformable_groups=8):
        super().__init__()

        # One entry per pyramid level, keyed 'l3', 'l2', 'l1':
        #   l3 -> 1/4 spatial size, l2 -> 1/2 spatial size, l1 -> original size.
        self.offset_conv1 = nn.ModuleDict()
        self.offset_conv2 = nn.ModuleDict()
        self.offset_conv3 = nn.ModuleDict()
        self.dcn_pack = nn.ModuleDict()
        self.feat_conv = nn.ModuleDict()

        # Layers are created coarse-to-fine, in the same order they are used.
        for idx in (3, 2, 1):
            key = f'l{idx}'
            is_coarsest = idx == 3

            self.offset_conv1[key] = nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1)
            # Finer levels also receive the upsampled offset of the coarser
            # level, hence the doubled input width.
            conv2_in = num_feat if is_coarsest else num_feat * 2
            self.offset_conv2[key] = nn.Conv2d(conv2_in, num_feat, 3, 1, 1)
            if not is_coarsest:
                self.offset_conv3[key] = nn.Conv2d(num_feat, num_feat, 3, 1, 1)

            self.dcn_pack[key] = DCNv2Pack(num_feat, num_feat, 3, padding=1, deformable_groups=deformable_groups)

            if not is_coarsest:
                # Fuses the aligned feature with the upsampled coarser one.
                self.feat_conv[key] = nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1)

        # Cascading deformable convolution applied after the pyramid.
        self.cas_offset_conv1 = nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1)
        self.cas_offset_conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.cas_dcnpack = DCNv2Pack(num_feat, num_feat, 3, padding=1, deformable_groups=deformable_groups)

        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

    def forward(self, nbr_feat_l, ref_feat_l):
        """Align neighboring frame features to the reference frame features.

        Args:
            nbr_feat_l (list[Tensor]): Neighboring feature list. It
                contains three pyramid levels (L1, L2, L3),
                each with shape (b, c, h, w).
            ref_feat_l (list[Tensor]): Reference feature list. It
                contains three pyramid levels (L1, L2, L3),
                each with shape (b, c, h, w).

        Returns:
            Tensor: Aligned features.
        """
        coarser_offset = None
        coarser_feat = None

        # Walk the pyramid from L3 (coarsest) down to L1 (finest).
        for idx in (3, 2, 1):
            key = f'l{idx}'
            nbr = nbr_feat_l[idx - 1]
            ref = ref_feat_l[idx - 1]

            offset = self.lrelu(self.offset_conv1[key](torch.cat([nbr, ref], dim=1)))
            if idx == 3:
                offset = self.lrelu(self.offset_conv2[key](offset))
            else:
                merged = torch.cat([offset, coarser_offset], dim=1)
                offset = self.lrelu(self.offset_conv2[key](merged))
                offset = self.lrelu(self.offset_conv3[key](offset))

            feat = self.dcn_pack[key](nbr, offset)
            if idx < 3:
                feat = self.feat_conv[key](torch.cat([feat, coarser_feat], dim=1))

            if idx > 1:
                feat = self.lrelu(feat)
                # The offset is scaled by 2 together with the x2 spatial
                # upsampling so that its magnitude matches the finer grid.
                coarser_offset = self.upsample(offset) * 2
                coarser_feat = self.upsample(feat)

        # Cascading refinement on the L1 result (no activation was applied
        # to `feat` at L1 above).
        cas_offset = torch.cat([feat, ref_feat_l[0]], dim=1)
        cas_offset = self.lrelu(self.cas_offset_conv1(cas_offset))
        cas_offset = self.lrelu(self.cas_offset_conv2(cas_offset))
        return self.lrelu(self.cas_dcnpack(feat, cas_offset))
class TSAFusion(nn.Module):
    """Temporal Spatial Attention (TSA) fusion module.

    Temporal: Calculate the correlation between center frame and
        neighboring frames;
    Spatial: It has 3 pyramid levels, the attention is similar to SFT.
        (SFT: Recovering realistic texture in image super-resolution by deep
        spatial feature transform.)

    Args:
        num_feat (int): Channel number of middle features. Default: 64.
        num_frame (int): Number of frames. Default: 5.
        center_frame_idx (int): The index of center frame. Default: 2.
    """

    def __init__(self, num_feat=64, num_frame=5, center_frame_idx=2):
        super().__init__()
        self.center_frame_idx = center_frame_idx

        # temporal attention (before fusion conv)
        self.temporal_attn1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.temporal_attn2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.feat_fusion = nn.Conv2d(num_frame * num_feat, num_feat, 1, 1)

        # spatial attention (after fusion conv)
        self.max_pool = nn.MaxPool2d(3, stride=2, padding=1)
        self.avg_pool = nn.AvgPool2d(3, stride=2, padding=1)
        self.spatial_attn1 = nn.Conv2d(num_frame * num_feat, num_feat, 1)
        self.spatial_attn2 = nn.Conv2d(num_feat * 2, num_feat, 1)
        self.spatial_attn3 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.spatial_attn4 = nn.Conv2d(num_feat, num_feat, 1)
        self.spatial_attn5 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.spatial_attn_l1 = nn.Conv2d(num_feat, num_feat, 1)
        self.spatial_attn_l2 = nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1)
        self.spatial_attn_l3 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.spatial_attn_add1 = nn.Conv2d(num_feat, num_feat, 1)
        self.spatial_attn_add2 = nn.Conv2d(num_feat, num_feat, 1)

        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)

    def forward(self, aligned_feat):
        """Fuse per-frame aligned features with temporal and spatial attention.

        The spatial branch pools twice with stride 2 and upsamples twice by
        a factor of 2, so h and w need to be multiples of 4 for the
        intermediate maps to line up.

        Args:
            aligned_feat (Tensor): Aligned features with shape (b, t, c, h, w).
                Any memory layout is accepted (contiguity is not required).

        Returns:
            Tensor: Features after TSA with the shape (b, c, h, w).
        """
        b, t, c, h, w = aligned_feat.size()

        # ---- temporal attention ----
        embedding_ref = self.temporal_attn1(aligned_feat[:, self.center_frame_idx, :, :, :].clone())
        # reshape (not view): the frame axis is folded into the batch axis,
        # which `view` can only do for inputs whose strides allow it.
        embedding = self.temporal_attn2(aligned_feat.reshape(-1, c, h, w))
        embedding = embedding.view(b, t, -1, h, w)  # (b, t, c, h, w)

        # Per-frame correlation with the center frame: channel-wise dot product.
        corr = torch.stack(
            [torch.sum(embedding[:, i, :, :, :] * embedding_ref, 1) for i in range(t)],
            dim=1)  # (b, t, h, w)
        corr_prob = torch.sigmoid(corr)
        # Broadcast each frame's weight over that frame's channels.
        corr_prob = corr_prob.unsqueeze(2).expand(b, t, c, h, w)
        corr_prob = corr_prob.reshape(b, -1, h, w)  # (b, t*c, h, w)
        aligned_feat = aligned_feat.reshape(b, -1, h, w) * corr_prob

        # ---- fusion ----
        feat = self.lrelu(self.feat_fusion(aligned_feat))

        # ---- spatial attention ----
        attn = self.lrelu(self.spatial_attn1(aligned_feat))
        attn_max = self.max_pool(attn)
        attn_avg = self.avg_pool(attn)
        attn = self.lrelu(self.spatial_attn2(torch.cat([attn_max, attn_avg], dim=1)))

        # pyramid levels
        attn_level = self.lrelu(self.spatial_attn_l1(attn))
        attn_max = self.max_pool(attn_level)
        attn_avg = self.avg_pool(attn_level)
        attn_level = self.lrelu(self.spatial_attn_l2(torch.cat([attn_max, attn_avg], dim=1)))
        attn_level = self.lrelu(self.spatial_attn_l3(attn_level))
        attn_level = self.upsample(attn_level)

        attn = self.lrelu(self.spatial_attn3(attn)) + attn_level
        attn = self.lrelu(self.spatial_attn4(attn))
        attn = self.upsample(attn)
        attn = self.spatial_attn5(attn)
        # The additive term is derived from the pre-sigmoid attention map.
        attn_add = self.spatial_attn_add2(self.lrelu(self.spatial_attn_add1(attn)))
        attn = torch.sigmoid(attn)

        # after initialization, * 2 makes (attn * 2) to be close to 1.
        feat = feat * attn * 2 + attn_add
        return feat
class PredeblurModule(nn.Module):
    """Pre-deblur module.

    Builds a three-level feature pyramid with strided convolutions, runs
    residual blocks on every level and merges the levels back coarse-to-fine
    through bilinear upsampling and addition.

    Args:
        num_in_ch (int): Channel number of input image. Default: 3.
        num_feat (int): Channel number of intermediate features. Default: 64.
        hr_in (bool): Whether the input has high resolution. Default: False.
    """

    def __init__(self, num_in_ch=3, num_feat=64, hr_in=False):
        super().__init__()
        self.hr_in = hr_in

        self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
        if hr_in:
            # Two stride-2 convolutions: x4 spatial reduction for HR input.
            self.stride_conv_hr1 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)
            self.stride_conv_hr2 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)

        # Stride-2 convolutions that produce pyramid levels L2 and L3.
        self.stride_conv_l2 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)
        self.stride_conv_l3 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)

        self.resblock_l3 = ResidualBlockNoBN(num_feat=num_feat)
        self.resblock_l2_1 = ResidualBlockNoBN(num_feat=num_feat)
        self.resblock_l2_2 = ResidualBlockNoBN(num_feat=num_feat)
        self.resblock_l1 = nn.ModuleList([ResidualBlockNoBN(num_feat=num_feat) for _ in range(5)])

        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

    def forward(self, x):
        """Run the pre-deblur pyramid on ``x`` and return the L1 features."""
        l1 = self.lrelu(self.conv_first(x))
        if self.hr_in:
            for reducer in (self.stride_conv_hr1, self.stride_conv_hr2):
                l1 = self.lrelu(reducer(l1))

        # Build the pyramid (each level halves the spatial size).
        l2 = self.lrelu(self.stride_conv_l2(l1))
        l3 = self.lrelu(self.stride_conv_l3(l2))

        # Merge L3 into L2, then prepare L2 for merging into L1.
        l3_up = self.upsample(self.resblock_l3(l3))
        l2 = self.resblock_l2_1(l2) + l3_up
        l2_up = self.upsample(self.resblock_l2_2(l2))

        # Two residual blocks before the L2 merge, three after it.
        for block in self.resblock_l1[:2]:
            l1 = block(l1)
        l1 = l1 + l2_up
        for block in self.resblock_l1[2:5]:
            l1 = block(l1)
        return l1
class EDVR(nn.Module):
    """EDVR network structure for video super-resolution.

    Now only support X4 upsampling factor.

    ``Paper: EDVR: Video Restoration with Enhanced Deformable Convolutional Networks``

    Args:
        num_in_ch (int): Channel number of input image. Default: 3.
        num_out_ch (int): Channel number of output image. Default: 3.
            The output is added in place to the (upsampled) center input
            frame, so it must be addable to a tensor with ``num_in_ch``
            channels (in practice: equal to ``num_in_ch``).
        num_feat (int): Channel number of intermediate features. Default: 64.
        num_frame (int): Number of input frames. Default: 5.
        deformable_groups (int): Deformable groups. Defaults: 8.
        num_extract_block (int): Number of blocks for feature extraction.
            Default: 5.
        num_reconstruct_block (int): Number of blocks for reconstruction.
            Default: 10.
        center_frame_idx (int): The index of center frame. Frame counting from
            0. Default: Middle of input frames.
        hr_in (bool): Whether the input has high resolution. Default: False.
        with_predeblur (bool): Whether has predeblur module.
            Default: False.
        with_tsa (bool): Whether has TSA module. Default: True.
    """

    def __init__(self,
                 num_in_ch=3,
                 num_out_ch=3,
                 num_feat=64,
                 num_frame=5,
                 deformable_groups=8,
                 num_extract_block=5,
                 num_reconstruct_block=10,
                 center_frame_idx=None,
                 hr_in=False,
                 with_predeblur=False,
                 with_tsa=True):
        super().__init__()
        if center_frame_idx is None:
            self.center_frame_idx = num_frame // 2
        else:
            self.center_frame_idx = center_frame_idx
        self.hr_in = hr_in
        self.with_predeblur = with_predeblur
        self.with_tsa = with_tsa

        # extract features for each frame
        if self.with_predeblur:
            # num_in_ch is forwarded so that non-RGB inputs also work on the
            # predeblur path (it used to fall back to the module default of 3).
            self.predeblur = PredeblurModule(num_in_ch=num_in_ch, num_feat=num_feat, hr_in=self.hr_in)
            self.conv_1x1 = nn.Conv2d(num_feat, num_feat, 1, 1)
        else:
            self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)

        # extract pyramid features
        self.feature_extraction = make_layer(ResidualBlockNoBN, num_extract_block, num_feat=num_feat)
        self.conv_l2_1 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)
        self.conv_l2_2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_l3_1 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)
        self.conv_l3_2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)

        # pcd and tsa module
        self.pcd_align = PCDAlignment(num_feat=num_feat, deformable_groups=deformable_groups)
        if self.with_tsa:
            self.fusion = TSAFusion(num_feat=num_feat, num_frame=num_frame, center_frame_idx=self.center_frame_idx)
        else:
            # Without TSA the frames are fused by a plain 1x1 convolution.
            self.fusion = nn.Conv2d(num_frame * num_feat, num_feat, 1, 1)

        # reconstruction
        self.reconstruction = make_layer(ResidualBlockNoBN, num_reconstruct_block, num_feat=num_feat)

        # upsample: two PixelShuffle(2) stages give the overall x4 factor.
        # The high-resolution branch has a fixed width of 64 channels.
        self.upconv1 = nn.Conv2d(num_feat, num_feat * 4, 3, 1, 1)
        self.upconv2 = nn.Conv2d(num_feat, 64 * 4, 3, 1, 1)
        self.pixel_shuffle = nn.PixelShuffle(2)
        self.conv_hr = nn.Conv2d(64, 64, 3, 1, 1)
        # Honour num_out_ch (previously hard-coded to 3 output channels).
        self.conv_last = nn.Conv2d(64, num_out_ch, 3, 1, 1)

        # activation function
        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

    def forward(self, x):
        """Restore the center frame from a stack of input frames.

        Args:
            x (Tensor): Input frames with shape (b, t, c, h, w). h and w must
                be multiples of 4 (multiples of 16 when ``hr_in`` is set).

        Returns:
            Tensor: Network output added to the center input frame, which is
                bilinearly upsampled x4 unless ``hr_in`` is set.

        Raises:
            AssertionError: If h or w violates the divisibility requirement.
        """
        b, t, c, h, w = x.size()
        if self.hr_in:
            assert h % 16 == 0 and w % 16 == 0, ('The height and width must be multiple of 16.')
        else:
            assert h % 4 == 0 and w % 4 == 0, ('The height and width must be multiple of 4.')

        x_center = x[:, self.center_frame_idx, :, :, :].contiguous()

        # extract features for each frame (frames folded into the batch axis)
        # L1
        if self.with_predeblur:
            feat_l1 = self.conv_1x1(self.predeblur(x.view(-1, c, h, w)))
            if self.hr_in:
                # The predeblur module downsamples HR input by 4.
                h, w = h // 4, w // 4
        else:
            # NOTE(review): nothing downsamples here, so with hr_in=True and
            # with_predeblur=False the x4-upsampled output would presumably
            # not match x_center's size - confirm hr_in is only used
            # together with the predeblur module.
            feat_l1 = self.lrelu(self.conv_first(x.view(-1, c, h, w)))

        feat_l1 = self.feature_extraction(feat_l1)
        # L2
        feat_l2 = self.lrelu(self.conv_l2_1(feat_l1))
        feat_l2 = self.lrelu(self.conv_l2_2(feat_l2))
        # L3
        feat_l3 = self.lrelu(self.conv_l3_1(feat_l2))
        feat_l3 = self.lrelu(self.conv_l3_2(feat_l3))

        # Unfold the frame axis again.
        feat_l1 = feat_l1.view(b, t, -1, h, w)
        feat_l2 = feat_l2.view(b, t, -1, h // 2, w // 2)
        feat_l3 = feat_l3.view(b, t, -1, h // 4, w // 4)

        # PCD alignment
        pyramid = (feat_l1, feat_l2, feat_l3)
        # reference feature list (center frame at every pyramid level)
        ref_feat_l = [level[:, self.center_frame_idx, :, :, :].clone() for level in pyramid]
        aligned_feat = []
        for i in range(t):
            # neighboring feature list (frame i at every pyramid level)
            nbr_feat_l = [level[:, i, :, :, :].clone() for level in pyramid]
            aligned_feat.append(self.pcd_align(nbr_feat_l, ref_feat_l))
        aligned_feat = torch.stack(aligned_feat, dim=1)  # (b, t, c, h, w)

        if not self.with_tsa:
            # The plain conv fusion expects frames stacked along channels.
            aligned_feat = aligned_feat.view(b, -1, h, w)
        feat = self.fusion(aligned_feat)

        out = self.reconstruction(feat)
        out = self.lrelu(self.pixel_shuffle(self.upconv1(out)))
        out = self.lrelu(self.pixel_shuffle(self.upconv2(out)))
        out = self.lrelu(self.conv_hr(out))
        out = self.conv_last(out)

        # Global residual: the network predicts a correction to the center frame.
        if self.hr_in:
            base = x_center
        else:
            base = F.interpolate(x_center, scale_factor=4, mode='bilinear', align_corners=False)
        out += base
        return out