|
import math |
|
import torch |
|
from torch import nn as nn |
|
from torch.nn import functional as F |
|
|
|
from basicsr.utils.registry import ARCH_REGISTRY |
|
from .arch_util import flow_warp |
|
|
|
|
|
class BasicModule(nn.Module):
    """Per-level flow estimator used by SpyNet.

    A stack of five 7x7 convolutions (stride 1, padding 3, so the spatial size
    is preserved) mapping 8 input channels to 2 output channels, with a ReLU
    after every convolution except the last. In SpyNet the input is the
    channel-wise concatenation of the reference image, the warped supporting
    image and the current flow estimate.
    """

    def __init__(self):
        super().__init__()
        widths = (8, 32, 64, 32, 16, 2)
        last_idx = len(widths) - 2
        layers = []
        for idx, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:])):
            layers.append(nn.Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=7, stride=1, padding=3))
            if idx != last_idx:
                layers.append(nn.ReLU(inplace=False))
        # Attribute name and Sequential indices (convs at 0, 2, 4, 6, 8) must stay
        # fixed: pretrained checkpoints are keyed by them.
        self.basic_module = nn.Sequential(*layers)

    def forward(self, tensor_input):
        """Run the convolution stack on ``tensor_input`` and return its 2-channel output."""
        return self.basic_module(tensor_input)
|
|
|
|
|
@ARCH_REGISTRY.register()
class SpyNet(nn.Module):
    """SpyNet architecture.

    Coarse-to-fine optical flow estimator: the inputs are normalized and
    downsampled into an image pyramid with one level per ``BasicModule``. At
    each level, starting from the coarsest, the flow from the previous level
    is upsampled and used to warp the supporting image. The level's module
    then predicts a residual correction to that upsampled flow.

    Args:
        load_path (str): path for pretrained SpyNet. Default: None.
    """

    def __init__(self, load_path=None):
        super(SpyNet, self).__init__()
        # One flow estimator per pyramid level, ordered coarsest first.
        self.basic_module = nn.ModuleList([BasicModule() for _ in range(6)])
        if load_path:
            # NOTE(review): torch.load unpickles the file and can execute arbitrary
            # code; only load checkpoints from trusted sources.
            # The checkpoint is loaded BEFORE the mean/std buffers are registered.
            # Because load_state_dict is strict by default, 'params' must contain
            # exactly the basic_module parameters. Keep this order.
            self.load_state_dict(torch.load(load_path, map_location=lambda storage, loc: storage)['params'])

        # Per-channel normalization constants used by ``preprocess``
        # (these values match the common ImageNet mean/std).
        self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def preprocess(self, tensor_input):
        """Normalize ``tensor_input`` with the registered per-channel mean/std."""
        tensor_output = (tensor_input - self.mean) / self.std
        return tensor_output

    def _build_pyramid(self, tensor_input):
        """Return the normalized input and its 2x average-pooled versions, coarsest first.

        The pyramid has one level per basic module, i.e. ``len(self.basic_module) - 1``
        successive 2x2 average-poolings of the normalized input.
        """
        pyramid = [self.preprocess(tensor_input)]
        for _ in range(len(self.basic_module) - 1):
            pyramid.insert(0, F.avg_pool2d(input=pyramid[0], kernel_size=2, stride=2, count_include_pad=False))
        return pyramid

    def process(self, ref, supp):
        """Estimate flow between ``ref`` and ``supp`` at their own resolution.

        Assumes 3-channel (N, 3, H, W) images, since ``preprocess`` broadcasts
        against (1, 3, 1, 1) buffers. TODO confirm with callers.

        Returns:
            Tensor: flow of shape (N, 2, H, W). It is the flow used to warp
            ``supp`` onto ``ref`` via ``flow_warp``.
        """
        ref = self._build_pyramid(ref)
        supp = self._build_pyramid(supp)

        # Start from zero flow at half the coarsest resolution. The first
        # upsampling below brings it to the coarsest pyramid level.
        coarsest = ref[0]
        flow = coarsest.new_zeros([coarsest.size(0), 2, coarsest.size(2) // 2, coarsest.size(3) // 2])

        for level, (ref_level, supp_level) in enumerate(zip(ref, supp)):
            # Double the grid and scale the flow vectors by 2 to match it.
            upsampled_flow = F.interpolate(input=flow, scale_factor=2, mode='bilinear', align_corners=True) * 2.0

            # When this level has an odd size, pooling floored the coarser size,
            # so the upsampled flow is one pixel short. Pad it by replication.
            if upsampled_flow.size(2) != ref_level.size(2):
                upsampled_flow = F.pad(input=upsampled_flow, pad=[0, 0, 0, 1], mode='replicate')
            if upsampled_flow.size(3) != ref_level.size(3):
                upsampled_flow = F.pad(input=upsampled_flow, pad=[0, 1, 0, 0], mode='replicate')

            warped_supp = flow_warp(
                supp_level, upsampled_flow.permute(0, 2, 3, 1), interp_mode='bilinear', padding_mode='border')
            # Residual refinement: the module predicts a correction to the upsampled flow.
            flow = self.basic_module[level](torch.cat([ref_level, warped_supp, upsampled_flow], 1)) + upsampled_flow

        return flow

    def forward(self, ref, supp):
        """Estimate flow between two equally sized image batches.

        The inputs are bilinearly resized so both spatial dimensions are
        divisible by the total pyramid downsampling factor. The resulting flow
        is resized back to the original (H, W) and its vectors are rescaled to
        that grid.

        Returns:
            Tensor: flow of shape (N, 2, H, W). Channel 0 is scaled with the
            width, channel 1 with the height.
        """
        assert ref.size() == supp.size()

        h, w = ref.size(2), ref.size(3)
        # Each of the len(basic_module) - 1 poolings halves the size, so round H
        # and W up to a multiple of 2 ** (levels - 1); that is 32 for 6 levels.
        stride = 2 ** (len(self.basic_module) - 1)
        h_resized = math.ceil(h / stride) * stride
        w_resized = math.ceil(w / stride) * stride

        ref = F.interpolate(input=ref, size=(h_resized, w_resized), mode='bilinear', align_corners=False)
        supp = F.interpolate(input=supp, size=(h_resized, w_resized), mode='bilinear', align_corners=False)

        flow = F.interpolate(input=self.process(ref, supp), size=(h, w), mode='bilinear', align_corners=False)

        # Flow was estimated in resized-pixel units; convert back to original pixels.
        flow[:, 0, :, :] *= w / w_resized
        flow[:, 1, :, :] *= h / h_resized

        return flow
|
|